This is the code to generate images and create interpolations between given images.
import numpy as np
import torch
from matplotlib import pyplot as plt
from torchvision.transforms.functional import to_pil_image, resize

from labml import experiment, monit
from labml_nn.diffusion.ddpm import DenoiseDiffusion, gather
from labml_nn.diffusion.ddpm.experiment import Configs
class Sampler:

diffusion is the DenoiseDiffusion instance
image_channels is the number of channels in the image
image_size is the image size
device is the device of the model

    def __init__(self, diffusion: DenoiseDiffusion, image_channels: int, image_size: int, device: torch.device):
        self.device = device
        self.image_size = image_size
        self.image_channels = image_channels
        self.diffusion = diffusion

        # Number of diffusion time steps $T$
        self.n_steps = diffusion.n_steps
        # Noise prediction model $\epsilon_\theta(x_t, t)$
        self.eps_model = diffusion.eps_model
        # $\beta_t$ schedule
        self.beta = diffusion.beta
        # $\alpha_t = 1 - \beta_t$
        self.alpha = diffusion.alpha
        # $\bar\alpha_t$
        self.alpha_bar = diffusion.alpha_bar
        # $\bar\alpha_{t-1}$, with $\bar\alpha_0 = 1$
        alpha_bar_tm1 = torch.cat([self.alpha_bar.new_ones((1,)), self.alpha_bar[:-1]])
To calculate

\begin{align}
q(x_{t-1}|x_t, x_0) &= \mathcal{N} \Big(x_{t-1}; \tilde\mu_t(x_t, x_0), \tilde\beta_t \mathbf{I} \Big) \\
\tilde\mu_t(x_t, x_0) &= \frac{\sqrt{\bar\alpha_{t-1}}\beta_t}{1 - \bar\alpha_t}x_0 + \frac{\sqrt{\alpha_t}(1 - \bar\alpha_{t-1})}{1-\bar\alpha_t}x_t \\
\tilde\beta_t &= \frac{1 - \bar\alpha_{t-1}}{1 - \bar\alpha_t}\beta_t
\end{align}
        # $\tilde\beta_t$
        self.beta_tilde = self.beta * (1 - alpha_bar_tm1) / (1 - self.alpha_bar)
        # Coefficient of $x_0$ in $\tilde\mu_t$
        self.mu_tilde_coef1 = self.beta * (alpha_bar_tm1 ** 0.5) / (1 - self.alpha_bar)
        # Coefficient of $x_t$ in $\tilde\mu_t$
        self.mu_tilde_coef2 = (self.alpha ** 0.5) * (1 - alpha_bar_tm1) / (1 - self.alpha_bar)
        # Sampling variance $\sigma_t^2 = \beta_t$
        self.sigma2 = self.beta
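As a quick standalone sanity check (a sketch, not part of the Sampler class; it assumes a linear $\beta$ schedule, which may differ from the schedule used in training), we can compute $\tilde\beta_t$ directly and confirm it never exceeds $\beta_t$:

n_steps = 1000
beta = torch.linspace(0.0001, 0.02, n_steps)
alpha = 1. - beta
alpha_bar = torch.cumprod(alpha, dim=0)
# $\bar\alpha_{t-1}$, taking $\bar\alpha_0 = 1$
alpha_bar_tm1 = torch.cat([alpha_bar.new_ones((1,)), alpha_bar[:-1]])
# $\tilde\beta_t = \frac{1 - \bar\alpha_{t-1}}{1 - \bar\alpha_t}\beta_t$
beta_tilde = beta * (1 - alpha_bar_tm1) / (1 - alpha_bar)
# Since $\bar\alpha_{t-1} \ge \bar\alpha_t$, the posterior variance is never larger than $\beta_t$
assert torch.all(beta_tilde <= beta)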
Helper function to display an image
    def show_image(self, img, title=""):
        img = img.clip(0, 1)
        img = img.cpu().numpy()
        plt.imshow(img.transpose(1, 2, 0))
        plt.title(title)
        plt.show()
Helper function to create a video
    def make_video(self, frames, path="video.mp4"):
        import imageio
20 second video
        writer = imageio.get_writer(path, fps=len(frames) // 20)
Add each image
        for f in frames:
            f = f.clip(0, 1)
            f = to_pil_image(resize(f, [368, 368]))
            writer.append_data(np.array(f))

        writer.close()
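A hypothetical usage sketch (the frames here are random stand-ins rather than model output, and sampler is assumed to be a Sampler instance; imageio needs an ffmpeg backend for mp4 output):

frames = [torch.rand(3, 32, 32) for _ in range(200)]
# 200 frames at fps = 200 // 20 = 10, i.e. a 20-second clip
sampler.make_video(frames, path="noise.mp4")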
We sample an image step-by-step using $p_\theta(x_{t-1}|x_t)$, and at each step show the estimate of $x_0$.
    def sample_animation(self, n_frames: int = 1000, create_video: bool = True):
        # Start from pure noise $x_T \sim \mathcal{N}(0, \mathbf{I})$
        xt = torch.randn([1, self.image_channels, self.image_size, self.image_size], device=self.device)
Interval at which to record frames
        interval = self.n_steps // n_frames
Frames for video
        frames = []
Sample steps
        for t_inv in monit.iterate('Denoise', self.n_steps):
            t_ = self.n_steps - t_inv - 1
$t$ in a tensor
            t = xt.new_full((1,), t_, dtype=torch.long)
            # $\epsilon_\theta(x_t, t)$
            eps_theta = self.eps_model(xt, t)
            if t_ % interval == 0:
Get the estimate of $x_0$ and add it to frames
                x0 = self.p_x0(xt, t, eps_theta)
                frames.append(x0[0])
                if not create_video:
                    self.show_image(x0[0], f"{t_}")
Sample from $p_\theta(x_{t-1}|x_t)$
            xt = self.p_sample(xt, t, eps_theta)
Make video
        if create_video:
            self.make_video(frames)
We interpolate two images by noising both up to step $t$ with $q(x_t|x_0)$, mixing the noised images as $\bar{x}_t = (1 - \lambda) x_t + \lambda x'_t$, and then denoising $\bar{x}_t$ back to an image.

    def interpolate(self, x1: torch.Tensor, x2: torch.Tensor, lambda_: float, t_: int = 100):
Number of samples
        n_samples = x1.shape[0]
$t$ in a tensor
        t = torch.full((n_samples,), t_, device=self.device)
        # $\bar{x}_t = (1 - \lambda) x_t + \lambda x'_t$ where $x_t \sim q(x_t|x_0)$ and $x'_t \sim q(x_t|x'_0)$
        xt = (1 - lambda_) * self.diffusion.q_sample(x1, t) + lambda_ * self.diffusion.q_sample(x2, t)

        return self._sample_x0(xt, t_)
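A hypothetical usage sketch (assuming sampler is a Sampler instance and data is a batch of dataset images on the same device, as in main() below), interpolating halfway between the first two images:

x1, x2 = data[0:1], data[1:2]
mixed = sampler.interpolate(x1, x2, lambda_=0.5, t_=100)
sampler.show_image(mixed[0], "lambda = 0.50")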
x1 is the first image
x2 is the second image
n_frames is the number of interpolation frames for the video
t_ is the step $t$ up to which the images are noised
create_video specifies whether to make a video or to show each frame

    def interpolate_animate(self, x1: torch.Tensor, x2: torch.Tensor, n_frames: int = 100, t_: int = 100,
                            create_video=True):
Show original images
        self.show_image(x1, "x1")
        self.show_image(x2, "x2")
Add batch dimension
        x1 = x1[None, :, :, :]
        x2 = x2[None, :, :, :]
$t$ in a tensor
        t = torch.full((1,), t_, device=self.device)
        # $x_t \sim q(x_t|x_0)$ for each of the two images
        x1t = self.diffusion.q_sample(x1, t)
        x2t = self.diffusion.q_sample(x2, t)

        frames = []
Get frames with different $\lambda$
        for i in monit.iterate('Interpolate', n_frames + 1, is_children_silent=True):
            lambda_ = i / n_frames
            xt = (1 - lambda_) * x1t + lambda_ * x2t
            x0 = self._sample_x0(xt, t_)
Add to frames
            frames.append(x0[0])
Show frame
            if not create_video:
                self.show_image(x0[0], f"{lambda_ :.2f}")
Make video
        if create_video:
            self.make_video(frames)

    def _sample_x0(self, xt: torch.Tensor, n_steps: int):
Number of samples
        n_samples = xt.shape[0]
Iterate over n_steps denoising steps
        for t_ in monit.iterate('Denoise', n_steps):
            t = n_steps - t_ - 1
Sample from $p_\theta(x_{t-1}|x_t)$
            xt = self.diffusion.p_sample(xt, xt.new_full((n_samples,), t, dtype=torch.long))
Return $x_0$
        return xt
    def sample(self, n_samples: int = 16):
        # $x_T \sim \mathcal{N}(0, \mathbf{I})$
        xt = torch.randn([n_samples, self.image_channels, self.image_size, self.image_size], device=self.device)
        # Denoise for all $T$ steps to get $x_0$
        x0 = self._sample_x0(xt, self.n_steps)
Show images
        for i in range(n_samples):
            self.show_image(x0[i])
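If sixteen separate figures are too noisy, a small variation (a sketch using torchvision's make_grid; this is not what the code above does) lays all samples out in one grid. It assumes x0 as computed in sample():

from torchvision.utils import make_grid

grid = make_grid(x0.clip(0, 1).cpu(), nrow=4)
plt.imshow(grid.permute(1, 2, 0).numpy())
plt.axis('off')
plt.show()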
We sample $x_{t-1}$ from $p_\theta(x_{t-1}|x_t) = \mathcal{N}\big(x_{t-1}; \mu_\theta(x_t, t), \sigma_t^2 \mathbf{I}\big)$ with $\mu_\theta(x_t, t) = \frac{1}{\sqrt{\alpha_t}} \Big(x_t - \frac{1 - \alpha_t}{\sqrt{1 - \bar\alpha_t}} \epsilon_\theta(x_t, t)\Big)$

    def p_sample(self, xt: torch.Tensor, t: torch.Tensor, eps_theta: torch.Tensor):
        # Gather $\bar\alpha_t$ and $\alpha_t$ for the given time steps
        alpha_bar = gather(self.alpha_bar, t)
        alpha = gather(self.alpha, t)
        # Coefficient of $\epsilon_\theta$: $\frac{1 - \alpha_t}{\sqrt{1 - \bar\alpha_t}}$
        eps_coef = (1 - alpha) / (1 - alpha_bar) ** .5
        # Mean $\mu_\theta(x_t, t)$
        mean = 1 / (alpha ** 0.5) * (xt - eps_coef * eps_theta)
        # Variance $\sigma_t^2$
        var = gather(self.sigma2, t)
        # Noise $\epsilon \sim \mathcal{N}(0, \mathbf{I})$
        eps = torch.randn(xt.shape, device=xt.device)
Sample $x_{t-1}$
        return mean + (var ** .5) * eps
We estimate $x_0$ from $x_t$ and the predicted noise using $x_0 = \frac{x_t - \sqrt{1 - \bar\alpha_t}\,\epsilon}{\sqrt{\bar\alpha_t}}$

    def p_x0(self, xt: torch.Tensor, t: torch.Tensor, eps: torch.Tensor):
        # Gather $\bar\alpha_t$ for the given time steps
        alpha_bar = gather(self.alpha_bar, t)
        return (xt - (1 - alpha_bar) ** 0.5 * eps) / (alpha_bar ** 0.5)
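As a sanity check (a standalone sketch, assuming sampler is a Sampler instance; the batch size and timestep are arbitrary), p_x0 should exactly invert the forward noising $x_t = \sqrt{\bar\alpha_t}\, x_0 + \sqrt{1 - \bar\alpha_t}\, \epsilon$ when it is given the true noise:

x0 = torch.rand(1, sampler.image_channels, sampler.image_size, sampler.image_size, device=sampler.device)
t = torch.tensor([10], dtype=torch.long, device=sampler.device)
eps = torch.randn_like(x0)
# Forward process: $x_t = \sqrt{\bar\alpha_t} x_0 + \sqrt{1 - \bar\alpha_t} \epsilon$
xt = (gather(sampler.alpha_bar, t) ** 0.5) * x0 + ((1 - gather(sampler.alpha_bar, t)) ** 0.5) * eps
# Recovering $x_0$ from $x_t$ and the true noise should give the original image back
assert torch.allclose(sampler.p_x0(xt, t, eps), x0, atol=1e-5)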
Generate samples
def main():
Training experiment run UUID
    run_uuid = "a44333ea251411ec8007d1a1762ed686"
Start an evaluation
    experiment.evaluate()
Create configs
    configs = Configs()
Load custom configuration of the training run
    configs_dict = experiment.load_configs(run_uuid)
Set configurations
    experiment.configs(configs, configs_dict)
Initialize
    configs.init()
Set PyTorch modules for saving and loading
    experiment.add_pytorch_models({'eps_model': configs.eps_model})
Load training experiment
    experiment.load(run_uuid)
Create sampler
    sampler = Sampler(diffusion=configs.diffusion,
                      image_channels=configs.image_channels,
                      image_size=configs.image_size,
                      device=configs.device)
Start evaluation
    with experiment.start():
No gradients
        with torch.no_grad():
Sample an image with a denoising animation
            sampler.sample_animation()

            # Disabled by default; change to True to run the interpolation demo
            if False:
Get some images from the data
                data = next(iter(configs.data_loader)).to(configs.device)
Create an interpolation animation
                sampler.interpolate_animate(data[0], data[1])


if __name__ == '__main__':
    main()