This is a U-Net based model used to predict noise in the denoising diffusion process.
U-Net gets its name from the U shape of its architecture diagram. It processes a given image by progressively lowering (halving) the feature map resolution and then increasing it again. There are pass-through (skip) connections at each resolution.
This implementation contains a number of modifications to the original U-Net (residual blocks, multi-head attention) and also adds time-step embeddings.
import math
from typing import Optional, Tuple, Union, List

import torch
from torch import nn
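Swish activation function: x · sigmoid(x)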
class Swish(nn.Module):
    def forward(self, x):
        return x * torch.sigmoid(x)

class TimeEmbedding(nn.Module):
n_channels is the number of dimensions in the embedding
    def __init__(self, n_channels: int):
        super().__init__()
        self.n_channels = n_channels
First linear layer
        self.lin1 = nn.Linear(self.n_channels // 4, self.n_channels)
Activation
        self.act = Swish()
Second linear layer
        self.lin2 = nn.Linear(self.n_channels, self.n_channels)

    def forward(self, t: torch.Tensor):
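Create sinusoidal time-step embeddings, similar to transformer positional encodings: compute t · 10000^(-i / (half_dim - 1)) for i = 0, …, half_dim - 1, then concatenate the sine and cosine of these values to get an embedding of size n_channels // 4.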
        half_dim = self.n_channels // 8
        emb = math.log(10_000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=t.device) * -emb)
        emb = t[:, None] * emb[None, :]
        emb = torch.cat((emb.sin(), emb.cos()), dim=1)
Transform with the MLP
        emb = self.act(self.lin1(emb))
        emb = self.lin2(emb)

        return emb
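As a quick illustration (not part of the model code; the sizes below are assumptions), the embedding maps a batch of time steps to vectors of size n_channels:
# Hypothetical shape check for TimeEmbedding
time_emb = TimeEmbedding(n_channels=256)
t = torch.tensor([1, 10, 100])
# The sinusoidal part has 256 // 4 = 64 dimensions; the MLP expands it to 256
assert time_emb(t).shape == (3, 256)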
A residual block has two convolution layers with group normalization. Each resolution is processed with two residual blocks.
class ResidualBlock(nn.Module):
in_channels is the number of input channels
out_channels is the number of output channels
time_channels is the number of channels in the time-step embeddings
n_groups is the number of groups for group normalization
dropout is the dropout rate
    def __init__(self, in_channels: int, out_channels: int, time_channels: int,
                 n_groups: int = 32, dropout: float = 0.1):
        super().__init__()
Group normalization and the first convolution layer
        self.norm1 = nn.GroupNorm(n_groups, in_channels)
        self.act1 = Swish()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=(3, 3), padding=(1, 1))
Group normalization and the second convolution layer
        self.norm2 = nn.GroupNorm(n_groups, out_channels)
        self.act2 = Swish()
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=(3, 3), padding=(1, 1))
If the number of input channels is not equal to the number of output channels, we have to project the shortcut connection
        if in_channels != out_channels:
            self.shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=(1, 1))
        else:
            self.shortcut = nn.Identity()
Linear layer for time embeddings
        self.time_emb = nn.Linear(time_channels, out_channels)
        self.time_act = Swish()

        self.dropout = nn.Dropout(dropout)
x has shape [batch_size, in_channels, height, width]
t has shape [batch_size, time_channels]
    def forward(self, x: torch.Tensor, t: torch.Tensor):
First convolution layer
        h = self.conv1(self.act1(self.norm1(x)))
Add time embeddings
        h += self.time_emb(self.time_act(t))[:, :, None, None]
Second convolution layer
        h = self.conv2(self.dropout(self.act2(self.norm2(h))))
Add the shortcut connection and return
        return h + self.shortcut(x)
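As a brief shape sketch (illustrative only; the sizes below are assumptions, not part of the model code), the residual block maps its input channels to out_channels while preserving the spatial resolution:
# Hypothetical shape check for ResidualBlock
block = ResidualBlock(in_channels=64, out_channels=128, time_channels=256)
x = torch.randn(8, 64, 32, 32)
t = torch.randn(8, 256)
# The 1x1 shortcut convolution projects 64 -> 128 channels so the residual addition works
assert block(x, t).shape == (8, 128, 32, 32)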
class AttentionBlock(nn.Module):
n_channels is the number of channels in the input
n_heads is the number of heads in multi-head attention
d_k is the number of dimensions in each head
n_groups is the number of groups for group normalization
    def __init__(self, n_channels: int, n_heads: int = 1, d_k: int = None, n_groups: int = 32):
        super().__init__()
Default d_k
        if d_k is None:
            d_k = n_channels
Normalization layer
        self.norm = nn.GroupNorm(n_groups, n_channels)
Projections for query, key and values
        self.projection = nn.Linear(n_channels, n_heads * d_k * 3)
Linear layer for final transformation
        self.output = nn.Linear(n_heads * d_k, n_channels)
Scale for dot-product attention
        self.scale = d_k ** -0.5
        self.n_heads = n_heads
        self.d_k = d_k
x has shape [batch_size, in_channels, height, width]
t has shape [batch_size, time_channels]
    def forward(self, x: torch.Tensor, t: Optional[torch.Tensor] = None):
t is not used, but it is kept in the arguments so that the attention layer's function signature matches that of ResidualBlock.
        _ = t
Get shape
        batch_size, n_channels, height, width = x.shape
Change x to shape [batch_size, seq, n_channels]
        x = x.view(batch_size, n_channels, -1).permute(0, 2, 1)
Get query, key, and values (concatenated) and shape it to [batch_size, seq, n_heads, 3 * d_k]
        qkv = self.projection(x).view(batch_size, -1, self.n_heads, 3 * self.d_k)
Split query, key, and values. Each of them will have shape [batch_size, seq, n_heads, d_k]
        q, k, v = torch.chunk(qkv, 3, dim=-1)
Calculate scaled dot-product
        attn = torch.einsum('bihd,bjhd->bijh', q, k) * self.scale
Softmax along the sequence dimension
        attn = attn.softmax(dim=2)
Multiply by values
        res = torch.einsum('bijh,bjhd->bihd', attn, v)
Reshape to [batch_size, seq, n_heads * d_k]
        res = res.view(batch_size, -1, self.n_heads * self.d_k)
Transform to [batch_size, seq, n_channels]
        res = self.output(res)
Add skip connection
        res += x
Change to shape [batch_size, in_channels, height, width]
        res = res.permute(0, 2, 1).view(batch_size, n_channels, height, width)
        return res
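As a sketch (illustrative only; the sizes are assumptions), the attention block attends over all spatial positions and preserves the input shape:
# Hypothetical shape check for AttentionBlock
attn = AttentionBlock(n_channels=128)
x = torch.randn(8, 128, 16, 16)
# Attention runs over the 16 * 16 = 256 spatial positions and returns the same shape
assert attn(x).shape == (8, 128, 16, 16)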
This combines ResidualBlock and AttentionBlock. These are used in the first half of U-Net at each resolution.
class DownBlock(nn.Module):
    def __init__(self, in_channels: int, out_channels: int, time_channels: int, has_attn: bool):
        super().__init__()
        self.res = ResidualBlock(in_channels, out_channels, time_channels)
        if has_attn:
            self.attn = AttentionBlock(out_channels)
        else:
            self.attn = nn.Identity()

    def forward(self, x: torch.Tensor, t: torch.Tensor):
        x = self.res(x, t)
        x = self.attn(x)
        return x
This combines ResidualBlock and AttentionBlock. These are used in the second half of U-Net at each resolution.
class UpBlock(nn.Module):
    def __init__(self, in_channels: int, out_channels: int, time_channels: int, has_attn: bool):
        super().__init__()
The input has in_channels + out_channels because we concatenate the output of the same resolution from the first half of the U-Net
        self.res = ResidualBlock(in_channels + out_channels, out_channels, time_channels)
        if has_attn:
            self.attn = AttentionBlock(out_channels)
        else:
            self.attn = nn.Identity()

    def forward(self, x: torch.Tensor, t: torch.Tensor):
        x = self.res(x, t)
        x = self.attn(x)
        return x
It combines a ResidualBlock and an AttentionBlock, followed by another ResidualBlock. This block is applied at the lowest resolution of the U-Net.
class MiddleBlock(nn.Module):
    def __init__(self, n_channels: int, time_channels: int):
        super().__init__()
        self.res1 = ResidualBlock(n_channels, n_channels, time_channels)
        self.attn = AttentionBlock(n_channels)
        self.res2 = ResidualBlock(n_channels, n_channels, time_channels)

    def forward(self, x: torch.Tensor, t: torch.Tensor):
        x = self.res1(x, t)
        x = self.attn(x)
        x = self.res2(x, t)
        return x
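Scale up the feature map by 2× using a 4×4 transposed convolution with stride 2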
class Upsample(nn.Module):
    def __init__(self, n_channels):
        super().__init__()
        self.conv = nn.ConvTranspose2d(n_channels, n_channels, (4, 4), (2, 2), (1, 1))

    def forward(self, x: torch.Tensor, t: torch.Tensor):
t is not used, but it is kept in the arguments so that the function signature matches that of ResidualBlock.
        _ = t
        return self.conv(x)
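Scale down the feature map to half the resolution using a 3×3 convolution with stride 2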
class Downsample(nn.Module):
    def __init__(self, n_channels):
        super().__init__()
        self.conv = nn.Conv2d(n_channels, n_channels, (3, 3), (2, 2), (1, 1))

    def forward(self, x: torch.Tensor, t: torch.Tensor):
t is not used, but it is kept in the arguments so that the function signature matches that of ResidualBlock.
        _ = t
        return self.conv(x)
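A short illustrative sketch (the sizes are assumptions) of how these two modules change the spatial resolution:
# Hypothetical shape check for Downsample and Upsample
x = torch.randn(8, 64, 32, 32)
t = torch.randn(8, 256)  # unused by these modules; passed only to match the signature
# The strided convolution halves the resolution; the transposed convolution doubles it
assert Downsample(64)(x, t).shape == (8, 64, 16, 16)
assert Upsample(64)(x, t).shape == (8, 64, 64, 64)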
class UNet(nn.Module):
image_channels is the number of channels in the image (3 for RGB)
n_channels is the number of channels in the initial feature map that we transform the image into
ch_mults is the list of channel multipliers at each resolution; the number of channels at resolution i is ch_mults[i] * n_channels
is_attn is a list of booleans that indicate whether to use attention at each resolution
n_blocks is the number of DownBlock (or UpBlock) blocks at each resolution
    def __init__(self, image_channels: int = 3, n_channels: int = 64,
                 ch_mults: Union[Tuple[int, ...], List[int]] = (1, 2, 2, 4),
                 is_attn: Union[Tuple[bool, ...], List[bool]] = (False, False, True, True),
                 n_blocks: int = 2):
        super().__init__()
Number of resolutions
        n_resolutions = len(ch_mults)
Project image into feature map
        self.image_proj = nn.Conv2d(image_channels, n_channels, kernel_size=(3, 3), padding=(1, 1))
Time embedding layer. Time embedding has n_channels * 4 channels
        self.time_emb = TimeEmbedding(n_channels * 4)

        down = []
Number of channels
        out_channels = in_channels = n_channels
For each resolution
        for i in range(n_resolutions):
Number of output channels at this resolution
            out_channels = in_channels * ch_mults[i]
Add n_blocks DownBlocks at this resolution
            for _ in range(n_blocks):
                down.append(DownBlock(in_channels, out_channels, n_channels * 4, is_attn[i]))
                in_channels = out_channels
Down sample at all resolutions except the last
            if i < n_resolutions - 1:
                down.append(Downsample(in_channels))
Combine the set of modules
        self.down = nn.ModuleList(down)
Middle block
        self.middle = MiddleBlock(out_channels, n_channels * 4)

        up = []
Number of channels
        in_channels = out_channels
For each resolution
        for i in reversed(range(n_resolutions)):
Add n_blocks UpBlocks at the same resolution
            out_channels = in_channels
            for _ in range(n_blocks):
                up.append(UpBlock(in_channels, out_channels, n_channels * 4, is_attn[i]))
Final block to reduce the number of channels
            out_channels = in_channels // ch_mults[i]
            up.append(UpBlock(in_channels, out_channels, n_channels * 4, is_attn[i]))
            in_channels = out_channels
Up sample at all resolutions except the last
            if i > 0:
                up.append(Upsample(in_channels))
Combine the set of modules
        self.up = nn.ModuleList(up)
Final normalization and convolution layer
        self.norm = nn.GroupNorm(8, n_channels)
        self.act = Swish()
        self.final = nn.Conv2d(in_channels, image_channels, kernel_size=(3, 3), padding=(1, 1))
x has shape [batch_size, in_channels, height, width]
t has shape [batch_size]
    def forward(self, x: torch.Tensor, t: torch.Tensor):
Get time-step embeddings
        t = self.time_emb(t)
Get image projection
        x = self.image_proj(x)
h will store outputs at each resolution for skip connection
        h = [x]
First half of U-Net
        for m in self.down:
            x = m(x, t)
            h.append(x)
Middle (bottom)
        x = self.middle(x, t)
Second half of U-Net
        for m in self.up:
            if isinstance(m, Upsample):
                x = m(x, t)
            else:
Get the skip connection from first half of U-Net and concatenate
                s = h.pop()
                x = torch.cat((x, s), dim=1)
                x = m(x, t)
Final normalization and convolution
        return self.final(self.act(self.norm(x)))
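Finally, a minimal usage sketch (illustrative only; the batch size, image size, and time-step range are assumptions, not part of this file), showing that the model maps a noisy image and its time step to a noise prediction of the same shape:
# Hypothetical end-to-end shape check for UNet
model = UNet(image_channels=3, n_channels=64)
x = torch.randn(8, 3, 32, 32)      # batch of noisy 32x32 RGB images
t = torch.randint(0, 1000, (8,))   # one diffusion time step per image
# The predicted noise has the same shape as the input image
assert model(x, t).shape == (8, 3, 32, 32)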