"""
|
|
---
|
|
title: Denoising Diffusion Probabilistic Models (DDPM) training
|
|
summary: >
|
|
Training code for
|
|
Denoising Diffusion Probabilistic Model.
|
|
---
|
|
|
|
# [Denoising Diffusion Probabilistic Models (DDPM)](index.html) training
|
|
|
|
[](https://colab.research.google.com/github/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/diffusion/ddpm/experiment.ipynb)
|
|
|
|
This trains a DDPM based model on CelebA HQ dataset. You can find the download instruction in this
|
|
[discussion on fast.ai](https://forums.fast.ai/t/download-celeba-hq-dataset/45873/3).
|
|
Save the images inside [`data/celebA` folder](#dataset_path).
|
|
|
|
The paper had used a exponential moving average of the model with a decay of $0.9999$. We have skipped this for
|
|
simplicity.
|
|
"""
from typing import List

import torchvision
from PIL import Image

import torch
import torch.utils.data
from labml import lab, tracker, experiment, monit
from labml.configs import BaseConfigs, option
from labml_nn.diffusion.ddpm import DenoiseDiffusion
from labml_nn.diffusion.ddpm.unet import UNet
from labml_nn.helpers.device import DeviceConfigs
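

# The paper keeps an exponential moving average of the model parameters with decay
# $0.9999$ and samples from the averaged weights. This is a minimal sketch of such a
# helper, assuming the usual EMA update rule. It is an illustration only: the `EMA`
# class is our own and is *not* wired into the training loop below.
class EMA:
    """
    ### Exponential moving average of model parameters (illustrative sketch)
    """

    def __init__(self, model: torch.nn.Module, decay: float = 0.9999):
        self.decay = decay
        # Keep a detached copy of every parameter
        self.shadow = {name: param.detach().clone() for name, param in model.named_parameters()}

    @torch.no_grad()
    def update(self, model: torch.nn.Module):
        # After each optimizer step: shadow <- decay * shadow + (1 - decay) * param
        for name, param in model.named_parameters():
            self.shadow[name].mul_(self.decay).add_(param, alpha=1. - self.decay)

    @torch.no_grad()
    def copy_to(self, model: torch.nn.Module):
        # Overwrite the live parameters with the averaged ones (e.g. before sampling)
        for name, param in model.named_parameters():
            param.copy_(self.shadow[name])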


class Configs(BaseConfigs):
    """
    ## Configurations
    """
    # Device to train the model on.
    # [`DeviceConfigs`](../../device.html)
    # picks up an available CUDA device or defaults to CPU.
    device: torch.device = DeviceConfigs()

    # U-Net model for $\textcolor{lightgreen}{\epsilon_\theta}(x_t, t)$
    eps_model: UNet
    # [DDPM algorithm](index.html)
    diffusion: DenoiseDiffusion

    # Number of channels in the image. $3$ for RGB.
    image_channels: int = 3
    # Image size
    image_size: int = 32
    # Number of channels in the initial feature map
    n_channels: int = 64
    # The list of channel multipliers at each resolution.
    # The number of channels at resolution $i$ is `channel_multipliers[i] * n_channels`.
    channel_multipliers: List[int] = [1, 2, 2, 4]
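    # For example, with `n_channels = 64` and multipliers `[1, 2, 2, 4]`, the feature
    # maps have $64, 128, 128, 256$ channels at the four resolutions.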
    # The list of booleans that indicate whether to use attention at each resolution
    is_attention: List[bool] = [False, False, False, True]

    # Number of time steps $T$
    n_steps: int = 1_000
    # Batch size
    batch_size: int = 64
    # Number of samples to generate
    n_samples: int = 16
    # Learning rate
    learning_rate: float = 2e-5

    # Number of training epochs
    epochs: int = 1_000

    # Dataset
    dataset: torch.utils.data.Dataset
    # Dataloader
    data_loader: torch.utils.data.DataLoader

    # Adam optimizer
    optimizer: torch.optim.Adam

    def init(self):
        # Create $\textcolor{lightgreen}{\epsilon_\theta}(x_t, t)$ model
        self.eps_model = UNet(
            image_channels=self.image_channels,
            n_channels=self.n_channels,
            ch_mults=self.channel_multipliers,
            is_attn=self.is_attention,
        ).to(self.device)

        # Create [DDPM class](index.html)
        self.diffusion = DenoiseDiffusion(
            eps_model=self.eps_model,
            n_steps=self.n_steps,
            device=self.device,
        )

        # Create dataloader
        self.data_loader = torch.utils.data.DataLoader(self.dataset, self.batch_size, shuffle=True, pin_memory=True)
        # Create optimizer
        self.optimizer = torch.optim.Adam(self.eps_model.parameters(), lr=self.learning_rate)

        # Image logging
        tracker.set_image("sample", True)

    def sample(self):
        """
        ### Sample images
        """
        with torch.no_grad():
            # $x_T \sim p(x_T) = \mathcal{N}(x_T; \mathbf{0}, \mathbf{I})$
            x = torch.randn([self.n_samples, self.image_channels, self.image_size, self.image_size],
                            device=self.device)

            # Remove noise for $T$ steps
            for t_ in monit.iterate('Sample', self.n_steps):
                # $t$
                t = self.n_steps - t_ - 1
                # Sample from $\textcolor{lightgreen}{p_\theta}(x_{t-1}|x_t)$
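                # Each step computes (Algorithm 2 of the paper, implemented by `p_sample`)
                # $x_{t-1} = \frac{1}{\sqrt{\alpha_t}} \Big( x_t - \frac{\beta_t}{\sqrt{1 - \bar\alpha_t}} \textcolor{lightgreen}{\epsilon_\theta}(x_t, t) \Big) + \sigma_t z$
                # where $z \sim \mathcal{N}(\mathbf{0}, \mathbf{I})$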
                x = self.diffusion.p_sample(x, x.new_full((self.n_samples,), t, dtype=torch.long))

            # Log samples
            tracker.save('sample', x)

    def train(self):
        """
        ### Train
        """

        # Iterate through the dataset
        for data in monit.iterate('Train', self.data_loader):
            # Increment global step
            tracker.add_global_step()
            # Move data to device
            data = data.to(self.device)

            # Make the gradients zero
            self.optimizer.zero_grad()
            # Calculate loss
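            # This is the simplified objective from the paper,
            # $L_{\text{simple}} = \mathbb{E}_{t, x_0, \epsilon} \Big[ \big\Vert \epsilon - \textcolor{lightgreen}{\epsilon_\theta}(\sqrt{\bar\alpha_t} x_0 + \sqrt{1 - \bar\alpha_t} \epsilon, t) \big\Vert^2 \Big]$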
            loss = self.diffusion.loss(data)
            # Compute gradients
            loss.backward()
            # Take an optimization step
            self.optimizer.step()
            # Track the loss
            tracker.save('loss', loss)

    def run(self):
        """
        ### Training loop
        """
        for _ in monit.loop(self.epochs):
            # Train the model
            self.train()
            # Sample some images
            self.sample()
            # New line in the console
            tracker.new_line()


class CelebADataset(torch.utils.data.Dataset):
    """
    ### CelebA HQ dataset
    """

    def __init__(self, image_size: int):
        super().__init__()

        # CelebA images folder
        folder = lab.get_data_path() / 'celebA'
        # List of files
        self._files = list(folder.glob('**/*.jpg'))

        # Transformations to resize the image and convert to tensor
        self._transform = torchvision.transforms.Compose([
            torchvision.transforms.Resize(image_size),
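            # (`Resize` with an `int` scales the shorter edge to `image_size`; CelebA-HQ
            # images are square, so the result is `image_size` by `image_size`)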
            torchvision.transforms.ToTensor(),
        ])

    def __len__(self):
        """
        Size of the dataset
        """
        return len(self._files)

    def __getitem__(self, index: int):
        """
        Get an image
        """
        img = Image.open(self._files[index])
        return self._transform(img)


@option(Configs.dataset, 'CelebA')
def celeb_dataset(c: Configs):
    """
    Create CelebA dataset
    """
    return CelebADataset(c.image_size)


class MNISTDataset(torchvision.datasets.MNIST):
    """
    ### MNIST dataset
    """

    def __init__(self, image_size):
        transform = torchvision.transforms.Compose([
            torchvision.transforms.Resize(image_size),
            torchvision.transforms.ToTensor(),
        ])

        super().__init__(str(lab.get_data_path()), train=True, download=True, transform=transform)

    def __getitem__(self, item):
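        # Return only the image and drop the class label, since DDPM training is unsupervised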
        return super().__getitem__(item)[0]


@option(Configs.dataset, 'MNIST')
def mnist_dataset(c: Configs):
    """
    Create MNIST dataset
    """
    return MNISTDataset(c.image_size)


def main():
    # Create experiment
    experiment.create(name='diffuse', writers={'screen', 'labml'})

    # Create configurations
    configs = Configs()

    # Set configurations. You can override the defaults by passing the values in the dictionary.
    experiment.configs(configs, {
        'dataset': 'CelebA',  # 'MNIST'
        'image_channels': 3,  # 1,
        'epochs': 100,  # 5,
    })
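
    # The `'CelebA'` / `'MNIST'` strings select the matching dataset function registered
    # with `@option(Configs.dataset, ...)` above. For example, to train on MNIST instead
    # (matching the commented values), pass
    # `{'dataset': 'MNIST', 'image_channels': 1, 'epochs': 5}`.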

    # Initialize
    configs.init()

    # Set models for saving and loading
    experiment.add_pytorch_models({'eps_model': configs.eps_model})

    # Start and run the training loop
    with experiment.start():
        configs.run()


#
if __name__ == '__main__':
    main()