r"""
This is a U-Net based model to predict noise $\epsilon_\theta(x_t, t)$.

U-Net gets its name from the U shape in the model diagram. It processes a given image by progressively lowering (halving) the feature map resolution and then increasing the resolution. There are pass-through connections at each resolution.

This implementation contains a number of modifications to the original U-Net (residual blocks, multi-head attention) and also adds time-step embeddings.
"""
import math
from typing import Optional, Tuple, Union, List

import torch
from torch import nn

from labml_helpers.module import Module


class Swish(Module):
    r"""
    Swish activation function: $x \cdot \sigma(x)$.
    """

    def forward(self, x):
        return x * torch.sigmoid(x)


class TimeEmbedding(nn.Module):
    """
    Time-step embedding: sinusoidal features followed by a two-layer MLP.

    * `n_channels` is the number of dimensions in the embedding
    """

    def __init__(self, n_channels: int):
        super().__init__()
        self.n_channels = n_channels
        # First linear layer; it consumes the `n_channels // 4` sinusoidal features
        self.lin1 = nn.Linear(self.n_channels // 4, self.n_channels)
        # Activation
        self.act = Swish()
        # Second linear layer
        self.lin2 = nn.Linear(self.n_channels, self.n_channels)

    def forward(self, t: torch.Tensor):
        """
        Embed the time steps `t`; `t` is indexed as `t[:, None]`, so it is
        presumably shape `[batch_size]` (TODO confirm with callers).

        With d = `n_channels // 8`, the sinusoidal features are

            PE(t, i)     = sin(t / 10000^(i / (d - 1)))
            PE(t, i + d) = cos(t / 10000^(i / (d - 1)))

        for i in [0, d). They are then transformed by `lin1 -> Swish -> lin2`.
        The result has `n_channels` features per time step.
        """
        half_dim = self.n_channels // 8
        # Clamp the denominator: for d == 1 the original formula divides by zero.
        emb = math.log(10_000) / max(half_dim - 1, 1)
        emb = torch.exp(torch.arange(half_dim, device=t.device) * -emb)
        emb = t[:, None] * emb[None, :]
        emb = torch.cat((emb.sin(), emb.cos()), dim=1)
        # `lin1` expects `n_channels // 4` features, which is 2 * d or 2 * d + 1.
        # Zero-pad the odd case so the shapes always agree.
        missing = self.n_channels // 4 - emb.shape[1]
        if missing > 0:
            emb = nn.functional.pad(emb, (0, missing))

        # Transform with the MLP
        emb = self.act(self.lin1(emb))
        emb = self.lin2(emb)

        return emb


# A residual block has two convolution layers with group normalization.
# Each resolution in the first half of the U-Net is processed with `n_blocks`
# (by default two) residual blocks.
class ResidualBlock(Module):
    """
    Residual block with two convolution layers and group normalization.

    Each 3x3 convolution is preceded by `GroupNorm` and a Swish activation.
    The projected time-step embedding is added between the two convolutions.
    The input (projected by a 1x1 convolution when the channel counts differ)
    is added to the result.

    * `in_channels` is the number of input channels
    * `out_channels` is the number of output channels
    * `time_channels` is the number of channels in the time step ($t$) embeddings
    * `n_groups` is the number of groups for group normalization
    """

    def __init__(self, in_channels: int, out_channels: int, time_channels: int, n_groups: int = 32):
        super().__init__()
        # Group normalization and the first convolution layer
        self.norm1 = nn.GroupNorm(n_groups, in_channels)
        self.act1 = Swish()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=(3, 3), padding=(1, 1))

        # Group normalization and the second convolution layer
        self.norm2 = nn.GroupNorm(n_groups, out_channels)
        self.act2 = Swish()
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=(3, 3), padding=(1, 1))

        # If the number of input channels is not equal to the number of output
        # channels, we have to project the shortcut connection
        if in_channels != out_channels:
            self.shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=(1, 1))
        else:
            self.shortcut = nn.Identity()

        # Linear layer for time embeddings
        self.time_emb = nn.Linear(time_channels, out_channels)

    def forward(self, x: torch.Tensor, t: torch.Tensor):
        """
        * `x` has shape `[batch_size, in_channels, height, width]`
        * `t` has shape `[batch_size, time_channels]`
        """
        # First convolution layer
        h = self.conv1(self.act1(self.norm1(x)))
        # Add time embeddings, broadcast over height and width
        h += self.time_emb(t)[:, :, None, None]
        # Second convolution layer
        h = self.conv2(self.act2(self.norm2(h)))

        # Add the shortcut connection and return
        return h + self.shortcut(x)


