This is a PyTorch implementation/tutorial of the paper Denoising Diffusion Probabilistic Models.
In simple terms, we take an image from the data and add noise to it step by step. Then we train a model to predict that noise at each step, and use the model to generate images.
The following definitions and derivations show how this works. For details please refer to the paper.
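The forward process adds noise to the data $x_0 \sim q(x_0)$ over $T$ timesteps.

$$q(x_t \mid x_{t-1}) = \mathcal{N}\big(x_t; \sqrt{1-\beta_t}\, x_{t-1}, \beta_t \mathbf{I}\big)$$

$$q(x_{1:T} \mid x_0) = \prod_{t=1}^{T} q(x_t \mid x_{t-1})$$

where $\beta_1, \dots, \beta_T$ is the variance schedule.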
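We can sample $x_t$ at any timestep $t$ directly from $x_0$ with

$$q(x_t \mid x_0) = \mathcal{N}\big(x_t; \sqrt{\bar\alpha_t}\, x_0, (1-\bar\alpha_t) \mathbf{I}\big)$$

where $\alpha_t = 1 - \beta_t$ and $\bar\alpha_t = \prod_{s=1}^{t} \alpha_s$.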
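The closed form holds because composing Gaussian steps gives another Gaussian. The sketch below is an illustration, not part of the implementation. It uses plain Python and assumes the same linear schedule from $10^{-4}$ to $0.02$ that `DenoiseDiffusion` uses. It tracks the scalar mean coefficient and the variance of $x_t$ in two ways: applying $q(x_t \mid x_{t-1})$ one step at a time, and using $\bar\alpha_t$ directly. The helper names are hypothetical.

```python
import math


def linear_schedule(n_steps: int, beta_1: float = 1e-4, beta_T: float = 0.02) -> list:
    # Hypothetical helper: linearly spaced beta_1..beta_T (same endpoints as DenoiseDiffusion)
    if n_steps == 1:
        return [beta_1]
    step = (beta_T - beta_1) / (n_steps - 1)
    return [beta_1 + i * step for i in range(n_steps)]


def closed_form_moments(betas: list) -> list:
    # q(x_t|x_0): mean coefficient sqrt(alpha_bar_t), variance 1 - alpha_bar_t
    alpha_bar = 1.0
    out = []
    for b in betas:
        alpha_bar *= 1.0 - b
        out.append((math.sqrt(alpha_bar), 1.0 - alpha_bar))
    return out


def iterative_moments(betas: list) -> list:
    # Apply x_t = sqrt(1 - beta_t) x_{t-1} + sqrt(beta_t) * noise step by step,
    # tracking x_t = m * x_0 + (zero-mean Gaussian noise with variance v)
    m, v = 1.0, 0.0
    out = []
    for b in betas:
        m = math.sqrt(1.0 - b) * m
        v = (1.0 - b) * v + b
        out.append((m, v))
    return out


betas = linear_schedule(1000)
closed = closed_form_moments(betas)
stepped = iterative_moments(betas)
max_err = max(max(abs(a[0] - b[0]), abs(a[1] - b[1])) for a, b in zip(closed, stepped))
print(f"max discrepancy: {max_err:.2e}")
```

The two methods agree up to floating-point rounding. This is why training can jump straight to any $t$ without simulating the chain.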
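The reverse process removes noise over $T$ time steps, starting at $p(x_T) = \mathcal{N}(x_T; \mathbf{0}, \mathbf{I})$.

$$p_\theta(x_{t-1} \mid x_t) = \mathcal{N}\big(x_{t-1}; \mu_\theta(x_t, t), \Sigma_\theta(x_t, t)\big)$$

$$p_\theta(x_{0:T}) = p(x_T) \prod_{t=1}^{T} p_\theta(x_{t-1} \mid x_t)$$

$$p_\theta(x_0) = \int p_\theta(x_{0:T}) \, dx_{1:T}$$

$\theta$ are the parameters we train.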
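We minimize the variational bound (ELBO, obtained from Jensen's inequality) on the negative log likelihood.

$$\mathbb{E}\big[-\log p_\theta(x_0)\big] \le \mathbb{E}_q\left[-\log \frac{p_\theta(x_{0:T})}{q(x_{1:T} \mid x_0)}\right] = L$$

The loss can be rewritten as follows.

$$\begin{aligned}
L &= \mathbb{E}_q\left[-\log \frac{p_\theta(x_{0:T})}{q(x_{1:T} \mid x_0)}\right] \\
&= \mathbb{E}_q\left[-\log p(x_T) - \sum_{t \ge 1} \log \frac{p_\theta(x_{t-1} \mid x_t)}{q(x_t \mid x_{t-1})}\right] \\
&= \mathbb{E}_q\left[D_{KL}\big(q(x_T \mid x_0) \,\|\, p(x_T)\big) + \sum_{t > 1} D_{KL}\big(q(x_{t-1} \mid x_t, x_0) \,\|\, p_\theta(x_{t-1} \mid x_t)\big) - \log p_\theta(x_0 \mid x_1)\right]
\end{aligned}$$

$D_{KL}(q(x_T \mid x_0) \,\|\, p(x_T))$ is constant with respect to $\theta$, since we keep $\beta_1, \dots, \beta_T$ fixed rather than learning them.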
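A small numeric check shows why this term can be ignored during training. It assumes the linear schedule from $10^{-4}$ to $0.02$ over $T = 1000$ steps used in the code below, and scalar data values in $[-1, 1]$. Under these assumptions $\bar\alpha_T$ is tiny, so $q(x_T \mid x_0)$ is very close to $\mathcal{N}(0, 1)$. The per-dimension KL divergence is then both small and independent of $\theta$. The function names are hypothetical.

```python
import math


def alpha_bar_T(n_steps: int = 1000, beta_1: float = 1e-4, beta_T: float = 0.02) -> float:
    # alpha_bar_T = prod_t (1 - beta_t) for a linear beta schedule (assumed, as in DenoiseDiffusion)
    step = (beta_T - beta_1) / (n_steps - 1)
    ab = 1.0
    for i in range(n_steps):
        ab *= 1.0 - (beta_1 + i * step)
    return ab


def kl_qT_prior(x0: float, ab: float) -> float:
    # KL( N(sqrt(ab) x0, 1 - ab) || N(0, 1) ) = 0.5 * (var + mean^2 - 1 - log var),
    # with var - 1 = -ab and log var computed stably via log1p
    return 0.5 * (ab * x0 ** 2 - ab - math.log1p(-ab))


ab = alpha_bar_T()
print(f"alpha_bar_T = {ab:.2e}")
for x0 in (-1.0, 0.0, 1.0):
    print(x0, kl_qT_prior(x0, ab))
```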
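The forward process posterior conditioned on $x_0$ is

$$q(x_{t-1} \mid x_t, x_0) = \mathcal{N}\big(x_{t-1}; \tilde\mu_t(x_t, x_0), \tilde\beta_t \mathbf{I}\big)$$

$$\tilde\mu_t(x_t, x_0) = \frac{\sqrt{\bar\alpha_{t-1}}\,\beta_t}{1-\bar\alpha_t} x_0 + \frac{\sqrt{\alpha_t}\,(1-\bar\alpha_{t-1})}{1-\bar\alpha_t} x_t$$

$$\tilde\beta_t = \frac{1-\bar\alpha_{t-1}}{1-\bar\alpha_t} \beta_t$$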
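These coefficients come from Bayes' rule with the Gaussian prior $q(x_{t-1} \mid x_0)$ and the Gaussian likelihood $q(x_t \mid x_{t-1})$. The scalar sketch below computes the posterior both from the closed-form expressions above and from a direct precision-weighted Bayes update, then checks that the two agree. The values of $\beta_t$, $\bar\alpha_{t-1}$, $x_0$ and $x_t$ are made up for illustration, and the function names are hypothetical.

```python
import math


def posterior(beta: float, alpha_bar_prev: float, x0: float, xt: float):
    # Closed form: tilde_mu_t(x_t, x_0) and tilde_beta_t
    alpha = 1.0 - beta
    alpha_bar = alpha_bar_prev * alpha
    coef_x0 = math.sqrt(alpha_bar_prev) * beta / (1.0 - alpha_bar)
    coef_xt = math.sqrt(alpha) * (1.0 - alpha_bar_prev) / (1.0 - alpha_bar)
    return coef_x0 * x0 + coef_xt * xt, (1.0 - alpha_bar_prev) / (1.0 - alpha_bar) * beta


def posterior_by_bayes(beta: float, alpha_bar_prev: float, x0: float, xt: float):
    # Prior q(x_{t-1}|x_0) = N(sqrt(alpha_bar_prev) x0, 1 - alpha_bar_prev);
    # likelihood q(x_t|x_{t-1}) = N(sqrt(alpha) x_{t-1}, beta). Combine precisions.
    alpha = 1.0 - beta
    prior_mean, prior_var = math.sqrt(alpha_bar_prev) * x0, 1.0 - alpha_bar_prev
    var = 1.0 / (1.0 / prior_var + alpha / beta)
    mean = var * (prior_mean / prior_var + math.sqrt(alpha) * xt / beta)
    return mean, var


beta, alpha_bar_prev = 0.01, 0.5   # illustrative values, not from a real schedule
x0, xt = 0.3, -0.7
closed_form = posterior(beta, alpha_bar_prev, x0, xt)
bayes = posterior_by_bayes(beta, alpha_bar_prev, x0, xt)
print(closed_form, bayes)
```

Note that $\tilde\beta_t \le \beta_t$, because the factor $(1-\bar\alpha_{t-1})/(1-\bar\alpha_t)$ is at most 1.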
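The paper sets $\Sigma_\theta(x_t, t) = \sigma_t^2 \mathbf{I}$, where $\sigma_t^2$ is fixed to the constant $\beta_t$ or $\tilde\beta_t$.

Then,

$$p_\theta(x_{t-1} \mid x_t) = \mathcal{N}\big(x_{t-1}; \mu_\theta(x_t, t), \sigma_t^2 \mathbf{I}\big)$$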
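For given noise $\epsilon \sim \mathcal{N}(\mathbf{0}, \mathbf{I})$, using $q(x_t \mid x_0)$,

$$x_t(x_0, \epsilon) = \sqrt{\bar\alpha_t}\, x_0 + \sqrt{1-\bar\alpha_t}\, \epsilon$$

This gives

$$\begin{aligned}
L_{t-1} &= \mathbb{E}_q\Big[D_{KL}\big(q(x_{t-1} \mid x_t, x_0) \,\|\, p_\theta(x_{t-1} \mid x_t)\big)\Big] \\
&= \mathbb{E}_q\left[\frac{1}{2\sigma_t^2} \big\|\tilde\mu_t(x_t, x_0) - \mu_\theta(x_t, t)\big\|^2\right] + C \\
&= \mathbb{E}_{x_0, \epsilon}\left[\frac{1}{2\sigma_t^2} \left\|\frac{1}{\sqrt{\alpha_t}}\Big(x_t(x_0, \epsilon) - \frac{\beta_t}{\sqrt{1-\bar\alpha_t}}\epsilon\Big) - \mu_\theta\big(x_t(x_0, \epsilon), t\big)\right\|^2\right] + C
\end{aligned}$$

where $C$ does not depend on $\theta$.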
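Both Gaussians in this KL divergence have fixed variances, so the divergence splits into a squared difference of means plus a constant $C$. The scalar sketch below checks this. It uses the standard closed-form KL between two univariate Gaussians and made-up values for $\tilde\mu_t$, $\tilde\beta_t$ and $\sigma_t^2$. It treats `mu_theta` as a free number standing in for the model output.

```python
import math


def kl_gauss(mu_q: float, var_q: float, mu_p: float, var_p: float) -> float:
    # KL( N(mu_q, var_q) || N(mu_p, var_p) ) for scalars
    return 0.5 * (math.log(var_p / var_q) + (var_q + (mu_q - mu_p) ** 2) / var_p - 1.0)


tilde_mu, tilde_beta = 0.2, 0.008   # illustrative posterior q(x_{t-1}|x_t, x_0)
sigma2 = 0.01                        # illustrative fixed model variance sigma_t^2
C = kl_gauss(tilde_mu, tilde_beta, tilde_mu, sigma2)   # the part left when the means match

checks = []
for mu_theta in (0.2, 0.25, 0.0):
    kl = kl_gauss(tilde_mu, tilde_beta, mu_theta, sigma2)
    squared_term = (tilde_mu - mu_theta) ** 2 / (2 * sigma2)
    checks.append((kl, squared_term + C))
    print(mu_theta, kl, squared_term + C)
```

Only the squared term depends on $\mu_\theta$, so minimizing the KL over $\theta$ is the same as minimizing the weighted squared error.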
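Re-parameterizing with a model that predicts the noise,

$$\mu_\theta(x_t, t) = \tilde\mu_t\left(x_t, \frac{1}{\sqrt{\bar\alpha_t}}\Big(x_t - \sqrt{1-\bar\alpha_t}\,\epsilon_\theta(x_t, t)\Big)\right) = \frac{1}{\sqrt{\alpha_t}}\Big(x_t - \frac{\beta_t}{\sqrt{1-\bar\alpha_t}}\epsilon_\theta(x_t, t)\Big)$$

where $\epsilon_\theta$ is a learned function that predicts $\epsilon$ given $(x_t, t)$.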
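The second equality is pure algebra. The first step recovers an estimate $\hat x_0$ from $x_t$ and the predicted noise, and the second plugs $\hat x_0$ into $\tilde\mu_t$. The sketch below checks the identity numerically. `eps_pred` is an arbitrary number standing in for $\epsilon_\theta(x_t, t)$, the schedule values are illustrative, and the function names are hypothetical.

```python
import math


def tilde_mu(beta: float, alpha_bar_prev: float, x0: float, xt: float) -> float:
    # Posterior mean tilde_mu_t(x_t, x_0)
    alpha = 1.0 - beta
    alpha_bar = alpha_bar_prev * alpha
    return (math.sqrt(alpha_bar_prev) * beta / (1.0 - alpha_bar) * x0
            + math.sqrt(alpha) * (1.0 - alpha_bar_prev) / (1.0 - alpha_bar) * xt)


def mu_from_eps(beta: float, alpha_bar_prev: float, xt: float, eps: float) -> float:
    # mu_theta = (x_t - beta_t / sqrt(1 - alpha_bar_t) * eps) / sqrt(alpha_t)
    alpha = 1.0 - beta
    alpha_bar = alpha_bar_prev * alpha
    return (xt - beta / math.sqrt(1.0 - alpha_bar) * eps) / math.sqrt(alpha)


beta, alpha_bar_prev = 0.02, 0.6     # illustrative values
alpha_bar = alpha_bar_prev * (1.0 - beta)
xt, eps_pred = 0.5, -1.2              # eps_pred stands in for eps_theta(x_t, t)

# Estimate x_0 by inverting x_t = sqrt(alpha_bar) x_0 + sqrt(1 - alpha_bar) eps
x0_hat = (xt - math.sqrt(1.0 - alpha_bar) * eps_pred) / math.sqrt(alpha_bar)
lhs = tilde_mu(beta, alpha_bar_prev, x0_hat, xt)
rhs = mu_from_eps(beta, alpha_bar_prev, xt, eps_pred)
print(lhs, rhs)
```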
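This gives

$$L_{t-1} = \mathbb{E}_{x_0, \epsilon}\left[\frac{\beta_t^2}{2\sigma_t^2 \alpha_t (1-\bar\alpha_t)} \Big\|\epsilon - \epsilon_\theta\big(\sqrt{\bar\alpha_t}\, x_0 + \sqrt{1-\bar\alpha_t}\,\epsilon, t\big)\Big\|^2\right] + C$$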
That is, we are training to predict the noise.
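The simplified loss is

$$L_{\text{simple}}(\theta) = \mathbb{E}_{t, x_0, \epsilon}\left[\Big\|\epsilon - \epsilon_\theta\big(\sqrt{\bar\alpha_t}\, x_0 + \sqrt{1-\bar\alpha_t}\,\epsilon, t\big)\Big\|^2\right]$$

This minimizes $-\log p_\theta(x_0 \mid x_1)$ when $t = 1$ and $L_{t-1}$ for $t > 1$, with the weighting in $L_{t-1}$ discarded. Discarding the weights increases the relative weight given to higher $t$ (which have higher noise levels), which the paper reports improves sample quality.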
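To see how the weighting shifts, take $\sigma_t^2 = \beta_t$ (the choice made in the code below). The weight in $L_{t-1}$ then reduces to $\beta_t / (2\alpha_t(1-\bar\alpha_t))$. The sketch assumes the linear schedule from $10^{-4}$ to $0.02$ with $T = 1000$, and `loss_weights` is a hypothetical helper. The weight comes out at about 0.5 at $t = 1$ and about 0.01 at $t = T$. Replacing every weight with 1 therefore puts relatively more emphasis on the noisier, large-$t$ steps.

```python
def loss_weights(n_steps: int = 1000, beta_1: float = 1e-4, beta_T: float = 0.02) -> list:
    # lambda_t = beta_t^2 / (2 sigma_t^2 alpha_t (1 - alpha_bar_t)), assuming sigma_t^2 = beta_t
    step = (beta_T - beta_1) / (n_steps - 1)
    alpha_bar, weights = 1.0, []
    for i in range(n_steps):
        beta = beta_1 + i * step
        alpha = 1.0 - beta
        alpha_bar *= alpha
        sigma2 = beta
        weights.append(beta ** 2 / (2.0 * sigma2 * alpha * (1.0 - alpha_bar)))
    return weights


w = loss_weights()
print(w[0], w[99], w[-1])
```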
This file implements the loss calculation and a basic sampling method that we use to generate images during training.
The UNet model that computes $\epsilon_\theta(x_t, t)$ and the training code live in separate files, as does a script that generates samples and interpolations from a trained model.
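```python
from typing import Tuple, Optional

import torch
import torch.nn.functional as F
import torch.utils.data
from torch import nn

# gather(consts, t) picks consts[t] for each sample, shaped to broadcast over image dims
from labml_nn.diffusion.ddpm.utils import gather


class DenoiseDiffusion:
    r"""
    ## Denoise Diffusion

    Time step indices are 0-based: index `t` here corresponds to $t + 1$ in the paper.
    """

    def __init__(self, eps_model: nn.Module, n_steps: int, device: torch.device):
        r"""
        * `eps_model` is the $\epsilon_\theta(x_t, t)$ model
        * `n_steps` is $T$
        * `device` is the device to place constants on
        """
        super().__init__()
        self.eps_model = eps_model

        # Create linearly increasing variance schedule $\beta_1, \dots, \beta_T$
        self.beta = torch.linspace(0.0001, 0.02, n_steps).to(device)
        # $\alpha_t = 1 - \beta_t$
        self.alpha = 1. - self.beta
        # $\bar\alpha_t = \prod_{s=1}^t \alpha_s$
        self.alpha_bar = torch.cumprod(self.alpha, dim=0)
        # $T$
        self.n_steps = n_steps
        # $\sigma_t^2 = \beta_t$
        self.sigma2 = self.beta

    def q_xt_x0(self, x0: torch.Tensor, t: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        r"""Get the mean and variance of $q(x_t|x_0)$"""
        # Gather $\bar\alpha_t$ and compute $\sqrt{\bar\alpha_t} x_0$
        mean = gather(self.alpha_bar, t) ** 0.5 * x0
        # $(1-\bar\alpha_t) \mathbf{I}$
        var = 1 - gather(self.alpha_bar, t)
        return mean, var

    def q_sample(self, x0: torch.Tensor, t: torch.Tensor, eps: Optional[torch.Tensor] = None):
        r"""Sample from $q(x_t|x_0)$"""
        # $\epsilon \sim \mathcal{N}(\mathbf{0}, \mathbf{I})$
        if eps is None:
            eps = torch.randn_like(x0)
        # Get $q(x_t|x_0)$
        mean, var = self.q_xt_x0(x0, t)
        # $x_t = \sqrt{\bar\alpha_t} x_0 + \sqrt{1-\bar\alpha_t}\epsilon$
        return mean + (var ** 0.5) * eps

    def p_sample(self, xt: torch.Tensor, t: torch.Tensor):
        r"""Sample from $p_\theta(x_{t-1}|x_t)$"""
        # $\epsilon_\theta(x_t, t)$
        eps_theta = self.eps_model(xt, t)
        # Gather $\bar\alpha_t$
        alpha_bar = gather(self.alpha_bar, t)
        # $\alpha_t$
        alpha = gather(self.alpha, t)
        # $\frac{\beta_t}{\sqrt{1-\bar\alpha_t}}$, using $\beta_t = 1 - \alpha_t$
        eps_coef = (1 - alpha) / (1 - alpha_bar) ** .5
        # $\mu_\theta(x_t, t) = \frac{1}{\sqrt{\alpha_t}}\big(x_t - \frac{\beta_t}{\sqrt{1-\bar\alpha_t}}\epsilon_\theta(x_t, t)\big)$
        mean = 1 / (alpha ** 0.5) * (xt - eps_coef * eps_theta)
        # $\sigma_t^2$
        var = gather(self.sigma2, t)
        # $\epsilon \sim \mathcal{N}(\mathbf{0}, \mathbf{I})$
        eps = torch.randn(xt.shape, device=xt.device)
        # Sample
        return mean + (var ** .5) * eps

    def loss(self, x0: torch.Tensor, noise: Optional[torch.Tensor] = None):
        """Simplified loss"""
        # Get batch size
        batch_size = x0.shape[0]
        # Get random $t$ for each sample in the batch
        t = torch.randint(0, self.n_steps, (batch_size,), device=x0.device, dtype=torch.long)
        # $\epsilon \sim \mathcal{N}(\mathbf{0}, \mathbf{I})$
        if noise is None:
            noise = torch.randn_like(x0)
        # Sample $x_t$ from $q(x_t|x_0)$
        xt = self.q_sample(x0, t, eps=noise)
        # Get $\epsilon_\theta(\sqrt{\bar\alpha_t} x_0 + \sqrt{1-\bar\alpha_t}\epsilon, t)$
        eps_theta = self.eps_model(xt, t)
        # MSE loss
        return F.mse_loss(noise, eps_theta)
```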
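`DenoiseDiffusion` only provides a single reverse step. The sketch below shows the full reverse loop (Algorithm 2 in the paper) as an illustration, not as part of this implementation. It assumes the same linear schedule and $\sigma_t^2 = \beta_t$, and re-creates the constants inline so it runs on its own. `sample_loop` and `ZeroEps` are hypothetical names, and `ZeroEps` is a stand-in for a trained UNet. Unlike `p_sample`, the loop adds no noise at the final step, as Algorithm 2 in the paper specifies.

```python
import torch
from torch import nn


@torch.no_grad()
def sample_loop(eps_model: nn.Module, shape: tuple, n_steps: int = 1000, device: str = "cpu") -> torch.Tensor:
    # Same linear schedule and sigma_t^2 = beta_t as DenoiseDiffusion (re-created here)
    beta = torch.linspace(0.0001, 0.02, n_steps, device=device)
    alpha = 1.0 - beta
    alpha_bar = torch.cumprod(alpha, dim=0)
    # Start from x_T ~ N(0, I)
    x = torch.randn(shape, device=device)
    for i in reversed(range(n_steps)):
        # 0-based index i corresponds to the paper's t = i + 1
        t = torch.full((shape[0],), i, device=device, dtype=torch.long)
        eps_theta = eps_model(x, t)
        # mu_theta(x_t, t) = (x_t - beta_t / sqrt(1 - alpha_bar_t) * eps_theta) / sqrt(alpha_t)
        mean = (x - beta[i] / (1.0 - alpha_bar[i]).sqrt() * eps_theta) / alpha[i].sqrt()
        if i > 0:
            x = mean + beta[i].sqrt() * torch.randn_like(x)
        else:
            # Final step: return the mean without adding noise
            x = mean
    return x


class ZeroEps(nn.Module):
    # Hypothetical stand-in for a trained UNet: always predicts zero noise
    def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        return torch.zeros_like(x)


samples = sample_loop(ZeroEps(), (2, 1, 8, 8), n_steps=50)
print(samples.shape)
```

With a real model, you would pass the trained $\epsilon_\theta$ network and the full $T$ used in training.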