This is a U-Net based model that predicts noise from an input image `x` and a time step `t`.
U-Net gets its name from the U shape in the model diagram. It processes a given image by progressively lowering (halving) the feature map resolution and then increasing the resolution. There are pass-through connections at each resolution.

This implementation contains several modifications to the original U-Net (residual blocks, multi-head attention) and also adds time-step ($t$) embeddings.
import math
from typing import Optional, Tuple, Union, List

import torch
from torch import nn

try:
    from labml_helpers.module import Module
except ImportError:
    # Fall back to the plain PyTorch base class when `labml_helpers` is not
    # installed; `Swish` below only defines `forward`, which `nn.Module`
    # supports directly.
    Module = nn.Module


class Swish(Module):
    """
    ### Swish activation function

    Computes `x * sigmoid(x)` element-wise.
    """

    def forward(self, x):
        return x * torch.sigmoid(x)


class TimeEmbedding(nn.Module):
    """
    ### Embeddings for the time step `t`

    Builds sinusoidal features from `t` and transforms them with a
    two-layer MLP (`Linear -> Swish -> Linear`).
    """

    def __init__(self, n_channels: int):
        """
        * `n_channels` is the number of dimensions in the embedding
        """
        super().__init__()
        self.n_channels = n_channels
        # First linear layer; its input is the `n_channels // 4` sinusoidal features
        self.lin1 = nn.Linear(self.n_channels // 4, self.n_channels)
        # Activation
        self.act = Swish()
        # Second linear layer
        self.lin2 = nn.Linear(self.n_channels, self.n_channels)

    def forward(self, t: torch.Tensor):
        """
        * `t` is indexed as a 1-D tensor (`t[:, None]`), i.e. shape `[batch_size]`

        Returns a tensor of shape `[batch_size, n_channels]`.
        """
        # Sinusoidal embeddings: `half_dim` geometrically spaced frequencies,
        # from `1` down to `1 / 10_000`. The sine and cosine halves together
        # give `2 * half_dim = n_channels // 4` features, matching `lin1`.
        half_dim = self.n_channels // 8
        emb = math.log(10_000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=t.device) * -emb)
        emb = t[:, None] * emb[None, :]
        emb = torch.cat((emb.sin(), emb.cos()), dim=1)

        # Transform with the MLP
        emb = self.act(self.lin1(emb))
        emb = self.lin2(emb)

        return emb


# A residual block has two convolution layers with group normalization.
# Each resolution is processed with two residual blocks.
class ResidualBlock(Module):
    """
    ### Residual block

    Two stages of group normalization, Swish and a 3x3 convolution (the second
    stage also applies dropout before its convolution). The projected time
    embedding is added between the two stages, and a shortcut connection
    from the input is added to the result.
    """

    def __init__(self, in_channels: int, out_channels: int, time_channels: int,
                 n_groups: int = 32, dropout: float = 0.1):
        """
        * `in_channels` is the number of input channels
        * `out_channels` is the number of output channels
        * `time_channels` is the number channels in the time step ($t$) embeddings
        * `n_groups` is the number of groups for group normalization
        * `dropout` is the dropout rate
        """
        super().__init__()
        # Group normalization and the first convolution layer
        self.norm1 = nn.GroupNorm(n_groups, in_channels)
        self.act1 = Swish()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=(3, 3), padding=(1, 1))

        # Group normalization and the second convolution layer
        self.norm2 = nn.GroupNorm(n_groups, out_channels)
        self.act2 = Swish()
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=(3, 3), padding=(1, 1))

        # If the number of input channels is not equal to the number of output channels
        # we have to project the shortcut connection
        if in_channels != out_channels:
            self.shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=(1, 1))
        else:
            self.shortcut = nn.Identity()

        # Linear layer for time embeddings
        self.time_emb = nn.Linear(time_channels, out_channels)
        self.time_act = Swish()

        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor, t: torch.Tensor):
        """
        * `x` has shape `[batch_size, in_channels, height, width]`
        * `t` has shape `[batch_size, time_channels]`
        """
        # First convolution layer
        h = self.conv1(self.act1(self.norm1(x)))
        # Add time embeddings, broadcast over the spatial dimensions
        h += self.time_emb(self.time_act(t))[:, :, None, None]
        # Second convolution layer
        h = self.conv2(self.dropout(self.act2(self.norm2(h))))

        # Add the shortcut connection and return
        return h + self.shortcut(x)


