用于去噪扩散概率模型 (DDPM) 的 U-Net 模型

这是一个基于 U-Net 的模型,用于预测噪声

U-Net 是从模型图中的 U 形中获取它的名字。它通过逐步降低(减半)要素图分辨率,然后提高分辨率来处理给定的图像。每种分辨率都有直通连接。

U-Net diagram from paper

此实现包含对原始 U-Net(残差块、多头注意)的大量修改,还添加了时间步长嵌入

24import math
25from typing import Optional, Tuple, Union, List
26
27import torch
28from torch import nn
29
30from labml_helpers.module import Module

Swish 激活功能

33class Swish(Module):
40    def forward(self, x):
41        return x * torch.sigmoid(x)

嵌入用于

44class TimeEmbedding(nn.Module):
  • n_channels 是嵌入中的维数
49    def __init__(self, n_channels: int):
53        super().__init__()
54        self.n_channels = n_channels

第一个线性层

56        self.lin1 = nn.Linear(self.n_channels // 4, self.n_channels)

激活

58        self.act = Swish()

第二个线性层

60        self.lin2 = nn.Linear(self.n_channels, self.n_channels)
62    def forward(self, t: torch.Tensor):

创建与变压器相同的正弦位置嵌入

在哪half_dim

72        half_dim = self.n_channels // 8
73        emb = math.log(10_000) / (half_dim - 1)
74        emb = torch.exp(torch.arange(half_dim, device=t.device) * -emb)
75        emb = t[:, None] * emb[None, :]
76        emb = torch.cat((emb.sin(), emb.cos()), dim=1)

使用 MLP 进行转型

79        emb = self.act(self.lin1(emb))
80        emb = self.lin2(emb)

83        return emb

剩余方块

残差块具有两个具有组归一化的卷积层。每个分辨率都使用两个残差块进行处理。

86class ResidualBlock(Module):
  • in_channels 是输入声道的数量
  • out_channels 是输入声道的数量
  • time_channels 是 time step () 嵌入中的数字通道
  • n_groups 是组归一化的组
94    def __init__(self, in_channels: int, out_channels: int, time_channels: int, n_groups: int = 32):
101        super().__init__()

组归一化和第一个卷积层

103        self.norm1 = nn.GroupNorm(n_groups, in_channels)
104        self.act1 = Swish()
105        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=(3, 3), padding=(1, 1))

组归一化和第二个卷积层

108        self.norm2 = nn.GroupNorm(n_groups, out_channels)
109        self.act2 = Swish()
110        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=(3, 3), padding=(1, 1))

如果输入通道的数量不等于输出通道的数量,我们必须投影快捷方式连接

114        if in_channels != out_channels:
115            self.shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=(1, 1))
116        else:
117            self.shortcut = nn.Identity()

用于时间嵌入的线性层

120        self.time_emb = nn.Linear(time_channels, out_channels)
  • x 有形状[batch_size, in_channels, height, width]
  • t 有形状[batch_size, time_channels]
122    def forward(self, x: torch.Tensor, t: torch.Tensor):

第一个卷积层

128        h = self.conv1(self.act1(self.norm1(x)))

添加时间嵌入

130        h += self.time_emb(t)[:, :, None, None]

第二个卷积层

132        h = self.conv2(self.act2(self.norm2(h)))

添加快捷方式连接并返回

135        return h + self.shortcut(x)

注意力块

这类似于变压器多头的关注

138class AttentionBlock(Module):
  • n_channels 是输入中的声道数
  • n_heads 是多头关注中的头部数量
  • d_k 是每个头部的尺寸数
  • n_groups 是组归一化的组
145    def __init__(self, n_channels: int, n_heads: int = 1, d_k: int = None, n_groups: int = 32):
152        super().__init__()

默认d_k

155        if d_k is None:
156            d_k = n_channels

归一化层

158        self.norm = nn.GroupNorm(n_groups, n_channels)

查询、键和值的投影

160        self.projection = nn.Linear(n_channels, n_heads * d_k * 3)

用于最终变换的线性层

162        self.output = nn.Linear(n_heads * d_k, n_channels)

缩放点产品注意力

164        self.scale = d_k ** -0.5

