"""
U-Net based model that predicts noise.

U-Net gets its name from the U shape of its architecture diagram. It processes
an image by progressively lowering (halving) the feature-map resolution and
then raising it again, with pass-through (skip) connections at every
resolution.

This implementation contains a number of modifications to the original U-Net
(residual blocks, multi-head attention) and also adds time-step embeddings.
"""

import math
from typing import Optional, Tuple, Union, List

import torch
from torch import nn

try:
    from labml_helpers.module import Module
except ImportError:
    # NOTE(review): assumes labml_helpers' `Module` is a drop-in subclass of
    # `nn.Module` -- confirm. The fallback keeps this file importable when the
    # optional package is not installed.
    Module = nn.Module


class Swish(Module):
    """Swish activation function: `x * sigmoid(x)`."""

    def forward(self, x):
        return x * torch.sigmoid(x)


class TimeEmbedding(nn.Module):
    """Embeddings for the time step `t`."""

    def __init__(self, n_channels: int):
        """
        * `n_channels` is the number of dimensions in the embedding
        """
        super().__init__()
        self.n_channels = n_channels
        # First linear layer; its input is the sinusoidal embedding of size
        # `n_channels // 4` built in `forward`
        self.lin1 = nn.Linear(self.n_channels // 4, self.n_channels)
        # Activation
        self.act = Swish()
        # Second linear layer
        self.lin2 = nn.Linear(self.n_channels, self.n_channels)

    def forward(self, t: torch.Tensor):
        """
        * `t` has shape `[batch_size]`

        Returns a tensor of shape `[batch_size, n_channels]`.
        """
        # Sinusoidal position embeddings: `half_dim` sines followed by
        # `half_dim` cosines, so the result has `n_channels // 4` features
        half_dim = self.n_channels // 8
        emb = math.log(10_000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=t.device) * -emb)
        emb = t[:, None] * emb[None, :]
        emb = torch.cat((emb.sin(), emb.cos()), dim=1)

        # Transform with the MLP
        emb = self.act(self.lin1(emb))
        emb = self.lin2(emb)

        return emb


class ResidualBlock(Module):
    """
    Residual block with two convolution layers, each preceded by group
    normalization and a Swish activation.
    """

    def __init__(self, in_channels: int, out_channels: int, time_channels: int, n_groups: int = 32):
        """
        * `in_channels` is the number of input channels
        * `out_channels` is the number of output channels
        * `time_channels` is the number of channels in the time step embeddings
        * `n_groups` is the number of groups for group normalization
        """
        super().__init__()
        # Group normalization and the first convolution layer
        self.norm1 = nn.GroupNorm(n_groups, in_channels)
        self.act1 = Swish()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=(3, 3), padding=(1, 1))

        # Group normalization and the second convolution layer
        self.norm2 = nn.GroupNorm(n_groups, out_channels)
        self.act2 = Swish()
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=(3, 3), padding=(1, 1))

        # If the number of input channels differs from the number of output
        # channels the shortcut connection has to be projected
        if in_channels != out_channels:
            self.shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=(1, 1))
        else:
            self.shortcut = nn.Identity()

        # Linear layer for the time embeddings
        self.time_emb = nn.Linear(time_channels, out_channels)

    def forward(self, x: torch.Tensor, t: torch.Tensor):
        """
        * `x` has shape `[batch_size, in_channels, height, width]`
        * `t` has shape `[batch_size, time_channels]`
        """
        # First convolution layer
        h = self.conv1(self.act1(self.norm1(x)))
        # Add the time embeddings, broadcast over height and width
        h += self.time_emb(t)[:, :, None, None]
        # Second convolution layer
        h = self.conv2(self.act2(self.norm2(h)))

        # Add the shortcut connection and return
        return h + self.shortcut(x)


