This trains a DDPM-based model on the CelebA HQ dataset. You can find download instructions in this discussion on fast.ai. Save the images inside the data/celebA folder.
The paper used an exponential moving average of the model weights with a decay of 0.9999. We have skipped this for simplicity.
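If you want to add the EMA back, a minimal sketch is shown below. This is an illustrative snippet only, not part of this implementation; the ModelEMA name is made up, and the decay default follows the value quoted above.

import copy

import torch


class ModelEMA:
    # Keeps an exponential moving average of a model's parameters (illustrative sketch)
    def __init__(self, model: torch.nn.Module, decay: float = 0.9999):
        self.decay = decay
        # A frozen copy of the model that holds the averaged weights
        self.ema_model = copy.deepcopy(model)
        self.ema_model.eval()
        for p in self.ema_model.parameters():
            p.requires_grad_(False)

    @torch.no_grad()
    def update(self, model: torch.nn.Module):
        # ema_param <- decay * ema_param + (1 - decay) * param
        for ema_p, p in zip(self.ema_model.parameters(), model.parameters()):
            ema_p.mul_(self.decay).add_(p, alpha=1.0 - self.decay)

You would construct it after creating the model, call update() after every optimizer step, and use ema_model when sampling.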
from typing import List

import torch
import torch.utils.data
import torchvision
from PIL import Image

from labml import lab, tracker, experiment, monit
from labml.configs import BaseConfigs, option
from labml_helpers.device import DeviceConfigs
from labml_nn.diffusion.ddpm import DenoiseDiffusion
from labml_nn.diffusion.ddpm.unet import UNet


class Configs(BaseConfigs):
Device to train the model on. DeviceConfigs picks up an available CUDA device or defaults to CPU.
    device: torch.device = DeviceConfigs()
U-Net model for $\epsilon_\theta(x_t, t)$
    eps_model: UNet
The DDPM algorithm
    diffusion: DenoiseDiffusion
Number of channels in the image. 3 for RGB.
    image_channels: int = 3
Image size
    image_size: int = 32
Number of channels in the initial feature map
    n_channels: int = 64
The list of channel numbers at each resolution. The number of channels at resolution i is channel_multipliers[i] * n_channels; see the worked example below.
    channel_multipliers: List[int] = [1, 2, 2, 4]
The list of booleans that indicate whether to use attention at each resolution
    is_attention: List[bool] = [False, False, False, True]
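As a quick worked example (illustrative only, not part of the training code): with the defaults above, the feature-map widths at the four resolutions are 64, 128, 128 and 256 channels, and attention is applied only at the lowest resolution.

n_channels = 64
channel_multipliers = [1, 2, 2, 4]
# Channels at each of the four resolutions
print([m * n_channels for m in channel_multipliers])  # [64, 128, 128, 256]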
Number of time steps $T$
    n_steps: int = 1_000
Batch size
    batch_size: int = 64
Number of samples to generate
    n_samples: int = 16
Learning rate
    learning_rate: float = 2e-5
Number of training epochs
    epochs: int = 1_000
Dataset
    dataset: torch.utils.data.Dataset
Dataloader
    data_loader: torch.utils.data.DataLoader
Adam optimizer
    optimizer: torch.optim.Adam
Initialize the model, dataset, dataloader and optimizer
    def init(self):
Create model
        self.eps_model = UNet(
            image_channels=self.image_channels,
            n_channels=self.n_channels,
            ch_mults=self.channel_multipliers,
            is_attn=self.is_attention,
        ).to(self.device)
Create DDPM class
        self.diffusion = DenoiseDiffusion(
            eps_model=self.eps_model,
            n_steps=self.n_steps,
            device=self.device,
        )
Create dataloader
        self.data_loader = torch.utils.data.DataLoader(self.dataset, self.batch_size, shuffle=True, pin_memory=True)
Create optimizer
        self.optimizer = torch.optim.Adam(self.eps_model.parameters(), lr=self.learning_rate)
Image logging
        tracker.set_image("sample", True)
Sample images
    def sample(self):
        with torch.no_grad():
Start from pure noise $x_T \sim \mathcal{N}(0, I)$
            x = torch.randn([self.n_samples, self.image_channels, self.image_size, self.image_size],
                            device=self.device)
Remove noise for $T$ steps
            for t_ in monit.iterate('Sample', self.n_steps):
$t$ runs from $T - 1$ down to $0$
                t = self.n_steps - t_ - 1
Sample from $p_\theta(x_{t-1}|x_t)$
                x = self.diffusion.p_sample(x, x.new_full((self.n_samples,), t, dtype=torch.long))
Log samples
            tracker.save('sample', x)
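For reference, sampling starts from pure noise and each p_sample call performs one reverse-diffusion step. In the DDPM paper this step samples

$$p_\theta(x_{t-1} \mid x_t) = \mathcal{N}\big(x_{t-1};\ \mu_\theta(x_t, t),\ \sigma_t^2 I\big), \qquad \mu_\theta(x_t, t) = \frac{1}{\sqrt{\alpha_t}} \Big(x_t - \frac{\beta_t}{\sqrt{1 - \bar\alpha_t}}\, \epsilon_\theta(x_t, t)\Big)$$

The exact variance $\sigma_t^2$ and parameterization used here are defined in the DenoiseDiffusion module.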
Train for one epoch
    def train(self):
Iterate through the dataset
        for data in monit.iterate('Train', self.data_loader):
Increment global step
            tracker.add_global_step()
Move data to device
            data = data.to(self.device)
Make the gradients zero
            self.optimizer.zero_grad()
Calculate loss
            loss = self.diffusion.loss(data)
Compute gradients
            loss.backward()
Take an optimization step
            self.optimizer.step()
Track the loss
            tracker.save('loss', loss)
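For reference, the simplified training objective from the DDPM paper, which is what diffusion.loss is expected to compute here, is

$$L_\text{simple}(\theta) = \mathbb{E}_{t,\, x_0,\, \epsilon}\Big[\big\lVert \epsilon - \epsilon_\theta\big(\sqrt{\bar\alpha_t}\, x_0 + \sqrt{1 - \bar\alpha_t}\,\epsilon,\ t\big)\big\rVert^2\Big]$$

where $t$ is sampled uniformly from $\{1, \dots, T\}$ and $\epsilon \sim \mathcal{N}(0, I)$.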
Training loop
    def run(self):
        for _ in monit.loop(self.epochs):
Train the model
            self.train()
Sample some images
            self.sample()
New line in the console
            tracker.new_line()
Save the model
            experiment.save_checkpoint()
CelebA HQ dataset
class CelebADataset(torch.utils.data.Dataset):
    def __init__(self, image_size: int):
        super().__init__()
CelebA images folder
        folder = lab.get_data_path() / 'celebA'
List of files
        self._files = list(folder.glob('**/*.jpg'))
Transformations to resize the image and convert to tensor
        self._transform = torchvision.transforms.Compose([
            torchvision.transforms.Resize(image_size),
            torchvision.transforms.ToTensor(),
        ])
Size of the dataset
    def __len__(self):
        return len(self._files)
Get an image
    def __getitem__(self, index: int):
        img = Image.open(self._files[index])
        return self._transform(img)
Create CelebA dataset
@option(Configs.dataset, 'CelebA')
def celeb_dataset(c: Configs):
    return CelebADataset(c.image_size)
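A quick way to sanity-check the dataset before training (an illustrative snippet, assuming the CelebA HQ images are already in data/celebA and are square, so Resize gives square outputs):

dataset = CelebADataset(image_size=32)
print(len(dataset))  # number of .jpg files found
x = dataset[0]       # a single image tensor with values in [0, 1]
print(x.shape)       # e.g. torch.Size([3, 32, 32])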
MNIST dataset
class MNISTDataset(torchvision.datasets.MNIST):
    def __init__(self, image_size):
        transform = torchvision.transforms.Compose([
            torchvision.transforms.Resize(image_size),
            torchvision.transforms.ToTensor(),
        ])

        super().__init__(str(lab.get_data_path()), train=True, download=True, transform=transform)
Return only the image (without the label)
    def __getitem__(self, item):
        return super().__getitem__(item)[0]
Create MNIST dataset
@option(Configs.dataset, 'MNIST')
def mnist_dataset(c: Configs):
    return MNISTDataset(c.image_size)
def main():
Create experiment
    experiment.create(name='diffuse', writers={'screen', 'comet'})
Create configurations
    configs = Configs()
Set configurations. You can override the defaults by passing the values in the dictionary.
    experiment.configs(configs, {
        'dataset': 'CelebA',  # 'MNIST'
        'image_channels': 3,  # 1,
        'epochs': 100,  # 5,
    })
Initialize
    configs.init()
Set models for saving and loading
    experiment.add_pytorch_models({'eps_model': configs.eps_model})
Start and run the training loop
    with experiment.start():
        configs.run()


if __name__ == '__main__':
    main()
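To train on MNIST instead, as hinted by the comments in main(), the overrides could look like this. This is a sketch reusing the same calls as above; the main_mnist function name and the reduced writers set are placeholders, not part of the original code.

def main_mnist():
    # Create the experiment (name and writers are a choice, not fixed by the code above)
    experiment.create(name='diffuse', writers={'screen'})
    configs = Configs()
    # Select the MNIST dataset option and matching overrides
    experiment.configs(configs, {
        'dataset': 'MNIST',
        'image_channels': 1,
        'epochs': 5,
    })
    configs.init()
    experiment.add_pytorch_models({'eps_model': configs.eps_model})
    with experiment.start():
        configs.run()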