mirror of https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
Commit: vit
@@ -31,6 +31,7 @@ implementations.
* [Masked Language Model](transformers/mlm/index.html)
* [MLP-Mixer: An all-MLP Architecture for Vision](transformers/mlp_mixer/index.html)
* [Pay Attention to MLPs (gMLP)](transformers/gmlp/index.html)
* [Vision Transformer (ViT)](transformers/vit/index.html)

#### ✨ [Recurrent Highway Networks](recurrent_highway_networks/index.html)

@@ -82,6 +82,11 @@ This is an implementation of the paper

This is an implementation of the paper
[Pay Attention to MLPs](https://papers.labml.ai/paper/2105.08050).

## [Vision Transformer (ViT)](vit/index.html)

This is an implementation of the paper
[An Image Is Worth 16x16 Words: Transformers For Image Recognition At Scale](https://arxiv.org/abs/2010.11929).
"""

from .configs import TransformerConfigs

@@ -1,3 +1,47 @@
"""
|
||||
---
|
||||
title: Vision Transformer (ViT)
|
||||
summary: >
|
||||
A PyTorch implementation/tutorial of the paper
|
||||
"An Image Is Worth 16x16 Words: Transformers For Image Recognition At Scale"
|
||||
---
|
||||
|
||||
# Vision Transformer (ViT)
|
||||
|
||||
This is a [PyTorch](https://pytorch.org) implementation of the paper
|
||||
[An Image Is Worth 16x16 Words: Transformers For Image Recognition At Scale](https://arxiv.org/abs/2010.11929).
|
||||
|
||||
Vision transformer applies a pure transformer to images
|
||||
without any convolution layers.
|
||||
They split the image into patches and apply a transformer on patch embeddings.
|
||||
[Patch embeddings](#PathEmbeddings) are generated by applying a simple linear transformation
|
||||
to the flattened pixel values of the patch.
|
||||
Then a standard transformer encoder is fed with the patch embeddings, along with a
|
||||
classification token `[CLS]`.
|
||||
The encoding on the `[CLS]` token is used to classify the image with an MLP.
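
For example, with the CIFAR-10 settings used in [the experiment below](experiment.html),
the sequence length works out as follows (numbers shown only for illustration):

```python
image_size = 32                               # CIFAR-10 images are 32x32
patch_size = 4                                # patch size used in experiment.py
n_patches = (image_size // patch_size) ** 2   # 8 * 8 = 64 patches
seq_len = n_patches + 1                       # +1 for the `[CLS]` token -> 65
```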

When feeding the transformer with the patches, learned positional embeddings are
added to the patch embeddings, because the patch embeddings do not carry any information
about where each patch came from.
The positional embeddings are a set of vectors, one for each patch location, that are trained
with gradient descent along with the other parameters.

ViTs perform well when they are pre-trained on large datasets.
The paper suggests pre-training them with an MLP classification head and
then using a single linear layer when fine-tuning.
The paper beats SOTA with a ViT pre-trained on a 300-million-image dataset.
They also use higher resolution images during inference while keeping the
patch size the same.
The positional embeddings for new patch locations are calculated by interpolating
the learned positional embeddings.
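
This implementation doesn't do that interpolation (the experiment below trains at a single
resolution), but a sketch of the idea, assuming the positional embeddings are stored as
`[n_patches, d_model]` for a square patch grid, might look like this (the helper name and
tensor layout are illustrative, not part of the paper's code):

```python
import torch
import torch.nn.functional as F


def interpolate_pos_emb(pos_emb: torch.Tensor, old_grid: int, new_grid: int) -> torch.Tensor:
    # `pos_emb` has shape `[old_grid * old_grid, d_model]` (no `[CLS]` entry here)
    d_model = pos_emb.shape[-1]
    # Reshape to an image-like grid: `[1, d_model, old_grid, old_grid]`
    pe = pos_emb.reshape(old_grid, old_grid, d_model).permute(2, 0, 1).unsqueeze(0)
    # 2D interpolation over the patch grid
    pe = F.interpolate(pe, size=(new_grid, new_grid), mode='bicubic', align_corners=False)
    # Back to `[new_grid * new_grid, d_model]`
    return pe.squeeze(0).permute(1, 2, 0).reshape(new_grid * new_grid, d_model)
```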

Here's [an experiment](experiment.html) that trains ViT on CIFAR-10.
It doesn't do very well because it's trained on a small dataset.
It's a simple experiment that anyone can run to play with ViTs.

[View Run](https://app.labml.ai/run/8b531d9ce3dc11eb84fc87df6756eb8f)
"""

import torch
from torch import nn

@@ -9,24 +53,41 @@ from labml_nn.utils import clone_module_list
class PatchEmbeddings(Module):
    """
    <a id="PatchEmbeddings">
    ## Get patch embeddings
    </a>

    The paper splits the image into patches of equal size and does a linear transformation
    on the flattened pixels of each patch.

    We implement the same thing through a convolution layer, because it's simpler to implement.
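
    As a quick sanity check (not part of this class, sizes arbitrary), here's a sketch showing
    that the convolution is equivalent to flattening each patch and applying a linear layer
    with the same weights:

    ```python
    import torch
    from torch import nn

    d_model, patch_size, in_channels = 8, 4, 3
    x = torch.randn(2, in_channels, 32, 32)

    conv = nn.Conv2d(in_channels, d_model, patch_size, stride=patch_size)
    linear = nn.Linear(in_channels * patch_size * patch_size, d_model)
    # Copy the convolution weights into the linear layer
    linear.weight.data = conv.weight.data.view(d_model, -1)
    linear.bias.data = conv.bias.data

    # Convolution path: `[batch, d_model, h, w]`
    out_conv = conv(x)
    # Unfold path: flattened patches of shape `[batch, c * p * p, n_patches]`, then the linear layer
    patches = nn.functional.unfold(x, kernel_size=patch_size, stride=patch_size)
    out_linear = linear(patches.transpose(1, 2))

    assert torch.allclose(out_conv.flatten(2).transpose(1, 2), out_linear, atol=1e-5)
    ```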
    """

    def __init__(self, d_model: int, patch_size: int, in_channels: int):
        """
        * `d_model` is the transformer embeddings size
        * `patch_size` is the size of the patch
        * `in_channels` is the number of channels in the input image (3 for RGB)
        """
        super().__init__()
        self.patch_size = patch_size

        # We create a convolution layer with a kernel size and stride length equal to the patch size.
        # This is equivalent to splitting the image into patches and doing a linear
        # transformation on each patch.
        self.conv = nn.Conv2d(in_channels, d_model, patch_size, stride=patch_size)

    def __call__(self, x: torch.Tensor):
        """
        * `x` is the input image of shape `[batch_size, channels, height, width]`
        """
        # Apply convolution layer
        x = self.conv(x)
        # Get the shape
        bs, c, h, w = x.shape
        # Rearrange to shape `[patches, batch_size, d_model]`
        x = x.permute(2, 3, 0, 1)
        x = x.view(h * w, bs, c)

        # Return the patch embeddings
        return x

@@ -35,56 +96,121 @@ class LearnedPositionalEmbeddings(Module):
    <a id="LearnedPositionalEmbeddings">
    ## Add parameterized positional encodings
    </a>

    This adds learned positional embeddings to patch embeddings.
    """

    def __init__(self, d_model: int, max_len: int = 5_000):
        """
        * `d_model` is the transformer embeddings size
        * `max_len` is the maximum number of patches
        """
        super().__init__()
        # Positional embeddings for each location
        self.positional_encodings = nn.Parameter(torch.zeros(max_len, 1, d_model), requires_grad=True)

    def __call__(self, x: torch.Tensor):
        """
        * `x` is the patch embeddings of shape `[patches, batch_size, d_model]`
        """
        # Get the positional embeddings for the given patches
        pe = self.positional_encodings[:x.shape[0]]
        # Add to patch embeddings and return
        return x + pe


class ClassificationHead(Module):
    """
    <a id="ClassificationHead">
    ## MLP Classification Head
    </a>

    This is the two-layer MLP head to classify the image based on the `[CLS]` token embedding.
    """
    def __init__(self, d_model: int, n_hidden: int, n_classes: int):
        """
        * `d_model` is the transformer embedding size
        * `n_hidden` is the size of the hidden layer
        * `n_classes` is the number of classes in the classification task
        """
        super().__init__()
        # First layer
        self.linear1 = nn.Linear(d_model, n_hidden)
        # Activation
        self.act = nn.ReLU()
        # Second layer
        self.linear2 = nn.Linear(n_hidden, n_classes)

    def __call__(self, x: torch.Tensor):
        """
        * `x` is the transformer encoding for the `[CLS]` token
        """
        # First layer and activation
        x = self.act(self.linear1(x))
        # Second layer
        x = self.linear2(x)

        #
        return x


class VisionTransformer(Module):
    """
    ## Vision Transformer

    This combines the [patch embeddings](#PatchEmbeddings),
    [positional embeddings](#LearnedPositionalEmbeddings),
    transformer and the [classification head](#ClassificationHead).
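
    A minimal usage sketch (assuming `transformer_layer` is a
    [transformer layer](../models.html#TransformerLayer) whose `size` equals `d_model`;
    the numbers here are illustrative):

    ```python
    vit = VisionTransformer(transformer_layer, n_layers=6,
                            patch_emb=PatchEmbeddings(d_model=512, patch_size=4, in_channels=3),
                            pos_emb=LearnedPositionalEmbeddings(d_model=512),
                            classification=ClassificationHead(d_model=512, n_hidden=2048, n_classes=10))
    logits = vit(torch.randn(8, 3, 32, 32))  # -> `[8, 10]`
    ```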
    """

    def __init__(self, transformer_layer: TransformerLayer, n_layers: int,
                 patch_emb: PatchEmbeddings, pos_emb: LearnedPositionalEmbeddings,
                 classification: ClassificationHead):
        """
        * `transformer_layer` is a copy of a single [transformer layer](../models.html#TransformerLayer).
          We make copies of it to build the transformer with `n_layers` layers.
        * `n_layers` is the number of [transformer layers](../models.html#TransformerLayer).
        * `patch_emb` is the [patch embeddings layer](#PatchEmbeddings).
        * `pos_emb` is the [positional embeddings layer](#LearnedPositionalEmbeddings).
        * `classification` is the [classification head](#ClassificationHead).
        """
        super().__init__()
        # Patch embeddings
        self.patch_emb = patch_emb
        self.pos_emb = pos_emb
        # Classification head
        self.classification = classification
        # Make copies of the transformer layer
        self.transformer_layers = clone_module_list(transformer_layer, n_layers)

        # `[CLS]` token embedding
        self.cls_token_emb = nn.Parameter(torch.randn(1, 1, transformer_layer.size), requires_grad=True)
        # Final normalization layer
        self.ln = nn.LayerNorm([transformer_layer.size])

    def __call__(self, x: torch.Tensor):
        """
        * `x` is the input image of shape `[batch_size, channels, height, width]`
        """
        # Get patch embeddings. This gives a tensor of shape `[patches, batch_size, d_model]`
        x = self.patch_emb(x)
        # Add positional embeddings
        x = self.pos_emb(x)
        # Concatenate the `[CLS]` token embedding before feeding the transformer
        cls_token_emb = self.cls_token_emb.expand(-1, x.shape[1], -1)
        x = torch.cat([cls_token_emb, x])

        # Pass through transformer layers with no attention masking
        for layer in self.transformer_layers:
            x = layer(x=x, mask=None)

        # Get the transformer output of the `[CLS]` token (which is the first in the sequence).
        x = x[0]

        # Layer normalization
        x = self.ln(x)

        # Classification head, to get logits
        x = self.classification(x)

        #
        return x

@@ -1,11 +1,13 @@
"""
---
title: Train a Vision Transformer (ViT) on CIFAR 10
summary: >
 Train a Vision Transformer (ViT) on CIFAR 10
---

# Train a [Vision Transformer (ViT)](index.html) on CIFAR 10

[View Run](https://app.labml.ai/run/8b531d9ce3dc11eb84fc87df6756eb8f)
"""

from labml import experiment
@@ -18,19 +20,27 @@ class Configs(CIFAR10Configs):
    """
    ## Configurations

    We use [`CIFAR10Configs`](../../experiments/cifar10.html) which defines all the
    dataset related configurations, optimizer, and a training loop.
    """

    # [Transformer configurations](../configs.html#TransformerConfigs)
    # to get the [transformer layer](../models.html#TransformerLayer)
    transformer: TransformerConfigs

    # Size of a patch
    patch_size: int = 4
    # Size of the hidden layer in the classification head
    n_hidden_classification: int = 2048
    # Number of classes in the task
    n_classes: int = 10


@option(Configs.transformer)
def _transformer():
    """
    Create transformer configs
    """
    return TransformerConfigs()


@@ -42,11 +52,13 @@ def _vit(c: Configs):
    from labml_nn.transformers.vit import VisionTransformer, LearnedPositionalEmbeddings, ClassificationHead, \
        PatchEmbeddings

    # Transformer size from [Transformer configurations](../configs.html#TransformerConfigs)
    d_model = c.transformer.d_model
    # Create a vision transformer
    return VisionTransformer(c.transformer.encoder_layer, c.transformer.n_layers,
                             PatchEmbeddings(d_model, c.patch_size, 3),
                             LearnedPositionalEmbeddings(d_model),
                             ClassificationHead(d_model, c.n_hidden_classification, c.n_classes)).to(c.device)


def main():
@@ -56,20 +68,20 @@ def main():
    conf = Configs()
    # Load configurations
    experiment.configs(conf, {
        'device.cuda_device': 0,

        # Optimizer
        'optimizer.optimizer': 'Adam',
        'optimizer.learning_rate': 2.5e-4,
        'optimizer.d_model': 512,

        # Transformer embedding size
        'transformer.d_model': 512,

        # Training epochs and batch size
        'epochs': 1000,
        'train_batch_size': 64,

        # Augment CIFAR 10 images for training
        'train_dataset': 'cifar10_train_augmented',
        # Do not augment CIFAR 10 images for validation
        'valid_dataset': 'cifar10_valid_no_augment',
    })
    # Set model for saving/loading

labml_nn/transformers/vit/readme.md (new file)
@@ -0,0 +1,32 @@
# [Vision Transformer (ViT)](https://nn.labml.ai/transformers/vit/index.html)

This is a [PyTorch](https://pytorch.org) implementation of the paper
[An Image Is Worth 16x16 Words: Transformers For Image Recognition At Scale](https://arxiv.org/abs/2010.11929).

The vision transformer applies a pure transformer to images,
without any convolution layers.
It splits the image into patches and applies a transformer to the patch embeddings.
[Patch embeddings](https://nn.labml.ai/transformers/vit/index.html#PatchEmbeddings) are generated by applying a simple linear transformation
to the flattened pixel values of each patch.
A standard transformer encoder is then fed the patch embeddings, along with a
classification token `[CLS]`.
The encoding of the `[CLS]` token is used to classify the image with an MLP.

When feeding the transformer with the patches, learned positional embeddings are
added to the patch embeddings, because the patch embeddings do not carry any information
about where each patch came from.
The positional embeddings are a set of vectors, one for each patch location, that are trained
with gradient descent along with the other parameters.

ViTs perform well when they are pre-trained on large datasets.
The paper suggests pre-training them with an MLP classification head and
then using a single linear layer when fine-tuning.
The paper beats SOTA with a ViT pre-trained on a 300-million-image dataset.
They also use higher resolution images during inference while keeping the
patch size the same.
The positional embeddings for new patch locations are calculated by interpolating
the learned positional embeddings.

Here's [an experiment](https://nn.labml.ai/transformers/vit/experiment.html) that trains ViT on CIFAR-10.
It doesn't do very well because it's trained on a small dataset.
It's a simple experiment that anyone can run to play with ViTs.