"""
---
title: Denoising Diffusion Probabilistic Models (DDPM) training
summary: >
  Training code for
  Denoising Diffusion Probabilistic Model.
---

# [Denoising Diffusion Probabilistic Models (DDPM)](index.html) training

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/diffusion/ddpm/experiment.ipynb)

This trains a DDPM based model on CelebA HQ dataset. You can find the download instruction in this
[discussion on fast.ai](https://forums.fast.ai/t/download-celeba-hq-dataset/45873/3).
Save the images inside [`data/celebA` folder](#dataset_path).

The paper had used an exponential moving average of the model with a decay of $0.9999$.
We have skipped this for simplicity.
"""
from typing import List

import torch
import torch.utils.data
import torchvision
from PIL import Image

from labml import lab, tracker, experiment, monit
from labml.configs import BaseConfigs, option
from labml_helpers.device import DeviceConfigs
from labml_nn.diffusion.ddpm import DenoiseDiffusion
from labml_nn.diffusion.ddpm.unet import UNet


class Configs(BaseConfigs):
    """
    ## Configurations

    Holds the hyper-parameters together with the objects built from them
    (model, diffusion process, data loader, optimizer) and the training,
    sampling and epoch loops that use them.
    """
    # Device to train the model on.
    # [`DeviceConfigs`](https://docs.labml.ai/api/helpers.html#labml_helpers.device.DeviceConfigs)
    #  picks up an available CUDA device or defaults to CPU.
    device: torch.device = DeviceConfigs()

    # U-Net model for $\textcolor{lightgreen}{\epsilon_\theta}(x_t, t)$
    eps_model: UNet
    # [DDPM algorithm](index.html)
    diffusion: DenoiseDiffusion

    # Number of channels in the image. $3$ for RGB.
    image_channels: int = 3
    # Image size
    image_size: int = 32
    # Number of channels in the initial feature map
    n_channels: int = 64
    # The list of channel numbers at each resolution.
    # The number of channels is `channel_multipliers[i] * n_channels`
    channel_multipliers: List[int] = [1, 2, 2, 4]
    # The list of booleans that indicate whether to use attention at each resolution
    is_attention: List[bool] = [False, False, False, True]

    # Number of time steps $T$
    n_steps: int = 1_000
    # Batch size
    batch_size: int = 64
    # Number of samples to generate
    n_samples: int = 16
    # Learning rate
    learning_rate: float = 2e-5

    # Number of training epochs
    epochs: int = 1_000

    # Dataset
    dataset: torch.utils.data.Dataset
    # Dataloader
    data_loader: torch.utils.data.DataLoader

    # Adam optimizer
    optimizer: torch.optim.Adam

    def init(self):
        """
        ### Initialize

        Builds the model, the diffusion process, the data loader and the
        optimizer from the configured values, and registers `sample` as an
        image with the tracker.
        """
        # Create $\textcolor{lightgreen}{\epsilon_\theta}(x_t, t)$ model
        self.eps_model = UNet(
            image_channels=self.image_channels,
            n_channels=self.n_channels,
            ch_mults=self.channel_multipliers,
            is_attn=self.is_attention,
        ).to(self.device)

        # Create [DDPM class](index.html)
        self.diffusion = DenoiseDiffusion(
            eps_model=self.eps_model,
            n_steps=self.n_steps,
            device=self.device,
        )

        # Create dataloader
        self.data_loader = torch.utils.data.DataLoader(self.dataset, self.batch_size, shuffle=True, pin_memory=True)
        # Create optimizer
        self.optimizer = torch.optim.Adam(self.eps_model.parameters(), lr=self.learning_rate)

        # Image logging
        tracker.set_image("sample", True)

    def sample(self):
        """
        ### Sample images

        Starts from Gaussian noise and applies `diffusion.p_sample` once for
        every time step from `n_steps - 1` down to `0`, then logs the result
        under the `sample` tracker key.
        """
        with torch.no_grad():
            # $x_T \sim p(x_T) = \mathcal{N}(x_T; \mathbf{0}, \mathbf{I})$
            x = torch.randn([self.n_samples, self.image_channels, self.image_size, self.image_size],
                            device=self.device)

            # Remove noise for $T$ steps
            for t_ in monit.iterate('Sample', self.n_steps):
                # $t$
                t = self.n_steps - t_ - 1
                # Sample from $\textcolor{lightgreen}{p_\theta}(x_{t-1}|x_t)$;
                # the time step is passed as one `long` value per sample
                x = self.diffusion.p_sample(x, x.new_full((self.n_samples,), t, dtype=torch.long))

            # Log samples
            tracker.save('sample', x)

    def train(self):
        """
        ### Train

        Runs one pass over `data_loader`, taking one optimizer step per batch
        and tracking the loss.
        """
        # Iterate through the dataset
        for data in monit.iterate('Train', self.data_loader):
            # Increment global step
            tracker.add_global_step()
            # Move data to device
            data = data.to(self.device)

            # Make the gradients zero
            self.optimizer.zero_grad()
            # Calculate loss
            loss = self.diffusion.loss(data)
            # Compute gradients
            loss.backward()
            # Take an optimization step
            self.optimizer.step()
            # Track the loss
            tracker.save('loss', loss)

    def run(self):
        """
        ### Training loop

        For each epoch: train, sample some images and save a checkpoint.
        """
        for _ in monit.loop(self.epochs):
            # Train the model
            self.train()
            # Sample some images
            self.sample()
            # New line in the console
            tracker.new_line()
            # Save the model
            experiment.save_checkpoint()