class AttentionBlock(Module):
    """
    Multi-head self-attention across the spatial positions of a feature map,
    with a residual connection.

    * `n_channels` is the number of channels in the input
    * `n_heads` is the number of heads in multi-head attention
    * `d_k` is the number of dimensions in each head (defaults to `n_channels`)
    * `n_groups` is the number of groups for group normalization
    """

    def __init__(self, n_channels: int, n_heads: int = 1, d_k: Optional[int] = None, n_groups: int = 32):
        super().__init__()

        # Default `d_k`
        if d_k is None:
            d_k = n_channels
        # Normalization layer, applied to the input before the q/k/v projection
        self.norm = nn.GroupNorm(n_groups, n_channels)
        # Projections for query, key and values
        self.projection = nn.Linear(n_channels, n_heads * d_k * 3)
        # Linear layer for final transformation
        self.output = nn.Linear(n_heads * d_k, n_channels)
        # Scale for dot-product attention
        self.scale = d_k ** -0.5
        self.n_heads = n_heads
        self.d_k = d_k

    def forward(self, x: torch.Tensor, t: Optional[torch.Tensor] = None):
        """
        * `x` has shape `[batch_size, in_channels, height, width]`
        * `t` has shape `[batch_size, time_channels]`. It is not used, but it is
          kept in the arguments so the signature matches `ResidualBlock`.

        Returns a tensor with the same shape as `x`.
        """
        _ = t
        # Get shape
        batch_size, n_channels, height, width = x.shape
        # Group-normalize before attention. The skip connection at the end
        # adds the un-normalized `x`.
        h = self.norm(x)
        # Change to shape `[batch_size, seq, n_channels]` with seq = height * width.
        # `reshape` rather than `view` so non-contiguous inputs are accepted.
        h = h.reshape(batch_size, n_channels, -1).permute(0, 2, 1)
        # Get query, key, and values (concatenated) as `[batch_size, seq, n_heads, 3 * d_k]`
        qkv = self.projection(h).view(batch_size, -1, self.n_heads, 3 * self.d_k)
        # Split query, key, and values; each is `[batch_size, seq, n_heads, d_k]`
        q, k, v = torch.chunk(qkv, 3, dim=-1)
        # Calculate scaled dot-product: `[batch_size, seq_i, seq_j, n_heads]`
        attn = torch.einsum('bihd,bjhd->bijh', q, k) * self.scale
        # Softmax along the key (`j`) sequence dimension
        attn = attn.softmax(dim=2)
        # Multiply by values
        res = torch.einsum('bijh,bjhd->bihd', attn, v)
        # Reshape to `[batch_size, seq, n_heads * d_k]`
        res = res.reshape(batch_size, -1, self.n_heads * self.d_k)
        # Transform to `[batch_size, seq, n_channels]`
        res = self.output(res)
        # Change back to `[batch_size, n_channels, height, width]` and add the skip connection
        res = res.permute(0, 2, 1).reshape(batch_size, n_channels, height, width)
        return res + x


