diff --git a/docs/transformers/index.html b/docs/transformers/index.html
index 1c2d646b..52ba8c9a 100644
--- a/docs/transformers/index.html
+++ b/docs/transformers/index.html
@@ -117,12 +117,15 @@ It does single GPU training but we implement the concept of switching as describ
Vision transformer applies a pure transformer to images
+without any convolution layers.
+They split the image into patches and apply a transformer on patch embeddings.
+Patch embeddings are generated by applying a simple linear transformation
+to the flattened pixel values of the patch.
+Then a standard transformer encoder is fed with the patch embeddings, along with a
+classification token [CLS].
+The encoding on the [CLS] token is used to classify the image with an MLP.
+
When feeding the transformer with the patches, learned positional embeddings are
+added to the patch embeddings, because the patch embeddings do not have any information
+about where that patch is from.
+The positional embeddings are a set of vectors for each patch location that get trained
+with gradient descent along with other parameters.
+
ViTs perform well when they are pre-trained on large datasets.
+The paper suggests pre-training them with an MLP classification head and
+then using a single linear layer when fine-tuning.
+The paper beats SOTA with a ViT pre-trained on a 300 million image dataset.
+They also use higher resolution images during inference while keeping the
+patch size the same.
+The positional embeddings for new patch locations are calculated by interpolating
+the learned positional embeddings.
+
Here’s an experiment that trains ViT on CIFAR-10.
+This doesn’t do very well because it’s trained on a small dataset.
+It’s a simple experiment that anyone can run to play with ViTs.
We create a convolution layer with a kernel size and stride length equal to the patch size.
+This is equivalent to splitting the image into patches and doing a linear
+transformation on each patch.
\ No newline at end of file
diff --git a/labml_nn/__init__.py b/labml_nn/__init__.py
index f719fb02..d7202226 100644
--- a/labml_nn/__init__.py
+++ b/labml_nn/__init__.py
@@ -31,6 +31,7 @@ implementations.
* [Masked Language Model](transformers/mlm/index.html)
* [MLP-Mixer: An all-MLP Architecture for Vision](transformers/mlp_mixer/index.html)
* [Pay Attention to MLPs (gMLP)](transformers/gmlp/index.html)
+* [Vision Transformer (ViT)](transformers/vit/index.html)
#### ✨ [Recurrent Highway Networks](recurrent_highway_networks/index.html)
diff --git a/labml_nn/transformers/__init__.py b/labml_nn/transformers/__init__.py
index d58ed81d..37c2ba99 100644
--- a/labml_nn/transformers/__init__.py
+++ b/labml_nn/transformers/__init__.py
@@ -82,6 +82,11 @@ This is an implementation of the paper
This is an implementation of the paper
[Pay Attention to MLPs](https://papers.labml.ai/paper/2105.08050).
+
+## [Vision Transformer (ViT)](vit/index.html)
+
+This is an implementation of the paper
+[An Image Is Worth 16x16 Words: Transformers For Image Recognition At Scale](https://arxiv.org/abs/2010.11929).
"""
from .configs import TransformerConfigs
diff --git a/labml_nn/transformers/vit/__init__.py b/labml_nn/transformers/vit/__init__.py
index d827192e..d45b8c05 100644
--- a/labml_nn/transformers/vit/__init__.py
+++ b/labml_nn/transformers/vit/__init__.py
@@ -1,3 +1,47 @@
+"""
+---
+title: Vision Transformer (ViT)
+summary: >
+ A PyTorch implementation/tutorial of the paper
+ "An Image Is Worth 16x16 Words: Transformers For Image Recognition At Scale"
+---
+
+# Vision Transformer (ViT)
+
+This is a [PyTorch](https://pytorch.org) implementation of the paper
+[An Image Is Worth 16x16 Words: Transformers For Image Recognition At Scale](https://arxiv.org/abs/2010.11929).
+
+Vision transformer applies a pure transformer to images
+without any convolution layers.
+They split the image into patches and apply a transformer on patch embeddings.
+[Patch embeddings](#PatchEmbeddings) are generated by applying a simple linear transformation
+to the flattened pixel values of the patch.
+Then a standard transformer encoder is fed with the patch embeddings, along with a
+classification token `[CLS]`.
+The encoding on the `[CLS]` token is used to classify the image with an MLP.
+
+When feeding the transformer with the patches, learned positional embeddings are
+added to the patch embeddings, because the patch embeddings do not have any information
+about where that patch is from.
+The positional embeddings are a set of vectors for each patch location that get trained
+with gradient descent along with other parameters.
+
+ViTs perform well when they are pre-trained on large datasets.
+The paper suggests pre-training them with an MLP classification head and
+then using a single linear layer when fine-tuning.
+The paper beats SOTA with a ViT pre-trained on a 300 million image dataset.
+They also use higher resolution images during inference while keeping the
+patch size the same.
+The positional embeddings for new patch locations are calculated by interpolating
+the learned positional embeddings.
+
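+The interpolation itself is not part of this implementation; here is a minimal sketch of the idea,
+with illustrative grid sizes, `d_model` and interpolation mode (none of these are taken from the paper):
+
+```python
+import torch
+import torch.nn.functional as F
+
+d_model = 128
+# Learned embeddings for an 8x8 patch grid, interpolated to the 12x12 grid
+# that a higher resolution image would produce with the same patch size
+pos_emb = torch.randn(8 * 8, 1, d_model)                          # [patches, 1, d_model]
+grid = pos_emb.permute(1, 2, 0).reshape(1, d_model, 8, 8)         # [1, d_model, 8, 8]
+grid = F.interpolate(grid, size=(12, 12), mode='bicubic', align_corners=False)
+new_pos_emb = grid.reshape(1, d_model, 12 * 12).permute(2, 0, 1)  # [144, 1, d_model]
+```
+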
+Here's [an experiment](experiment.html) that trains ViT on CIFAR-10.
+This doesn't do very well because it's trained on a small dataset.
+It's a simple experiment that anyone can run to play with ViTs.
+
+[](https://app.labml.ai/run/8b531d9ce3dc11eb84fc87df6756eb8f)
+"""
+
import torch
from torch import nn
@@ -9,24 +53,41 @@ from labml_nn.utils import clone_module_list
class PatchEmbeddings(Module):
"""
- ## Embed patches
+ ## Get patch embeddings
+
+    The paper splits the image into patches of equal size and does a linear transformation
+    on the flattened pixels of each patch.
+
+    We implement the same thing with a convolution layer because it's simpler.
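+
+    A rough sketch of the equivalence, with example numbers (a 32x32 RGB image, 4x4 patches
+    and `d_model = 128` are just illustrative values):
+
+    ```python
+    import torch
+    from torch import nn
+
+    x = torch.randn(1, 3, 32, 32)                      # a batch of one RGB image
+    conv = nn.Conv2d(3, 128, kernel_size=4, stride=4)  # kernel size == stride == patch size
+    out = conv(x)                                      # [1, 128, 8, 8] -> 8 * 8 = 64 patches
+    patches = out.permute(2, 3, 0, 1).reshape(64, 1, 128)
+
+    # The same result comes from flattening each 4x4 patch and applying one linear map
+    flat = nn.Unfold(kernel_size=4, stride=4)(x).transpose(1, 2)  # [1, 64, 3 * 4 * 4]
+    same = flat @ conv.weight.view(128, -1).t() + conv.bias       # [1, 64, 128]
+    assert torch.allclose(patches.squeeze(1), same.squeeze(0), atol=1e-5)
+    ```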
"""
def __init__(self, d_model: int, patch_size: int, in_channels: int):
+ """
+    * `d_model` is the transformer embedding size
+    * `patch_size` is the size of the patch
+    * `in_channels` is the number of channels in the input image (3 for RGB)
+ """
super().__init__()
- self.patch_size = patch_size
+
+    # We create a convolution layer with a kernel size and stride length equal to the patch size.
+ # This is equivalent to splitting the image into patches and doing a linear
+ # transformation on each patch.
self.conv = nn.Conv2d(in_channels, d_model, patch_size, stride=patch_size)
def __call__(self, x: torch.Tensor):
"""
- x has shape `[batch_size, channels, height, width]`
+ * `x` is the input image of shape `[batch_size, channels, height, width]`
"""
+ # Apply convolution layer
x = self.conv(x)
+ # Get the shape.
bs, c, h, w = x.shape
+ # Rearrange to shape `[patches, batch_size, d_model]`
x = x.permute(2, 3, 0, 1)
x = x.view(h * w, bs, c)
+ # Return the patch embeddings
return x
@@ -35,56 +96,121 @@ class LearnedPositionalEmbeddings(Module):
## Add parameterized positional encodings
+
+ This adds learned positional embeddings to patch embeddings.
"""
def __init__(self, d_model: int, max_len: int = 5_000):
+ """
+    * `d_model` is the transformer embedding size
+ * `max_len` is the maximum number of patches
+ """
super().__init__()
+ # Positional embeddings for each location
self.positional_encodings = nn.Parameter(torch.zeros(max_len, 1, d_model), requires_grad=True)
def __call__(self, x: torch.Tensor):
+ """
+ * `x` is the patch embeddings of shape `[patches, batch_size, d_model]`
+ """
+ # Get the positional embeddings for the given patches
pe = self.positional_encodings[:x.shape[0]]
+ # Add to patch embeddings and return
return x + pe
class ClassificationHead(Module):
+ """
+    ## MLP Classification Head
+
+    This is the two-layer MLP head to classify the image based on the `[CLS]` token embedding.
+ """
def __init__(self, d_model: int, n_hidden: int, n_classes: int):
+ """
+ * `d_model` is the transformer embedding size
+ * `n_hidden` is the size of the hidden layer
+ * `n_classes` is the number of classes in the classification task
+ """
super().__init__()
- self.ln = nn.LayerNorm([d_model])
+ # First layer
self.linear1 = nn.Linear(d_model, n_hidden)
+ # Activation
self.act = nn.ReLU()
+ # Second layer
self.linear2 = nn.Linear(n_hidden, n_classes)
def __call__(self, x: torch.Tensor):
- x = self.ln(x)
+ """
+ * `x` is the transformer encoding for `[CLS]` token
+ """
+ # First layer and activation
x = self.act(self.linear1(x))
+ # Second layer
x = self.linear2(x)
+ #
return x
class VisionTransformer(Module):
+ """
+ ## Vision Transformer
+
+ This combines the [patch embeddings](#PatchEmbeddings),
+ [positional embeddings](#LearnedPositionalEmbeddings),
+ transformer and the [classification head](#ClassificationHead).
+ """
def __init__(self, transformer_layer: TransformerLayer, n_layers: int,
patch_emb: PatchEmbeddings, pos_emb: LearnedPositionalEmbeddings,
classification: ClassificationHead):
+ """
+    * `transformer_layer` is a copy of a single [transformer layer](../models.html#TransformerLayer).
+    We make copies of it to build the transformer with `n_layers` layers.
+    * `n_layers` is the number of [transformer layers](../models.html#TransformerLayer).
+ * `patch_emb` is the [patch embeddings layer](#PatchEmbeddings).
+ * `pos_emb` is the [positional embeddings layer](#LearnedPositionalEmbeddings).
+ * `classification` is the [classification head](#ClassificationHead).
+ """
super().__init__()
- # Make copies of the transformer layer
- self.classification = classification
- self.pos_emb = pos_emb
+ # Patch embeddings
self.patch_emb = patch_emb
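+    # Positional embeddings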
+ self.pos_emb = pos_emb
+ # Classification head
+ self.classification = classification
+ # Make copies of the transformer layer
self.transformer_layers = clone_module_list(transformer_layer, n_layers)
+ # `[CLS]` token embedding
self.cls_token_emb = nn.Parameter(torch.randn(1, 1, transformer_layer.size), requires_grad=True)
+ # Final normalization layer
+ self.ln = nn.LayerNorm([transformer_layer.size])
- def __call__(self, x):
+ def __call__(self, x: torch.Tensor):
+ """
+ * `x` is the input image of shape `[batch_size, channels, height, width]`
+ """
+ # Get patch embeddings. This gives a tensor of shape `[patches, batch_size, d_model]`
x = self.patch_emb(x)
+ # Add positional embeddings
x = self.pos_emb(x)
+    # Prepend the `[CLS]` token embedding to the patch embeddings before feeding them to the transformer
cls_token_emb = self.cls_token_emb.expand(-1, x.shape[1], -1)
x = torch.cat([cls_token_emb, x])
+
+ # Pass through transformer layers with no attention masking
for layer in self.transformer_layers:
x = layer(x=x, mask=None)
+ # Get the transformer output of the `[CLS]` token (which is the first in the sequence).
x = x[0]
+ # Layer normalization
+ x = self.ln(x)
+
+ # Classification head, to get logits
x = self.classification(x)
+ #
return x
diff --git a/labml_nn/transformers/vit/experiment.py b/labml_nn/transformers/vit/experiment.py
index d8655897..febcb186 100644
--- a/labml_nn/transformers/vit/experiment.py
+++ b/labml_nn/transformers/vit/experiment.py
@@ -1,11 +1,13 @@
"""
---
-title: Train a ViT on CIFAR 10
+title: Train a Vision Transformer (ViT) on CIFAR 10
summary: >
- Train a ViT on CIFAR 10
+ Train a Vision Transformer (ViT) on CIFAR 10
---
-# Train a ViT on CIFAR 10
+# Train a [Vision Transformer (ViT)](index.html) on CIFAR 10
+
+[](https://app.labml.ai/run/8b531d9ce3dc11eb84fc87df6756eb8f)
"""
from labml import experiment
@@ -18,19 +20,27 @@ class Configs(CIFAR10Configs):
"""
## Configurations
- We use [`CIFAR10Configs`](../experiments/cifar10.html) which defines all the
+ We use [`CIFAR10Configs`](../../experiments/cifar10.html) which defines all the
dataset related configurations, optimizer, and a training loop.
"""
+ # [Transformer configurations](../configs.html#TransformerConfigs)
+ # to get [transformer layer](../models.html#TransformerLayer)
transformer: TransformerConfigs
+ # Size of a patch
patch_size: int = 4
- n_hidden: int = 2048
+ # Size of the hidden layer in classification head
+ n_hidden_classification: int = 2048
+ # Number of classes in the task
n_classes: int = 10
@option(Configs.transformer)
-def _transformer(c: Configs):
+def _transformer():
+ """
+ Create transformer configs
+ """
return TransformerConfigs()
@@ -42,11 +52,13 @@ def _vit(c: Configs):
from labml_nn.transformers.vit import VisionTransformer, LearnedPositionalEmbeddings, ClassificationHead, \
PatchEmbeddings
+ # Transformer size from [Transformer configurations](../configs.html#TransformerConfigs)
d_model = c.transformer.d_model
+ # Create a vision transformer
return VisionTransformer(c.transformer.encoder_layer, c.transformer.n_layers,
PatchEmbeddings(d_model, c.patch_size, 3),
LearnedPositionalEmbeddings(d_model),
- ClassificationHead(d_model, c.n_hidden, c.n_classes)).to(c.device)
+ ClassificationHead(d_model, c.n_hidden_classification, c.n_classes)).to(c.device)
def main():
@@ -56,20 +68,20 @@ def main():
conf = Configs()
# Load configurations
experiment.configs(conf, {
- 'device.cuda_device': 0,
-
- # 'optimizer.optimizer': 'Noam',
- # 'optimizer.learning_rate': 1.,
+ # Optimizer
'optimizer.optimizer': 'Adam',
'optimizer.learning_rate': 2.5e-4,
- 'optimizer.d_model': 512,
+ # Transformer embedding size
'transformer.d_model': 512,
+ # Training epochs and batch size
'epochs': 1000,
'train_batch_size': 64,
+ # Augment CIFAR 10 images for training
'train_dataset': 'cifar10_train_augmented',
+ # Do not augment CIFAR 10 images for validation
'valid_dataset': 'cifar10_valid_no_augment',
})
# Set model for saving/loading
diff --git a/labml_nn/transformers/vit/readme.md b/labml_nn/transformers/vit/readme.md
new file mode 100644
index 00000000..636ddb0c
--- /dev/null
+++ b/labml_nn/transformers/vit/readme.md
@@ -0,0 +1,32 @@
+# [Vision Transformer (ViT)](https://nn.labml.ai/transformers/vit/index.html)
+
+This is a [PyTorch](https://pytorch.org) implementation of the paper
+[An Image Is Worth 16x16 Words: Transformers For Image Recognition At Scale](https://arxiv.org/abs/2010.11929).
+
+Vision transformer applies a pure transformer to images
+without any convolution layers.
+They split the image into patches and apply a transformer on patch embeddings.
+[Patch embeddings](https://nn.labml.ai/transformers/vit/index.html#PatchEmbeddings) are generated by applying a simple linear transformation
+to the flattened pixel values of the patch.
+Then a standard transformer encoder is fed with the patch embeddings, along with a
+classification token `[CLS]`.
+The encoding on the `[CLS]` token is used to classify the image with an MLP.
+
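+Below is a minimal sketch of the `[CLS]` mechanics, mirroring this implementation but with
+illustrative sizes and a stand-in MLP head (the transformer encoder itself is elided):
+
+```python
+import torch
+from torch import nn
+
+d_model, batch_size = 128, 4
+patch_emb = torch.randn(64, batch_size, d_model)      # [patches, batch_size, d_model]
+
+cls_token = nn.Parameter(torch.randn(1, 1, d_model))  # one learned `[CLS]` embedding
+x = torch.cat([cls_token.expand(-1, batch_size, -1), patch_emb])  # [65, batch_size, d_model]
+
+# ... pass `x` through the transformer encoder here ...
+
+# Classify from the encoding at the `[CLS]` position (the first in the sequence)
+mlp_head = nn.Sequential(nn.Linear(d_model, 256), nn.ReLU(), nn.Linear(256, 10))
+logits = mlp_head(x[0])                               # [batch_size, n_classes]
+```
+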
+When feeding the transformer with the patches, learned positional embeddings are
+added to the patch embeddings, because the patch embeddings do not have any information
+about where that patch is from.
+The positional embeddings are a set of vectors for each patch location that get trained
+with gradient descent along with other parameters.
+
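+A minimal sketch of these learned positional embeddings (sizes are illustrative; `max_len = 5_000`
+matches the default in this implementation):
+
+```python
+import torch
+from torch import nn
+
+d_model = 128
+# One trainable vector per patch position, trained along with the rest of the model
+pos_emb = nn.Parameter(torch.zeros(5_000, 1, d_model))
+
+patch_emb = torch.randn(64, 4, d_model)               # [patches, batch_size, d_model]
+x = patch_emb + pos_emb[:patch_emb.shape[0]]          # broadcast over the batch dimension
+```
+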
+ViTs perform well when they are pre-trained on large datasets.
+The paper suggests pre-training them with an MLP classification head and
+then using a single linear layer when fine-tuning.
+The paper beats SOTA with a ViT pre-trained on a 300 million image dataset.
+They also use higher resolution images during inference while keeping the
+patch size the same.
+The positional embeddings for new patch locations are calculated by interpolating
+the learned positional embeddings.
+
+Here's [an experiment](https://nn.labml.ai/transformers/vit/experiment.html) that trains ViT on CIFAR-10.
+This doesn't do very well because it's trained on a small dataset.
+It's a simple experiment that anyone can run to play with ViTs.
diff --git a/readme.md b/readme.md
index 21601c63..64b505dd 100644
--- a/readme.md
+++ b/readme.md
@@ -37,6 +37,7 @@ implementations almost weekly.
* [Masked Language Model](https://nn.labml.ai/transformers/mlm/index.html)
* [MLP-Mixer: An all-MLP Architecture for Vision](https://nn.labml.ai/transformers/mlp_mixer/index.html)
* [Pay Attention to MLPs (gMLP)](https://nn.labml.ai/transformers/gmlp/index.html)
+* [Vision Transformer (ViT)](https://nn.labml.ai/transformers/vit/index.html)
#### ✨ [Recurrent Highway Networks](https://nn.labml.ai/recurrent_highway_networks/index.html)
diff --git a/setup.py b/setup.py
index fbc15d2d..4b41b521 100644
--- a/setup.py
+++ b/setup.py
@@ -5,7 +5,7 @@ with open("readme.md", "r") as f:
setuptools.setup(
name='labml-nn',
- version='0.4.102',
+ version='0.4.103',
author="Varuna Jayasiri, Nipun Wijerathne",
author_email="vpjayasiri@gmail.com, hnipun@gmail.com",
description="A collection of PyTorch implementations of neural network architectures and layers.",