Denoising Diffusion Probabilistic Models (DDPM) evaluation/sampling

This is the code to generate images and create interpolations between given images.

import numpy as np
import torch
from matplotlib import pyplot as plt
from torchvision.transforms.functional import to_pil_image, resize

from labml import experiment, monit
from labml_nn.diffusion.ddpm import DenoiseDiffusion, gather
from labml_nn.diffusion.ddpm.experiment import Configs

Sampler class

class Sampler:
  • diffusion is the DenoiseDiffusion instance
  • image_channels is the number of channels in the image
  • image_size is the image size
  • device is the device of the model
    def __init__(self, diffusion: DenoiseDiffusion, image_channels: int, image_size: int, device: torch.device):
        self.device = device
        self.image_size = image_size
        self.image_channels = image_channels
        self.diffusion = diffusion

        self.n_steps = diffusion.n_steps

        self.eps_model = diffusion.eps_model

        self.beta = diffusion.beta

        self.alpha = diffusion.alpha

        self.alpha_bar = diffusion.alpha_bar

$\bar\alpha_{t-1}$, i.e. $\bar\alpha_t$ shifted right by one step, with $\bar\alpha_0 = 1$

        alpha_bar_tm1 = torch.cat([self.alpha_bar.new_ones((1,)), self.alpha_bar[:-1]])

To calculate

$$\begin{align} q(x_{t-1}|x_t, x_0) &= \mathcal{N}\Big(x_{t-1}; \tilde\mu_t(x_t, x_0), \tilde\beta_t \mathbf{I}\Big) \\ \tilde\mu_t(x_t, x_0) &= \frac{\sqrt{\bar\alpha_{t-1}}\beta_t}{1 - \bar\alpha_t}x_0 + \frac{\sqrt{\alpha_t}(1 - \bar\alpha_{t-1})}{1 - \bar\alpha_t}x_t \\ \tilde\beta_t &= \frac{1 - \bar\alpha_{t-1}}{1 - \bar\alpha_t}\beta_t \end{align}$$

$\tilde\beta_t$

        self.beta_tilde = self.beta * (1 - alpha_bar_tm1) / (1 - self.alpha_bar)

$\frac{\sqrt{\bar\alpha_{t-1}}\beta_t}{1 - \bar\alpha_t}$

        self.mu_tilde_coef1 = self.beta * (alpha_bar_tm1 ** 0.5) / (1 - self.alpha_bar)

$\frac{\sqrt{\alpha_t}(1 - \bar\alpha_{t-1})}{1 - \bar\alpha_t}$

        self.mu_tilde_coef2 = (self.alpha ** 0.5) * (1 - alpha_bar_tm1) / (1 - self.alpha_bar)

$\sigma^2 = \beta$

        self.sigma2 = self.beta
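These coefficients depend only on the $\beta$ schedule, so they can be sanity-checked in isolation. A minimal sketch (not part of the Sampler class), assuming the linear schedule from the DDPM paper; $\tilde\beta_1 = 0$ because $\bar\alpha_0 = 1$, and $\tilde\beta_t \le \beta_t$ everywhere:

n_steps = 1000
# Assumed linear schedule: beta from 1e-4 to 0.02
beta = torch.linspace(1e-4, 0.02, n_steps)
alpha = 1. - beta
alpha_bar = torch.cumprod(alpha, dim=0)
# Shift right by one step so that alpha_bar acts as 1 at t=0
alpha_bar_tm1 = torch.cat([alpha_bar.new_ones((1,)), alpha_bar[:-1]])
beta_tilde = beta * (1 - alpha_bar_tm1) / (1 - alpha_bar)
assert beta_tilde[0] == 0.                    # no posterior variance at the first step
assert torch.all(beta_tilde[1:] <= beta[1:])  # posterior variance never exceeds beta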

Helper function to display an image

    def show_image(self, img, title=""):
        img = img.clip(0, 1)
        img = img.cpu().numpy()
        plt.imshow(img.transpose(1, 2, 0))
        plt.title(title)
        plt.show()

Helper function to create a video

    def make_video(self, frames, path="video.mp4"):
        import imageio

Aim for a roughly 20-second video; the guard keeps fps valid when there are fewer than 20 frames

        writer = imageio.get_writer(path, fps=max(1, len(frames) // 20))

Add each image

        for f in frames:
            f = f.clip(0, 1)
            f = to_pil_image(resize(f, [368, 368]))
            writer.append_data(np.array(f))

        writer.close()

Sample an image step-by-step using $p_\theta(x_{t-1}|x_t)$

We sample an image step-by-step using $p_\theta(x_{t-1}|x_t)$, and at each step show the estimate
$$x_0 \approx \hat x_0 = \frac{1}{\sqrt{\bar\alpha_t}} \Big(x_t - \sqrt{1 - \bar\alpha_t}\, \epsilon_\theta(x_t, t)\Big)$$

    def sample_animation(self, n_frames: int = 1000, create_video: bool = True):

$x_T \sim \mathcal{N}(\mathbf{0}, \mathbf{I})$

        xt = torch.randn([1, self.image_channels, self.image_size, self.image_size], device=self.device)

Interval at which to record frames

        interval = self.n_steps // n_frames

Frames for video

        frames = []

Sample steps

        for t_inv in monit.iterate('Denoise', self.n_steps):

            t_ = self.n_steps - t_inv - 1

$t$ in a tensor

            t = xt.new_full((1,), t_, dtype=torch.long)

$\epsilon_\theta(x_t, t)$

            eps_theta = self.eps_model(xt, t)
            if t_ % interval == 0:

Get $\hat x_0$ and add it to the frames

                x0 = self.p_x0(xt, t, eps_theta)
                frames.append(x0[0])
                if not create_video:
                    self.show_image(x0[0], f"{t_}")

Sample from $p_\theta(x_{t-1}|x_t)$

            xt = self.p_sample(xt, t, eps_theta)

Make video

        if create_video:
            self.make_video(frames)
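A hypothetical usage, assuming a sampler constructed as in main() below:

# Record 100 frames of the denoising process into video.mp4
sampler.sample_animation(n_frames=100, create_video=True)
# Or show ten intermediate estimates of x0 inline instead of writing a video
sampler.sample_animation(n_frames=10, create_video=False)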

Interpolate two images $x_0$ and $x'_0$

We get $x_t \sim q(x_t|x_0)$ and $x'_t \sim q(x'_t|x'_0)$.

Then interpolate to $\bar x_t = (1 - \lambda) x_t + \lambda x'_t$

Then get $\bar x_0 \sim p_\theta(\bar x_0|\bar x_t)$

  • x1 is $x_0$
  • x2 is $x'_0$
  • lambda_ is $\lambda$
  • t_ is $t$
    def interpolate(self, x1: torch.Tensor, x2: torch.Tensor, lambda_: float, t_: int = 100):

Number of samples

        n_samples = x1.shape[0]

$t$ tensor

        t = torch.full((n_samples,), t_, device=self.device)

$\bar x_t = (1 - \lambda) x_t + \lambda x'_t$

        xt = (1 - lambda_) * self.diffusion.q_sample(x1, t) + lambda_ * self.diffusion.q_sample(x2, t)

Sample $\bar x_0 \sim p_\theta(\bar x_0|\bar x_t)$

        return self._sample_x0(xt, t_)
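For instance, an equal blend of two dataset images could be produced like this (a sketch; data is assumed to be a batch from configs.data_loader, as in main() below):

# Slices keep the batch dimension, which interpolate reads via x1.shape[0]
x_mix = sampler.interpolate(data[0:1], data[1:2], lambda_=0.5, t_=100)
sampler.show_image(x_mix[0], "lambda = 0.50")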

Interpolate two images $x_0$ and $x'_0$ and make a video

  • x1 is $x_0$
  • x2 is $x'_0$
  • n_frames is the number of frames for the video
  • t_ is $t$
  • create_video specifies whether to make a video or to show each frame
    def interpolate_animate(self, x1: torch.Tensor, x2: torch.Tensor, n_frames: int = 100, t_: int = 100,
                            create_video=True):

Show original images

        self.show_image(x1, "x1")
        self.show_image(x2, "x2")

Add batch dimension

        x1 = x1[None, :, :, :]
        x2 = x2[None, :, :, :]

$t$ tensor

        t = torch.full((1,), t_, device=self.device)

Get $x_t \sim q(x_t|x_0)$

        x1t = self.diffusion.q_sample(x1, t)

Get $x'_t \sim q(x'_t|x'_0)$

        x2t = self.diffusion.q_sample(x2, t)

        frames = []

Get frames with different $\lambda$

        for i in monit.iterate('Interpolate', n_frames + 1, is_children_silent=True):

            lambda_ = i / n_frames

$\bar x_t = (1 - \lambda) x_t + \lambda x'_t$

            xt = (1 - lambda_) * x1t + lambda_ * x2t

Sample $\bar x_0 \sim p_\theta(\bar x_0|\bar x_t)$

            x0 = self._sample_x0(xt, t_)

Add to frames

            frames.append(x0[0])

Show frame

            if not create_video:
                self.show_image(x0[0], f"{lambda_ :.2f}")

Make video

        if create_video:
            self.make_video(frames)

Sample an image using $p_\theta(x_{t-1}|x_t)$

  • xt is $x_t$
  • n_steps is $t$
    def _sample_x0(self, xt: torch.Tensor, n_steps: int):

Number of samples

        n_samples = xt.shape[0]

Iterate until $t$ steps

        for t_ in monit.iterate('Denoise', n_steps):
            t = n_steps - t_ - 1

Sample from $p_\theta(x_{t-1}|x_t)$

            xt = self.diffusion.p_sample(xt, xt.new_full((n_samples,), t, dtype=torch.long))

Return $x_0$

        return xt
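Since the loop starts at $t =$ n_steps $- 1$, passing a value smaller than self.n_steps denoises a partially noised image, which is exactly how interpolate uses this method. A sketch, assuming a sampler and an image batch x0:

# Hypothetical: noise an image to t = 100, then denoise only those 100 steps
t = torch.full((1,), 100, device=sampler.device, dtype=torch.long)
xt = sampler.diffusion.q_sample(x0[:1], t)
x0_hat = sampler._sample_x0(xt, 100)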

Generate images

    def sample(self, n_samples: int = 16):

$x_T \sim \mathcal{N}(\mathbf{0}, \mathbf{I})$

        xt = torch.randn([n_samples, self.image_channels, self.image_size, self.image_size], device=self.device)

        x0 = self._sample_x0(xt, self.n_steps)

Show images

        for i in range(n_samples):
            self.show_image(x0[i])
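To view the whole batch at once instead of image by image, the samples could be tiled into a grid; a sketch using torchvision.utils.make_grid, which the original code does not use:

from torchvision.utils import make_grid

# Hypothetical alternative to the loop above: one 4x4 grid of 16 samples
grid = make_grid(x0.clip(0, 1), nrow=4)
plt.imshow(grid.cpu().numpy().transpose(1, 2, 0))
plt.show()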

Sample $x_{t-1}$ from $p_\theta(x_{t-1}|x_t)$

$$p_\theta(x_{t-1}|x_t) = \mathcal{N}\big(x_{t-1}; \mu_\theta(x_t, t), \sigma_t^2 \mathbf{I}\big)$$

    def p_sample(self, xt: torch.Tensor, t: torch.Tensor, eps_theta: torch.Tensor):

gather $\bar\alpha_t$

        alpha_bar = gather(self.alpha_bar, t)

$\alpha_t$

        alpha = gather(self.alpha, t)

$\frac{\beta_t}{\sqrt{1 - \bar\alpha_t}}$

        eps_coef = (1 - alpha) / (1 - alpha_bar) ** .5

$\frac{1}{\sqrt{\alpha_t}} \Big(x_t - \frac{\beta_t}{\sqrt{1 - \bar\alpha_t}} \epsilon_\theta(x_t, t)\Big)$

        mean = 1 / (alpha ** 0.5) * (xt - eps_coef * eps_theta)

$\sigma^2$

        var = gather(self.sigma2, t)

$\epsilon \sim \mathcal{N}(\mathbf{0}, \mathbf{I})$

        eps = torch.randn(xt.shape, device=xt.device)

Sample

        return mean + (var ** .5) * eps

Estimate $x_0$ given $x_t$ and $\epsilon_\theta(x_t, t)$

$$x_0 \approx \hat x_0 = \frac{1}{\sqrt{\bar\alpha_t}} \Big(x_t - \sqrt{1 - \bar\alpha_t}\, \epsilon_\theta(x_t, t)\Big)$$

    def p_x0(self, xt: torch.Tensor, t: torch.Tensor, eps: torch.Tensor):

gather $\bar\alpha_t$

        alpha_bar = gather(self.alpha_bar, t)

        return (xt - (1 - alpha_bar) ** 0.5 * eps) / (alpha_bar ** 0.5)
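Since $q(x_t|x_0)$ adds exactly the noise $\epsilon$, feeding the same $\epsilon$ back into this estimate recovers $x_0$ up to floating-point error. A round-trip sketch (assuming a sampler instance; not part of the original file):

# Build x_t = sqrt(alpha_bar) * x0 + sqrt(1 - alpha_bar) * eps by hand
x0 = torch.randn([1, sampler.image_channels, sampler.image_size, sampler.image_size],
                 device=sampler.device)
t = x0.new_full((1,), 500, dtype=torch.long)
eps = torch.randn_like(x0)
alpha_bar = gather(sampler.alpha_bar, t)
xt = (alpha_bar ** 0.5) * x0 + ((1 - alpha_bar) ** 0.5) * eps
# p_x0 inverts the forward process when given the true noise
assert torch.allclose(sampler.p_x0(xt, t, eps), x0, atol=1e-5)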

Generate samples

def main():

Training experiment run UUID

    run_uuid = "a44333ea251411ec8007d1a1762ed686"

Start an evaluation

    experiment.evaluate()

Create configs

    configs = Configs()

Load custom configuration of the training run

    configs_dict = experiment.load_configs(run_uuid)

Set configurations

    experiment.configs(configs, configs_dict)

Initialize

    configs.init()

Set PyTorch modules for saving and loading

    experiment.add_pytorch_models({'eps_model': configs.eps_model})

Load training experiment

    experiment.load(run_uuid)

Create sampler

    sampler = Sampler(diffusion=configs.diffusion,
                      image_channels=configs.image_channels,
                      image_size=configs.image_size,
                      device=configs.device)

Start evaluation

    with experiment.start():

No gradients

        with torch.no_grad():

Sample an image with a denoising animation

            sampler.sample_animation()

            if False:

Get some images from the data (this block is disabled by default)

                data = next(iter(configs.data_loader)).to(configs.device)

Create an interpolation animation

                sampler.interpolate_animate(data[0], data[1])

if __name__ == '__main__':
    main()