Denoising Diffusion Probabilistic Models (DDPM) training

Open In Colab Open In Comet

This trains a DDPM-based model on the CelebA HQ dataset. You can find the download instructions in this discussion on fast.ai. Save the images inside the `data/celebA` folder.

The paper used an exponential moving average (EMA) of the model parameters with a decay of $0.9999$. We have skipped this for simplicity.

21from typing import List
22
23import torch
24import torch.utils.data
25import torchvision
26from PIL import Image
27
28from labml import lab, tracker, experiment, monit
29from labml.configs import BaseConfigs, option
30from labml_helpers.device import DeviceConfigs
31from labml_nn.diffusion.ddpm import DenoiseDiffusion
32from labml_nn.diffusion.ddpm.unet import UNet

Configurations

class Configs(BaseConfigs):
    """
    Configurations for training a DDPM model.

    Holds the hyper-parameters and the components (model, diffusion wrapper,
    dataset, data loader, optimizer) used by `init`, `train`, `sample` and `run`.
    """

    # Device to train the model on.
    # `DeviceConfigs` picks up an available CUDA device or defaults to CPU.
    device: torch.device = DeviceConfigs()

    # U-Net model for $\epsilon_\theta(x_t, t)$
    eps_model: UNet
    # DDPM algorithm; provides `loss` (used in `train`) and `p_sample` (used in `sample`)
    diffusion: DenoiseDiffusion

    # Number of channels in the image. $3$ for RGB.
    image_channels: int = 3
    # Image size
    image_size: int = 32
    # Number of channels in the initial feature map
    n_channels: int = 64
    # The list of channel numbers at each resolution.
    # The number of channels is `channel_multipliers[i] * n_channels`
    channel_multipliers: List[int] = [1, 2, 2, 4]
    # The list of booleans that indicate whether to use attention at each resolution
    is_attention: List[bool] = [False, False, False, True]

    # Number of time steps $T$
    n_steps: int = 1_000
    # Batch size
    batch_size: int = 64
    # Number of samples to generate
    n_samples: int = 16
    # Learning rate
    learning_rate: float = 2e-5

    # Number of training epochs
    epochs: int = 1_000

    # Dataset
    dataset: torch.utils.data.Dataset
    # Dataloader
    data_loader: torch.utils.data.DataLoader

    # Adam optimizer
    optimizer: torch.optim.Adam

    def init(self):
        """Create the model, the diffusion wrapper, the data loader and the optimizer."""
        # Create $\epsilon_\theta(x_t, t)$ model
        self.eps_model = UNet(
            image_channels=self.image_channels,
            n_channels=self.n_channels,
            ch_mults=self.channel_multipliers,
            is_attn=self.is_attention,
        ).to(self.device)

        # Create DDPM class
        self.diffusion = DenoiseDiffusion(
            eps_model=self.eps_model,
            n_steps=self.n_steps,
            device=self.device,
        )

        # Create dataloader
        self.data_loader = torch.utils.data.DataLoader(self.dataset, self.batch_size, shuffle=True, pin_memory=True)
        # Create optimizer
        self.optimizer = torch.optim.Adam(self.eps_model.parameters(), lr=self.learning_rate)

        # Image logging
        tracker.set_image("sample", True)

    def sample(self):
        """
        Sample images.

        Starts from Gaussian noise and applies `p_sample` for t = T-1 down to 0,
        then logs the resulting batch under the 'sample' key.
        """
        with torch.no_grad():
            # $x_T \sim \mathcal{N}(\mathbf{0}, \mathbf{I})$
            x = torch.randn([self.n_samples, self.image_channels, self.image_size, self.image_size],
                            device=self.device)

            # Remove noise for $T$ steps
            for t_ in monit.iterate('Sample', self.n_steps):
                # $t$ runs from $T-1$ down to $0$
                t = self.n_steps - t_ - 1
                # Sample from $p_\theta(x_{t-1}|x_t)$
                x = self.diffusion.p_sample(x, x.new_full((self.n_samples,), t, dtype=torch.long))

            # Log samples
            tracker.save('sample', x)

    def train(self):
        """Train for one epoch: one pass over `data_loader` with one Adam step per batch."""
        # Iterate through the dataset
        for data in monit.iterate('Train', self.data_loader):
            # Increment global step
            tracker.add_global_step()
            # Move data to device
            data = data.to(self.device)

            # Make the gradients zero
            self.optimizer.zero_grad()
            # Calculate loss
            loss = self.diffusion.loss(data)
            # Compute gradients
            loss.backward()
            # Take an optimization step
            self.optimizer.step()
            # Track the loss
            tracker.save('loss', loss)

    def run(self):
        """Training loop: for each epoch, train, sample images and save a checkpoint."""
        for _ in monit.loop(self.epochs):
            # Train the model
            self.train()
            # Sample some images
            self.sample()
            # New line in the console
            tracker.new_line()
            # Save the model
            experiment.save_checkpoint()