class AttentionBlock(Module):
    """
    ### Attention block

    Multi-head self-attention over the spatial positions of a feature map,
    with a skip connection from the input.
    """

    def __init__(self, n_channels: int, n_heads: int = 1, d_k: Optional[int] = None, n_groups: int = 32):
        """
        * `n_channels` is the number of channels in the input
        * `n_heads` is the number of heads in multi-head attention
        * `d_k` is the number of dimensions in each head (defaults to `n_channels`)
        * `n_groups` is the number of groups for group normalization
        """
        super().__init__()

        # Default `d_k`
        if d_k is None:
            d_k = n_channels
        # Normalization layer.
        # NOTE: this layer is constructed here but `forward` never applies it.
        self.norm = nn.GroupNorm(n_groups, n_channels)
        # Projections for query, key and values
        self.projection = nn.Linear(n_channels, n_heads * d_k * 3)
        # Linear layer for final transformation
        self.output = nn.Linear(n_heads * d_k, n_channels)
        # Scale for dot-product attention
        self.scale = d_k ** -0.5

        self.n_heads = n_heads
        self.d_k = d_k

    def forward(self, x: torch.Tensor, t: Optional[torch.Tensor] = None):
        """
        * `x` has shape `[batch_size, in_channels, height, width]`
        * `t` has shape `[batch_size, time_channels]`
        """
        # `t` is not used, but it's kept in the arguments so that the attention layer's
        # function signature matches `ResidualBlock`.
        _ = t
        # Get shape
        batch_size, n_channels, height, width = x.shape
        # Change `x` to shape `[batch_size, seq, n_channels]`
        x = x.view(batch_size, n_channels, -1).permute(0, 2, 1)
        # Get query, key, and values (concatenated) and shape it to
        # `[batch_size, seq, n_heads, 3 * d_k]`
        qkv = self.projection(x).view(batch_size, -1, self.n_heads, 3 * self.d_k)
        # Split query, key, and values. Each of them will have shape
        # `[batch_size, seq, n_heads, d_k]`
        q, k, v = torch.chunk(qkv, 3, dim=-1)
        # Calculate scaled dot-product; result is indexed `[batch, query i, key j, head]`
        attn = torch.einsum('bihd,bjhd->bijh', q, k) * self.scale
        # Softmax along the key sequence dimension `j`
        attn = attn.softmax(dim=2)
        # Multiply by values
        res = torch.einsum('bijh,bjhd->bihd', attn, v)
        # Reshape to `[batch_size, seq, n_heads * d_k]`
        res = res.view(batch_size, -1, self.n_heads * self.d_k)
        # Transform to `[batch_size, seq, n_channels]`
        res = self.output(res)

        # Add skip connection
        res += x

        # Change to shape `[batch_size, in_channels, height, width]`
        res = res.permute(0, 2, 1).view(batch_size, n_channels, height, width)

        return res


# Down block: this combines `ResidualBlock` and `AttentionBlock`.
# These are used in the first half of U-Net at each resolution.
class DownBlock(Module):
    """
    ### Down block

    A `ResidualBlock` followed by an `AttentionBlock` when `has_attn` is set
    (otherwise by an identity).
    """

    def __init__(self, in_channels: int, out_channels: int, time_channels: int, has_attn: bool):
        """
        * `in_channels` is the number of input channels
        * `out_channels` is the number of output channels
        * `time_channels` is the number of channels in the time step embeddings
        * `has_attn` selects whether attention is applied after the residual block
        """
        super().__init__()
        self.res = ResidualBlock(in_channels, out_channels, time_channels)
        if has_attn:
            self.attn = AttentionBlock(out_channels)
        else:
            self.attn = nn.Identity()

    def forward(self, x: torch.Tensor, t: torch.Tensor):
        x = self.res(x, t)
        x = self.attn(x)
        return x


# Up block: this combines `ResidualBlock` and `AttentionBlock`.
# These are used in the second half of U-Net at each resolution.
class UpBlock(Module):
    """
    ### Up block

    A `ResidualBlock` followed by an `AttentionBlock` when `has_attn` is set
    (otherwise by an identity).
    """

    def __init__(self, in_channels: int, out_channels: int, time_channels: int, has_attn: bool):
        """
        * `in_channels` is the number of channels coming from the previous layer
        * `out_channels` is the number of output channels
        * `time_channels` is the number of channels in the time step embeddings
        * `has_attn` selects whether attention is applied after the residual block
        """
        super().__init__()
        # The input has `in_channels + out_channels` because we concatenate the output
        # of the same resolution from the first half of the U-Net
        self.res = ResidualBlock(in_channels + out_channels, out_channels, time_channels)
        if has_attn:
            self.attn = AttentionBlock(out_channels)
        else:
            self.attn = nn.Identity()

    def forward(self, x: torch.Tensor, t: torch.Tensor):
        x = self.res(x, t)
        x = self.attn(x)
        return x


# Middle block: it combines a `ResidualBlock`, `AttentionBlock`, followed by another
# `ResidualBlock`. This block is applied at the lowest resolution of the U-Net.
class MiddleBlock(Module):
    """
    ### Middle block

    `ResidualBlock`, then `AttentionBlock`, then another `ResidualBlock`,
    all with the same number of channels.
    """

    def __init__(self, n_channels: int, time_channels: int):
        """
        * `n_channels` is the number of input and output channels
        * `time_channels` is the number of channels in the time step embeddings
        """
        super().__init__()
        self.res1 = ResidualBlock(n_channels, n_channels, time_channels)
        self.attn = AttentionBlock(n_channels)
        self.res2 = ResidualBlock(n_channels, n_channels, time_channels)

    def forward(self, x: torch.Tensor, t: torch.Tensor):
        x = self.res1(x, t)
        x = self.attn(x)
        x = self.res2(x, t)
        return x


