This is a PyTorch implementation of the paper An Image Is Worth 16x16 Words: Transformers for Image Recognition at Scale.
The vision transformer (ViT) applies a pure transformer to images,
without any convolution layers.
They split the image into patches and apply a transformer to the patch embeddings.
Patch embeddings are generated by applying a simple linear transformation
to the flattened pixel values of the patch.
For example, a 224×224 RGB image split into 16×16 patches gives (224/16)² = 196 patches, each with 16·16·3 = 768 flattened pixel values.
A standard transformer encoder is then fed the patch embeddings, along with a
classification token [CLS].
The encoding of the [CLS] token is used to classify the image with an MLP.
When feeding the patches to the transformer, learned positional embeddings are added to the patch embeddings, because the patch embeddings do not have any information about where each patch came from. The positional embeddings are a set of vectors, one per patch location, trained with gradient descent along with the other parameters.
ViTs perform well when they are pre-trained on large datasets. The paper suggests pre-training them with an MLP classification head and then using a single linear layer when fine-tuning. The paper beats SOTA with a ViT pre-trained on a 300-million-image dataset. They also use higher resolution images during inference while keeping the patch size the same. The positional embeddings for the new patch locations are calculated by interpolating the learned positional embeddings.
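The sketch below illustrates that interpolation step. It is not the paper's code: the helper name is hypothetical, and it assumes the stored embeddings correspond to a square grid of patch locations.

import torch
import torch.nn.functional as F

def interpolate_pos_embeddings(pos_emb: torch.Tensor, new_n_patches: int) -> torch.Tensor:
    # pos_emb has shape [n_patches, 1, d_model]; assumes the patches form a square grid (hypothetical helper)
    n_patches, _, d_model = pos_emb.shape
    old_side = int(n_patches ** 0.5)
    new_side = int(new_n_patches ** 0.5)
    # [n_patches, 1, d_model] -> [1, d_model, old_side, old_side]
    grid = pos_emb.permute(1, 2, 0).reshape(1, d_model, old_side, old_side)
    # Bilinear interpolation over the 2D grid of patch locations
    grid = F.interpolate(grid, size=(new_side, new_side), mode='bilinear', align_corners=False)
    # Back to [new_n_patches, 1, d_model]
    return grid.reshape(1, d_model, new_side * new_side).permute(2, 0, 1)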
Here’s an experiment that trains a ViT on CIFAR-10. It doesn’t do very well because it’s trained from scratch on a small dataset, but it’s a simple experiment that anyone can run to play with ViTs.
import torch
from torch import nn

from labml_helpers.module import Module
from labml_nn.transformers import TransformerLayer
from labml_nn.utils import clone_module_list

The paper splits the image into patches of equal size and does a linear transformation on the flattened pixels of each patch.
We implement the same thing through a convolution layer, because it’s simpler to implement.
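As a quick, standalone sanity check of that equivalence (not part of the implementation; the sizes below are arbitrary), the convolution can be compared against an explicit split-and-flatten followed by the same linear map:

import torch
import torch.nn.functional as F
from torch import nn

d_model, patch_size, in_channels = 8, 4, 3
conv = nn.Conv2d(in_channels, d_model, patch_size, stride=patch_size)
x = torch.randn(2, in_channels, 16, 16)

# Convolution with kernel size == stride == patch size, flattened over the spatial grid
out_conv = conv(x).flatten(2)                                     # [batch_size, d_model, patches]

# Explicit split-and-flatten, followed by the same linear transformation
patches = F.unfold(x, kernel_size=patch_size, stride=patch_size)  # [batch_size, in_channels * patch_size ** 2, patches]
weight = conv.weight.view(d_model, -1)                            # [d_model, in_channels * patch_size ** 2]
out_linear = weight @ patches + conv.bias[None, :, None]

assert torch.allclose(out_conv, out_linear, atol=1e-5)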
class PatchEmbeddings(Module):

d_model is the transformer embedding size. patch_size is the size of the patch. in_channels is the number of channels in the input image (3 for RGB).

    def __init__(self, d_model: int, patch_size: int, in_channels: int):
        super().__init__()

        # We create a convolution layer with a kernel size and stride length equal to the patch size.
        # This is equivalent to splitting the image into patches and doing a linear transformation on each patch.
        self.conv = nn.Conv2d(in_channels, d_model, patch_size, stride=patch_size)

x is the input image of shape [batch_size, channels, height, width].

    def forward(self, x: torch.Tensor):
        # Apply the convolution layer
        x = self.conv(x)
        # Get the shape
        bs, c, h, w = x.shape
        # Rearrange to shape [patches, batch_size, d_model]
        x = x.permute(2, 3, 0, 1)
        x = x.view(h * w, bs, c)
        # Return the patch embeddings
        return x
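A quick shape check (illustrative sizes, not from the original code):

patch_emb = PatchEmbeddings(d_model=128, patch_size=4, in_channels=3)
images = torch.randn(16, 3, 32, 32)   # a batch of CIFAR-10-sized images
print(patch_emb(images).shape)        # torch.Size([64, 16, 128]): 8 * 8 patches, batch of 16, d_model of 128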
This adds learned positional embeddings to the patch embeddings.

class LearnedPositionalEmbeddings(Module):

d_model is the transformer embedding size. max_len is the maximum number of patches.

    def __init__(self, d_model: int, max_len: int = 5_000):
        super().__init__()

        # Positional embeddings for each location
        self.positional_encodings = nn.Parameter(torch.zeros(max_len, 1, d_model), requires_grad=True)

x is the patch embeddings of shape [patches, batch_size, d_model].

    def forward(self, x: torch.Tensor):
        # Get the positional embeddings for the given patches
        pe = self.positional_encodings[:x.shape[0]]
        # Add to the patch embeddings and return
        return x + pe
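Continuing the shape check with the same illustrative sizes:

pos_emb = LearnedPositionalEmbeddings(d_model=128)
patches = torch.randn(64, 16, 128)    # patch embeddings of shape [patches, batch_size, d_model]
print(pos_emb(patches).shape)         # torch.Size([64, 16, 128]); the first 64 positional vectors were added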
This is the two-layer MLP head to classify the image based on the [CLS] token embedding.

class ClassificationHead(Module):

d_model is the transformer embedding size. n_hidden is the size of the hidden layer. n_classes is the number of classes in the classification task.

    def __init__(self, d_model: int, n_hidden: int, n_classes: int):
        super().__init__()

        # First layer
        self.linear1 = nn.Linear(d_model, n_hidden)
        # Activation
        self.act = nn.ReLU()
        # Second layer
        self.linear2 = nn.Linear(n_hidden, n_classes)

x is the transformer encoding for the [CLS] token.

    def forward(self, x: torch.Tensor):
        # First layer and activation
        x = self.act(self.linear1(x))
        # Second layer
        x = self.linear2(x)

        return x
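For example, with illustrative sizes:

head = ClassificationHead(d_model=128, n_hidden=256, n_classes=10)
cls_encoding = torch.randn(16, 128)   # [CLS] encodings for a batch of 16
print(head(cls_encoding).shape)       # torch.Size([16, 10]): one logit per class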
This combines the patch embeddings, positional embeddings, transformer layers and the classification head.

class VisionTransformer(Module):

transformer_layer is a copy of a single transformer layer. We make copies of it to build the transformer with n_layers layers. n_layers is the number of transformer layers. patch_emb is the patch embeddings layer. pos_emb is the positional embeddings layer. classification is the classification head.

    def __init__(self, transformer_layer: TransformerLayer, n_layers: int,
                 patch_emb: PatchEmbeddings, pos_emb: LearnedPositionalEmbeddings,
                 classification: ClassificationHead):
        super().__init__()

        # Patch embeddings
        self.patch_emb = patch_emb
        self.pos_emb = pos_emb
        # Classification head
        self.classification = classification
        # Make copies of the transformer layer
        self.transformer_layers = clone_module_list(transformer_layer, n_layers)

        # [CLS] token embedding
        self.cls_token_emb = nn.Parameter(torch.randn(1, 1, transformer_layer.size), requires_grad=True)
        # Final normalization layer
        self.ln = nn.LayerNorm([transformer_layer.size])

x is the input image of shape [batch_size, channels, height, width].

    def forward(self, x: torch.Tensor):
        # Get patch embeddings. This gives a tensor of shape [patches, batch_size, d_model]
        x = self.patch_emb(x)
        # Add positional embeddings
        x = self.pos_emb(x)
        # Concatenate the [CLS] token embedding before feeding the transformer
        cls_token_emb = self.cls_token_emb.expand(-1, x.shape[1], -1)
        x = torch.cat([cls_token_emb, x])

        # Pass through the transformer layers with no attention masking
        for layer in self.transformer_layers:
            x = layer(x=x, mask=None)

        # Get the transformer output of the [CLS] token (which is the first in the sequence)
        x = x[0]

        # Layer normalization
        x = self.ln(x)

        # Classification head, to get the logits
        x = self.classification(x)

        return x
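Finally, a rough sketch of how the pieces might be wired together. The constructor arguments for TransformerLayer, MultiHeadAttention and FeedForward are assumptions about the labml_nn API, and all sizes are illustrative; see the library's ViT experiment for the actual configuration.

from labml_nn.transformers import MultiHeadAttention, TransformerLayer
from labml_nn.transformers.feed_forward import FeedForward

d_model = 128

# A single transformer layer; VisionTransformer clones it n_layers times
transformer_layer = TransformerLayer(d_model=d_model,
                                     self_attn=MultiHeadAttention(heads=4, d_model=d_model, dropout_prob=0.1),
                                     src_attn=None,
                                     feed_forward=FeedForward(d_model, 256, 0.1),
                                     dropout_prob=0.1)

model = VisionTransformer(transformer_layer, n_layers=6,
                          patch_emb=PatchEmbeddings(d_model, patch_size=4, in_channels=3),
                          pos_emb=LearnedPositionalEmbeddings(d_model),
                          classification=ClassificationHead(d_model, n_hidden=256, n_classes=10))

images = torch.randn(16, 3, 32, 32)   # a batch of CIFAR-10-sized images
logits = model(images)                # [batch_size, n_classes]
print(logits.shape)                   # torch.Size([16, 10])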