class CelebADataset(torch.utils.data.Dataset):
    """
    ### CelebA HQ dataset

    Serves every `.jpg` file found (recursively) under the `celebA` folder of
    the labml data path as a tensor.
    """

    def __init__(self, image_size: int):
        """
        * `image_size` is the size passed to the resize transformation

        Raises `FileNotFoundError` if no `.jpg` images are found.
        """
        super().__init__()

        # CelebA images folder
        folder = lab.get_data_path() / 'celebA'
        # List of files; sorted so that the index-to-file mapping does not
        # depend on the order in which the file system lists them
        self._files = sorted(folder.glob('**/*.jpg'))
        # Fail early with a clear message instead of training on nothing
        if not self._files:
            raise FileNotFoundError(f"No '.jpg' images found under '{folder}'")

        # Transformations to resize the image and convert to tensor.
        # NOTE(review): with an integer size `Resize` matches the shorter edge,
        # so this presumably relies on the images being square - confirm.
        self._transform = torchvision.transforms.Compose([
            torchvision.transforms.Resize(image_size),
            torchvision.transforms.ToTensor(),
        ])

    def __len__(self):
        """
        Size of the dataset
        """
        return len(self._files)

    def __getitem__(self, index: int):
        """
        Get an image
        """
        # The transformation reads the pixel data while the file is still open;
        # the context manager then closes the file deterministically
        with Image.open(self._files[index]) as img:
            return self._transform(img)


@option(Configs.dataset, 'CelebA')
def celeb_dataset(c: Configs):
    """
    Create CelebA dataset
    """
    return CelebADataset(c.image_size)


class MNISTDataset(torchvision.datasets.MNIST):
    """
    ### MNIST dataset

    The training split of MNIST, returning only the transformed image
    (the label is dropped).
    """

    def __init__(self, image_size: int):
        """
        * `image_size` is the size passed to the resize transformation
        """
        transform = torchvision.transforms.Compose([
            torchvision.transforms.Resize(image_size),
            torchvision.transforms.ToTensor(),
        ])

        super().__init__(str(lab.get_data_path()), train=True, download=True, transform=transform)

    def __getitem__(self, item: int):
        """
        Get an image; keeps only the first element of what the parent returns
        """
        return super().__getitem__(item)[0]


@option(Configs.dataset, 'MNIST')
def mnist_dataset(c: Configs):
    """
    Create MNIST dataset
    """
    return MNISTDataset(c.image_size)


def main():
    """
    Create the experiment, set the configurations and run the training loop
    """
    # Create experiment
    experiment.create(name='diffuse', writers={'screen', 'labml'})

    # Create configurations
    configs = Configs()

    # Set configurations. You can override the defaults by passing the values in the dictionary.
    experiment.configs(configs, {
        'dataset': 'CelebA',  # 'MNIST'
        'image_channels': 3,  # 1,
        'epochs': 100,  # 5,
    })

    # Initialize
    configs.init()

    # Set models for saving and loading
    experiment.add_pytorch_models({'eps_model': configs.eps_model})

    # Start and run the training loop
    with experiment.start():
        configs.run()


#
if __name__ == '__main__':
    main()