这是 PayTorch 实现的纸质补丁是你所需要的吗?。

ConvMixer 与 MLP 混音器类似。MLP-Mixer 将空间维度和信道维度的混合分开,方法是跨空间维度应用 MLP,然后在通道维度上应用 MLP(空间 MLP 取代 ViT 注意力,频道 MLP 是 FFNViT)。
C@@onvMixer 使用卷积进行通道混合,使用深度卷积进行空间混合。由于它是卷积而不是整个空间的完整MLP,因此与 ViT 或 MLP 混音器相比,它只混合附近的批次。此外,MLP-Mixer 在每次混音时使用两层的 MLP,而 ConvMixer 为每次混音使用单个层。
本文建议移除通道混音中的残余连接(逐点卷积),并且在空间混合(深度卷积)上只有一个剩余连接。他们还使用批量归一化而不是图层规范化。
38import torch
39from torch import nn
40
41from labml_helpers.module import Module
42from labml_nn.utils import clone_module_list45class ConvMixerLayer(Module):d_model
是补丁嵌入中的通道数,kernel_size
是空间卷积内核的大小,54 def __init__(self, d_model: int, kernel_size: int):59 super().__init__()深度卷积是每个通道的单独卷积。我们使用卷积层来完成此操作,该卷积层的组数等于通道数。因此,每个频道都是它自己的组。
63 self.depth_wise_conv = nn.Conv2d(d_model, d_model,
64 kernel_size=kernel_size,
65 groups=d_model,
66 padding=(kernel_size - 1) // 2)深度卷积后激活
68 self.act1 = nn.GELU()深度卷积后的归一化
70 self.norm1 = nn.BatchNorm2d(d_model)逐点卷积是一种卷积。即补丁嵌入的线性变换
74 self.point_wise_conv = nn.Conv2d(d_model, d_model, kernel_size=1)逐点卷积后激活
76 self.act2 = nn.GELU()逐点卷积后的归一化
78 self.norm2 = nn.BatchNorm2d(d_model)80 def forward(self, x: torch.Tensor):对于围绕深度卷积的剩余连接
82 residual = x深度卷积、激活和归一化
85 x = self.depth_wise_conv(x)
86 x = self.act1(x)
87 x = self.norm1(x)添加剩余连接
90 x += residual逐点卷积、激活和归一化
93 x = self.point_wise_conv(x)
94 x = self.act2(x)
95 x = self.norm2(x)98 return x101class PatchEmbeddings(Module):d_model
是补丁嵌入中的通道数patch_size
是补丁的大小,in_channels
是输入图像中的通道数(rgb 为 3)110 def __init__(self, d_model: int, patch_size: int, in_channels: int):116 super().__init__()我们创建一个卷积层,其内核大小和步长等于补丁大小。这相当于将图像分割成色块并在每个面片上进行线性变换。
121 self.conv = nn.Conv2d(in_channels, d_model, kernel_size=patch_size, stride=patch_size)激活功能
123 self.act = nn.GELU()批量标准化
125 self.norm = nn.BatchNorm2d(d_model)x
是形状的输入图像[batch_size, channels, height, width]
127 def forward(self, x: torch.Tensor):应用卷积层
132 x = self.conv(x)激活和规范化
134 x = self.act(x)
135 x = self.norm(x)138 return x141class ClassificationHead(Module):d_model
是补丁嵌入中的通道数,n_classes
是分类任务中的类数151 def __init__(self, d_model: int, n_classes: int):156 super().__init__()平均池
158 self.pool = nn.AdaptiveAvgPool2d((1, 1))线性层
160 self.linear = nn.Linear(d_model, n_classes)162 def forward(self, x: torch.Tensor):平均汇集
164 x = self.pool(x)得到嵌入,x
会有形状[batch_size, d_model, 1, 1]
166 x = x[:, :, 0, 0]线性层
168 x = self.linear(x)171 return x174class ConvMixer(Module):conv_mixer_layer
是单个 C onvMixer 层的副本。我们制作它的副本来制作 ConvMixern_layers
。n_layers
是 ConvMixer 层(或深度)的数量。patch_emb
是补丁嵌入层。classification
是分类头。181 def __init__(self, conv_mixer_layer: ConvMixerLayer, n_layers: int,
182 patch_emb: PatchEmbeddings,
183 classification: ClassificationHead):191 super().__init__()补丁嵌入
193 self.patch_emb = patch_emb分类主管
195 self.classification = classification制作 C onvMixer 图层的副本
197 self.conv_mixer_layers = clone_module_list(conv_mixer_layer, n_layers)x
是形状的输入图像[batch_size, channels, height, width]
199 def forward(self, x: torch.Tensor):获取补丁嵌入。这给出了形状的张量[batch_size, d_model, height / patch_size, width / patch_size]
。
204 x = self.patch_emb(x)穿过 ConvMixer 图层
207 for layer in self.conv_mixer_layers:
208 x = layer(x)分类头,获取日志
211 x = self.classification(x)214 return x