"""
---
title: Denoising Diffusion Probabilistic Models (DDPM) training
summary: >
  Training code for
  Denoising Diffusion Probabilistic Model.
---
# [Denoising Diffusion Probabilistic Models (DDPM)](index.html) training
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/diffusion/ddpm/experiment.ipynb)
This trains a DDPM based model on the CelebA HQ dataset. You can find the download instructions in this
[discussion on fast.ai](https://forums.fast.ai/t/download-celeba-hq-dataset/45873/3).
Save the images inside [`data/celebA` folder](#dataset_path).
The paper used an exponential moving average of the model weights with a decay of $0.9999$. We have skipped this
for simplicity; a minimal sketch of the EMA update is given after the imports below.
"""
from typing import List
import torchvision
from PIL import Image
import torch
import torch.utils.data
from labml import lab, tracker, experiment, monit
from labml.configs import BaseConfigs, option
from labml_nn.diffusion.ddpm import DenoiseDiffusion
from labml_nn.diffusion.ddpm.unet import UNet
from labml_nn.helpers.device import DeviceConfigs
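
# The paper keeps an exponential moving average (EMA) of the model weights with a
# decay of $0.9999$ and samples with the averaged weights. We skip that in this
# experiment for simplicity; the helper below is only a minimal sketch of the
# update, assuming a hypothetical `ema_model` initialized as a deep copy of
# `eps_model`. It is not wired into the training loop.
def ema_update(ema_model: torch.nn.Module, model: torch.nn.Module, decay: float = 0.9999):
    # Move each EMA parameter towards the corresponding model parameter,
    # $\theta_{\text{EMA}} \leftarrow d \, \theta_{\text{EMA}} + (1 - d) \, \theta$
    with torch.no_grad():
        for p_ema, p in zip(ema_model.parameters(), model.parameters()):
            p_ema.mul_(decay).add_(p, alpha=1.0 - decay)
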
class Configs(BaseConfigs):
"""
## Configurations
"""
# Device to train the model on.
# [`DeviceConfigs`](../../device.html)
# picks up an available CUDA device or defaults to CPU.
device: torch.device = DeviceConfigs()
# U-Net model for $\textcolor{lightgreen}{\epsilon_\theta}(x_t, t)$
eps_model: UNet
# [DDPM algorithm](index.html)
diffusion: DenoiseDiffusion
# Number of channels in the image. $3$ for RGB.
image_channels: int = 3
# Image size
image_size: int = 32
# Number of channels in the initial feature map
n_channels: int = 64
    # The channel multipliers at each resolution;
    # the number of channels at resolution $i$ is `channel_multipliers[i] * n_channels`
    channel_multipliers: List[int] = [1, 2, 2, 4]
    # The list of booleans that indicate whether to use attention at each resolution
    is_attention: List[bool] = [False, False, False, True]
    # Number of time steps $T$
    n_steps: int = 1_000
    # Batch size
    batch_size: int = 64
    # Number of samples to generate
    n_samples: int = 16
    # Learning rate
    learning_rate: float = 2e-5
    # Number of training epochs
    epochs: int = 1_000
    # Dataset
    dataset: torch.utils.data.Dataset
    # Dataloader
    data_loader: torch.utils.data.DataLoader
    # Adam optimizer
    optimizer: torch.optim.Adam

    def init(self):
        # Create $\textcolor{lightgreen}{\epsilon_\theta}(x_t, t)$ model
        self.eps_model = UNet(
            image_channels=self.image_channels,
            n_channels=self.n_channels,
            ch_mults=self.channel_multipliers,
            is_attn=self.is_attention,
        ).to(self.device)
        # Create the [DDPM class](index.html),
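        # which implements the forward diffusion process
        # $q(x_t|x_0) = \mathcal{N}\big(x_t; \sqrt{\bar\alpha_t} \, x_0, (1 - \bar\alpha_t) \mathbf{I}\big)$
        # and sampling from the learned reverse process $\textcolor{lightgreen}{p_\theta}(x_{t-1}|x_t)$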
        self.diffusion = DenoiseDiffusion(
            eps_model=self.eps_model,
            n_steps=self.n_steps,
            device=self.device,
        )
        # Create dataloader
        self.data_loader = torch.utils.data.DataLoader(self.dataset, self.batch_size,
                                                       shuffle=True, pin_memory=True)
        # Create optimizer
        self.optimizer = torch.optim.Adam(self.eps_model.parameters(), lr=self.learning_rate)
        # Image logging
        tracker.set_image("sample", True)

    def sample(self):
        """
        ### Sample images
        """
        with torch.no_grad():
            # $x_T \sim p(x_T) = \mathcal{N}(x_T; \mathbf{0}, \mathbf{I})$
            x = torch.randn([self.n_samples, self.image_channels, self.image_size, self.image_size],
                            device=self.device)
            # Remove noise for $T$ steps
            for t_ in monit.iterate('Sample', self.n_steps):
                # $t$ runs from $T - 1$ down to $0$
                t = self.n_steps - t_ - 1
                # Sample from $\textcolor{lightgreen}{p_\theta}(x_{t-1}|x_t)$.
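                # The reverse step from the paper is
                # $x_{t-1} = \frac{1}{\sqrt{\alpha_t}} \Big( x_t - \frac{\beta_t}{\sqrt{1 - \bar\alpha_t}} \textcolor{lightgreen}{\epsilon_\theta}(x_t, t) \Big) + \sigma_t z$
                # where $z \sim \mathcal{N}(\mathbf{0}, \mathbf{I})$ and $\sigma_t^2 = \beta_t$.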
                x = self.diffusion.p_sample(x, x.new_full((self.n_samples,), t, dtype=torch.long))
            # Log samples
            tracker.save('sample', x)

    def train(self):
        """
        ### Train
        """
        # Iterate through the dataset
        for data in monit.iterate('Train', self.data_loader):
            # Increment global step
            tracker.add_global_step()
            # Move data to device
            data = data.to(self.device)
            # Make the gradients zero
            self.optimizer.zero_grad()
            # Calculate the loss.
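            # This is the simplified training objective from the paper,
            # $L_{\text{simple}}(\theta) = \mathbb{E}_{t, x_0, \epsilon} \Big[ \big\Vert \epsilon - \textcolor{lightgreen}{\epsilon_\theta}\big(\sqrt{\bar\alpha_t} x_0 + \sqrt{1 - \bar\alpha_t} \, \epsilon, t\big) \big\Vert^2 \Big]$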
            loss = self.diffusion.loss(data)
            # Compute gradients
            loss.backward()
            # Take an optimization step
            self.optimizer.step()
            # Track the loss
            tracker.save('loss', loss)

    def run(self):
        """
        ### Training loop
        """
        for _ in monit.loop(self.epochs):
            # Train the model
            self.train()
            # Sample some images
            self.sample()
            # New line in the console
            tracker.new_line()

class CelebADataset(torch.utils.data.Dataset):
"""
### CelebA HQ dataset
"""
def __init__(self, image_size: int):
super().__init__()
# CelebA images folder
folder = lab.get_data_path() / 'celebA'
# List of files
self._files = [p for p in folder.glob(f'**/*.jpg')]
# Transformations to resize the image and convert to tensor
self._transform = torchvision.transforms.Compose([
torchvision.transforms.Resize(image_size),
torchvision.transforms.ToTensor(),
])
    def __len__(self):
        """
        Size of the dataset
        """
        return len(self._files)

    def __getitem__(self, index: int):
        """
        Get an image
        """
        img = Image.open(self._files[index])
        return self._transform(img)

@option(Configs.dataset, 'CelebA')
def celeb_dataset(c: Configs):
"""
Create CelebA dataset
"""
return CelebADataset(c.image_size)
class MNISTDataset(torchvision.datasets.MNIST):
"""
### MNIST dataset
"""
def __init__(self, image_size):
transform = torchvision.transforms.Compose([
torchvision.transforms.Resize(image_size),
torchvision.transforms.ToTensor(),
])
super().__init__(str(lab.get_data_path()), train=True, download=True, transform=transform)
    def __getitem__(self, item):
        # Return only the image and discard the label,
        # since the diffusion model is trained without labels
        return super().__getitem__(item)[0]

@option(Configs.dataset, 'MNIST')
def mnist_dataset(c: Configs):
"""
Create MNIST dataset
"""
return MNISTDataset(c.image_size)
def main():
    # Create experiment
    experiment.create(name='diffuse', writers={'screen', 'labml'})
    # Create configurations
    configs = Configs()
    # Set configurations. You can override the defaults by passing the values in the dictionary.
    experiment.configs(configs, {
        'dataset': 'CelebA',  # 'MNIST'
        'image_channels': 3,  # 1,
        'epochs': 100,  # 5,
    })
    # Initialize
    configs.init()
    # Set models for saving and loading
    experiment.add_pytorch_models({'eps_model': configs.eps_model})
    # Start and run the training loop
    with experiment.start():
        configs.run()

#
if __name__ == '__main__':
    main()