This is a PyTorch implementation/tutorial of the paper Denoising Diffusion Probabilistic Models.
In simple terms, we take images from the data and add noise step by step. We then train a model to predict that noise at each step, and use the model to generate images.
The following definitions and derivations show how this works. For details please refer to the paper.
The forward process adds noise to the data $x_0 \sim q(x_0)$, for $T$ timesteps:
$$q(x_t | x_{t-1}) = \mathcal{N}\big(x_t; \sqrt{1 - \beta_t}\, x_{t-1}, \beta_t \mathbf{I}\big), \qquad q(x_{1:T} | x_0) = \prod_{t=1}^{T} q(x_t | x_{t-1})$$
where $\beta_1, \dots, \beta_T$ is the variance schedule.
We can sample $x_t$ at any timestep $t$ with
$$q(x_t | x_0) = \mathcal{N}\big(x_t; \sqrt{\bar\alpha_t}\, x_0, (1 - \bar\alpha_t) \mathbf{I}\big)$$
where $\alpha_t = 1 - \beta_t$ and $\bar\alpha_t = \prod_{s=1}^t \alpha_s$.
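To see where this closed form comes from (a short derivation sketch, not from the paper's text verbatim): each forward step scales the previous sample by $\sqrt{\alpha_t}$ and adds independent Gaussian noise, so consecutive steps compose into a single Gaussian whose variances add, $\alpha_t (1 - \alpha_{t-1}) + (1 - \alpha_t) = 1 - \alpha_t \alpha_{t-1}$. Unrolling all $t$ steps gives
$$\begin{aligned} x_t &= \sqrt{\alpha_t}\, x_{t-1} + \sqrt{1 - \alpha_t}\,\epsilon_{t-1} \\ &= \sqrt{\alpha_t \alpha_{t-1}}\, x_{t-2} + \sqrt{1 - \alpha_t \alpha_{t-1}}\,\bar\epsilon \\ &\;\;\vdots \\ &= \sqrt{\bar\alpha_t}\, x_0 + \sqrt{1 - \bar\alpha_t}\,\epsilon \end{aligned}$$
which is exactly the $q(x_t|x_0)$ above.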
The reverse process removes noise starting at $p(x_T) = \mathcal{N}(x_T; \mathbf{0}, \mathbf{I})$ for $T$ time steps:
$$\color{cyan}{p_\theta}(x_{t-1} | x_t) = \mathcal{N}\big(x_{t-1}; \color{cyan}{\mu_\theta}(x_t, t), \color{cyan}{\Sigma_\theta}(x_t, t)\big), \qquad \color{cyan}{p_\theta}(x_{0:T}) = p(x_T) \prod_{t=1}^{T} \color{cyan}{p_\theta}(x_{t-1} | x_t)$$
$\color{cyan}\theta$ are the parameters we train.
We optimize the ELBO (from Jensen's inequality) on the negative log likelihood:
$$\mathbb{E}[-\log \color{cyan}{p_\theta}(x_0)] \le \mathbb{E}_q\Big[-\log \frac{\color{cyan}{p_\theta}(x_{0:T})}{q(x_{1:T}|x_0)}\Big] = L$$
The loss can be rewritten as follows:
$$L = \mathbb{E}_q\Big[ D_{KL}\big(q(x_T|x_0) \Vert p(x_T)\big) + \sum_{t>1} D_{KL}\big(q(x_{t-1}|x_t, x_0) \Vert \color{cyan}{p_\theta}(x_{t-1}|x_t)\big) - \log \color{cyan}{p_\theta}(x_0|x_1)\Big]$$
$D_{KL}(q(x_T|x_0) \Vert p(x_T))$ is constant since we keep $\beta_1, \dots, \beta_T$ constant.
The forward process posterior conditioned on $x_0$ is
$$q(x_{t-1} | x_t, x_0) = \mathcal{N}\big(x_{t-1}; \tilde\mu_t(x_t, x_0), \tilde\beta_t \mathbf{I}\big)$$
$$\tilde\mu_t(x_t, x_0) = \frac{\sqrt{\bar\alpha_{t-1}}\,\beta_t}{1 - \bar\alpha_t} x_0 + \frac{\sqrt{\alpha_t}\,(1 - \bar\alpha_{t-1})}{1 - \bar\alpha_t} x_t, \qquad \tilde\beta_t = \frac{1 - \bar\alpha_{t-1}}{1 - \bar\alpha_t} \beta_t$$
The paper sets $\color{cyan}{\Sigma_\theta}(x_t, t) = \sigma_t^2 \mathbf{I}$ where $\sigma_t^2$ is set to constants $\beta_t$ or $\tilde\beta_t$.
Then,
$$L_{t-1} = \mathbb{E}_q\Big[\frac{1}{2\sigma_t^2} \big\Vert \tilde\mu_t(x_t, x_0) - \color{cyan}{\mu_\theta}(x_t, t) \big\Vert^2\Big] + C$$
For given noise $\epsilon \sim \mathcal{N}(\mathbf{0}, \mathbf{I})$, using $q(x_t|x_0)$,
$$x_t(x_0, \epsilon) = \sqrt{\bar\alpha_t}\, x_0 + \sqrt{1 - \bar\alpha_t}\,\epsilon, \qquad x_0 = \frac{1}{\sqrt{\bar\alpha_t}}\big(x_t(x_0, \epsilon) - \sqrt{1 - \bar\alpha_t}\,\epsilon\big)$$
This gives
$$L_{t-1} = \mathbb{E}_{x_0, \epsilon}\Big[\frac{1}{2\sigma_t^2} \Big\Vert \frac{1}{\sqrt{\alpha_t}}\Big(x_t(x_0, \epsilon) - \frac{\beta_t}{\sqrt{1 - \bar\alpha_t}}\epsilon\Big) - \color{cyan}{\mu_\theta}(x_t(x_0, \epsilon), t) \Big\Vert^2\Big]$$
Re-parameterizing with a model to predict noise,
$$\color{cyan}{\mu_\theta}(x_t, t) = \tilde\mu_t\Big(x_t, \frac{1}{\sqrt{\bar\alpha_t}}\big(x_t - \sqrt{1 - \bar\alpha_t}\,\color{cyan}{\epsilon_\theta}(x_t, t)\big)\Big) = \frac{1}{\sqrt{\alpha_t}}\Big(x_t - \frac{\beta_t}{\sqrt{1 - \bar\alpha_t}}\,\color{cyan}{\epsilon_\theta}(x_t, t)\Big)$$
where $\color{cyan}{\epsilon_\theta}$ is a learned function that predicts $\epsilon$ given $(x_t, t)$.
This gives
$$L_{t-1} = \mathbb{E}_{x_0, \epsilon}\Big[\frac{\beta_t^2}{2\sigma_t^2 \alpha_t (1 - \bar\alpha_t)} \big\Vert \epsilon - \color{cyan}{\epsilon_\theta}\big(\sqrt{\bar\alpha_t}\, x_0 + \sqrt{1 - \bar\alpha_t}\,\epsilon, t\big) \big\Vert^2\Big]$$
That is, we are training to predict the noise. The simplified objective is
$$L_{\text{simple}}(\theta) = \mathbb{E}_{t, x_0, \epsilon}\Big[\big\Vert \epsilon - \color{cyan}{\epsilon_\theta}\big(\sqrt{\bar\alpha_t}\, x_0 + \sqrt{1 - \bar\alpha_t}\,\epsilon, t\big) \big\Vert^2\Big]$$
This minimizes $-\log \color{cyan}{p_\theta}(x_0|x_1)$ when $t=1$ and $L_{t-1}$ for $t>1$, discarding the weighting in $L_{t-1}$. Discarding the weights $\frac{\beta_t^2}{2\sigma_t^2 \alpha_t (1 - \bar\alpha_t)}$ increases the weight given to higher $t$ (which have higher noise levels), and therefore improves sample quality.
This file implements the loss calculation and a basic sampling method that we use to generate images during training.
Here is the UNet model that gives $\color{cyan}{\epsilon_\theta}(x_t, t)$, along with the training code. There is also code to generate samples and interpolations from a trained model.
from typing import Tuple, Optional

import torch
import torch.nn.functional as F
import torch.utils.data
from torch import nn

from labml_nn.diffusion.ddpm.utils import gather
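gather is a small helper from labml_nn.diffusion.ddpm.utils. Conceptually it picks the constant for each sample's timestep and reshapes it so that it broadcasts over image tensors; a minimal sketch of the assumed behavior (not the library source itself):

```python
def gather(consts: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
    # Pick consts[t] for each sample in the batch and reshape to
    # (batch_size, 1, 1, 1) so it broadcasts over (batch, channels, height, width).
    c = consts.gather(-1, t)
    return c.reshape(-1, 1, 1, 1)
```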
class DenoiseDiffusion:
eps_model is the $\color{cyan}{\epsilon_\theta}(x_t, t)$ model
n_steps is $T$
device is the device to place constants on
    def __init__(self, eps_model: nn.Module, n_steps: int, device: torch.device):
        super().__init__()
        self.eps_model = eps_model
Create $\beta_1, \dots, \beta_T$ linearly increasing variance schedule
        self.beta = torch.linspace(0.0001, 0.02, n_steps).to(device)
$\alpha_t = 1 - \beta_t$
        self.alpha = 1. - self.beta
$\bar\alpha_t = \prod_{s=1}^t \alpha_s$
        self.alpha_bar = torch.cumprod(self.alpha, dim=0)
$T$
        self.n_steps = n_steps
$\sigma^2 = \beta$
        self.sigma2 = self.beta
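As a rough sanity check of this schedule (the script below is only an illustration, with n_steps = 1000 chosen as an example; values are approximate), $\bar\alpha_t$ decays from near 1 to almost 0, so $x_T$ is close to pure noise:

```python
import torch

n_steps = 1_000  # example value; the class takes n_steps as an argument
beta = torch.linspace(0.0001, 0.02, n_steps)
alpha_bar = torch.cumprod(1. - beta, dim=0)

print(alpha_bar[0].item())    # ~0.9999: x_1 is almost the original image
print(alpha_bar[-1].item())   # ~4e-5:   x_T is essentially pure noise
```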
Get the distribution $q(x_t|x_0)$
    def q_xt_x0(self, x0: torch.Tensor, t: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
$\sqrt{\bar\alpha_t} x_0$
        mean = gather(self.alpha_bar, t) ** 0.5 * x0
$(1-\bar\alpha_t) \mathbf{I}$
        var = 1 - gather(self.alpha_bar, t)
        return mean, var
Sample from $q(x_t|x_0)$
    def q_sample(self, x0: torch.Tensor, t: torch.Tensor, eps: Optional[torch.Tensor] = None):
$\epsilon \sim \mathcal{N}(\mathbf{0}, \mathbf{I})$
        if eps is None:
            eps = torch.randn_like(x0)
Get $q(x_t|x_0)$
        mean, var = self.q_xt_x0(x0, t)
Sample from $q(x_t|x_0)$
        return mean + (var ** 0.5) * eps
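A hypothetical usage sketch (the batch size, image shape, and the diffusion variable are illustrative assumptions, not part of the implementation): noising a batch of images at random timesteps looks like this:

```python
# Assumes: diffusion = DenoiseDiffusion(eps_model, n_steps=1000, device=torch.device('cpu'))
x0 = torch.randn(16, 3, 32, 32)                                    # stand-in for a batch of images
t = torch.randint(0, diffusion.n_steps, (16,), dtype=torch.long)   # one timestep per sample
xt = diffusion.q_sample(x0, t)                                     # x_t = sqrt(alpha_bar_t) x0 + sqrt(1 - alpha_bar_t) eps
```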
Sample from $\color{cyan}{p_\theta}(x_{t-1}|x_t)$
    def p_sample(self, xt: torch.Tensor, t: torch.Tensor):
$\color{cyan}{\epsilon_\theta}(x_t, t)$
        eps_theta = self.eps_model(xt, t)
$\bar\alpha_t$
        alpha_bar = gather(self.alpha_bar, t)
$\alpha_t$
        alpha = gather(self.alpha, t)
$\frac{\beta_t}{\sqrt{1-\bar\alpha_t}}$
        eps_coef = (1 - alpha) / (1 - alpha_bar) ** .5
$\frac{1}{\sqrt{\alpha_t}} \Big(x_t - \frac{\beta_t}{\sqrt{1-\bar\alpha_t}} \color{cyan}{\epsilon_\theta}(x_t, t)\Big)$
        mean = 1 / (alpha ** 0.5) * (xt - eps_coef * eps_theta)
$\sigma^2$
        var = gather(self.sigma2, t)
$\epsilon \sim \mathcal{N}(\mathbf{0}, \mathbf{I})$
        eps = torch.randn(xt.shape, device=xt.device)
Sample
        return mean + (var ** .5) * eps
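p_sample performs a single reverse step. To generate images, you start from pure noise $x_T$ and apply it repeatedly for $t = T-1, \dots, 0$; a minimal sketch (the function name, batch size, and image shape are illustrative assumptions):

```python
@torch.no_grad()
def sample_images(diffusion: DenoiseDiffusion, n_samples: int = 16,
                  image_shape=(3, 32, 32), device=torch.device('cpu')) -> torch.Tensor:
    # device should match the device passed to DenoiseDiffusion
    # Start from x_T ~ N(0, I)
    x = torch.randn(n_samples, *image_shape, device=device)
    # Remove noise one step at a time, t = T-1, ..., 0
    for t_ in reversed(range(diffusion.n_steps)):
        t = x.new_full((n_samples,), t_, dtype=torch.long)
        x = diffusion.p_sample(x, t)
    return x
```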
Simplified loss $L_{\text{simple}}$
    def loss(self, x0: torch.Tensor, noise: Optional[torch.Tensor] = None):
Get batch size
        batch_size = x0.shape[0]
Get random $t$ for each sample in the batch
        t = torch.randint(0, self.n_steps, (batch_size,), device=x0.device, dtype=torch.long)
$\epsilon \sim \mathcal{N}(\mathbf{0}, \mathbf{I})$
        if noise is None:
            noise = torch.randn_like(x0)
Sample $x_t$ from $q(x_t|x_0)$
        xt = self.q_sample(x0, t, eps=noise)
Get $\color{cyan}{\epsilon_\theta}(\sqrt{\bar\alpha_t} x_0 + \sqrt{1-\bar\alpha_t}\epsilon, t)$
        eps_theta = self.eps_model(xt, t)
MSE loss between the actual and the predicted noise
        return F.mse_loss(noise, eps_theta)
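Putting it together, a training step only needs the images; the random $t$, the noise, and the forward sample all happen inside loss. A minimal sketch of how this class is typically driven (the model, optimizer, learning rate, and data loader here are placeholders, not the actual training code):

```python
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

eps_model = MyUNet().to(device)            # placeholder for the UNet that predicts the noise
diffusion = DenoiseDiffusion(eps_model=eps_model, n_steps=1_000, device=device)
optimizer = torch.optim.Adam(eps_model.parameters(), lr=2e-5)

for x0 in data_loader:                     # placeholder data loader yielding image batches
    x0 = x0.to(device)
    optimizer.zero_grad()
    diffusion.loss(x0).backward()          # L_simple: MSE between the true and predicted noise
    optimizer.step()
```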