class AttentionBlock(Module):
    """Multi-head self-attention over the spatial positions of a feature map."""

    def __init__(self, n_channels: int, n_heads: int = 1, d_k: Optional[int] = None, n_groups: int = 32):
        """
        * `n_channels` is the number of channels in the input
        * `n_heads` is the number of attention heads
        * `d_k` is the number of dimensions in each head; defaults to `n_channels`
        * `n_groups` is the number of groups for group normalization
        """
        super().__init__()

        # Default `d_k`
        if d_k is None:
            d_k = n_channels
        # Normalization layer. It is not applied in `forward`; it is kept so
        # the module's parameters and state dict stay unchanged.
        self.norm = nn.GroupNorm(n_groups, n_channels)
        # Projections for query, key and value
        self.projection = nn.Linear(n_channels, n_heads * d_k * 3)
        # Linear layer for the final transformation
        self.output = nn.Linear(n_heads * d_k, n_channels)
        # Scale for scaled dot-product attention
        self.scale = d_k ** -0.5

        self.n_heads = n_heads
        self.d_k = d_k

    def forward(self, x: torch.Tensor, t: Optional[torch.Tensor] = None):
        """
        * `x` has shape `[batch_size, in_channels, height, width]`
        * `t` has shape `[batch_size, time_channels]`
        """
        # `t` is not used; it is kept in the arguments so that the attention
        # layer's signature matches `ResidualBlock`
        _ = t
        # Get the shape
        batch_size, n_channels, height, width = x.shape
        # Change `x` to shape `[batch_size, seq, n_channels]`
        # (`reshape` rather than `view` so non-contiguous inputs also work)
        x = x.reshape(batch_size, n_channels, -1).permute(0, 2, 1)
        # Get query, key and value (concatenated) and shape them to
        # `[batch_size, seq, n_heads, 3 * d_k]`
        qkv = self.projection(x).view(batch_size, -1, self.n_heads, 3 * self.d_k)
        # Split query, key and value; each has shape
        # `[batch_size, seq, n_heads, d_k]`
        q, k, v = torch.chunk(qkv, 3, dim=-1)
        # Scaled dot product; `i` indexes queries and `j` indexes keys
        attn = torch.einsum('bihd,bjhd->bijh', q, k) * self.scale
        # Softmax along the key dimension `j`, so that the weights each query
        # assigns to the keys sum to one
        attn = attn.softmax(dim=2)
        # Multiply by the values
        res = torch.einsum('bijh,bjhd->bihd', attn, v)
        # Reshape to `[batch_size, seq, n_heads * d_k]`
        res = res.reshape(batch_size, -1, self.n_heads * self.d_k)
        # Transform to `[batch_size, seq, n_channels]`
        res = self.output(res)

        # Add the skip connection
        res += x

        # Change to shape `[batch_size, in_channels, height, width]`
        res = res.permute(0, 2, 1).reshape(batch_size, n_channels, height, width)

        return res


class DownBlock(Module):
    """
    A `ResidualBlock` optionally followed by an `AttentionBlock`; used in the
    first (down-sampling) half of the U-Net.
    """

    def __init__(self, in_channels: int, out_channels: int, time_channels: int, has_attn: bool):
        super().__init__()
        self.res = ResidualBlock(in_channels, out_channels, time_channels)
        if has_attn:
            self.attn = AttentionBlock(out_channels)
        else:
            self.attn = nn.Identity()

    def forward(self, x: torch.Tensor, t: torch.Tensor):
        x = self.res(x, t)
        x = self.attn(x)
        return x


class UpBlock(Module):
    """
    A `ResidualBlock` optionally followed by an `AttentionBlock`; used in the
    second (up-sampling) half of the U-Net.
    """

    def __init__(self, in_channels: int, out_channels: int, time_channels: int, has_attn: bool):
        super().__init__()
        # The input has `in_channels + out_channels` channels because the
        # output of the same resolution from the first half of the U-Net is
        # concatenated to it
        self.res = ResidualBlock(in_channels + out_channels, out_channels, time_channels)
        if has_attn:
            self.attn = AttentionBlock(out_channels)
        else:
            self.attn = nn.Identity()

    def forward(self, x: torch.Tensor, t: torch.Tensor):
        x = self.res(x, t)
        x = self.attn(x)
        return x


class MiddleBlock(Module):
    """
    A `ResidualBlock`, an `AttentionBlock` and another `ResidualBlock`; applied
    at the lowest resolution of the U-Net.
    """

    def __init__(self, n_channels: int, time_channels: int):
        super().__init__()
        self.res1 = ResidualBlock(n_channels, n_channels, time_channels)
        self.attn = AttentionBlock(n_channels)
        self.res2 = ResidualBlock(n_channels, n_channels, time_channels)

    def forward(self, x: torch.Tensor, t: torch.Tensor):
        x = self.res1(x, t)
        x = self.attn(x)
        x = self.res2(x, t)
        return x


class Upsample(nn.Module):
    """Scale the feature map up with a stride-2 transposed convolution."""

    def __init__(self, n_channels):
        super().__init__()
        self.conv = nn.ConvTranspose2d(n_channels, n_channels, (4, 4), (2, 2), (1, 1))

    def forward(self, x: torch.Tensor, t: torch.Tensor):
        # `t` is not used; it is kept in the arguments so that the signature
        # matches `ResidualBlock`
        _ = t
        return self.conv(x)