166        self.n_heads = n_heads
167        self.d_k = d_k
  • x 有形状[batch_size, in_channels, height, width]
  • t 有形状[batch_size, time_channels]
169    def forward(self, x: torch.Tensor, t: Optional[torch.Tensor] = None):

t 未使用,但它保留在参数中,因为要与注意层函数签名匹配ResidualBlock

176        _ = t

塑造身材

178        batch_size, n_channels, height, width = x.shape

x 成形状[batch_size, seq, n_channels]

180        x = x.view(batch_size, n_channels, -1).permute(0, 2, 1)

获取查询、键和值(串联)并将其调整为[batch_size, seq, n_heads, 3 * d_k]

182        qkv = self.projection(x).view(batch_size, -1, self.n_heads, 3 * self.d_k)

拆分查询、键和值。他们每个人都会有形状[batch_size, seq, n_heads, d_k]

184        q, k, v = torch.chunk(qkv, 3, dim=-1)

计算缩放的点积

186        attn = torch.einsum('bihd,bjhd->bijh', q, k) * self.scale

顺序维度上的 Softmax

188        attn = attn.softmax(dim=1)

乘以值

190        res = torch.einsum('bijh,bjhd->bihd', attn, v)

重塑为[batch_size, seq, n_heads * d_k]

192        res = res.view(batch_size, -1, self.n_heads * self.d_k)

变换为[batch_size, seq, n_channels]

194        res = self.output(res)

添加跳过连接

197        res += x

改成形状[batch_size, in_channels, height, width]

200        res = res.permute(0, 2, 1).view(batch_size, n_channels, height, width)

203        return res

向下方块

这结合了ResidualBlockAttentionBlock .这些在U-Net的前半部分以每种分辨率使用。

206class DownBlock(Module):
213    def __init__(self, in_channels: int, out_channels: int, time_channels: int, has_attn: bool):
214        super().__init__()
215        self.res = ResidualBlock(in_channels, out_channels, time_channels)
216        if has_attn:
217            self.attn = AttentionBlock(out_channels)
218        else:
219            self.attn = nn.Identity()
221    def forward(self, x: torch.Tensor, t: torch.Tensor):
222        x = self.res(x, t)
223        x = self.attn(x)
224        return x

向上方块

这结合了ResidualBlockAttentionBlock .这些在U-Net的后半部分以每种分辨率使用。

227class UpBlock(Module):
234    def __init__(self, in_channels: int, out_channels: int, time_channels: int, has_attn: bool):
235        super().__init__()

输入之in_channels + out_channels 所以有,是因为我们将 U-Net 前半部分相同分辨率的输出连接起来

238        self.res = ResidualBlock(in_channels + out_channels, out_channels, time_channels)
239        if has_attn:
240            self.attn = AttentionBlock(out_channels)
241        else:
242            self.attn = nn.Identity()
244    def forward(self, x: torch.Tensor, t: torch.Tensor):
245        x = self.res(x, t)
246        x = self.attn(x)
247        return x

中间方块

它结合了ResidualBlockAttentionBlock 、后跟另一个ResidualBlock 。此块应用于 U-Net 的最低分辨率。

250class MiddleBlock(Module):
258    def __init__(self, n_channels: int, time_channels: int):
259        super().__init__()
260        self.res1 = ResidualBlock(n_channels, n_channels, time_channels)
261        self.attn = AttentionBlock(n_channels)
262        self.res2 = ResidualBlock(n_channels, n_channels, time_channels)
264    def forward(self, x: torch.Tensor, t: torch.Tensor):
265        x = self.res1(x, t)
266        x = self.attn(x)
267        x = self.res2(x, t)
268        return x

按比例放大要素地图

271class Upsample(nn.Module):
276    def __init__(self, n_channels):
277        super().__init__()
278        self.conv = nn.ConvTranspose2d(n_channels, n_channels, (4, 4), (2, 2), (1, 1))
280    def forward(self, x: torch.Tensor, t: torch.Tensor):

t 未使用,但它保留在参数中,因为要与注意层函数签名匹配ResidualBlock

283        _ = t
284        return self.conv(x)

按比例缩小要素地图

