This is the code to generate images and create interpolations between given images.
+14import numpy as np
+15import torch
+16from matplotlib import pyplot as plt
+17from torchvision.transforms.functional import to_pil_image, resize
+18
+19from labml import experiment, monit
+20from labml_nn.diffusion.ddpm import DenoiseDiffusion, gather
+21from labml_nn.diffusion.ddpm.experiment import Configs
24class Sampler:
* diffusion is the DenoiseDiffusion instance
* image_channels is the number of channels in the image
* image_size is the image size
* device is the device of the model

29 def __init__(self, diffusion: DenoiseDiffusion, image_channels: int, image_size: int, device: torch.device):
36 self.device = device
+37 self.image_size = image_size
+38 self.image_channels = image_channels
+39 self.diffusion = diffusion
$T$
+42 self.n_steps = diffusion.n_steps
$\color{cyan}{\epsilon_\theta}(x_t, t)$
+44 self.eps_model = diffusion.eps_model
$\beta_t$
+46 self.beta = diffusion.beta
$\alpha_t$
+48 self.alpha = diffusion.alpha
$\bar\alpha_t$
+50 self.alpha_bar = diffusion.alpha_bar
$\bar\alpha_{t-1}$
+52 alpha_bar_tm1 = torch.cat([self.alpha_bar.new_ones((1,)), self.alpha_bar[:-1]])
To calculate the forward process posterior $$q(x_{t-1}|x_t, x_0) = \mathcal{N}\Big(x_{t-1}; \tilde\mu_t(x_t, x_0), \tilde\beta_t \mathbf{I}\Big)$$ where $$\tilde\mu_t(x_t, x_0) = \frac{\sqrt{\bar\alpha_{t-1}}\beta_t}{1 - \bar\alpha_t} x_0 + \frac{\sqrt{\alpha_t}(1 - \bar\alpha_{t-1})}{1 - \bar\alpha_t} x_t \qquad \tilde\beta_t = \frac{1 - \bar\alpha_{t-1}}{1 - \bar\alpha_t} \beta_t$$
+$\tilde\beta_t$
+63 self.beta_tilde = self.beta * (1 - alpha_bar_tm1) / (1 - self.alpha_bar)
$\frac{\sqrt{\bar\alpha_{t-1}}\beta_t}{1 - \bar\alpha_t}$
+65 self.mu_tilde_coef1 = self.beta * (alpha_bar_tm1 ** 0.5) / (1 - self.alpha_bar)
$\frac{\sqrt{\alpha_t}(1 - \bar\alpha_{t-1})}{1 - \bar\alpha_t}$
+67 self.mu_tilde_coef2 = (self.alpha ** 0.5) * (1 - alpha_bar_tm1) / (1 - self.alpha_bar)
$\sigma^2 = \beta$
+69 self.sigma2 = self.beta
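The stored coefficients are enough to evaluate the posterior mean $\tilde\mu_t(x_t, x_0)$ directly. A minimal sketch of how they would combine, using a hypothetical helper that is not part of the class and assumes the gather utility imported above:

def mu_tilde(sampler: 'Sampler', xt: torch.Tensor, x0: torch.Tensor, t: torch.Tensor):
    # tilde mu_t(x_t, x_0) = coef1 * x_0 + coef2 * x_t, with per-step coefficients gathered for t
    coef1 = gather(sampler.mu_tilde_coef1, t)
    coef2 = gather(sampler.mu_tilde_coef2, t)
    return coef1 * x0 + coef2 * xt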
Helper function to display an image
+71 def show_image(self, img, title=""):
73 img = img.clip(0, 1)
+74 img = img.cpu().numpy()
+75 plt.imshow(img.transpose(1, 2, 0))
+76 plt.title(title)
+77 plt.show()
Helper function to create a video
+79 def make_video(self, frames, path="video.mp4"):
81 import imageio
20 second video
+83 writer = imageio.get_writer(path, fps=len(frames) // 20)
Add each image
+85 for f in frames:
+86 f = f.clip(0, 1)
+87 f = to_pil_image(resize(f, [368, 368]))
+88 writer.append_data(np.array(f))
90 writer.close()
We sample an image step-by-step using $\color{cyan}{p_\theta}(x_{t-1}|x_t)$ and at each step show the estimate $$x_0 \approx \hat{x}_0 = \frac{1}{\sqrt{\bar\alpha_t}} \Big(x_t - \sqrt{1 - \bar\alpha_t}\, \color{cyan}{\epsilon_\theta}(x_t, t)\Big)$$
+92 def sample_animation(self, n_frames: int = 1000, create_video: bool = True):
$x_T \sim p(x_T) = \mathcal{N}(x_T; \mathbf{0}, \mathbf{I})$
+103 xt = torch.randn([1, self.image_channels, self.image_size, self.image_size], device=self.device)
Interval to log $\hat{x}_0$
+106 interval = self.n_steps // n_frames
Frames for video
+108 frames = []
Sample $T$ steps
+110 for t_inv in monit.iterate('Denoise', self.n_steps):
$t$
+112 t_ = self.n_steps - t_inv - 1
$t$ in a tensor
+114 t = xt.new_full((1,), t_, dtype=torch.long)
$\color{cyan}{\epsilon_\theta}(x_t, t)$
+116 eps_theta = self.eps_model(xt, t)
+117 if t_ % interval == 0:
Get $\hat{x}_0$ and add to frames
+119 x0 = self.p_x0(xt, t, eps_theta)
+120 frames.append(x0[0])
+121 if not create_video:
+122 self.show_image(x0[0], f"{t_}")
Sample from $\color{cyan}{p_\theta}(x_{t-1}|x_t)$
+124 xt = self.p_sample(xt, t, eps_theta)
Make video
+127 if create_video:
+128 self.make_video(frames)
We get $x_t \sim q(x_t|x_0)$ and $x'_t \sim q(x'_t|x_0)$.
Then we interpolate to $$\bar{x}_t = (1 - \lambda) x_t + \lambda x'_t$$
Then we get $$\bar{x}_0 \sim \color{cyan}{p_\theta}(x_0|\bar{x}_t)$$
* x1 is $x_0$
* x2 is $x'_0$
* lambda_ is $\lambda$
* t_ is $t$

130 def interpolate(self, x1: torch.Tensor, x2: torch.Tensor, lambda_: float, t_: int = 100):
Number of samples
+149 n_samples = x1.shape[0]
$t$ tensor
+151 t = torch.full((n_samples,), t_, device=self.device)
$\bar{x}_t = (1 - \lambda) x_t + \lambda x'_t$
+153 xt = (1 - lambda_) * self.diffusion.q_sample(x1, t) + lambda_ * self.diffusion.q_sample(x2, t)
$\bar{x}_0 \sim \color{cyan}{p_\theta}(x_0|\bar{x}_t)$
+156 return self._sample_x0(xt, t_)
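A hedged usage sketch of interpolate, assuming a sampler built as in main() below and data drawn from the data loader; the variable names are illustrative only:

# Mix two dataset images half-and-half at t = 100
x1, x2 = data[0][None], data[1][None]   # add a batch dimension to each image
x0_mix = sampler.interpolate(x1, x2, lambda_=0.5, t_=100)
sampler.show_image(x0_mix[0], "lambda = 0.5")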
* x1 is $x_0$
* x2 is $x'_0$
* n_frames is the number of frames in the interpolation animation
* t_ is $t$
* create_video specifies whether to make a video or to show each frame

158 def interpolate_animate(self, x1: torch.Tensor, x2: torch.Tensor, n_frames: int = 100, t_: int = 100,
159 create_video=True):
Show original images
+171 self.show_image(x1, "x1")
+172 self.show_image(x2, "x2")
Add batch dimension
+174 x1 = x1[None, :, :, :]
+175 x2 = x2[None, :, :, :]
$t$ tensor
+177 t = torch.full((1,), t_, device=self.device)
$x_t \sim q(x_t|x_0)$
+179 x1t = self.diffusion.q_sample(x1, t)
$x’_t \sim q(x’_t|x_0)$
+181 x2t = self.diffusion.q_sample(x2, t)
+182
+183 frames = []
Get frames with different $\lambda$
+185 for i in monit.iterate('Interpolate', n_frames + 1, is_children_silent=True):
$\lambda$
+187 lambda_ = i / n_frames
$\bar{x}_t = (1 - \lambda) x_t + \lambda x'_t$
+189 xt = (1 - lambda_) * x1t + lambda_ * x2t
$\bar{x}_0 \sim \color{cyan}{p_\theta}(x_0|\bar{x}_t)$
+191 x0 = self._sample_x0(xt, t_)
Add to frames
+193 frames.append(x0[0])
Show frame
+195 if not create_video:
+196 self.show_image(x0[0], f"{lambda_ :.2f}")
Make video
+199 if create_video:
+200 self.make_video(frames)
* xt is $x_t$
* n_steps is $t$

202 def _sample_x0(self, xt: torch.Tensor, n_steps: int):
Number of samples
+211 n_samples = xt.shape[0]
Iterate for $t$ steps
+213 for t_ in monit.iterate('Denoise', n_steps):
+214 t = n_steps - t_ - 1
Sample from $\color{cyan}{p_\theta}(x_{t-1}|x_t)$
+216 xt = self.diffusion.p_sample(xt, xt.new_full((n_samples,), t, dtype=torch.long))
Return $x_0$
+219 return xt
221 def sample(self, n_samples: int = 16):
$x_T \sim p(x_T) = \mathcal{N}(x_T; \mathbf{0}, \mathbf{I})$
+226 xt = torch.randn([n_samples, self.image_channels, self.image_size, self.image_size], device=self.device)
$$x_0 \sim \color{cyan}{p_\theta}(x_0|x_T)$$
+229 x0 = self._sample_x0(xt, self.n_steps)
Show images
+232 for i in range(n_samples):
+233 self.show_image(x0[i])
235 def p_sample(self, xt: torch.Tensor, t: torch.Tensor, eps_theta: torch.Tensor):
$\bar\alpha_t$
alpha_bar = gather(self.alpha_bar, t)
$\alpha_t$
250 alpha = gather(self.alpha, t)
$\frac{\beta}{\sqrt{1-\bar\alpha_t}}$
+252 eps_coef = (1 - alpha) / (1 - alpha_bar) ** .5
$$\frac{1}{\sqrt{\alpha_t}} \Big(x_t - \frac{\beta_t}{\sqrt{1-\bar\alpha_t}}\, \color{cyan}{\epsilon_\theta}(x_t, t)\Big)$$
+255 mean = 1 / (alpha ** 0.5) * (xt - eps_coef * eps_theta)
$\sigma^2$
+257 var = gather(self.sigma2, t)
$\epsilon \sim \mathcal{N}(\mathbf{0}, \mathbf{I})$
+260 eps = torch.randn(xt.shape, device=xt.device)
Sample
+262 return mean + (var ** .5) * eps
264 def p_x0(self, xt: torch.Tensor, t: torch.Tensor, eps: torch.Tensor):
$\bar\alpha_t$
alpha_bar = gather(self.alpha_bar, t)
$$x_0 \approx \hat{x}_0 = \frac{1}{\sqrt{\bar\alpha_t}} \Big(x_t - \sqrt{1 - \bar\alpha_t}\, \epsilon\Big)$$
+276 return (xt - (1 - alpha_bar) ** 0.5 * eps) / (alpha_bar ** 0.5)
Generate samples
+279def main():
Training experiment run UUID
+283 run_uuid = "a44333ea251411ec8007d1a1762ed686"
Start an evaluation
+286 experiment.evaluate()
Create configs
+289 configs = Configs()
Load custom configuration of the training run
+291 configs_dict = experiment.load_configs(run_uuid)
Set configurations
+293 experiment.configs(configs, configs_dict)
Initialize
+296 configs.init()
Set PyTorch modules for saving and loading
+299 experiment.add_pytorch_models({'eps_model': configs.eps_model})
Load training experiment
+302 experiment.load(run_uuid)
Create sampler
+305 sampler = Sampler(diffusion=configs.diffusion,
+306 image_channels=configs.image_channels,
+307 image_size=configs.image_size,
+308 device=configs.device)
Start evaluation
+311 with experiment.start():
No gradients
+313 with torch.no_grad():
Sample an image with a denoising animation
+315 sampler.sample_animation()
+316
+317 if False:
Get some images from the data
+319 data = next(iter(configs.data_loader)).to(configs.device)
Create an interpolation animation
+322 sampler.interpolate_animate(data[0], data[1])
326if __name__ == '__main__':
+327 main()
This trains a DDPM-based model on the CelebA HQ dataset. You can find the download instructions in this discussion on fast.ai.
Save the images inside the data/celebA folder.
The paper uses an exponential moving average of the model weights with a decay of $0.9999$. We have skipped this for simplicity.
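If you do want the EMA, a minimal sketch (not part of this implementation) is to keep a shadow copy of $\color{cyan}{\epsilon_\theta}$ and update it after every optimizer step:

import copy

import torch


def ema_update(ema_model: torch.nn.Module, model: torch.nn.Module, decay: float = 0.9999):
    # Move the shadow parameters slowly towards the current parameters
    with torch.no_grad():
        for ema_p, p in zip(ema_model.parameters(), model.parameters()):
            ema_p.mul_(decay).add_(p, alpha=1.0 - decay)


# Hypothetical usage inside the training loop:
#   ema_model = copy.deepcopy(eps_model)
#   ... loss.backward(); optimizer.step()
#   ema_update(ema_model, eps_model)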
+18from typing import List
+19
+20import torch
+21import torch.utils.data
+22import torchvision
+23from PIL import Image
+24
+25from labml import lab, tracker, experiment, monit
+26from labml.configs import BaseConfigs, option
+27from labml_helpers.device import DeviceConfigs
+28from labml_nn.diffusion.ddpm import DenoiseDiffusion
+29from labml_nn.diffusion.ddpm.unet import UNet
32class Configs(BaseConfigs):
Device to train the model on.
+DeviceConfigs
+ picks up an available CUDA device or defaults to CPU.
39 device: torch.device = DeviceConfigs()
U-Net model for $\color{cyan}{\epsilon_\theta}(x_t, t)$
+42 eps_model: UNet
44 diffusion: DenoiseDiffusion
Number of channels in the image. $3$ for RGB.
+47 image_channels: int = 3
Image size
+49 image_size: int = 32
Number of channels in the initial feature map
+51 n_channels: int = 64
The list of channel numbers at each resolution.
+The number of channels is channel_multipliers[i] * n_channels
54 channel_multipliers: List[int] = [1, 2, 2, 4]
The list of booleans that indicate whether to use attention at each resolution
+56 is_attention: List[int] = [False, False, False, True]
Number of time steps $T$
+59 n_steps: int = 1_000
Batch size
+61 batch_size: int = 64
Number of samples to generate
+63 n_samples: int = 16
Learning rate
+65 learning_rate: float = 2e-5
Number of training epochs
+68 epochs: int = 1_000
Dataset
+71 dataset: torch.utils.data.Dataset
Dataloader
+73 data_loader: torch.utils.data.DataLoader
Adam optimizer
+76 optimizer: torch.optim.Adam
78 def init(self):
Create $\color{cyan}{\epsilon_\theta}(x_t, t)$ model
+80 self.eps_model = UNet(
+81 image_channels=self.image_channels,
+82 n_channels=self.n_channels,
+83 ch_mults=self.channel_multipliers,
+84 is_attn=self.is_attention,
+85 ).to(self.device)
Create DDPM class
+88 self.diffusion = DenoiseDiffusion(
+89 eps_model=self.eps_model,
+90 n_steps=self.n_steps,
+91 device=self.device,
+92 )
Create dataloader
+95 self.data_loader = torch.utils.data.DataLoader(self.dataset, self.batch_size, shuffle=True, pin_memory=True)
Create optimizer
+97 self.optimizer = torch.optim.Adam(self.eps_model.parameters(), lr=self.learning_rate)
Image logging
+100 tracker.set_image("sample", True)
102 def sample(self):
106 with torch.no_grad():
$x_T \sim p(x_T) = \mathcal{N}(x_T; \mathbf{0}, \mathbf{I})$
+108 x = torch.randn([self.n_samples, self.image_channels, self.image_size, self.image_size],
+109 device=self.device)
Remove noise for $T$ steps
+112 for t_ in monit.iterate('Sample', self.n_steps):
$t$
+114 t = self.n_steps - t_ - 1
Sample from $\color{cyan}{p_\theta}(x_{t-1}|x_t)$
+116 x = self.diffusion.p_sample(x, x.new_full((self.n_samples,), t, dtype=torch.long))
Log samples
+119 tracker.save('sample', x)
121 def train(self):
Iterate through the dataset
+127 for data in monit.iterate('Train', self.data_loader):
Increment global step
+129 tracker.add_global_step()
Move data to device
+131 data = data.to(self.device)
Make the gradients zero
+134 self.optimizer.zero_grad()
Calculate loss
+136 loss = self.diffusion.loss(data)
Compute gradients
+138 loss.backward()
Take an optimization step
+140 self.optimizer.step()
Track the loss
+142 tracker.save('loss', loss)
144 def run(self):
148 for _ in monit.loop(self.epochs):
Train the model
+150 self.train()
Sample some images
+152 self.sample()
New line in the console
+154 tracker.new_line()
Save the model
+156 experiment.save_checkpoint()
159class CelebADataset(torch.utils.data.Dataset):
164 def __init__(self, image_size: int):
+165 super().__init__()
CelebA images folder
+168 folder = lab.get_data_path() / 'celebA'
List of files
+170 self._files = [p for p in folder.glob(f'**/*.jpg')]
Transformations to resize the image and convert to tensor
+173 self._transform = torchvision.transforms.Compose([
+174 torchvision.transforms.Resize(image_size),
+175 torchvision.transforms.ToTensor(),
+176 ])
Size of the dataset
+178 def __len__(self):
182 return len(self._files)
Get an image
+184 def __getitem__(self, index: int):
188 img = Image.open(self._files[index])
+189 return self._transform(img)
Create CelebA dataset
+192@option(Configs.dataset, 'CelebA')
+193def celeb_dataset(c: Configs):
197 return CelebADataset(c.image_size)
200class MNISTDataset(torchvision.datasets.MNIST):
205 def __init__(self, image_size):
+206 transform = torchvision.transforms.Compose([
+207 torchvision.transforms.Resize(image_size),
+208 torchvision.transforms.ToTensor(),
+209 ])
+210
+211 super().__init__(str(lab.get_data_path()), train=True, download=True, transform=transform)
213 def __getitem__(self, item):
+214 return super().__getitem__(item)[0]
Create MNIST dataset
+217@option(Configs.dataset, 'MNIST')
+218def mnist_dataset(c: Configs):
222 return MNISTDataset(c.image_size)
225def main():
Create experiment
+227 experiment.create(name='diffuse')
Create configurations
+230 configs = Configs()
Set configurations. You can override the defaults by passing the values in the dictionary.
+233 experiment.configs(configs, {
+234 })
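For example, a hypothetical override dictionary that selects the CelebA dataset option registered below and trains at a higher resolution (the keys are Configs attributes, the values are illustrative):

experiment.configs(configs, {
    'dataset': 'CelebA',  # select the 'CelebA' dataset option defined below
    'image_size': 64,     # override the default image_size of 32
})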
Initialize
+237 configs.init()
Set models for saving and loading
+240 experiment.add_pytorch_models({'eps_model': configs.eps_model})
Start and run the training loop
+243 with experiment.start():
+244 configs.run()
248if __name__ == '__main__':
+249 main()
This is a PyTorch implementation/tutorial of the paper +Denoising Diffusion Probabilistic Models.
In simple terms, we get an image from the data and add noise step by step. Then we train a model to predict that noise at each step and use the model to generate images.
+The following definitions and derivations show how this works. +For details please refer to the paper.
+The forward process adds noise to the data $x_0 \sim q(x_0)$, for $T$ timesteps.
++ +
+where $\beta_1, \dots, \beta_T$ is the variance schedule.
+We can sample $x_t$ at any timestep $t$ with,
++ +
+where $\alpha_t = 1 - \beta_t$ and $\bar\alpha_t = \prod_{s=1}^t \alpha_s$
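A minimal numeric sketch of this closed form, using the same linear $\beta$ schedule that DenoiseDiffusion sets up below:

import torch

n_steps = 1_000
beta = torch.linspace(0.0001, 0.02, n_steps)   # beta_1, ..., beta_T
alpha = 1. - beta                              # alpha_t = 1 - beta_t
alpha_bar = torch.cumprod(alpha, dim=0)        # alpha_bar_t = product of alpha_s for s <= t

x0 = torch.randn(1, 3, 32, 32)                 # a stand-in for an image from the data
t = torch.tensor([499])                        # some intermediate time step
eps = torch.randn_like(x0)                     # epsilon ~ N(0, I)
# x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * epsilon
xt = (alpha_bar[t] ** 0.5)[:, None, None, None] * x0 + \
     ((1 - alpha_bar[t]) ** 0.5)[:, None, None, None] * eps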
+The reverse process removes noise starting at $p(x_T) = \mathcal{N}(x_T; \mathbf{0}, \mathbf{I})$ +for $T$ time steps.
++ +
+$\color{cyan}\theta$ are the parameters we train.
We optimize the ELBO (from Jensen's inequality) on the negative log likelihood.
++ +
+The loss can be rewritten as follows.
++ +
+$D_{KL}(q(x_T|x_0) \Vert p(x_T))$ is constant since we keep $\beta_1, \dots, \beta_T$ constant.
+The forward process posterior conditioned by $x_0$ is,
++ +
+The paper sets $\color{cyan}{\Sigma_\theta}(x_t, t) = \sigma_t^2 \mathbf{I}$ where $\sigma_t^2$ is set to constants +$\beta_t$ or $\tilde\beta_t$.
Then, $$L_{t-1} = \mathbb{E}_q \Bigg[ \frac{1}{2\sigma_t^2} \Big\Vert \tilde\mu_t(x_t, x_0) - \color{cyan}{\mu_\theta}(x_t, t) \Big\Vert^2 \Bigg] + C$$ where $C$ does not depend on $\theta$.
+For given noise $\epsilon \sim \mathcal{N}(\mathbf{0}, \mathbf{I})$ using $q(x_t|x_0)$
++ +
+This gives,
++ +
+Re-parameterizing with a model to predict noise
++ +
where $\color{cyan}{\epsilon_\theta}$ is a learned function that predicts $\epsilon$ given $(x_t, t)$.
+This gives,
++ +
+That is, we are training to predict the noise.
++ +
This minimizes $-\log \color{cyan}{p_\theta}(x_0|x_1)$ when $t=1$ and $L_{t-1}$ for $t>1$, discarding the weighting in $L_{t-1}$. Discarding the weights $\frac{\beta_t^2}{2\sigma_t^2 \alpha_t (1 - \bar\alpha_t)}$ increases the weight given to higher $t$ (which have higher noise levels), thereby improving sample quality.
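In code this objective is just an MSE between the sampled noise and the model's prediction at a random time step; a minimal sketch of one loss evaluation, essentially what DenoiseDiffusion.loss below implements (here eps_model stands for the U-Net):

import torch
import torch.nn.functional as F

def simple_loss(eps_model, x0: torch.Tensor, alpha_bar: torch.Tensor, n_steps: int):
    # Random t for each sample and fresh Gaussian noise
    t = torch.randint(0, n_steps, (x0.shape[0],), device=x0.device)
    eps = torch.randn_like(x0)
    # x_t = sqrt(alpha_bar_t) x_0 + sqrt(1 - alpha_bar_t) epsilon
    a_bar = alpha_bar[t][:, None, None, None]
    xt = a_bar ** 0.5 * x0 + (1 - a_bar) ** 0.5 * eps
    # || epsilon - eps_theta(x_t, t) ||^2
    return F.mse_loss(eps, eps_model(xt, t))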
+This file implements the loss calculation and a basic sampling method that we use to generate images during +training.
+Here is the UNet model that gives $\color{cyan}{\epsilon_\theta}(x_t, t)$ and +training code. +This file can generate samples and interpolations from a trained model.
+ +162from typing import Tuple, Optional
+163
+164import torch
+165import torch.nn.functional as F
+166import torch.utils.data
+167from torch import nn
+168
+169from labml_nn.diffusion.ddpm.utils import gather
172class DenoiseDiffusion:
* eps_model is the $\color{cyan}{\epsilon_\theta}(x_t, t)$ model
* n_steps is $T$
* device is the device to place constants on

177 def __init__(self, eps_model: nn.Module, n_steps: int, device: torch.device):
183 super().__init__()
+184 self.eps_model = eps_model
Create $\beta_1, \dots, \beta_T$ linearly increasing variance schedule
+187 self.beta = torch.linspace(0.0001, 0.02, n_steps).to(device)
$\alpha_t = 1 - \beta_t$
+190 self.alpha = 1. - self.beta
$\bar\alpha_t = \prod_{s=1}^t \alpha_s$
+192 self.alpha_bar = torch.cumprod(self.alpha, dim=0)
$T$
+194 self.n_steps = n_steps
$\sigma^2 = \beta$
+196 self.sigma2 = self.beta
198 def q_xt_x0(self, x0: torch.Tensor, t: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
208 mean = gather(self.alpha_bar, t) ** 0.5 * x0
$(1-\bar\alpha_t) \mathbf{I}$
+210 var = 1 - gather(self.alpha_bar, t)
212 return mean, var
214 def q_sample(self, x0: torch.Tensor, t: torch.Tensor, eps: Optional[torch.Tensor] = None):
$\epsilon \sim \mathcal{N}(\mathbf{0}, \mathbf{I})$
+224 if eps is None:
+225 eps = torch.randn_like(x0)
get $q(x_t|x_0)$
+228 mean, var = self.q_xt_x0(x0, t)
Sample from $q(x_t|x_0)$
+230 return mean + (var ** 0.5) * eps
232 def p_sample(self, xt: torch.Tensor, t: torch.Tensor):
$\color{cyan}{\epsilon_\theta}(x_t, t)$
+246 eps_theta = self.eps_model(xt, t)
$\bar\alpha_t$
alpha_bar = gather(self.alpha_bar, t)
$\alpha_t$
250 alpha = gather(self.alpha, t)
$\frac{\beta}{\sqrt{1-\bar\alpha_t}}$
+252 eps_coef = (1 - alpha) / (1 - alpha_bar) ** .5
$$\frac{1}{\sqrt{\alpha_t}} \Big(x_t - \frac{\beta_t}{\sqrt{1-\bar\alpha_t}}\, \color{cyan}{\epsilon_\theta}(x_t, t)\Big)$$
+255 mean = 1 / (alpha ** 0.5) * (xt - eps_coef * eps_theta)
$\sigma^2$
+257 var = gather(self.sigma2, t)
$\epsilon \sim \mathcal{N}(\mathbf{0}, \mathbf{I})$
+260 eps = torch.randn(xt.shape, device=xt.device)
Sample
+262 return mean + (var ** .5) * eps
264 def loss(self, x0: torch.Tensor, noise: Optional[torch.Tensor] = None):
Get batch size
+273 batch_size = x0.shape[0]
Get random $t$ for each sample in the batch
+275 t = torch.randint(0, self.n_steps, (batch_size,), device=x0.device, dtype=torch.long)
$\epsilon \sim \mathcal{N}(\mathbf{0}, \mathbf{I})$
+278 if noise is None:
+279 noise = torch.randn_like(x0)
Sample $x_t$ from $q(x_t|x_0)$
+282 xt = self.q_sample(x0, t, eps=noise)
Get $\color{cyan}{\epsilon_\theta}(\sqrt{\bar\alpha_t} x_0 + \sqrt{1-\bar\alpha_t}\epsilon, t)$
+284 eps_theta = self.eps_model(xt, t)
MSE loss
+287 return F.mse_loss(noise, eps_theta)
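A hedged end-to-end sketch of one training step with this class, using the UNet documented below as the noise predictor (the small channel counts are chosen only to keep the example cheap):

import torch

from labml_nn.diffusion.ddpm import DenoiseDiffusion
from labml_nn.diffusion.ddpm.unet import UNet

device = torch.device('cpu')
eps_model = UNet(image_channels=3, n_channels=32, ch_mults=(1, 2), is_attn=(False, True)).to(device)
diffusion = DenoiseDiffusion(eps_model=eps_model, n_steps=1_000, device=device)
optimizer = torch.optim.Adam(eps_model.parameters(), lr=2e-5)

x0 = torch.randn(2, 3, 32, 32, device=device)  # a stand-in batch of images
optimizer.zero_grad()
loss = diffusion.loss(x0)                      # L_simple with a random t per sample
loss.backward()
optimizer.step()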
This is a PyTorch implementation/tutorial of the paper +Denoising Diffusion Probabilistic Models.
In simple terms, we get an image from the data and add noise step by step. Then we train a model to predict that noise at each step and use the model to generate images.
+Here is the UNet model that predicts the noise and +training code. +This file can generate samples and interpolations +from a trained model.
+ +This is a U-Net based model to predict noise +$\color{cyan}{\epsilon_\theta}(x_t, t)$.
U-Net gets its name from the U shape in the model diagram. It processes a given image by progressively lowering (halving) the feature map resolution and then increasing it again. There are pass-through connections at each resolution.
This implementation contains a number of modifications to the original U-Net (residual blocks, multi-head attention) and also adds time-step embeddings $t$.
+24import math
+25from typing import Optional, Tuple, Union, List
+26
+27import torch
+28from torch import nn
+29
+30from labml_helpers.module import Module
33class Swish(Module):
40 def forward(self, x):
+41 return x * torch.sigmoid(x)
44class TimeEmbedding(nn.Module):
* n_channels is the number of dimensions in the embedding

49 def __init__(self, n_channels: int):
53 super().__init__()
+54 self.n_channels = n_channels
First linear layer
+56 self.lin1 = nn.Linear(self.n_channels // 4, self.n_channels)
Activation
+58 self.act = Swish()
Second linear layer
+60 self.lin2 = nn.Linear(self.n_channels, self.n_channels)
62 def forward(self, t: torch.Tensor):
Create sinusoidal position embeddings, same as those from the transformer: $$PE^{(1)}_{t,i} = \sin\Bigg(\frac{t}{10000^{\frac{i}{d-1}}}\Bigg) \qquad PE^{(2)}_{t,i} = \cos\Bigg(\frac{t}{10000^{\frac{i}{d-1}}}\Bigg)$$ where $d$ is half_dim
70 half_dim = self.n_channels // 8
+71 emb = math.log(10_000) / (half_dim - 1)
+72 emb = torch.exp(torch.arange(half_dim, device=t.device) * -emb)
+73 emb = t[:, None] * emb[None, :]
+74 emb = torch.cat((emb.sin(), emb.cos()), dim=1)
Transform with the MLP
+77 emb = self.act(self.lin1(emb))
+78 emb = self.lin2(emb)
81 return emb
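A small shape check of the embedding, assuming the sizes UNet uses below (n_channels * 4 = 256) and a batch of integer time steps as in the training code:

import torch

from labml_nn.diffusion.ddpm.unet import TimeEmbedding

time_emb = TimeEmbedding(n_channels=256)
t = torch.randint(0, 1_000, (8,))  # a batch of 8 time steps
emb = time_emb(t)
assert emb.shape == (8, 256)       # [batch_size, n_channels]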
A residual block has two convolution layers with group normalization. +Each resolution is processed with two residual blocks.
+84class ResidualBlock(Module):
* in_channels is the number of input channels
* out_channels is the number of output channels
* time_channels is the number of channels in the time step ($t$) embeddings
* n_groups is the number of groups for group normalization

92 def __init__(self, in_channels: int, out_channels: int, time_channels: int, n_groups: int = 32):
99 super().__init__()
Group normalization and the first convolution layer
+101 self.norm1 = nn.GroupNorm(n_groups, in_channels)
+102 self.act1 = Swish()
+103 self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=(3, 3), padding=(1, 1))
Group normalization and the second convolution layer
+106 self.norm2 = nn.GroupNorm(n_groups, out_channels)
+107 self.act2 = Swish()
+108 self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=(3, 3), padding=(1, 1))
If the number of input channels is not equal to the number of output channels we have to +project the shortcut connection
+112 if in_channels != out_channels:
+113 self.shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=(1, 1))
+114 else:
+115 self.shortcut = nn.Identity()
Linear layer for time embeddings
+118 self.time_emb = nn.Linear(time_channels, out_channels)
* x has shape [batch_size, in_channels, height, width]
* t has shape [batch_size, time_channels]
120 def forward(self, x: torch.Tensor, t: torch.Tensor):
First convolution layer
+126 h = self.conv1(self.act1(self.norm1(x)))
Add time embeddings
+128 h += self.time_emb(t)[:, :, None, None]
Second convolution layer
+130 h = self.conv2(self.act2(self.norm2(h)))
Add the shortcut connection and return
+133 return h + self.shortcut(x)
136class AttentionBlock(Module):
* n_channels is the number of channels in the input
* n_heads is the number of heads in multi-head attention
* d_k is the number of dimensions in each head
* n_groups is the number of groups for group normalization

143 def __init__(self, n_channels: int, n_heads: int = 1, d_k: int = None, n_groups: int = 32):
150 super().__init__()
Default d_k
153 if d_k is None:
+154 d_k = n_channels
Normalization layer
+156 self.norm = nn.GroupNorm(n_groups, n_channels)
Projections for query, key and values
+158 self.projection = nn.Linear(n_channels, n_heads * d_k * 3)
Linear layer for final transformation
+160 self.output = nn.Linear(n_heads * d_k, n_channels)
Scale for dot-product attention
+162 self.scale = d_k ** -0.5
164 self.n_heads = n_heads
+165 self.d_k = d_k
* x has shape [batch_size, in_channels, height, width]
* t has shape [batch_size, time_channels]
167 def forward(self, x: torch.Tensor, t: Optional[torch.Tensor] = None):
t is not used, but it's kept in the arguments so that the function signature matches that of ResidualBlock.
174 _ = t
Get shape
+176 batch_size, n_channels, height, width = x.shape
Change x to shape [batch_size, seq, n_channels]
178 x = x.view(batch_size, n_channels, -1).permute(0, 2, 1)
Get query, key, and values (concatenated) and shape it to [batch_size, seq, n_heads, 3 * d_k]
180 qkv = self.projection(x).view(batch_size, -1, self.n_heads, 3 * self.d_k)
Split query, key, and values. Each of them will have shape [batch_size, seq, n_heads, d_k]
182 q, k, v = torch.chunk(qkv, 3, dim=-1)
Calculate scaled dot-product $\frac{Q K^\top}{\sqrt{d_k}}$
+184 attn = torch.einsum('bihd,bjhd->bijh', q, k) * self.scale
Softmax along the sequence dimension $\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_k}}\Bigg)$
186 attn = attn.softmax(dim=2)
Multiply by values
+188 res = torch.einsum('bijh,bjhd->bihd', attn, v)
Reshape to [batch_size, seq, n_heads * d_k]
190 res = res.view(batch_size, -1, self.n_heads * self.d_k)
Transform to [batch_size, seq, n_channels]
192 res = self.output(res)
Add skip connection
+195 res += x
Change to shape [batch_size, in_channels, height, width]
198 res = res.permute(0, 2, 1).view(batch_size, n_channels, height, width)
201 return res
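A hedged shape walkthrough: the block flattens the spatial dimensions into a sequence of height * width positions, attends over that sequence, and restores the feature-map shape, so the output shape matches the input:

import torch

from labml_nn.diffusion.ddpm.unet import AttentionBlock

attn = AttentionBlock(n_channels=64)
x = torch.randn(2, 64, 16, 16)   # [batch_size, n_channels, height, width]
res = attn(x)                    # internally seq = 16 * 16 = 256
assert res.shape == x.shape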
This combines ResidualBlock and AttentionBlock. These are used in the first half of the U-Net at each resolution.
204class DownBlock(Module):
211 def __init__(self, in_channels: int, out_channels: int, time_channels: int, has_attn: bool):
+212 super().__init__()
+213 self.res = ResidualBlock(in_channels, out_channels, time_channels)
+214 if has_attn:
+215 self.attn = AttentionBlock(out_channels)
+216 else:
+217 self.attn = nn.Identity()
219 def forward(self, x: torch.Tensor, t: torch.Tensor):
+220 x = self.res(x, t)
+221 x = self.attn(x)
+222 return x
This combines ResidualBlock and AttentionBlock. These are used in the second half of the U-Net at each resolution.
225class UpBlock(Module):
232 def __init__(self, in_channels: int, out_channels: int, time_channels: int, has_attn: bool):
+233 super().__init__()
The input has in_channels + out_channels channels because we concatenate the output of the same resolution from the first half of the U-Net
236 self.res = ResidualBlock(in_channels + out_channels, out_channels, time_channels)
+237 if has_attn:
+238 self.attn = AttentionBlock(out_channels)
+239 else:
+240 self.attn = nn.Identity()
242 def forward(self, x: torch.Tensor, t: torch.Tensor):
+243 x = self.res(x, t)
+244 x = self.attn(x)
+245 return x
It combines a ResidualBlock and an AttentionBlock, followed by another ResidualBlock. This block is applied at the lowest resolution of the U-Net.
248class MiddleBlock(Module):
256 def __init__(self, n_channels: int, time_channels: int):
+257 super().__init__()
+258 self.res1 = ResidualBlock(n_channels, n_channels, time_channels)
+259 self.attn = AttentionBlock(n_channels)
+260 self.res2 = ResidualBlock(n_channels, n_channels, time_channels)
262 def forward(self, x: torch.Tensor, t: torch.Tensor):
+263 x = self.res1(x, t)
+264 x = self.attn(x)
+265 x = self.res2(x, t)
+266 return x
269class Upsample(nn.Module):
274 def __init__(self, n_channels):
+275 super().__init__()
+276 self.conv = nn.ConvTranspose2d(n_channels, n_channels, (4, 4), (2, 2), (1, 1))
278 def forward(self, x: torch.Tensor, t: torch.Tensor):
t is not used, but it's kept in the arguments so that the function signature matches that of ResidualBlock.
281 _ = t
+282 return self.conv(x)
285class Downsample(nn.Module):
290 def __init__(self, n_channels):
+291 super().__init__()
+292 self.conv = nn.Conv2d(n_channels, n_channels, (3, 3), (2, 2), (1, 1))
294 def forward(self, x: torch.Tensor, t: torch.Tensor):
t is not used, but it's kept in the arguments so that the function signature matches that of ResidualBlock.
297 _ = t
+298 return self.conv(x)
301class UNet(Module):
* image_channels is the number of channels in the image. $3$ for RGB.
* n_channels is the number of channels in the initial feature map that we transform the image into
* ch_mults is the list of channel multipliers at each resolution. The number of channels is ch_mults[i] * n_channels
* is_attn is a list of booleans that indicate whether to use attention at each resolution
* n_blocks is the number of DownBlock/UpBlock blocks at each resolution

306 def __init__(self, image_channels: int = 3, n_channels: int = 64,
+307 ch_mults: Union[Tuple[int, ...], List[int]] = (1, 2, 2, 4),
+308 is_attn: Union[Tuple[bool, ...], List[int]] = (False, False, True, True),
+309 n_blocks: int = 2):
317 super().__init__()
Number of resolutions
+320 n_resolutions = len(ch_mults)
Project image into feature map
+323 self.image_proj = nn.Conv2d(image_channels, n_channels, kernel_size=(3, 3), padding=(1, 1))
Time embedding layer. Time embedding has n_channels * 4 channels
326 self.time_emb = TimeEmbedding(n_channels * 4)
329 down = []
Number of channels
+331 out_channels = in_channels = n_channels
For each resolution
+333 for i in range(n_resolutions):
Number of output channels at this resolution
+335 out_channels = in_channels * ch_mults[i]
Add n_blocks
337 for _ in range(n_blocks):
+338 down.append(DownBlock(in_channels, out_channels, n_channels * 4, is_attn[i]))
+339 in_channels = out_channels
Down sample at all resolutions except the last
+341 if i < n_resolutions - 1:
+342 down.append(Downsample(in_channels))
Combine the set of modules
+345 self.down = nn.ModuleList(down)
Middle block
+348 self.middle = MiddleBlock(out_channels, n_channels * 4, )
351 up = []
Number of channels
+353 in_channels = out_channels
For each resolution
+355 for i in reversed(range(n_resolutions)):
n_blocks at the same resolution
357 out_channels = in_channels
+358 for _ in range(n_blocks):
+359 up.append(UpBlock(in_channels, out_channels, n_channels * 4, is_attn[i]))
Final block to reduce the number of channels
+361 out_channels = in_channels // ch_mults[i]
+362 up.append(UpBlock(in_channels, out_channels, n_channels * 4, is_attn[i]))
+363 in_channels = out_channels
Up sample at all resolutions except last
+365 if i > 0:
+366 up.append(Upsample(in_channels))
Combine the set of modules
+369 self.up = nn.ModuleList(up)
Final normalization and convolution layer
+372 self.norm = nn.GroupNorm(8, n_channels)
+373 self.act = Swish()
+374 self.final = nn.Conv2d(in_channels, image_channels, kernel_size=(3, 3), padding=(1, 1))
* x has shape [batch_size, in_channels, height, width]
* t has shape [batch_size]
376 def forward(self, x: torch.Tensor, t: torch.Tensor):
Get time-step embeddings
+383 t = self.time_emb(t)
Get image projection
+386 x = self.image_proj(x)
h will store outputs at each resolution for the skip connections
389 h = [x]
First half of U-Net
+391 for m in self.down:
+392 x = m(x, t)
+393 h.append(x)
Middle (bottom)
+396 x = self.middle(x, t)
Second half of U-Net
+399 for m in self.up:
+400 if isinstance(m, Upsample):
+401 x = m(x, t)
+402 else:
Get the skip connection from first half of U-Net and concatenate
+404 s = h.pop()
+405 x = torch.cat((x, s), dim=1)
407 x = m(x, t)
Final normalization and convolution
+410 return self.final(self.act(self.norm(x)))
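A hedged sanity check that the U-Net maps a batch of images to predicted noise of the same shape, with one time step per sample:

import torch

from labml_nn.diffusion.ddpm.unet import UNet

unet = UNet(image_channels=3, n_channels=64,
            ch_mults=(1, 2, 2, 4), is_attn=(False, False, True, True))
x = torch.randn(2, 3, 32, 32)      # [batch_size, image_channels, height, width]
t = torch.randint(0, 1_000, (2,))  # [batch_size]
eps = unet(x, t)
assert eps.shape == x.shape        # predicted noise has the same shape as the image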
Gather consts for $t$ and reshape to feature map shape
+13def gather(consts: torch.Tensor, t: torch.Tensor):
15 c = consts.gather(-1, t)
+16 return c.reshape(-1, 1, 1, 1)
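For example (an illustrative sketch), gathering $\bar\alpha_t$ for a batch of time steps gives a [batch_size, 1, 1, 1] tensor that broadcasts over an image batch:

import torch

from labml_nn.diffusion.ddpm.utils import gather

alpha_bar = torch.cumprod(1. - torch.linspace(0.0001, 0.02, 1_000), dim=0)
t = torch.tensor([0, 499, 999])
c = gather(alpha_bar, t)       # shape [3, 1, 1, 1]
x0 = torch.randn(3, 3, 32, 32)
mean = c ** 0.5 * x0           # broadcasts over channels and spatial dimensions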