CelebA HQ dataset

class CelebADataset(torch.utils.data.Dataset):
    """
    CelebA HQ dataset.

    Loads every `.jpg` found (recursively) under the `celebA` folder of the
    lab data path. Each item is returned as a `[3, image_size, image_size]`
    tensor produced by `ToTensor`.
    """

    def __init__(self, image_size: int):
        super().__init__()

        # CelebA images folder
        folder = lab.get_data_path() / 'celebA'
        # List of files, sorted so indices are stable across runs
        # (`glob` yields entries in filesystem-dependent order)
        self._files = sorted(folder.glob('**/*.jpg'))

        # Transformations to resize the image and convert to tensor.
        # `Resize` with an int scales only the shorter edge, so `CenterCrop`
        # guarantees a square output (required to stack images into a batch);
        # for already-square images it is a no-op.
        self._transform = torchvision.transforms.Compose([
            torchvision.transforms.Resize(image_size),
            torchvision.transforms.CenterCrop(image_size),
            torchvision.transforms.ToTensor(),
        ])

    def __len__(self):
        """Size of the dataset"""
        return len(self._files)

    def __getitem__(self, index: int):
        """Get an image as a 3-channel tensor"""
        # Open inside `with` so the file handle is released promptly instead of
        # leaking until garbage collection. `convert` loads the pixel data and
        # returns an independent copy, and forcing RGB ensures grayscale/RGBA
        # files still yield 3 channels.
        with Image.open(self._files[index]) as img:
            rgb = img.convert('RGB')
        return self._transform(rgb)

Create CelebA dataset

@option(Configs.dataset, 'CelebA')
def celeb_dataset(c: Configs):
    """
    Build the CelebA dataset for the `'CelebA'` choice of `Configs.dataset`,
    using the configured image size.
    """
    size = c.image_size
    return CelebADataset(size)

MNIST dataset

class MNISTDataset(torchvision.datasets.MNIST):
    """
    MNIST dataset.

    Training split of MNIST stored under the lab data path (`download=True`
    fetches it when missing), resized to `image_size`. Items are the image
    tensors only; labels are dropped.
    """

    def __init__(self, image_size):
        steps = [
            torchvision.transforms.Resize(image_size),
            torchvision.transforms.ToTensor(),
        ]
        pipeline = torchvision.transforms.Compose(steps)

        root = str(lab.get_data_path())
        super().__init__(root, train=True, download=True, transform=pipeline)

    def __getitem__(self, item):
        # The parent returns an (image, label) pair; keep only the image
        pair = super().__getitem__(item)
        return pair[0]

Create MNIST dataset

@option(Configs.dataset, 'MNIST')
def mnist_dataset(c: Configs):
    """
    Build the MNIST dataset for the `'MNIST'` choice of `Configs.dataset`,
    using the configured image size.
    """
    size = c.image_size
    return MNISTDataset(size)
def main():
    """Create, configure and run the DDPM training experiment."""
    # Overrides for the default configurations; the trailing comments give
    # the values to use when training on MNIST instead.
    overrides = {
        'dataset': 'CelebA',  # or 'MNIST'
        'image_channels': 3,  # 1 for MNIST
        'epochs': 100,  # e.g. 5 for MNIST
    }

    # Create experiment
    experiment.create(name='diffuse', writers={'screen', 'comet'})

    # Create configurations and apply the overrides
    configs = Configs()
    experiment.configs(configs, overrides)

    # Initialize the model, diffusion wrapper, data loader and optimizer
    configs.init()

    # Register models for saving and loading checkpoints
    models = {'eps_model': configs.eps_model}
    experiment.add_pytorch_models(models)

    # Start the experiment and run the training loop
    with experiment.start():
        configs.run()


if __name__ == '__main__':
    main()