mirror of
https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
synced 2025-08-14 09:31:42 +08:00
329 lines
12 KiB
Python
329 lines
12 KiB
Python
"""
|
|
---
|
|
title: Denoising Diffusion Probabilistic Models (DDPM) evaluation/sampling
|
|
summary: >
|
|
Code to generate samples from a trained
|
|
Denoising Diffusion Probabilistic Model.
|
|
---
|
|
|
|
# [Denoising Diffusion Probabilistic Models (DDPM)](index.html) evaluation/sampling
|
|
|
|
This is the code to generate images and create interpolations between given images.
|
|
"""
|
|
|
|
import numpy as np
|
|
import torch
|
|
from matplotlib import pyplot as plt
|
|
from torchvision.transforms.functional import to_pil_image, resize
|
|
|
|
from labml import experiment, monit
|
|
from labml_nn.diffusion.ddpm import DenoiseDiffusion, gather
|
|
from labml_nn.diffusion.ddpm.experiment import Configs
|
|
|
|
|
|
class Sampler:
    r"""
    ## Sampler class

    Utilities to draw samples from a trained `DenoiseDiffusion` model:
    plain sampling, a step-by-step denoising animation, and interpolation
    between two images.

    Docstrings are raw strings because they contain LaTeX; in a normal string
    sequences such as `\t`, `\b` and `\f` would be turned into control characters
    and others (`\s`, `\m`, ...) are invalid escape sequences.
    """

    def __init__(self, diffusion: DenoiseDiffusion, image_channels: int, image_size: int, device: torch.device):
        r"""
        * `diffusion` is the `DenoiseDiffusion` instance
        * `image_channels` is the number of channels in the image
        * `image_size` is the image size
        * `device` is the device of the model
        """
        self.device = device
        self.image_size = image_size
        self.image_channels = image_channels
        self.diffusion = diffusion

        # $T$
        self.n_steps = diffusion.n_steps
        # $\textcolor{lightgreen}{\epsilon_\theta}(x_t, t)$
        self.eps_model = diffusion.eps_model
        # $\beta_t$
        self.beta = diffusion.beta
        # $\alpha_t$
        self.alpha = diffusion.alpha
        # $\bar\alpha_t$
        self.alpha_bar = diffusion.alpha_bar
        # $\bar\alpha_{t-1}$; shift $\bar\alpha_t$ right by one step and
        # prepend $1$ for the first step
        alpha_bar_tm1 = torch.cat([self.alpha_bar.new_ones((1,)), self.alpha_bar[:-1]])

        # To calculate
        #
        # \begin{align}
        # q(x_{t-1}|x_t, x_0) &= \mathcal{N} \Big(x_{t-1}; \tilde\mu_t(x_t, x_0), \tilde\beta_t \mathbf{I} \Big) \\
        # \tilde\mu_t(x_t, x_0) &= \frac{\sqrt{\bar\alpha_{t-1}}\beta_t}{1 - \bar\alpha_t}x_0
        #                          + \frac{\sqrt{\alpha_t}(1 - \bar\alpha_{t-1})}{1-\bar\alpha_t}x_t \\
        # \tilde\beta_t &= \frac{1 - \bar\alpha_{t-1}}{1 - \bar\alpha_t} \beta_t
        # \end{align}
        #
        # These posterior coefficients are precomputed here; the methods of
        # this class do not read them.

        # $$\tilde\beta_t = \frac{1 - \bar\alpha_{t-1}}{1 - \bar\alpha_t} \beta_t$$
        self.beta_tilde = self.beta * (1 - alpha_bar_tm1) / (1 - self.alpha_bar)
        # $$\frac{\sqrt{\bar\alpha_{t-1}}\beta_t}{1 - \bar\alpha_t}$$
        self.mu_tilde_coef1 = self.beta * (alpha_bar_tm1 ** 0.5) / (1 - self.alpha_bar)
        # $$\frac{\sqrt{\alpha_t}(1 - \bar\alpha_{t-1})}{1-\bar\alpha_t}$$
        self.mu_tilde_coef2 = (self.alpha ** 0.5) * (1 - alpha_bar_tm1) / (1 - self.alpha_bar)
        # $\sigma^2 = \beta$
        self.sigma2 = self.beta

    def show_image(self, img, title=""):
        r"""
        Helper function to display an image

        * `img` is a single image tensor with the channel dimension first
        * `title` is the title of the plot
        """
        # Clamp values to the displayable range $[0, 1]$
        img = img.clip(0, 1)
        img = img.cpu().numpy()
        # Move the first axis (channels) last, as `imshow` expects
        plt.imshow(img.transpose(1, 2, 0))
        plt.title(title)
        plt.show()

    def make_video(self, frames, path="video.mp4"):
        r"""
        Helper function to create a video

        * `frames` is the list of image tensors to write, in order
        * `path` is the path of the output video file
        """
        import imageio
        # Aim for a 20 second video. Clamp to at least one frame per second,
        # otherwise fewer than 20 frames would give a frame rate of zero.
        fps = max(1, len(frames) // 20)
        writer = imageio.get_writer(path, fps=fps)
        try:
            # Add each image
            for f in frames:
                f = f.clip(0, 1)
                f = to_pil_image(resize(f, [368, 368]))
                writer.append_data(np.array(f))
        finally:
            # Always release the writer, even if encoding a frame fails
            writer.close()

    def sample_animation(self, n_frames: int = 1000, create_video: bool = True):
        r"""
        #### Sample an image step-by-step using $\textcolor{lightgreen}{p_\theta}(x_{t-1}|x_t)$

        We sample an image step-by-step using $\textcolor{lightgreen}{p_\theta}(x_{t-1}|x_t)$ and at each step
        show the estimate
        $$x_0 \approx \hat{x}_0 = \frac{1}{\sqrt{\bar\alpha_t}}
         \Big( x_t - \sqrt{1 - \bar\alpha_t} \textcolor{lightgreen}{\epsilon_\theta}(x_t, t) \Big)$$

        * `n_frames` is the approximate number of frames to capture; it must be at least 1
        * `create_video` specifies whether to make a video or to show each frame

        Raises `ValueError` if `n_frames` is less than 1.
        """
        if n_frames < 1:
            raise ValueError(f"n_frames must be at least 1, got {n_frames}")

        # $x_T \sim p(x_T) = \mathcal{N}(x_T; \mathbf{0}, \mathbf{I})$
        xt = torch.randn([1, self.image_channels, self.image_size, self.image_size], device=self.device)

        # Interval to log $\hat{x}_0$. Clamp to at least 1 so that asking for
        # more frames than there are steps does not cause a modulo by zero.
        interval = max(1, self.n_steps // n_frames)
        # Frames for video
        frames = []
        # Sample $T$ steps
        for t_inv in monit.iterate('Denoise', self.n_steps):
            # $t$
            t_ = self.n_steps - t_inv - 1
            # $t$ in a tensor
            t = xt.new_full((1,), t_, dtype=torch.long)
            # $\textcolor{lightgreen}{\epsilon_\theta}(x_t, t)$
            eps_theta = self.eps_model(xt, t)
            if t_ % interval == 0:
                # Get $\hat{x}_0$ and add to frames
                x0 = self.p_x0(xt, t, eps_theta)
                frames.append(x0[0])
                if not create_video:
                    self.show_image(x0[0], f"{t_}")
            # Sample from $\textcolor{lightgreen}{p_\theta}(x_{t-1}|x_t)$
            xt = self.p_sample(xt, t, eps_theta)

        # Make video
        if create_video:
            self.make_video(frames)

    def interpolate(self, x1: torch.Tensor, x2: torch.Tensor, lambda_: float, t_: int = 100):
        r"""
        #### Interpolate two images $x_0$ and $x'_0$

        We get $x_t \sim q(x_t|x_0)$ and $x'_t \sim q(x'_t|x'_0)$.

        Then interpolate to
        $$\bar{x}_t = (1 - \lambda)x_t + \lambda x'_t$$

        Then get
        $$\bar{x}_0 \sim \textcolor{lightgreen}{p_\theta}(x_0|\bar{x}_t)$$

        * `x1` is $x_0$
        * `x2` is $x'_0$
        * `lambda_` is $\lambda$
        * `t_` is $t$

        Returns $\bar{x}_0$.
        """

        # Number of samples
        n_samples = x1.shape[0]
        # $t$ tensor; explicitly integer typed, like the other time-step tensors in this class
        t = torch.full((n_samples,), t_, dtype=torch.long, device=self.device)
        # $$\bar{x}_t = (1 - \lambda)x_t + \lambda x'_t$$
        xt = (1 - lambda_) * self.diffusion.q_sample(x1, t) + lambda_ * self.diffusion.q_sample(x2, t)

        # $$\bar{x}_0 \sim \textcolor{lightgreen}{p_\theta}(x_0|\bar{x}_t)$$
        return self._sample_x0(xt, t_)

    def interpolate_animate(self, x1: torch.Tensor, x2: torch.Tensor, n_frames: int = 100, t_: int = 100,
                            create_video=True):
        r"""
        #### Interpolate two images $x_0$ and $x'_0$ and make a video

        * `x1` is $x_0$
        * `x2` is $x'_0$
        * `n_frames` is the number of interpolation intervals; `n_frames + 1` frames are generated
          and it must be at least 1
        * `t_` is $t$
        * `create_video` specifies whether to make a video or to show each frame

        Raises `ValueError` if `n_frames` is less than 1.
        """
        if n_frames < 1:
            raise ValueError(f"n_frames must be at least 1, got {n_frames}")

        # Show original images
        self.show_image(x1, "x1")
        self.show_image(x2, "x2")
        # Add batch dimension
        x1 = x1[None, :, :, :]
        x2 = x2[None, :, :, :]
        # $t$ tensor; explicitly integer typed, like the other time-step tensors in this class
        t = torch.full((1,), t_, dtype=torch.long, device=self.device)
        # $x_t \sim q(x_t|x_0)$
        x1t = self.diffusion.q_sample(x1, t)
        # $x'_t \sim q(x'_t|x'_0)$
        x2t = self.diffusion.q_sample(x2, t)

        frames = []
        # Get frames with different $\lambda$
        for i in monit.iterate('Interpolate', n_frames + 1, is_children_silent=True):
            # $\lambda$ goes from $0$ to $1$ inclusive
            lambda_ = i / n_frames
            # $$\bar{x}_t = (1 - \lambda)x_t + \lambda x'_t$$
            xt = (1 - lambda_) * x1t + lambda_ * x2t
            # $$\bar{x}_0 \sim \textcolor{lightgreen}{p_\theta}(x_0|\bar{x}_t)$$
            x0 = self._sample_x0(xt, t_)
            # Add to frames
            frames.append(x0[0])
            # Show frame
            if not create_video:
                self.show_image(x0[0], f"{lambda_ :.2f}")

        # Make video
        if create_video:
            self.make_video(frames)

    def _sample_x0(self, xt: torch.Tensor, n_steps: int):
        r"""
        #### Sample an image using $\textcolor{lightgreen}{p_\theta}(x_{t-1}|x_t)$

        * `xt` is $x_t$
        * `n_steps` is $t$

        Returns the result of applying `n_steps` reverse steps to `xt`.
        """

        # Number of samples
        n_samples = xt.shape[0]
        # Iterate until $t$ steps
        for t_ in monit.iterate('Denoise', n_steps):
            # Count down from `n_steps - 1` to `0`
            t = n_steps - t_ - 1
            # Sample from $\textcolor{lightgreen}{p_\theta}(x_{t-1}|x_t)$
            xt = self.diffusion.p_sample(xt, xt.new_full((n_samples,), t, dtype=torch.long))

        # Return $x_0$
        return xt

    def sample(self, n_samples: int = 16):
        r"""
        #### Generate images

        * `n_samples` is the number of images to generate and display
        """
        # $x_T \sim p(x_T) = \mathcal{N}(x_T; \mathbf{0}, \mathbf{I})$
        xt = torch.randn([n_samples, self.image_channels, self.image_size, self.image_size], device=self.device)

        # $$x_0 \sim \textcolor{lightgreen}{p_\theta}(x_0|x_t)$$
        x0 = self._sample_x0(xt, self.n_steps)

        # Show images
        for i in range(n_samples):
            self.show_image(x0[i])

    def p_sample(self, xt: torch.Tensor, t: torch.Tensor, eps_theta: torch.Tensor):
        r"""
        #### Sample from $\textcolor{lightgreen}{p_\theta}(x_{t-1}|x_t)$

        \begin{align}
        \textcolor{lightgreen}{p_\theta}(x_{t-1} | x_t) &= \mathcal{N}\big(x_{t-1};
        \textcolor{lightgreen}{\mu_\theta}(x_t, t), \sigma_t^2 \mathbf{I} \big) \\
        \textcolor{lightgreen}{\mu_\theta}(x_t, t)
          &= \frac{1}{\sqrt{\alpha_t}} \Big(x_t -
            \frac{\beta_t}{\sqrt{1-\bar\alpha_t}}\textcolor{lightgreen}{\epsilon_\theta}(x_t, t) \Big)
        \end{align}

        * `xt` is $x_t$
        * `t` is the tensor of time steps
        * `eps_theta` is the already computed $\textcolor{lightgreen}{\epsilon_\theta}(x_t, t)$

        Returns a sample of $x_{t-1}$. Fresh noise is added at every step.
        """
        # [gather](utils.html) $\bar\alpha_t$
        alpha_bar = gather(self.alpha_bar, t)
        # $\alpha_t$
        alpha = gather(self.alpha, t)
        # $\frac{\beta_t}{\sqrt{1-\bar\alpha_t}}$, with $\beta_t$ written as $1 - \alpha_t$
        eps_coef = (1 - alpha) / (1 - alpha_bar) ** .5
        # $$\frac{1}{\sqrt{\alpha_t}} \Big(x_t -
        #      \frac{\beta_t}{\sqrt{1-\bar\alpha_t}}\textcolor{lightgreen}{\epsilon_\theta}(x_t, t) \Big)$$
        mean = 1 / (alpha ** 0.5) * (xt - eps_coef * eps_theta)
        # $\sigma^2$
        var = gather(self.sigma2, t)

        # $\epsilon \sim \mathcal{N}(\mathbf{0}, \mathbf{I})$
        eps = torch.randn(xt.shape, device=xt.device)
        # Sample
        return mean + (var ** .5) * eps

    def p_x0(self, xt: torch.Tensor, t: torch.Tensor, eps: torch.Tensor):
        r"""
        #### Estimate $x_0$

        $$x_0 \approx \hat{x}_0 = \frac{1}{\sqrt{\bar\alpha_t}}
         \Big( x_t - \sqrt{1 - \bar\alpha_t} \textcolor{lightgreen}{\epsilon_\theta}(x_t, t) \Big)$$

        * `xt` is $x_t$
        * `t` is the tensor of time steps
        * `eps` is $\textcolor{lightgreen}{\epsilon_\theta}(x_t, t)$

        Returns the estimate $\hat{x}_0$.
        """
        # [gather](utils.html) $\bar\alpha_t$
        alpha_bar = gather(self.alpha_bar, t)

        # $$x_0 \approx \hat{x}_0 = \frac{1}{\sqrt{\bar\alpha_t}}
        #  \Big( x_t - \sqrt{1 - \bar\alpha_t} \textcolor{lightgreen}{\epsilon_\theta}(x_t, t) \Big)$$
        return (xt - (1 - alpha_bar) ** 0.5 * eps) / (alpha_bar ** 0.5)
|
|
|
|
|
|
def main():
    """Load a trained DDPM run and generate samples from it"""

    # UUID of the training run whose configuration and weights are loaded
    run_uuid = "a44333ea251411ec8007d1a1762ed686"

    # Put the experiment into evaluation mode
    experiment.evaluate()

    # Build the configurations and override them with the values saved by
    # the training run
    conf = Configs()
    experiment.configs(conf, experiment.load_configs(run_uuid))

    # Initialize
    conf.init()

    # Register the PyTorch modules so that they can be loaded
    experiment.add_pytorch_models({'eps_model': conf.eps_model})

    # Load the training experiment
    experiment.load(run_uuid)

    # Create the sampler
    sampler = Sampler(
        diffusion=conf.diffusion,
        image_channels=conf.image_channels,
        image_size=conf.image_size,
        device=conf.device,
    )

    # Run the evaluation without tracking gradients
    with experiment.start(), torch.no_grad():
        # Sample an image with a denoising animation
        sampler.sample_animation()

        # Disabled by default; change the condition to also produce an
        # interpolation animation
        if False:
            # Get some images from the data
            batch = next(iter(conf.data_loader)).to(conf.device)

            # Create an interpolation animation between the first two images
            sampler.interpolate_animate(batch[0], batch[1])
|
|
|
|
|
|
#
|
|
# Run the sampling script only when this file is executed directly
if __name__ == "__main__":
    main()
|