Denoising Diffusion Probabilistic Models (DDPM) training

This trains a DDPM-based model on the CelebA HQ dataset. You can find the download instructions in this discussion on fast.ai. Save the images inside the data/celebA folder.

The paper used an exponential moving average of the model weights with a decay of 0.9999. We have skipped this for simplicity.
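If you do want it, here is a minimal sketch of what such an EMA could look like (the EMA class and its API are illustrative assumptions, not part of this codebase):

import copy

import torch


class EMA:
    """Exponential moving average of model weights (illustrative sketch)."""

    def __init__(self, model: torch.nn.Module, decay: float = 0.9999):
        self.decay = decay
        # A frozen copy of the model holds the averaged weights
        self.shadow = copy.deepcopy(model).eval()
        for p in self.shadow.parameters():
            p.requires_grad_(False)

    @torch.no_grad()
    def update(self, model: torch.nn.Module):
        # shadow <- decay * shadow + (1 - decay) * online weights
        for s, p in zip(self.shadow.parameters(), model.parameters()):
            s.mul_(self.decay).add_(p, alpha=1.0 - self.decay)

You would call update() after each optimizer step and generate samples from the shadow copy.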

from typing import List

import torch
import torch.utils.data
import torchvision
from PIL import Image

from labml import lab, tracker, experiment, monit
from labml.configs import BaseConfigs, option
from labml_helpers.device import DeviceConfigs
from labml_nn.diffusion.ddpm import DenoiseDiffusion
from labml_nn.diffusion.ddpm.unet import UNet

Configurations

class Configs(BaseConfigs):

Device to train the model on. DeviceConfigs picks up an available CUDA device or defaults to CPU.

    device: torch.device = DeviceConfigs()

U-Net model for $\epsilon_\theta(x_t, t)$

    eps_model: UNet

DDPM algorithm

    diffusion: DenoiseDiffusion

Number of channels in the image. 3 for RGB.

    image_channels: int = 3

Image size

    image_size: int = 32

Number of channels in the initial feature map

    n_channels: int = 64

The list of channel multipliers for each resolution. The number of channels at resolution $i$ is channel_multipliers[i] * n_channels; with the defaults below this gives [64, 128, 128, 256].

    channel_multipliers: List[int] = [1, 2, 2, 4]

The list of booleans that indicate whether to use attention at each resolution

    is_attention: List[bool] = [False, False, False, True]

Number of time steps $T$

    n_steps: int = 1_000

Batch size

    batch_size: int = 64

Number of samples to generate

    n_samples: int = 16

Learning rate

    learning_rate: float = 2e-5

Number of training epochs

    epochs: int = 1_000

Dataset

    dataset: torch.utils.data.Dataset

Dataloader

    data_loader: torch.utils.data.DataLoader

Adam optimizer

    optimizer: torch.optim.Adam

    def init(self):

Create model

        self.eps_model = UNet(
            image_channels=self.image_channels,
            n_channels=self.n_channels,
            ch_mults=self.channel_multipliers,
            is_attn=self.is_attention,
        ).to(self.device)

Create DDPM class

        self.diffusion = DenoiseDiffusion(
            eps_model=self.eps_model,
            n_steps=self.n_steps,
            device=self.device,
        )

Create dataloader

        self.data_loader = torch.utils.data.DataLoader(self.dataset, self.batch_size, shuffle=True, pin_memory=True)

Create optimizer

        self.optimizer = torch.optim.Adam(self.eps_model.parameters(), lr=self.learning_rate)

Image logging

        tracker.set_image("sample", True)

Sample images

    def sample(self):
        with torch.no_grad():

Start from pure noise, $x_T \sim p(x_T) = \mathcal{N}(x_T; \mathbf{0}, \mathbf{I})$

            x = torch.randn([self.n_samples, self.image_channels, self.image_size, self.image_size],
                            device=self.device)

Remove noise for $T$ steps

            for t_ in monit.iterate('Sample', self.n_steps):

$t$ goes from $T - 1$ down to $0$

                t = self.n_steps - t_ - 1

Sample from $p_\theta(x_{t-1} \vert x_t)$

                x = self.diffusion.p_sample(x, x.new_full((self.n_samples,), t, dtype=torch.long))
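For reference, each p_sample call performs one reverse step; in the paper this step is

$$x_{t-1} = \frac{1}{\sqrt{\alpha_t}} \Big( x_t - \frac{\beta_t}{\sqrt{1 - \bar{\alpha}_t}} \epsilon_\theta(x_t, t) \Big) + \sigma_t z, \qquad z \sim \mathcal{N}(\mathbf{0}, \mathbf{I})$$

with $\sigma_t^2 = \beta_t$.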

Log samples

            tracker.save('sample', x)

Train

    def train(self):

Iterate through the dataset

        for data in monit.iterate('Train', self.data_loader):

Increment global step

            tracker.add_global_step()

Move data to device

            data = data.to(self.device)

Make the gradients zero

            self.optimizer.zero_grad()

Calculate loss

            loss = self.diffusion.loss(data)
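This is the simplified objective from the paper: pick a random step $t$ and noise $\epsilon \sim \mathcal{N}(\mathbf{0}, \mathbf{I})$, then regress the model's noise prediction against the true noise,

$$L_{\text{simple}} = \mathbb{E}_{t, x_0, \epsilon} \Big[ \big\| \epsilon - \epsilon_\theta\big(\sqrt{\bar{\alpha}_t}\, x_0 + \sqrt{1 - \bar{\alpha}_t}\, \epsilon,\; t\big) \big\|^2 \Big]$$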

Compute gradients

            loss.backward()

Take an optimization step

            self.optimizer.step()

Track the loss

            tracker.save('loss', loss)

Training loop

    def run(self):
        for _ in monit.loop(self.epochs):

Train the model

            self.train()

Sample some images

            self.sample()

New line in the console

            tracker.new_line()

Save the model

            experiment.save_checkpoint()

CelebA HQ dataset

class CelebADataset(torch.utils.data.Dataset):
    def __init__(self, image_size: int):
        super().__init__()

CelebA images folder

        folder = lab.get_data_path() / 'celebA'

List of files

        self._files = [p for p in folder.glob('**/*.jpg')]

Transformations to resize the image and convert to tensor

        self._transform = torchvision.transforms.Compose([
            torchvision.transforms.Resize(image_size),
            torchvision.transforms.ToTensor(),
        ])

Size of the dataset

    def __len__(self):
        return len(self._files)

Get an image

    def __getitem__(self, index: int):
        img = Image.open(self._files[index])
        return self._transform(img)

Create CelebA dataset

@option(Configs.dataset, 'CelebA')
def celeb_dataset(c: Configs):
    return CelebADataset(c.image_size)

MNIST dataset

class MNISTDataset(torchvision.datasets.MNIST):
    def __init__(self, image_size):
        transform = torchvision.transforms.Compose([
            torchvision.transforms.Resize(image_size),
            torchvision.transforms.ToTensor(),
        ])

        super().__init__(str(lab.get_data_path()), train=True, download=True, transform=transform)

    def __getitem__(self, item):
        # Return only the image; the label is not needed for diffusion training
        return super().__getitem__(item)[0]

Create MNIST dataset

@option(Configs.dataset, 'MNIST')
def mnist_dataset(c: Configs):
    return MNISTDataset(c.image_size)


def main():

Create experiment

    experiment.create(name='diffuse')

Create configurations

    configs = Configs()

Set configurations. You can override the defaults by passing values in this dictionary.

    experiment.configs(configs, {
    })
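For example, to train on MNIST instead of CelebA (both dataset options are registered above with @option), you could pass something like the following; since MNIST is grayscale, image_channels must also be overridden:

    experiment.configs(configs, {
        'dataset': 'MNIST',  # selects the MNIST dataset option
        'image_size': 28,  # MNIST's native resolution
        'image_channels': 1,  # MNIST images are grayscale
    })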

Initialize

    configs.init()

Set models for saving and loading

    experiment.add_pytorch_models({'eps_model': configs.eps_model})

Start and run the training loop

    with experiment.start():
        configs.run()

if __name__ == '__main__':
    main()