""" --- title: U-Net model for Denoising Diffusion Probabilistic Models (DDPM) summary: > UNet model for Denoising Diffusion Probabilistic Models (DDPM) --- # U-Net model for [Denoising Diffusion Probabilistic Models (DDPM)](index.html) This is a [U-Net](../../unet/index.html) based model to predict noise $\textcolor{lightgreen}{\epsilon_\theta}(x_t, t)$. U-Net is a gets it's name from the U shape in the model diagram. It processes a given image by progressively lowering (halving) the feature map resolution and then increasing the resolution. There are pass-through connection at each resolution. ![U-Net diagram from paper](../../unet/unet.png) This implementation contains a bunch of modifications to original U-Net (residual blocks, multi-head attention) and also adds time-step embeddings $t$. """ import math from typing import Optional, Tuple, Union, List import torch from torch import nn class Swish(nn.Module): """ ### Swish activation function $$x \cdot \sigma(x)$$ """ def forward(self, x): return x * torch.sigmoid(x) class TimeEmbedding(nn.Module): """ ### Embeddings for $t$ """ def __init__(self, n_channels: int): """ * `n_channels` is the number of dimensions in the embedding """ super().__init__() self.n_channels = n_channels # First linear layer self.lin1 = nn.Linear(self.n_channels // 4, self.n_channels) # Activation self.act = Swish() # Second linear layer self.lin2 = nn.Linear(self.n_channels, self.n_channels) def forward(self, t: torch.Tensor): # Create sinusoidal position embeddings # [same as those from the transformer](../../transformers/positional_encoding.html) # # \begin{align} # PE^{(1)}_{t,i} &= sin\Bigg(\frac{t}{10000^{\frac{i}{d - 1}}}\Bigg) \\ # PE^{(2)}_{t,i} &= cos\Bigg(\frac{t}{10000^{\frac{i}{d - 1}}}\Bigg) # \end{align} # # where $d$ is `half_dim` half_dim = self.n_channels // 8 emb = math.log(10_000) / (half_dim - 1) emb = torch.exp(torch.arange(half_dim, device=t.device) * -emb) emb = t[:, None] * emb[None, :] emb = torch.cat((emb.sin(), emb.cos()), dim=1) # Transform with the MLP emb = self.act(self.lin1(emb)) emb = self.lin2(emb) # return emb class ResidualBlock(nn.Module): """ ### Residual block A residual block has two convolution layers with group normalization. Each resolution is processed with two residual blocks. 
""" def __init__(self, in_channels: int, out_channels: int, time_channels: int, n_groups: int = 32, dropout: float = 0.1): """ * `in_channels` is the number of input channels * `out_channels` is the number of input channels * `time_channels` is the number channels in the time step ($t$) embeddings * `n_groups` is the number of groups for [group normalization](../../normalization/group_norm/index.html) * `dropout` is the dropout rate """ super().__init__() # Group normalization and the first convolution layer self.norm1 = nn.GroupNorm(n_groups, in_channels) self.act1 = Swish() self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=(3, 3), padding=(1, 1)) # Group normalization and the second convolution layer self.norm2 = nn.GroupNorm(n_groups, out_channels) self.act2 = Swish() self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=(3, 3), padding=(1, 1)) # If the number of input channels is not equal to the number of output channels we have to # project the shortcut connection if in_channels != out_channels: self.shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=(1, 1)) else: self.shortcut = nn.Identity() # Linear layer for time embeddings self.time_emb = nn.Linear(time_channels, out_channels) self.time_act = Swish() self.dropout = nn.Dropout(dropout) def forward(self, x: torch.Tensor, t: torch.Tensor): """ * `x` has shape `[batch_size, in_channels, height, width]` * `t` has shape `[batch_size, time_channels]` """ # First convolution layer h = self.conv1(self.act1(self.norm1(x))) # Add time embeddings h += self.time_emb(self.time_act(t))[:, :, None, None] # Second convolution layer h = self.conv2(self.dropout(self.act2(self.norm2(h)))) # Add the shortcut connection and return return h + self.shortcut(x) class AttentionBlock(nn.Module): """ ### Attention block This is similar to [transformer multi-head attention](../../transformers/mha.html). """ def __init__(self, n_channels: int, n_heads: int = 1, d_k: int = None, n_groups: int = 32): """ * `n_channels` is the number of channels in the input * `n_heads` is the number of heads in multi-head attention * `d_k` is the number of dimensions in each head * `n_groups` is the number of groups for [group normalization](../../normalization/group_norm/index.html) """ super().__init__() # Default `d_k` if d_k is None: d_k = n_channels # Normalization layer self.norm = nn.GroupNorm(n_groups, n_channels) # Projections for query, key and values self.projection = nn.Linear(n_channels, n_heads * d_k * 3) # Linear layer for final transformation self.output = nn.Linear(n_heads * d_k, n_channels) # Scale for dot-product attention self.scale = d_k ** -0.5 # self.n_heads = n_heads self.d_k = d_k def forward(self, x: torch.Tensor, t: Optional[torch.Tensor] = None): """ * `x` has shape `[batch_size, in_channels, height, width]` * `t` has shape `[batch_size, time_channels]` """ # `t` is not used, but it's kept in the arguments because for the attention layer function signature # to match with `ResidualBlock`. _ = t # Get shape batch_size, n_channels, height, width = x.shape # Change `x` to shape `[batch_size, seq, n_channels]` x = x.view(batch_size, n_channels, -1).permute(0, 2, 1) # Get query, key, and values (concatenated) and shape it to `[batch_size, seq, n_heads, 3 * d_k]` qkv = self.projection(x).view(batch_size, -1, self.n_heads, 3 * self.d_k) # Split query, key, and values. 
class AttentionBlock(nn.Module):
    """
    ### Attention block

    This is similar to [transformer multi-head attention](../../transformers/mha.html).
    """

    def __init__(self, n_channels: int, n_heads: int = 1, d_k: int = None, n_groups: int = 32):
        """
        * `n_channels` is the number of channels in the input
        * `n_heads` is the number of heads in multi-head attention
        * `d_k` is the number of dimensions in each head
        * `n_groups` is the number of groups for [group normalization](../../normalization/group_norm/index.html)
        """
        super().__init__()

        # Default `d_k`
        if d_k is None:
            d_k = n_channels
        # Normalization layer
        self.norm = nn.GroupNorm(n_groups, n_channels)
        # Projections for query, key and values
        self.projection = nn.Linear(n_channels, n_heads * d_k * 3)
        # Linear layer for final transformation
        self.output = nn.Linear(n_heads * d_k, n_channels)
        # Scale for dot-product attention
        self.scale = d_k ** -0.5
        #
        self.n_heads = n_heads
        self.d_k = d_k

    def forward(self, x: torch.Tensor, t: Optional[torch.Tensor] = None):
        """
        * `x` has shape `[batch_size, in_channels, height, width]`
        * `t` has shape `[batch_size, time_channels]`
        """
        # `t` is not used, but it's kept in the arguments so that the function signature
        # matches `ResidualBlock`.
        _ = t
        # Get shape
        batch_size, n_channels, height, width = x.shape
        # Change `x` to shape `[batch_size, seq, n_channels]`
        x = x.view(batch_size, n_channels, -1).permute(0, 2, 1)
        # Get query, key, and values (concatenated) and shape it to `[batch_size, seq, n_heads, 3 * d_k]`
        qkv = self.projection(x).view(batch_size, -1, self.n_heads, 3 * self.d_k)
        # Split query, key, and values. Each of them will have shape `[batch_size, seq, n_heads, d_k]`
        q, k, v = torch.chunk(qkv, 3, dim=-1)
        # Calculate scaled dot-product $\frac{Q K^\top}{\sqrt{d_k}}$
        attn = torch.einsum('bihd,bjhd->bijh', q, k) * self.scale
        # Softmax along the sequence dimension
        # $\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_k}}\Bigg)$
        attn = attn.softmax(dim=2)
        # Multiply by values
        res = torch.einsum('bijh,bjhd->bihd', attn, v)
        # Reshape to `[batch_size, seq, n_heads * d_k]`
        res = res.view(batch_size, -1, self.n_heads * self.d_k)
        # Transform to `[batch_size, seq, n_channels]`
        res = self.output(res)

        # Add skip connection
        res += x

        # Change to shape `[batch_size, in_channels, height, width]`
        res = res.permute(0, 2, 1).view(batch_size, n_channels, height, width)

        #
        return res


class DownBlock(nn.Module):
    """
    ### Down block

    This combines `ResidualBlock` and `AttentionBlock`.
    These are used in the first half of U-Net at each resolution.
    """

    def __init__(self, in_channels: int, out_channels: int, time_channels: int, has_attn: bool):
        super().__init__()
        self.res = ResidualBlock(in_channels, out_channels, time_channels)
        if has_attn:
            self.attn = AttentionBlock(out_channels)
        else:
            self.attn = nn.Identity()

    def forward(self, x: torch.Tensor, t: torch.Tensor):
        x = self.res(x, t)
        x = self.attn(x)
        return x


class UpBlock(nn.Module):
    """
    ### Up block

    This combines `ResidualBlock` and `AttentionBlock`.
    These are used in the second half of U-Net at each resolution.
    """

    def __init__(self, in_channels: int, out_channels: int, time_channels: int, has_attn: bool):
        super().__init__()
        # The input has `in_channels + out_channels` because we concatenate the output of the same resolution
        # from the first half of the U-Net
        self.res = ResidualBlock(in_channels + out_channels, out_channels, time_channels)
        if has_attn:
            self.attn = AttentionBlock(out_channels)
        else:
            self.attn = nn.Identity()

    def forward(self, x: torch.Tensor, t: torch.Tensor):
        x = self.res(x, t)
        x = self.attn(x)
        return x


class MiddleBlock(nn.Module):
    """
    ### Middle block

    It combines a `ResidualBlock`, an `AttentionBlock`, followed by another `ResidualBlock`.
    This block is applied at the lowest resolution of the U-Net.
    """

    def __init__(self, n_channels: int, time_channels: int):
        super().__init__()
        self.res1 = ResidualBlock(n_channels, n_channels, time_channels)
        self.attn = AttentionBlock(n_channels)
        self.res2 = ResidualBlock(n_channels, n_channels, time_channels)

    def forward(self, x: torch.Tensor, t: torch.Tensor):
        x = self.res1(x, t)
        x = self.attn(x)
        x = self.res2(x, t)
        return x


class Upsample(nn.Module):
    """
    ### Scale up the feature map by $2 \times$
    """

    def __init__(self, n_channels):
        super().__init__()
        self.conv = nn.ConvTranspose2d(n_channels, n_channels, (4, 4), (2, 2), (1, 1))

    def forward(self, x: torch.Tensor, t: torch.Tensor):
        # `t` is not used, but it's kept in the arguments so that the function signature
        # matches `ResidualBlock`.
        _ = t
        return self.conv(x)


class Downsample(nn.Module):
    """
    ### Scale down the feature map by $\frac{1}{2} \times$
    """

    def __init__(self, n_channels):
        super().__init__()
        self.conv = nn.Conv2d(n_channels, n_channels, (3, 3), (2, 2), (1, 1))

    def forward(self, x: torch.Tensor, t: torch.Tensor):
        # `t` is not used, but it's kept in the arguments so that the function signature
        # matches `ResidualBlock`.
        _ = t
        return self.conv(x)
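# An illustrative shape check for the resampling blocks (not part of the original file),
# assuming a 64-channel, 32x32 feature map; `t` is unused, so `None` can be passed:
#
#     x = torch.randn(1, 64, 32, 32)
#     assert Downsample(64)(x, None).shape == (1, 64, 16, 16)
#     assert Upsample(64)(x, None).shape == (1, 64, 64, 64)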
class UNet(nn.Module):
    """
    ## U-Net
    """

    def __init__(self, image_channels: int = 3, n_channels: int = 64,
                 ch_mults: Union[Tuple[int, ...], List[int]] = (1, 2, 2, 4),
                 is_attn: Union[Tuple[bool, ...], List[bool]] = (False, False, True, True),
                 n_blocks: int = 2):
        """
        * `image_channels` is the number of channels in the image. $3$ for RGB.
        * `n_channels` is the number of channels in the initial feature map that we transform the image into
        * `ch_mults` is the list of channel multipliers at each resolution; the number of channels at
          resolution `i` is `ch_mults[i] * n_channels`
        * `is_attn` is a list of booleans that indicate whether to use attention at each resolution
        * `n_blocks` is the number of `DownBlock`s (and `UpBlock`s) at each resolution
        """
        super().__init__()

        # Number of resolutions
        n_resolutions = len(ch_mults)

        # Project image into feature map
        self.image_proj = nn.Conv2d(image_channels, n_channels, kernel_size=(3, 3), padding=(1, 1))

        # Time embedding layer. Time embedding has `n_channels * 4` channels
        self.time_emb = TimeEmbedding(n_channels * 4)

        # #### First half of U-Net - decreasing resolution
        down = []
        # Number of channels
        out_channels = in_channels = n_channels
        # For each resolution
        for i in range(n_resolutions):
            # Number of output channels at this resolution
            out_channels = in_channels * ch_mults[i]
            # Add `n_blocks`
            for _ in range(n_blocks):
                down.append(DownBlock(in_channels, out_channels, n_channels * 4, is_attn[i]))
                in_channels = out_channels
            # Down sample at all resolutions except the last
            if i < n_resolutions - 1:
                down.append(Downsample(in_channels))

        # Combine the set of modules
        self.down = nn.ModuleList(down)

        # Middle block
        self.middle = MiddleBlock(out_channels, n_channels * 4)

        # #### Second half of U-Net - increasing resolution
        up = []
        # Number of channels
        in_channels = out_channels
        # For each resolution
        for i in reversed(range(n_resolutions)):
            # `n_blocks` at the same resolution
            out_channels = in_channels
            for _ in range(n_blocks):
                up.append(UpBlock(in_channels, out_channels, n_channels * 4, is_attn[i]))
            # Final block to reduce the number of channels
            out_channels = in_channels // ch_mults[i]
            up.append(UpBlock(in_channels, out_channels, n_channels * 4, is_attn[i]))
            in_channels = out_channels
            # Up sample at all resolutions except last
            if i > 0:
                up.append(Upsample(in_channels))

        # Combine the set of modules
        self.up = nn.ModuleList(up)

        # Final normalization and convolution layer
        self.norm = nn.GroupNorm(8, n_channels)
        self.act = Swish()
        self.final = nn.Conv2d(in_channels, image_channels, kernel_size=(3, 3), padding=(1, 1))

    def forward(self, x: torch.Tensor, t: torch.Tensor):
        """
        * `x` has shape `[batch_size, in_channels, height, width]`
        * `t` has shape `[batch_size]`
        """

        # Get time-step embeddings
        t = self.time_emb(t)

        # Get image projection
        x = self.image_proj(x)

        # `h` will store outputs at each resolution for skip connection
        h = [x]
        # First half of U-Net
        for m in self.down:
            x = m(x, t)
            h.append(x)

        # Middle (bottom)
        x = self.middle(x, t)

        # Second half of U-Net
        for m in self.up:
            if isinstance(m, Upsample):
                x = m(x, t)
            else:
                # Get the skip connection from first half of U-Net and concatenate
                s = h.pop()
                x = torch.cat((x, s), dim=1)
                #
                x = m(x, t)

        # Final normalization and convolution
        return self.final(self.act(self.norm(x)))
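# The following is a minimal usage sketch and not part of the original implementation.
# It assumes 32x32 RGB inputs (e.g. CIFAR-10) and 1,000 diffusion time steps; the model
# takes noisy images $x_t$ and integer time steps $t$, and predicts the noise
# $\textcolor{lightgreen}{\epsilon_\theta}(x_t, t)$ with the same shape as the input.
if __name__ == '__main__':
    # Configuration chosen for illustration only
    unet = UNet(image_channels=3, n_channels=64)
    x_t = torch.randn(4, 3, 32, 32)
    t = torch.randint(0, 1_000, (4,), dtype=torch.long)
    eps_theta = unet(x_t, t)
    # The predicted noise has the same shape as the input image batch
    assert eps_theta.shape == x_t.shape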