补丁是你所需要的吗?(convMixer)

这是 PayTorch 实现的纸质补丁是你所需要的吗?

ConvMixer diagram from the paper

ConvMixer 与 MLP 混音器类似。MLP-Mixer 将空间维度和信道维度的混合分开,方法是跨空间维度应用 MLP,然后在通道维度上应用 MLP(空间 MLP 取代 ViT 注意力,频道 MLP 是 FFNViT)。

C@@

onvMixer 使用卷积进行通道混合,使用深度卷积进行空间混合。由于它是卷积而不是整个空间的完整MLP,因此与 ViT 或 MLP 混音器相比,它只混合附近的批次。此外,MLP-Mixer 在每次混音时使用两层的 MLP,而 ConvMixer 为每次混音使用单个层。

本文建议移除通道混音中的残余连接(逐点卷积),并且在空间混合(深度卷积)上只有一个剩余连接。他们还使用批量归一化而不是图层规范化

这是一个在 CIFAR-10 上训练 ConvMixer 的实验

View Run

38import torch
39from torch import nn
40
41from labml_helpers.module import Module
42from labml_nn.utils import clone_module_list

混音器层

这是单个 ConvMixer 层。该模型将有一系列这样的。

45class ConvMixerLayer(Module):
  • d_model 是补丁嵌入中的通道数,
  • kernel_size 是空间卷积内核的大小,
54    def __init__(self, d_model: int, kernel_size: int):
59        super().__init__()

深度卷积是每个通道的单独卷积。我们使用卷积层来完成此操作,该卷积层的组数等于通道数。因此,每个频道都是它自己的组。

63        self.depth_wise_conv = nn.Conv2d(d_model, d_model,
64                                         kernel_size=kernel_size,
65                                         groups=d_model,
66                                         padding=(kernel_size - 1) // 2)

深度卷积后激活

68        self.act1 = nn.GELU()

深度卷积后的归一化

70        self.norm1 = nn.BatchNorm2d(d_model)

逐点卷积是一种卷积。即补丁嵌入的线性变换

74        self.point_wise_conv = nn.Conv2d(d_model, d_model, kernel_size=1)

逐点卷积后激活

76        self.act2 = nn.GELU()

逐点卷积后的归一化

78        self.norm2 = nn.BatchNorm2d(d_model)
80    def forward(self, x: torch.Tensor):

对于围绕深度卷积的剩余连接

82        residual = x

深度卷积、激活和归一化

85        x = self.depth_wise_conv(x)
86        x = self.act1(x)
87        x = self.norm1(x)

添加剩余连接

90        x += residual

逐点卷积、激活和归一化

93        x = self.point_wise_conv(x)
94        x = self.act2(x)
95        x = self.norm2(x)

98        return x

获取补丁嵌入

这会将图像拆分为大小的补丁,并为每个补丁提供嵌入。

101class PatchEmbeddings(Module):
  • d_model 是补丁嵌入中的通道数
  • patch_size 是补丁的大小,
  • in_channels 是输入图像中的通道数(rgb 为 3)
110    def __init__(self, d_model: int, patch_size: int, in_channels: int):
116        super().__init__()

我们创建一个卷积层,其内核大小和步长等于补丁大小。这相当于将图像分割成色块并在每个面片上进行线性变换。

121        self.conv = nn.Conv2d(in_channels, d_model, kernel_size=patch_size, stride=patch_size)

激活功能

123        self.act = nn.GELU()

批量标准化

125        self.norm = nn.BatchNorm2d(d_model)
  • x 是形状的输入图像[batch_size, channels, height, width]
127    def forward(self, x: torch.Tensor):

应用卷积层

132        x = self.conv(x)

激活和规范化

134        x = self.act(x)
135        x = self.norm(x)

138        return x

分类主管

它们进行平均池(取所有补丁嵌入的均值)和最终的线性变换来预测影像类的对数概率。

141class ClassificationHead(Module):
  • d_model 是补丁嵌入中的通道数,
  • n_classes 是分类任务中的类数
151    def __init__(self, d_model: int, n_classes: int):
156        super().__init__()

平均池

158        self.pool = nn.AdaptiveAvgPool2d((1, 1))

线性层

160        self.linear = nn.Linear(d_model, n_classes)
162    def forward(self, x: torch.Tensor):

平均汇集

164        x = self.pool(x)

得到嵌入,x 会有形状[batch_size, d_model, 1, 1]

166        x = x[:, :, 0, 0]

线性层

168        x = self.linear(x)

171        return x

混音器

它结合了补丁嵌入块、许多 ConvMixer 层和一个分类头。

174class ConvMixer(Module):
  • conv_mixer_layer 是单个 C onvMixer 层的副本。我们制作它的副本来制作 ConvMixern_layers
  • n_layers 是 ConvMixer 层(或深度)的数量
  • patch_emb补丁嵌入层
  • classification分类头
181    def __init__(self, conv_mixer_layer: ConvMixerLayer, n_layers: int,
182                 patch_emb: PatchEmbeddings,
183                 classification: ClassificationHead):
191        super().__init__()

补丁嵌入

193        self.patch_emb = patch_emb

分类主管

195        self.classification = classification

制作 C onvMixer 图层的副本

197        self.conv_mixer_layers = clone_module_list(conv_mixer_layer, n_layers)
  • x 是形状的输入图像[batch_size, channels, height, width]
199    def forward(self, x: torch.Tensor):

获取补丁嵌入。这给出了形状的张量[batch_size, d_model, height / patch_size, width / patch_size]

204        x = self.patch_emb(x)
207        for layer in self.conv_mixer_layers:
208            x = layer(x)

分类头,获取日志

211        x = self.classification(x)

214        return x