# This combines `ResidualBlock` and `AttentionBlock`. These are used in the
# first half of U-Net at each resolution.
class DownBlock(Module):
    """
    One stage of the first (downsampling) half of the U-Net.

    A `ResidualBlock` conditioned on the time embedding, optionally followed
    by an `AttentionBlock`.

    * `in_channels` / `out_channels` are the input / output channel counts
    * `time_channels` is the number of channels in the time-step embedding
    * `has_attn` selects whether self-attention follows the residual block
    """

    def __init__(self, in_channels: int, out_channels: int, time_channels: int, has_attn: bool):
        super().__init__()
        self.res = ResidualBlock(in_channels, out_channels, time_channels)
        # Without attention the second stage is a no-op
        self.attn = AttentionBlock(out_channels) if has_attn else nn.Identity()

    def forward(self, x: torch.Tensor, t: torch.Tensor):
        # Residual block (uses t), then the optional attention (does not)
        return self.attn(self.res(x, t))


# This combines `ResidualBlock` and `AttentionBlock`. These are used in the
# second half of U-Net at each resolution.
class UpBlock(Module):
    """
    One stage of the second (upsampling) half of the U-Net.

    A `ResidualBlock` conditioned on the time embedding, optionally followed
    by an `AttentionBlock`.

    * `in_channels` / `out_channels` are the channel counts of this stage
    * `time_channels` is the number of channels in the time-step embedding
    * `has_attn` selects whether self-attention follows the residual block
    """

    def __init__(self, in_channels: int, out_channels: int, time_channels: int, has_attn: bool):
        super().__init__()
        # The residual block sees `in_channels + out_channels` channels because the
        # output of the same resolution from the first half of the U-Net is
        # concatenated to the input before this block is called.
        merged_channels = in_channels + out_channels
        self.res = ResidualBlock(merged_channels, out_channels, time_channels)
        self.attn = AttentionBlock(out_channels) if has_attn else nn.Identity()

    def forward(self, x: torch.Tensor, t: torch.Tensor):
        h = self.res(x, t)
        return self.attn(h)


# It combines a `ResidualBlock`, an `AttentionBlock`, followed by another
# `ResidualBlock`. This block is applied at the lowest resolution of the U-Net.
class MiddleBlock(Module):
    """
    Bottom of the U-Net: a `ResidualBlock`, an `AttentionBlock`, then another
    `ResidualBlock`, all keeping `n_channels` channels.

    * `n_channels` is the number of channels
    * `time_channels` is the number of channels in the time-step embedding
    """

    def __init__(self, n_channels: int, time_channels: int):
        super().__init__()
        self.res1 = ResidualBlock(n_channels, n_channels, time_channels)
        self.attn = AttentionBlock(n_channels)
        self.res2 = ResidualBlock(n_channels, n_channels, time_channels)

    def forward(self, x: torch.Tensor, t: torch.Tensor):
        x = self.res1(x, t)
        x = self.attn(x)
        x = self.res2(x, t)
        return x


class Upsample(nn.Module):
    """
    Scale up the feature map by 2x with a transposed convolution
    (kernel 4, stride 2, padding 1), keeping the number of channels.
    """

    def __init__(self, n_channels):
        super().__init__()
        self.conv = nn.ConvTranspose2d(n_channels, n_channels, (4, 4), (2, 2), (1, 1))

    def forward(self, x: torch.Tensor, t: torch.Tensor):
        # `t` is not used. It is kept so the signature matches `ResidualBlock`,
        # letting `UNet.forward` call every block as `m(x, t)`.
        _ = t
        return self.conv(x)


class Downsample(nn.Module):
    """
    Scale down the feature map with a stride-2 3x3 convolution (padding 1),
    keeping the number of channels. Output size is ceil(size / 2).
    """

    def __init__(self, n_channels):
        super().__init__()
        self.conv = nn.Conv2d(n_channels, n_channels, (3, 3), (2, 2), (1, 1))

    def forward(self, x: torch.Tensor, t: torch.Tensor):
        # `t` is not used. It is kept so the signature matches `ResidualBlock`,
        # letting `UNet.forward` call every block as `m(x, t)`.
        _ = t
        return self.conv(x)


