This trains a DDPM-based model on the CelebA HQ dataset. You can find the download instructions in this discussion on fast.ai. Save the images inside the data/celebA folder.
The paper used an exponential moving average of the model weights with a decay of 0.9999. We have skipped this for simplicity.
20from typing import List
21
22import torch
23import torch.utils.data
24import torchvision
25from PIL import Image
26
27from labml import lab, tracker, experiment, monit
28from labml.configs import BaseConfigs, option
29from labml_helpers.device import DeviceConfigs
30from labml_nn.diffusion.ddpm import DenoiseDiffusion
from labml_nn.diffusion.ddpm.unet import UNet


class Configs(BaseConfigs):
    """
    ## Configurations

    Holds the model, diffusion process, dataset and training hyper-parameters,
    and implements the train/sample/run loops.
    """
    # Device to train the model on. `DeviceConfigs` picks up an available CUDA
    # device or defaults to CPU.
    device: torch.device = DeviceConfigs()

    # U-Net model that predicts the noise added at a given time step
    eps_model: UNet
    # DDPM forward/reverse process wrapper
    diffusion: DenoiseDiffusion

    # Number of channels in the image. 3 for RGB.
    image_channels: int = 3
    # Image size (images are resized to `image_size x image_size`)
    image_size: int = 32
    # Number of channels in the initial feature map
    n_channels: int = 64
    # The list of channel numbers at each resolution.
    # The number of channels is `channel_multipliers[i] * n_channels`
    channel_multipliers: List[int] = [1, 2, 2, 4]
    # The list of booleans that indicate whether to use attention at each resolution.
    # NOTE: fixed the annotation — these are booleans, so `List[bool]`, not `List[int]`.
    is_attention: List[bool] = [False, False, False, True]

    # Number of time steps $T$
    n_steps: int = 1_000
    # Batch size
    batch_size: int = 64
    # Number of samples to generate when logging images
    n_samples: int = 16
    # Learning rate
    learning_rate: float = 2e-5

    # Number of training epochs
    epochs: int = 1_000

    # Dataset (selected via a config option, e.g. 'CelebA' or 'MNIST')
    dataset: torch.utils.data.Dataset
    # Dataloader
    data_loader: torch.utils.data.DataLoader

    # Adam optimizer
    optimizer: torch.optim.Adam

    def init(self):
        """Initialize the model, diffusion process, dataloader and optimizer."""
        # Create the noise-prediction model and move it to the training device
        self.eps_model = UNet(
            image_channels=self.image_channels,
            n_channels=self.n_channels,
            ch_mults=self.channel_multipliers,
            is_attn=self.is_attention,
        ).to(self.device)

        # Create DDPM class
        self.diffusion = DenoiseDiffusion(
            eps_model=self.eps_model,
            n_steps=self.n_steps,
            device=self.device,
        )

        # Create dataloader
        self.data_loader = torch.utils.data.DataLoader(self.dataset, self.batch_size, shuffle=True, pin_memory=True)
        # Create optimizer
        self.optimizer = torch.optim.Adam(self.eps_model.parameters(), lr=self.learning_rate)

        # Image logging
        tracker.set_image("sample", True)

    def sample(self):
        """Generate `n_samples` images by running the reverse diffusion process."""
        with torch.no_grad():
            # Start from pure noise $x_T \sim \mathcal{N}(0, I)$
            x = torch.randn([self.n_samples, self.image_channels, self.image_size, self.image_size],
                            device=self.device)

            # Remove noise for $T$ steps
            for t_ in monit.iterate('Sample', self.n_steps):
                # Time steps run backwards: $t = T - 1, \dots, 0$
                t = self.n_steps - t_ - 1
                # Sample $x_{t-1} \sim p_\theta(x_{t-1} | x_t)$
                x = self.diffusion.p_sample(x, x.new_full((self.n_samples,), t, dtype=torch.long))

            # Log samples
            tracker.save('sample', x)

    def train(self):
        """Train the model for one epoch over the dataloader."""
        # Iterate through the dataset
        for data in monit.iterate('Train', self.data_loader):
            # Increment global step
            tracker.add_global_step()
            # Move data to device
            data = data.to(self.device)

            # Make the gradients zero
            self.optimizer.zero_grad()
            # Calculate loss
            loss = self.diffusion.loss(data)
            # Compute gradients
            loss.backward()
            # Take an optimization step
            self.optimizer.step()
            # Track the loss
            tracker.save('loss', loss)

    def run(self):
        """Training loop: alternate training and sampling for `epochs` epochs."""
        for _ in monit.loop(self.epochs):
            # Train the model
            self.train()
            # Sample some images
            self.sample()
            # New line in the console
            tracker.new_line()
            # Save the model
158            experiment.save_checkpoint()161class CelebADataset(torch.utils.data.Dataset):166    def __init__(self, image_size: int):
167        super().__init__()CelebA images folder
170        folder = lab.get_data_path() / 'celebA'List of files
172        self._files = [p for p in folder.glob(f'**/*.jpg')]Transformations to resize the image and convert to tensor
175        self._transform = torchvision.transforms.Compose([
176            torchvision.transforms.Resize(image_size),
177            torchvision.transforms.ToTensor(),
178        ])Size of the dataset
180    def __len__(self):184        return len(self._files)Get an image
186    def __getitem__(self, index: int):190        img = Image.open(self._files[index])
191        return self._transform(img)Create CelebA dataset
194@option(Configs.dataset, 'CelebA')
195def celeb_dataset(c: Configs):199    return CelebADataset(c.image_size)202class MNISTDataset(torchvision.datasets.MNIST):207    def __init__(self, image_size):
208        transform = torchvision.transforms.Compose([
209            torchvision.transforms.Resize(image_size),
210            torchvision.transforms.ToTensor(),
211        ])
212
213        super().__init__(str(lab.get_data_path()), train=True, download=True, transform=transform)215    def __getitem__(self, item):
216        return super().__getitem__(item)[0]Create MNIST dataset
219@option(Configs.dataset, 'MNIST')
220def mnist_dataset(c: Configs):224    return MNISTDataset(c.image_size)227def main():Create experiment
229    experiment.create(name='diffuse', writers={'screen', 'labml'})Create configurations
232    configs = Configs()Set configurations. You can override the defaults by passing the values in the dictionary.
235    experiment.configs(configs, {
236        'dataset': 'CelebA',  # 'MNIST'
237        'image_channels': 3,  # 1,
238        'epochs': 100,  # 5,
239    })Initialize
242    configs.init()Set models for saving and loading
245    experiment.add_pytorch_models({'eps_model': configs.eps_model})Start and run the training loop
248    with experiment.start():
249        configs.run()253if __name__ == '__main__':
254    main()