287class Downsample(nn.Module):
292    def __init__(self, n_channels):
293        super().__init__()
294        self.conv = nn.Conv2d(n_channels, n_channels, (3, 3), (2, 2), (1, 1))
296    def forward(self, x: torch.Tensor, t: torch.Tensor):

t 未使用,但它保留在参数中,因为要与注意层函数签名匹配ResidualBlock

299        _ = t
300        return self.conv(x)

U-Net

303class UNet(Module):
  • image_channels 是图像中的通道数。对于 RGB。
  • n_channels 是初始特征图中我们将图像转换为的通道数
  • ch_mults 是每种分辨率下的通道编号列表。频道的数量是ch_mults[i] * n_channels
  • is_attn 是一个布尔值列表,用于指示是否在每个分辨率下使用注意力
  • n_blocks 是每种分辨UpDownBlocks 率的数字
308    def __init__(self, image_channels: int = 3, n_channels: int = 64,
309                 ch_mults: Union[Tuple[int, ...], List[int]] = (1, 2, 2, 4),
310                 is_attn: Union[Tuple[bool, ...], List[int]] = (False, False, True, True),
311                 n_blocks: int = 2):
319        super().__init__()

分辨率数量

322        n_resolutions = len(ch_mults)

将图像投影到要素地图中

325        self.image_proj = nn.Conv2d(image_channels, n_channels, kernel_size=(3, 3), padding=(1, 1))

时间嵌入层。时间嵌入有n_channels * 4 频道

328        self.time_emb = TimeEmbedding(n_channels * 4)

U-Net 的前半部分-分辨率降低

331        down = []

频道数量

333        out_channels = in_channels = n_channels

对于每种分辨率

335        for i in range(n_resolutions):

此分辨率下的输出声道数

337            out_channels = in_channels * ch_mults[i]

添加n_blocks

339            for _ in range(n_blocks):
340                down.append(DownBlock(in_channels, out_channels, n_channels * 4, is_attn[i]))
341                in_channels = out_channels

除最后一个分辨率之外的所有分辨率都向下采样

343            if i < n_resolutions - 1:
344                down.append(Downsample(in_channels))

组合这组模块

347        self.down = nn.ModuleList(down)

中间方块

350        self.middle = MiddleBlock(out_channels, n_channels * 4, )

U-Net 的后半部分-提高分辨率

353        up = []

频道数量

355        in_channels = out_channels

对于每种分辨率

357        for i in reversed(range(n_resolutions)):

n_blocks 以相同的分辨率

359            out_channels = in_channels
360            for _ in range(n_blocks):
361                up.append(UpBlock(in_channels, out_channels, n_channels * 4, is_attn[i]))

减少信道数量的最终区块

363            out_channels = in_channels // ch_mults[i]
364            up.append(UpBlock(in_channels, out_channels, n_channels * 4, is_attn[i]))
365            in_channels = out_channels

除最后一个以外的所有分辨率向上采样

367            if i > 0:
368                up.append(Upsample(in_channels))

组合这组模块

371        self.up = nn.ModuleList(up)

最终归一化和卷积层

374        self.norm = nn.GroupNorm(8, n_channels)
375        self.act = Swish()
376        self.final = nn.Conv2d(in_channels, image_channels, kernel_size=(3, 3), padding=(1, 1))
  • x 有形状[batch_size, in_channels, height, width]
  • t 有形状[batch_size]
378    def forward(self, x: torch.Tensor, t: torch.Tensor):

获取时间步长嵌入

385        t = self.time_emb(t)

获取图像投影

388        x = self.image_proj(x)

h 将以每种分辨率存储输出以进行跳过连接

391        h = [x]

U-Net 的上半年

393        for m in self.down:
394            x = m(x, t)
395            h.append(x)

中间(底部)

398        x = self.middle(x, t)

U-Net 的下半场

401        for m in self.up:
402            if isinstance(m, Upsample):
403                x = m(x, t)
404            else:

从 U-Net 的前半部分获取跳过连接并连接

406                s = h.pop()
407                x = torch.cat((x, s), dim=1)

409                x = m(x, t)

最终归一化和卷积

412        return self.final(self.act(self.norm(x)))