mirror of
				https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
				synced 2025-11-04 06:16:05 +08:00 
			
		
		
		
	
		
			
				
	
	
		
			250 lines
		
	
	
		
			7.0 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			250 lines
		
	
	
		
			7.0 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
"""
 | 
						|
---
 | 
						|
title: Denoising Diffusion Probabilistic Models (DDPM) training
 | 
						|
summary: >
 | 
						|
  Training code for
 | 
						|
  Denoising Diffusion Probabilistic Model.
 | 
						|
---
 | 
						|
 | 
						|
# [Denoising Diffusion Probabilistic Models (DDPM)](index.html) training
 | 
						|
 | 
						|
This trains a DDPM based model on CelebA HQ dataset. You can find the download instruction in this
 | 
						|
[discussion on fast.ai](https://forums.fast.ai/t/download-celeba-hq-dataset/45873/3).
 | 
						|
Save the images inside [`data/celebA` folder](#dataset_path).
 | 
						|
 | 
						|
The paper had used a exponential moving average of the model with a decay of $0.9999$. We have skipped this for
 | 
						|
simplicity.
 | 
						|
"""
 | 
						|
from typing import List
 | 
						|
 | 
						|
import torch
 | 
						|
import torch.utils.data
 | 
						|
import torchvision
 | 
						|
from PIL import Image
 | 
						|
 | 
						|
from labml import lab, tracker, experiment, monit
 | 
						|
from labml.configs import BaseConfigs, option
 | 
						|
from labml_helpers.device import DeviceConfigs
 | 
						|
from labml_nn.diffusion.ddpm import DenoiseDiffusion
 | 
						|
from labml_nn.diffusion.ddpm.unet import UNet
 | 
						|
 | 
						|
 | 
						|
class Configs(BaseConfigs):
    """
    ## Configurations

    Hyper-parameters and components for DDPM training, together with the
    training loop (`train` / `run`) and the image sampler (`sample`).
    """
    # Device to train the model on.
    # [`DeviceConfigs`](https://docs.labml.ai/api/helpers.html#labml_helpers.device.DeviceConfigs)
    #  picks up an available CUDA device or defaults to CPU.
    device: torch.device = DeviceConfigs()

    # U-Net model for $\textcolor{cyan}{\epsilon_\theta}(x_t, t)$
    eps_model: UNet
    # [DDPM algorithm](index.html)
    diffusion: DenoiseDiffusion

    # Number of channels in the image. $3$ for RGB.
    image_channels: int = 3
    # Image size
    image_size: int = 32
    # Number of channels in the initial feature map
    n_channels: int = 64
    # The list of channel numbers at each resolution.
    # The number of channels is `channel_multipliers[i] * n_channels`
    channel_multipliers: List[int] = [1, 2, 2, 4]
    # The list of booleans that indicate whether to use attention at each resolution
    # (one entry per entry of `channel_multipliers`; hence `List[bool]`, not `List[int]`)
    is_attention: List[bool] = [False, False, False, True]

    # Number of time steps $T$
    n_steps: int = 1_000
    # Batch size
    batch_size: int = 64
    # Number of samples to generate
    n_samples: int = 16
    # Learning rate
    learning_rate: float = 2e-5

    # Number of training epochs
    epochs: int = 1_000

    # Dataset
    dataset: torch.utils.data.Dataset
    # Dataloader
    data_loader: torch.utils.data.DataLoader

    # Adam optimizer
    optimizer: torch.optim.Adam

    def init(self):
        """
        Build the $\\epsilon_\\theta$ U-Net, the DDPM wrapper, the data loader
        and the optimizer, and enable image logging for samples.

        `self.dataset` is supplied by one of the `@option(Configs.dataset, ...)`
        functions defined in this file.
        """
        # Create $\textcolor{cyan}{\epsilon_\theta}(x_t, t)$ model
        self.eps_model = UNet(
            image_channels=self.image_channels,
            n_channels=self.n_channels,
            ch_mults=self.channel_multipliers,
            is_attn=self.is_attention,
        ).to(self.device)

        # Create [DDPM class](index.html)
        self.diffusion = DenoiseDiffusion(
            eps_model=self.eps_model,
            n_steps=self.n_steps,
            device=self.device,
        )

        # Create dataloader
        self.data_loader = torch.utils.data.DataLoader(self.dataset, self.batch_size, shuffle=True, pin_memory=True)
        # Create optimizer
        self.optimizer = torch.optim.Adam(self.eps_model.parameters(), lr=self.learning_rate)

        # Image logging: values saved under "sample" are treated as images
        tracker.set_image("sample", True)

    def sample(self):
        """
        ### Sample images

        Start from Gaussian noise of shape
        `[n_samples, image_channels, image_size, image_size]`, run the reverse
        process for `n_steps` steps (t = n_steps - 1 down to 0), and log the
        result under `'sample'`. Runs without gradient tracking.
        """
        with torch.no_grad():
            # $x_T \sim p(x_T) = \mathcal{N}(x_T; \mathbf{0}, \mathbf{I})$
            x = torch.randn([self.n_samples, self.image_channels, self.image_size, self.image_size],
                            device=self.device)

            # Remove noise for $T$ steps
            for t_ in monit.iterate('Sample', self.n_steps):
                # $t$ counts down from $T - 1$ to $0$
                t = self.n_steps - t_ - 1
                # Sample from $\textcolor{cyan}{p_\theta}(x_{t-1}|x_t)$;
                # the same time step is used for every sample in the batch
                x = self.diffusion.p_sample(x, x.new_full((self.n_samples,), t, dtype=torch.long))

            # Log samples
            tracker.save('sample', x)

    def train(self):
        """
        ### Train

        One pass over `data_loader`. For each batch, compute the diffusion
        loss, back-propagate, take an optimizer step and track the loss.
        The global step is advanced once per batch.
        """
        # Iterate through the dataset
        for data in monit.iterate('Train', self.data_loader):
            # Increment global step
            tracker.add_global_step()
            # Move data to device
            data = data.to(self.device)

            # Make the gradients zero
            self.optimizer.zero_grad()
            # Calculate loss
            loss = self.diffusion.loss(data)
            # Compute gradients
            loss.backward()
            # Take an optimization step
            self.optimizer.step()
            # Track the loss
            tracker.save('loss', loss)

    def run(self):
        """
        ### Training loop

        For each of `epochs` epochs: train, sample images, end the console
        line and save a checkpoint.
        """
        for _ in monit.loop(self.epochs):
            # Train the model
            self.train()
            # Sample some images
            self.sample()
            # New line in the console
            tracker.new_line()
            # Save the model
            experiment.save_checkpoint()
 | 
						|
 | 
						|
 | 
						|
