This trains a DDPM-based model on the CelebA HQ dataset. You can find the download instructions in this discussion on fast.ai. Save the images inside the data/celebA folder.
The paper used an exponential moving average of the model weights with a decay of 0.9999. We have skipped this for simplicity; a sketch of what it would look like is given below.
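If you want to add EMA back, a minimal sketch is shown next. It keeps a second copy of the model and updates its parameters after every optimizer step; the names ema_model and ema_update are assumptions for illustration, not part of this file.
import copy

import torch

def ema_update(ema_model: torch.nn.Module, model: torch.nn.Module, decay: float = 0.9999):
    # Shadow parameters: ema <- decay * ema + (1 - decay) * current
    with torch.no_grad():
        for ema_p, p in zip(ema_model.parameters(), model.parameters()):
            ema_p.mul_(decay).add_(p, alpha=1.0 - decay)

# Hypothetical usage: create the shadow copy once, then update it after each step.
# ema_model = copy.deepcopy(self.eps_model)
# ... after self.optimizer.step():
# ema_update(ema_model, self.eps_model)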
from typing import List

import torch
import torch.utils.data
import torchvision
from PIL import Image

from labml import lab, tracker, experiment, monit
from labml.configs import BaseConfigs, option
from labml_helpers.device import DeviceConfigs
from labml_nn.diffusion.ddpm import DenoiseDiffusion
from labml_nn.diffusion.ddpm.unet import UNet

class Configs(BaseConfigs):
Device to train the model on. DeviceConfigs picks up an available CUDA device or defaults to CPU.
    device: torch.device = DeviceConfigs()
U-Net model for $\epsilon_\theta(x_t, t)$
    eps_model: UNet
DDPM algorithm
    diffusion: DenoiseDiffusion
Number of channels in the image; 3 for RGB.
    image_channels: int = 3
Image size
    image_size: int = 32
Number of channels in the initial feature map
    n_channels: int = 64
The list of channel numbers at each resolution. The number of channels is channel_multipliers[i] * n_channels (see the worked example after the attribute list).
    channel_multipliers: List[int] = [1, 2, 2, 4]
The list of booleans that indicate whether to use attention at each resolution
    is_attention: List[bool] = [False, False, False, True]
Number of time steps $T$
    n_steps: int = 1_000
Batch size
    batch_size: int = 64
Number of samples to generate
    n_samples: int = 16
Learning rate
    learning_rate: float = 2e-5
Number of training epochs
    epochs: int = 1_000
Dataset
    dataset: torch.utils.data.Dataset
Dataloader
    data_loader: torch.utils.data.DataLoader
Adam optimizer
    optimizer: torch.optim.Adam
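With the defaults above, and assuming the usual halving of spatial resolution between U-Net levels, the feature maps would be (a worked example, not taken from the file itself):
32x32 resolution with 1 * 64 = 64 channels
16x16 resolution with 2 * 64 = 128 channels
8x8 resolution with 2 * 64 = 128 channels
4x4 resolution with 4 * 64 = 256 channels
Attention is enabled only at the lowest, 4x4, resolution, matching is_attention = [False, False, False, True].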
    def init(self):
Create model
        self.eps_model = UNet(
            image_channels=self.image_channels,
            n_channels=self.n_channels,
            ch_mults=self.channel_multipliers,
            is_attn=self.is_attention,
        ).to(self.device)
Create DDPM class
        self.diffusion = DenoiseDiffusion(
            eps_model=self.eps_model,
            n_steps=self.n_steps,
            device=self.device,
        )
Create dataloader
        self.data_loader = torch.utils.data.DataLoader(self.dataset, self.batch_size, shuffle=True, pin_memory=True)
Create optimizer
        self.optimizer = torch.optim.Adam(self.eps_model.parameters(), lr=self.learning_rate)
Image logging
        tracker.set_image("sample", True)
    def sample(self):
        with torch.no_grad():
Sample the initial noise $x_T \sim \mathcal{N}(\mathbf{0}, \mathbf{I})$
            x = torch.randn([self.n_samples, self.image_channels, self.image_size, self.image_size],
                            device=self.device)
Remove noise for $T$ steps
            for t_ in monit.iterate('Sample', self.n_steps):
                t = self.n_steps - t_ - 1
Sample from $p_\theta(x_{t-1} \vert x_t)$
                x = self.diffusion.p_sample(x, x.new_full((self.n_samples,), t, dtype=torch.long))
Log samples
            tracker.save('sample', x)
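For reference, the reverse step that each p_sample call is expected to perform is the DDPM update from the paper,
$$x_{t-1} = \frac{1}{\sqrt{\alpha_t}}\Big(x_t - \frac{\beta_t}{\sqrt{1 - \bar\alpha_t}}\,\epsilon_\theta(x_t, t)\Big) + \sigma_t z, \qquad z \sim \mathcal{N}(\mathbf{0}, \mathbf{I}),$$
with $\sigma_t^2 = \beta_t$, the simpler of the two variance choices discussed in the paper.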
    def train(self):
Iterate through the dataset
        for data in monit.iterate('Train', self.data_loader):
Increment global step
            tracker.add_global_step()
Move data to device
            data = data.to(self.device)
Make the gradients zero
            self.optimizer.zero_grad()
Calculate loss
            loss = self.diffusion.loss(data)
Compute gradients
            loss.backward()
Take an optimization step
            self.optimizer.step()
Track the loss
            tracker.save('loss', loss)
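The loss returned by DenoiseDiffusion corresponds to the simplified DDPM training objective from the paper,
$$L_{\text{simple}}(\theta) = \mathbb{E}_{t, x_0, \epsilon}\Big[\big\lVert \epsilon - \epsilon_\theta\big(\sqrt{\bar\alpha_t}\, x_0 + \sqrt{1 - \bar\alpha_t}\,\epsilon,\ t\big)\big\rVert^2\Big],$$
i.e. the mean squared error between the true noise $\epsilon$ and the noise the U-Net predicts at a uniformly sampled time step $t$.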
    def run(self):
        for _ in monit.loop(self.epochs):
Train the model
            self.train()
Sample some images
            self.sample()
New line in the console
            tracker.new_line()
Save the model
            experiment.save_checkpoint()
class CelebADataset(torch.utils.data.Dataset):
    def __init__(self, image_size: int):
        super().__init__()
CelebA images folder
        folder = lab.get_data_path() / 'celebA'
List of files
        self._files = [p for p in folder.glob('**/*.jpg')]
Transformations to resize the image and convert to tensor
        self._transform = torchvision.transforms.Compose([
            torchvision.transforms.Resize(image_size),
            torchvision.transforms.ToTensor(),
        ])
Size of the dataset
    def __len__(self):
        return len(self._files)
Get an image
    def __getitem__(self, index: int):
        img = Image.open(self._files[index])
        return self._transform(img)
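A quick, hypothetical sanity check that the images were saved in the right place and are transformed as expected (not part of the experiment itself):
dataset = CelebADataset(image_size=32)
print(len(dataset))       # number of .jpg files found under data/celebA
print(dataset[0].shape)   # torch.Size([3, 32, 32]) for the square CelebA HQ images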
Create CelebA dataset
@option(Configs.dataset, 'CelebA')
def celeb_dataset(c: Configs):
    return CelebADataset(c.image_size)
class MNISTDataset(torchvision.datasets.MNIST):
    def __init__(self, image_size):
        transform = torchvision.transforms.Compose([
            torchvision.transforms.Resize(image_size),
            torchvision.transforms.ToTensor(),
        ])

        super().__init__(str(lab.get_data_path()), train=True, download=True, transform=transform)

    def __getitem__(self, item):
        return super().__getitem__(item)[0]
Create MNIST dataset
@option(Configs.dataset, 'MNIST')
def mnist_dataset(c: Configs):
    return MNISTDataset(c.image_size)
def main():
Create experiment
    experiment.create(name='diffuse')
Create configurations
    configs = Configs()
Set configurations. You can override the defaults by passing the values in the dictionary.
    experiment.configs(configs, {
    })
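For example, a hypothetical set of overrides for a quick MNIST run (the values below are illustrative, not defaults from this file) might look like:
experiment.configs(configs, {
    'dataset': 'MNIST',       # use the MNIST option registered above
    'image_channels': 1,      # MNIST images are grayscale
    'epochs': 5,              # fewer epochs for a quick test run
})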
Initialize
    configs.init()
Set models for saving and loading
    experiment.add_pytorch_models({'eps_model': configs.eps_model})
Start and run the training loop
    with experiment.start():
        configs.run()

if __name__ == '__main__':
    main()