class UNet(Module):
    """
    U-Net mapping an image with `image_channels` channels and a time step
    to a tensor with `image_channels` channels and the same height and width.

    * `image_channels` is the number of channels in the image; 3 for RGB.
    * `n_channels` is the number of channels in the initial feature map that
      the image is projected into.
    * `ch_mults` is the list of channel multipliers, one per resolution. The
      multipliers are cumulative: the number of channels at resolution `i` is
      `n_channels * ch_mults[0] * ... * ch_mults[i]`.
    * `is_attn` is a list of booleans, one per resolution, indicating whether
      to use attention at that resolution.
    * `n_blocks` is the number of `DownBlock`s at each resolution. The second
      half uses `n_blocks + 1` `UpBlock`s per resolution.

    Height and width of the input must be divisible by `2 ** (len(ch_mults) - 1)`.
    Otherwise, after down- and up-sampling, the skip connections do not line up.
    """

    def __init__(self, image_channels: int = 3, n_channels: int = 64,
                 ch_mults: Union[Tuple[int, ...], List[int]] = (1, 2, 2, 4),
                 is_attn: Union[Tuple[bool, ...], List[bool]] = (False, False, True, True),
                 n_blocks: int = 2):
        super().__init__()

        # Number of resolutions
        n_resolutions = len(ch_mults)

        # Project image into feature map
        self.image_proj = nn.Conv2d(image_channels, n_channels, kernel_size=(3, 3), padding=(1, 1))

        # Time embedding layer. Time embedding has `n_channels * 4` channels
        self.time_emb = TimeEmbedding(n_channels * 4)

        # #### First half of U-Net - decreasing resolution
        down = []
        # Number of channels
        out_channels = in_channels = n_channels
        # For each resolution
        for i in range(n_resolutions):
            # Multipliers compound: this resolution scales the previous one's channel count
            out_channels = in_channels * ch_mults[i]
            # Add `n_blocks`
            for _ in range(n_blocks):
                down.append(DownBlock(in_channels, out_channels, n_channels * 4, is_attn[i]))
                in_channels = out_channels
            # Down sample at all resolutions except the last
            if i < n_resolutions - 1:
                down.append(Downsample(in_channels))

        # Combine the set of modules
        self.down = nn.ModuleList(down)

        # Middle block
        self.middle = MiddleBlock(out_channels, n_channels * 4)

        # #### Second half of U-Net - increasing resolution
        up = []
        # Number of channels
        in_channels = out_channels
        # For each resolution
        for i in reversed(range(n_resolutions)):
            # `n_blocks` at the same resolution
            out_channels = in_channels
            for _ in range(n_blocks):
                up.append(UpBlock(in_channels, out_channels, n_channels * 4, is_attn[i]))
            # Final block to reduce the number of channels
            out_channels = in_channels // ch_mults[i]
            up.append(UpBlock(in_channels, out_channels, n_channels * 4, is_attn[i]))
            in_channels = out_channels
            # Up sample at all resolutions except last
            if i > 0:
                up.append(Upsample(in_channels))

        # Combine the set of modules
        self.up = nn.ModuleList(up)

        # Final normalization and convolution layer. Both are sized from
        # `in_channels`, the channel count leaving the last `UpBlock`.
        # Dividing out every multiplier brings this back to `n_channels`.
        self.norm = nn.GroupNorm(8, in_channels)
        self.act = Swish()
        self.final = nn.Conv2d(in_channels, image_channels, kernel_size=(3, 3), padding=(1, 1))

    def forward(self, x: torch.Tensor, t: torch.Tensor):
        """
        * `x` has shape `[batch_size, in_channels, height, width]`
        * `t` has shape `[batch_size]`
        """
        # Get time-step embeddings
        t = self.time_emb(t)

        # Get image projection
        x = self.image_proj(x)

        # `h` will store outputs at each resolution for skip connection
        h = [x]
        # First half of U-Net
        for m in self.down:
            x = m(x, t)
            h.append(x)

        # Middle (bottom)
        x = self.middle(x, t)

        # Second half of U-Net
        for m in self.up:
            if isinstance(m, Upsample):
                x = m(x, t)
            else:
                # Get the skip connection from first half of U-Net and concatenate
                s = h.pop()
                x = torch.cat((x, s), dim=1)
                x = m(x, t)

        # Final normalization and convolution
        return self.final(self.act(self.norm(x)))