class Downsample(nn.Module):
    """Scale the feature map down with a stride-2 convolution."""

    def __init__(self, n_channels):
        super().__init__()
        self.conv = nn.Conv2d(n_channels, n_channels, (3, 3), (2, 2), (1, 1))

    def forward(self, x: torch.Tensor, t: torch.Tensor):
        # `t` is not used; it is kept in the arguments so that the signature
        # matches `ResidualBlock`
        _ = t
        return self.conv(x)


class UNet(Module):
    """U-Net that takes an image and a time step and predicts noise."""

    def __init__(self, image_channels: int = 3, n_channels: int = 64,
                 ch_mults: Union[Tuple[int, ...], List[int]] = (1, 2, 2, 4),
                 is_attn: Union[Tuple[bool, ...], List[bool]] = (False, False, True, True),
                 n_blocks: int = 2):
        """
        * `image_channels` is the number of channels in the image, e.g. 3 for RGB
        * `n_channels` is the number of channels in the initial feature map
          that the image is projected to
        * `ch_mults` is the list of channel multipliers, one per resolution.
          Each multiplier is applied to the channel count of the previous
          resolution, so the counts are cumulative products starting from
          `n_channels`
        * `is_attn` is a list of booleans indicating whether to use attention
          at each resolution
        * `n_blocks` is the number of `DownBlock`s (and matching `UpBlock`s)
          at each resolution
        """
        super().__init__()

        # Number of resolutions
        n_resolutions = len(ch_mults)

        # Project the image into the feature map
        self.image_proj = nn.Conv2d(image_channels, n_channels, kernel_size=(3, 3), padding=(1, 1))

        # Time embedding layer; the time embedding has `n_channels * 4` channels
        self.time_emb = TimeEmbedding(n_channels * 4)

        # #### First half of the U-Net: decreasing resolution
        down = []
        # Number of channels
        out_channels = in_channels = n_channels
        # For each resolution
        for i in range(n_resolutions):
            # Number of output channels at this resolution
            out_channels = in_channels * ch_mults[i]
            # Add `n_blocks` blocks
            for _ in range(n_blocks):
                down.append(DownBlock(in_channels, out_channels, n_channels * 4, is_attn[i]))
                in_channels = out_channels
            # Down-sample at all resolutions except the last
            if i < n_resolutions - 1:
                down.append(Downsample(in_channels))

        # Combine the set of modules
        self.down = nn.ModuleList(down)

        # Middle block
        self.middle = MiddleBlock(out_channels, n_channels * 4)

        # #### Second half of the U-Net: increasing resolution
        up = []
        # Number of channels
        in_channels = out_channels
        # For each resolution
        for i in reversed(range(n_resolutions)):
            # `n_blocks` blocks at the same resolution
            out_channels = in_channels
            for _ in range(n_blocks):
                up.append(UpBlock(in_channels, out_channels, n_channels * 4, is_attn[i]))
            # Final block that reduces the number of channels
            out_channels = in_channels // ch_mults[i]
            up.append(UpBlock(in_channels, out_channels, n_channels * 4, is_attn[i]))
            in_channels = out_channels
            # Up-sample at all resolutions except the last
            if i > 0:
                up.append(Upsample(in_channels))

        # Combine the set of modules
        self.up = nn.ModuleList(up)

        # Final normalization and convolution layer
        self.norm = nn.GroupNorm(8, n_channels)
        self.act = Swish()
        self.final = nn.Conv2d(in_channels, image_channels, kernel_size=(3, 3), padding=(1, 1))

    def forward(self, x: torch.Tensor, t: torch.Tensor):
        """
        * `x` has shape `[batch_size, in_channels, height, width]`
        * `t` has shape `[batch_size]`
        """
        # Get the time-step embeddings
        t = self.time_emb(t)

        # Get the image projection
        x = self.image_proj(x)

        # `h` stores the output at each resolution for the skip connections
        h = [x]
        # First half of the U-Net
        for m in self.down:
            x = m(x, t)
            h.append(x)

        # Middle (bottom)
        x = self.middle(x, t)

        # Second half of the U-Net
        for m in self.up:
            if isinstance(m, Upsample):
                x = m(x, t)
            else:
                # Take the skip connection from the first half of the U-Net
                # and concatenate it along the channel dimension
                s = h.pop()
                x = torch.cat((x, s), dim=1)
                x = m(x, t)

        # Final normalization and convolution
        return self.final(self.act(self.norm(x)))