class Upsample(nn.Module):
    """
    ### Up-sample the feature map with a stride-2 transposed convolution
    """

    def __init__(self, n_channels):
        super().__init__()
        self.conv = nn.ConvTranspose2d(n_channels, n_channels, (4, 4), (2, 2), (1, 1))

    def forward(self, x: torch.Tensor, t: torch.Tensor):
        # `t` is not used, but it's kept in the arguments so that the function
        # signature matches `ResidualBlock`.
        _ = t
        return self.conv(x)


class Downsample(nn.Module):
    """
    ### Down-sample the feature map with a stride-2 convolution
    """

    def __init__(self, n_channels):
        super().__init__()
        self.conv = nn.Conv2d(n_channels, n_channels, (3, 3), (2, 2), (1, 1))

    def forward(self, x: torch.Tensor, t: torch.Tensor):
        # `t` is not used, but it's kept in the arguments so that the function
        # signature matches `ResidualBlock`.
        _ = t
        return self.conv(x)


class UNet(Module):
    """
    ## U-Net
    """

    def __init__(self, image_channels: int = 3, n_channels: int = 64,
                 ch_mults: Union[Tuple[int, ...], List[int]] = (1, 2, 2, 4),
                 is_attn: Union[Tuple[bool, ...], List[bool]] = (False, False, True, True),
                 n_blocks: int = 2):
        """
        * `image_channels` is the number of channels in the image. $3$ for RGB.
        * `n_channels` is number of channels in the initial feature map that we transform
          the image into
        * `ch_mults` is the list of channel multipliers, one per resolution. The multipliers
          compound: the number of channels at resolution `i` is the channel count of the
          previous resolution (starting from `n_channels`) times `ch_mults[i]`.
        * `is_attn` is a list of booleans that indicate whether to use attention at each
          resolution
        * `n_blocks` is the number of `DownBlock`s at each resolution; the second half uses
          `n_blocks + 1` `UpBlock`s per resolution
        """
        super().__init__()

        # Number of resolutions
        n_resolutions = len(ch_mults)

        # Project image into feature map
        self.image_proj = nn.Conv2d(image_channels, n_channels, kernel_size=(3, 3), padding=(1, 1))

        # Time embedding layer. Time embedding has `n_channels * 4` channels
        self.time_emb = TimeEmbedding(n_channels * 4)

        # #### First half of U-Net - decreasing resolution
        down = []
        # Number of channels
        out_channels = in_channels = n_channels
        # For each resolution
        for i in range(n_resolutions):
            # Number of output channels at this resolution
            out_channels = in_channels * ch_mults[i]
            # Add `n_blocks`
            for _ in range(n_blocks):
                down.append(DownBlock(in_channels, out_channels, n_channels * 4, is_attn[i]))
                in_channels = out_channels
            # Down sample at all resolutions except the last
            if i < n_resolutions - 1:
                down.append(Downsample(in_channels))

        # Combine the set of modules
        self.down = nn.ModuleList(down)

        # Middle block
        self.middle = MiddleBlock(out_channels, n_channels * 4)

        # #### Second half of U-Net - increasing resolution
        up = []
        # Number of channels
        in_channels = out_channels
        # For each resolution
        for i in reversed(range(n_resolutions)):
            # `n_blocks` at the same resolution
            out_channels = in_channels
            for _ in range(n_blocks):
                up.append(UpBlock(in_channels, out_channels, n_channels * 4, is_attn[i]))
            # Final block to reduce the number of channels
            out_channels = in_channels // ch_mults[i]
            up.append(UpBlock(in_channels, out_channels, n_channels * 4, is_attn[i]))
            in_channels = out_channels
            # Up sample at all resolutions except last
            if i > 0:
                up.append(Upsample(in_channels))

        # Combine the set of modules
        self.up = nn.ModuleList(up)

        # Final normalization and convolution layer.
        # After the loop above `in_channels` is back to `n_channels`, because each
        # resolution divides by the same multiplier the first half multiplied by.
        self.norm = nn.GroupNorm(8, n_channels)
        self.act = Swish()
        self.final = nn.Conv2d(in_channels, image_channels, kernel_size=(3, 3), padding=(1, 1))

    def forward(self, x: torch.Tensor, t: torch.Tensor):
        """
        * `x` has shape `[batch_size, image_channels, height, width]`
        * `t` has shape `[batch_size]`
        """
        # Get time-step embeddings
        t = self.time_emb(t)

        # Get image projection
        x = self.image_proj(x)

        # `h` will store outputs at each resolution for skip connection
        h = [x]
        # First half of U-Net
        for m in self.down:
            x = m(x, t)
            h.append(x)

        # Middle (bottom)
        x = self.middle(x, t)

        # Second half of U-Net
        for m in self.up:
            if isinstance(m, Upsample):
                x = m(x, t)
            else:
                # Get the skip connection from first half of U-Net and concatenate
                s = h.pop()
                x = torch.cat((x, s), dim=1)
                #
                x = m(x, t)

        # Final normalization and convolution
        return self.final(self.act(self.norm(x)))