class CelebADataset(torch.utils.data.Dataset):
    """
    ### CelebA HQ dataset

    Every `.jpg` file found (recursively) under the `celebA` folder of the
    labml data path, resized and returned as an RGB tensor.
    """

    def __init__(self, image_size: int):
        """
        * `image_size` is passed to `Resize`; it sets the length of the
          shorter edge, so square inputs become `image_size` x `image_size`.
        """
        super().__init__()

        # CelebA images folder
        folder = lab.get_data_path() / 'celebA'
        # List of files, sorted so that dataset indices are stable across runs
        self._files = sorted(folder.glob('**/*.jpg'))

        # Transformations to resize the image and convert to tensor
        self._transform = torchvision.transforms.Compose([
            torchvision.transforms.Resize(image_size),
            torchvision.transforms.ToTensor(),
        ])

    def __len__(self):
        """
        Size of the dataset (number of image files found)
        """
        return len(self._files)

    def __getitem__(self, index: int):
        """
        Get an image as a 3-channel tensor produced by `ToTensor`.
        """
        # `with` releases the file handle deterministically. `convert('RGB')`
        # forces 3 channels, so grayscale/CMYK JPEGs can be batched with RGB
        # ones and match `image_channels=3`.
        with Image.open(self._files[index]) as img:
            return self._transform(img.convert('RGB'))
 | 
						|
 | 
						|
 | 
						|
@option(Configs.dataset, 'CelebA')
def celeb_dataset(c: Configs):
    """
    `'CelebA'` choice for `Configs.dataset`: the CelebA HQ images, resized
    according to `c.image_size`.
    """
    size = c.image_size
    return CelebADataset(size)
 | 
						|
 | 
						|
 | 
						|
class MNISTDataset(torchvision.datasets.MNIST):
    """
    ### MNIST dataset

    The MNIST training split (downloaded into the labml data path if
    missing), resized to `image_size`. Indexing yields only the image
    tensor; the digit label is dropped.
    """

    def __init__(self, image_size):
        steps = [
            torchvision.transforms.Resize(image_size),
            torchvision.transforms.ToTensor(),
        ]
        pipeline = torchvision.transforms.Compose(steps)

        data_root = str(lab.get_data_path())
        super().__init__(data_root, train=True, download=True, transform=pipeline)

    def __getitem__(self, item):
        # The parent returns an (image, label) pair; keep only the image.
        pair = super().__getitem__(item)
        return pair[0]
 | 
						|
 | 
						|
 | 
						|
@option(Configs.dataset, 'MNIST')
def mnist_dataset(c: Configs):
    """
    `'MNIST'` choice for `Configs.dataset`: MNIST training images, resized
    according to `c.image_size`.
    """
    size = c.image_size
    return MNISTDataset(size)
 | 
						|
 | 
						|
 | 
						|
def main():
    """
    Set up the `diffuse` experiment, apply configuration overrides,
    build the model, register it for checkpointing and run training.
    """
    experiment.create(name='diffuse')

    configs = Configs()
    # Override defaults here, e.g. `{'dataset': 'MNIST'}`.
    # NOTE: MNIST images are grayscale, so `image_channels` would also need
    # to be set to 1 in that case.
    overrides = {}
    experiment.configs(configs, overrides)

    # Build model, diffusion, data loader and optimizer
    configs.init()

    # Register the U-Net so checkpoints save/load it
    experiment.add_pytorch_models({'eps_model': configs.eps_model})

    with experiment.start():
        configs.run()
 | 
						|
 | 
						|
 | 
						|
# Entry point: start training when this file is executed as a script.
if __name__ == '__main__':
    main()
 |