diff --git a/docs/index.html b/docs/index.html index 0f724160..6d8df374 100644 --- a/docs/index.html +++ b/docs/index.html @@ -95,6 +95,7 @@ implementations.

  • Masked Language Model
  • MLP-Mixer: An all-MLP Architecture for Vision
  • Pay Attention to MLPs (gMLP)
  • + Vision Transformer (ViT)
  • Recurrent Highway Networks

    LSTM

    diff --git a/docs/transformers/index.html b/docs/transformers/index.html index 1c2d646b..52ba8c9a 100644 --- a/docs/transformers/index.html +++ b/docs/transformers/index.html @@ -117,12 +117,15 @@ It does single GPU training but we implement the concept of switching as describ

    Pay Attention to MLPs (gMLP)

    This is an implementation of the paper Pay Attention to MLPs.

    + Vision Transformer (ViT)

    + This is an implementation of the paper
    + An Image Is Worth 16x16 Words: Transformers For Image Recognition At Scale.

    - 87from .configs import TransformerConfigs
    - 88from .models import TransformerLayer, Encoder, Decoder, Generator, EncoderDecoder
    - 89from .mha import MultiHeadAttention
    - 90from labml_nn.transformers.xl.relative_mha import RelativeMultiHeadAttention
    + 92from .configs import TransformerConfigs
    + 93from .models import TransformerLayer, Encoder, Decoder, Generator, EncoderDecoder
    + 94from .mha import MultiHeadAttention
    + 95from labml_nn.transformers.xl.relative_mha import RelativeMultiHeadAttention
diff --git a/docs/transformers/vit/experiment.html b/docs/transformers/vit/experiment.html
index 8333bfdc..02d3c5c5 100644
--- a/docs/transformers/vit/experiment.html
+++ b/docs/transformers/vit/experiment.html
@@ -3,24 +3,24 @@ (page metadata updated)
-    Train a ViT on CIFAR 10
+    Train a Vision Transformer (ViT) on CIFAR 10
@@ -67,13 +67,14 @@

    -   Train a ViT on CIFAR 10
    +   Train a Vision Transformer (ViT) on CIFAR 10
    +   (View Run badge)

    (The rest of the page is regenerated documentation for
    labml_nn/transformers/vit/experiment.py and mirrors the source diff below:
    code listings are renumbered, the CIFAR10Configs link now points to
    ../../experiments/cifar10.html, the configuration fields (transformer,
    patch_size, n_hidden_classification, n_classes) and the _transformer and
    _vit options are documented individually, and the training settings in
    main() are annotated: Adam optimizer at 2.5e-4, transformer.d_model 512,
    1000 epochs, batch size 64, augmented CIFAR-10 for training and
    un-augmented CIFAR-10 for validation.)
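The training settings documented above map onto plain PyTorch in a straightforward way. The sketch below reproduces the data pipeline and optimizer outside the labml config system; the exact augmentations behind `cifar10_train_augmented` and the normalization statistics are assumptions, not taken from this diff, and the small model is only a placeholder for the ViT built by `_vit`.

```python
# Minimal sketch of the documented training setup with plain PyTorch/torchvision.
# The augmentation pipeline and normalization statistics are assumptions about what
# `cifar10_train_augmented` / `cifar10_valid_no_augment` do; they are not from this diff.
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616))
train_transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),       # assumed augmentation
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    normalize,
])
valid_transform = transforms.Compose([transforms.ToTensor(), normalize])

train_set = datasets.CIFAR10('./data', train=True, download=True, transform=train_transform)
valid_set = datasets.CIFAR10('./data', train=False, download=True, transform=valid_transform)

train_loader = DataLoader(train_set, batch_size=64, shuffle=True)   # 'train_batch_size': 64
valid_loader = DataLoader(valid_set, batch_size=64)

# Stand-in for the VisionTransformer built by `_vit` in the experiment source below
model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 10))
optimizer = torch.optim.Adam(model.parameters(), lr=2.5e-4)          # 'optimizer.learning_rate': 2.5e-4
loss_fn = nn.CrossEntropyLoss()

for images, labels in train_loader:
    optimizer.zero_grad()
    loss_fn(model(images), labels).backward()
    optimizer.step()
    break  # one step shown; the experiment itself runs for 'epochs': 1000
```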
diff --git a/docs/transformers/vit/index.html b/docs/transformers/vit/index.html
index a25d9121..5b4dbd0b 100644
--- a/docs/transformers/vit/index.html
+++ b/docs/transformers/vit/index.html
@@ -3,24 +3,24 @@ (page metadata updated)
-    __init__.py
+    Vision Transformer (ViT)
@@ -63,19 +63,46 @@
    +   Vision Transformer (ViT)

    +   This is a PyTorch implementation of the paper
    +   An Image Is Worth 16x16 Words: Transformers For Image Recognition At Scale.

    +   The vision transformer applies a pure transformer to images, without any
    +   convolution layers. It splits the image into patches and applies a
    +   transformer to the patch embeddings. Patch embeddings are generated by
    +   applying a simple linear transformation to the flattened pixel values of
    +   each patch. A standard transformer encoder is then fed the patch
    +   embeddings, along with a classification token [CLS]. The encoding of the
    +   [CLS] token is used to classify the image with an MLP.

    +   When feeding the transformer with the patches, learned positional
    +   embeddings are added to the patch embeddings, because the patch
    +   embeddings do not carry any information about where each patch came from.
    +   The positional embeddings are a set of vectors, one for each patch
    +   location, that are trained with gradient descent along with the other
    +   parameters.

    +   ViTs perform well when they are pre-trained on large datasets. The paper
    +   suggests pre-training them with an MLP classification head and then using
    +   a single linear layer when fine-tuning. The paper beats SOTA with a ViT
    +   pre-trained on a 300 million image dataset. They also use higher
    +   resolution images during inference while keeping the patch size the same;
    +   the positional embeddings for the new patch locations are calculated by
    +   interpolating the learned positional embeddings.

    +   Here's an experiment that trains a ViT on CIFAR-10. It doesn't do very
    +   well because it's trained on a small dataset, but it's a simple
    +   experiment that anyone can run and play with.

    +   View Run
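The description above says a patch embedding is just a shared linear transformation of each patch's flattened pixels, implemented as a convolution whose kernel size and stride equal the patch size. The following sketch checks that equivalence numerically; the sizes are made up for illustration and none of the names come from the repository.

```python
# Check that Conv2d with kernel_size = stride = patch_size matches a linear layer
# applied to each flattened patch (illustrative sizes, not from the repo).
import torch
from torch import nn

d_model, patch_size, in_channels = 8, 4, 3
x = torch.randn(2, in_channels, 32, 32)                       # [batch, channels, h, w]

conv = nn.Conv2d(in_channels, d_model, patch_size, stride=patch_size)

# The same weights viewed as a linear map over flattened patches
linear = nn.Linear(in_channels * patch_size * patch_size, d_model)
linear.weight.data = conv.weight.data.view(d_model, -1)
linear.bias.data = conv.bias.data

# Conv path: [batch, d_model, h/p, w/p] -> [patches, batch, d_model]
a = conv(x)
bs, c, h, w = a.shape
a = a.permute(2, 3, 0, 1).reshape(h * w, bs, c)

# Unfold path: extract patches, flatten each one, apply the linear layer
patches = nn.functional.unfold(x, kernel_size=patch_size, stride=patch_size)  # [batch, c*p*p, patches]
b = linear(patches.transpose(1, 2))                           # [batch, patches, d_model]
b = b.transpose(0, 1)                                         # [patches, batch, d_model]

assert torch.allclose(a, b, atol=1e-4)
```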

    -   1import torch
    -   2from torch import nn
    -   3
    -   4from labml_helpers.module import Module
    -   5from labml_nn.transformers import TransformerLayer
    -   6from labml_nn.utils import clone_module_list
    +   45import torch
    +   46from torch import nn
    +   47
    +   48from labml_helpers.module import Module
    +   49from labml_nn.transformers import TransformerLayer
    +   50from labml_nn.utils import clone_module_list

    (The rest of the page is regenerated documentation for
    labml_nn/transformers/vit/__init__.py and mirrors the source diff below.
    The sections for PatchEmbeddings ("Get patch embeddings"),
    LearnedPositionalEmbeddings ("Add parameterized positional encodings"),
    ClassificationHead ("MLP Classification Head") and VisionTransformer now
    carry the expanded explanations, argument lists and line-by-line comments
    added to the docstrings, and the code listings are renumbered accordingly.)
diff --git a/docs/transformers/vit/readme.html b/docs/transformers/vit/readme.html
new file mode 100644
index 00000000..a02f2eea
--- /dev/null
+++ b/docs/transformers/vit/readme.html
@@ -0,0 +1,162 @@ (new page, title: Vision Transformer (ViT))
    +   Vision Transformer (ViT)

    +   (The body of this new page is the rendered form of
    +   labml_nn/transformers/vit/readme.md, added below, with a link to the
    +   CIFAR-10 experiment.)
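The readme (and the docstring added below) mentions that, for higher-resolution inputs at fine-tuning time, positional embeddings for the new patch locations are obtained by interpolating the learned ones. That step is not implemented anywhere in this diff; the helper below is a hypothetical sketch of one common way to do it, assuming a square patch grid and the `[patches, 1, d_model]` parameter layout used by `LearnedPositionalEmbeddings`.

```python
# Hypothetical helper (not part of this diff): resize learned positional embeddings
# from an old patch grid to a new, larger one by 2D interpolation.
import torch
import torch.nn.functional as F


def interpolate_pos_embeddings(pos: torch.Tensor, old_grid: int, new_grid: int) -> torch.Tensor:
    """`pos` has shape `[old_grid * old_grid, 1, d_model]` (square patch grid assumed)."""
    d_model = pos.shape[-1]
    # [n, 1, d_model] -> [1, d_model, old_grid, old_grid]
    grid = pos.squeeze(1).t().reshape(1, d_model, old_grid, old_grid)
    # Bicubic interpolation to the new grid size
    grid = F.interpolate(grid, size=(new_grid, new_grid), mode='bicubic', align_corners=False)
    # Back to [new_grid * new_grid, 1, d_model]
    return grid.reshape(d_model, new_grid * new_grid).t().unsqueeze(1)


# Example: 8x8 patch grid (32x32 image, patch size 4) -> 16x16 grid (64x64 image)
old = torch.zeros(8 * 8, 1, 512)
new = interpolate_pos_embeddings(old, old_grid=8, new_grid=16)
assert new.shape == (16 * 16, 1, 512)
```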
\ No newline at end of file
diff --git a/labml_nn/__init__.py b/labml_nn/__init__.py
index f719fb02..d7202226 100644
--- a/labml_nn/__init__.py
+++ b/labml_nn/__init__.py
@@ -31,6 +31,7 @@ implementations.
 * [Masked Language Model](transformers/mlm/index.html)
 * [MLP-Mixer: An all-MLP Architecture for Vision](transformers/mlp_mixer/index.html)
 * [Pay Attention to MLPs (gMLP)](transformers/gmlp/index.html)
+* [Vision Transformer (ViT)](transformers/vit/index.html)

 #### ✨ [Recurrent Highway Networks](recurrent_highway_networks/index.html)
diff --git a/labml_nn/transformers/__init__.py b/labml_nn/transformers/__init__.py
index d58ed81d..37c2ba99 100644
--- a/labml_nn/transformers/__init__.py
+++ b/labml_nn/transformers/__init__.py
@@ -82,6 +82,11 @@ This is an implementation of the paper
 This is an implementation of the paper
 [Pay Attention to MLPs](https://papers.labml.ai/paper/2105.08050).
+
+## [Vision Transformer (ViT)](vit/index.html)
+
+This is an implementation of the paper
+[An Image Is Worth 16x16 Words: Transformers For Image Recognition At Scale](https://arxiv.org/abs/2010.11929).
 """

 from .configs import TransformerConfigs
diff --git a/labml_nn/transformers/vit/__init__.py b/labml_nn/transformers/vit/__init__.py
index d827192e..d45b8c05 100644
--- a/labml_nn/transformers/vit/__init__.py
+++ b/labml_nn/transformers/vit/__init__.py
@@ -1,3 +1,47 @@
+"""
+---
+title: Vision Transformer (ViT)
+summary: >
+  A PyTorch implementation/tutorial of the paper
+  "An Image Is Worth 16x16 Words: Transformers For Image Recognition At Scale"
+---
+
+# Vision Transformer (ViT)
+
+This is a [PyTorch](https://pytorch.org) implementation of the paper
+[An Image Is Worth 16x16 Words: Transformers For Image Recognition At Scale](https://arxiv.org/abs/2010.11929).
+
+The vision transformer applies a pure transformer to images,
+without any convolution layers.
+It splits the image into patches and applies a transformer to the patch embeddings.
+[Patch embeddings](#PatchEmbeddings) are generated by applying a simple linear transformation
+to the flattened pixel values of each patch.
+A standard transformer encoder is then fed the patch embeddings, along with a
+classification token `[CLS]`.
+The encoding of the `[CLS]` token is used to classify the image with an MLP.
+
+When feeding the transformer with the patches, learned positional embeddings are
+added to the patch embeddings, because the patch embeddings do not carry any information
+about where each patch came from.
+The positional embeddings are a set of vectors, one for each patch location, that are trained
+with gradient descent along with the other parameters.
+
+ViTs perform well when they are pre-trained on large datasets.
+The paper suggests pre-training them with an MLP classification head and
+then using a single linear layer when fine-tuning.
+The paper beats SOTA with a ViT pre-trained on a 300 million image dataset.
+They also use higher resolution images during inference while keeping the
+patch size the same.
+The positional embeddings for the new patch locations are calculated by interpolating
+the learned positional embeddings.
+
+Here's [an experiment](experiment.html) that trains a ViT on CIFAR-10.
+It doesn't do very well because it's trained on a small dataset,
+but it's a simple experiment that anyone can run and play with.
+
+[![View Run](https://img.shields.io/badge/labml-experiment-brightgreen)](https://app.labml.ai/run/8b531d9ce3dc11eb84fc87df6756eb8f)
+"""
+
 import torch
 from torch import nn
@@ -9,24 +53,41 @@ from labml_nn.utils import clone_module_list

 class PatchEmbeddings(Module):
     """
-    ## Embed patches
+    ## Get patch embeddings
+
+    The paper splits the image into patches of equal size and does a linear transformation
+    on the flattened pixels of each patch.
+
+    We implement the same thing through a convolution layer, because it's simpler to implement.
     """

     def __init__(self, d_model: int, patch_size: int, in_channels: int):
+        """
+        * `d_model` is the transformer embeddings size
+        * `patch_size` is the size of the patch
+        * `in_channels` is the number of channels in the input image (3 for RGB)
+        """
         super().__init__()
-        self.patch_size = patch_size
+
+        # We create a convolution layer with a kernel size and stride length equal to the patch size.
+        # This is equivalent to splitting the image into patches and doing a linear
+        # transformation on each patch.
         self.conv = nn.Conv2d(in_channels, d_model, patch_size, stride=patch_size)

     def __call__(self, x: torch.Tensor):
         """
-        x has shape `[batch_size, channels, height, width]`
+        * `x` is the input image of shape `[batch_size, channels, height, width]`
         """
+        # Apply convolution layer
         x = self.conv(x)
+        # Get the shape
         bs, c, h, w = x.shape
+        # Rearrange to shape `[patches, batch_size, d_model]`
         x = x.permute(2, 3, 0, 1)
         x = x.view(h * w, bs, c)

+        # Return the patch embeddings
         return x
@@ -35,56 +96,121 @@ class LearnedPositionalEmbeddings(Module):
     ## Add parameterized positional encodings
+
+    This adds learned positional embeddings to the patch embeddings.
     """

     def __init__(self, d_model: int, max_len: int = 5_000):
+        """
+        * `d_model` is the transformer embeddings size
+        * `max_len` is the maximum number of patches
+        """
         super().__init__()
+        # Positional embeddings for each location
         self.positional_encodings = nn.Parameter(torch.zeros(max_len, 1, d_model), requires_grad=True)

     def __call__(self, x: torch.Tensor):
+        """
+        * `x` is the patch embeddings of shape `[patches, batch_size, d_model]`
+        """
+        # Get the positional embeddings for the given patches
         pe = self.positional_encodings[:x.shape[0]]
+        # Add to patch embeddings and return
         return x + pe


 class ClassificationHead(Module):
+    """
+    ## MLP Classification Head
+
+    This is the two-layer MLP head to classify the image based on the `[CLS]` token embedding.
+    """
     def __init__(self, d_model: int, n_hidden: int, n_classes: int):
+        """
+        * `d_model` is the transformer embedding size
+        * `n_hidden` is the size of the hidden layer
+        * `n_classes` is the number of classes in the classification task
+        """
         super().__init__()
-        self.ln = nn.LayerNorm([d_model])
+        # First layer
         self.linear1 = nn.Linear(d_model, n_hidden)
+        # Activation
         self.act = nn.ReLU()
+        # Second layer
         self.linear2 = nn.Linear(n_hidden, n_classes)

     def __call__(self, x: torch.Tensor):
-        x = self.ln(x)
+        """
+        * `x` is the transformer encoding for the `[CLS]` token
+        """
+        # First layer and activation
         x = self.act(self.linear1(x))
+        # Second layer
         x = self.linear2(x)

+        #
         return x


 class VisionTransformer(Module):
+    """
+    ## Vision Transformer
+
+    This combines the [patch embeddings](#PatchEmbeddings),
+    [positional embeddings](#LearnedPositionalEmbeddings),
+    transformer and the [classification head](#ClassificationHead).
+ """ def __init__(self, transformer_layer: TransformerLayer, n_layers: int, patch_emb: PatchEmbeddings, pos_emb: LearnedPositionalEmbeddings, classification: ClassificationHead): + """ + * `transformer_layer` is a copy of a single [transformer layer](../models.html#TransformerLayer). + We make copies of it to make the transformer with `n_layers`. + * `n_layers` is the number of [transformer layers]((../models.html#TransformerLayer). + * `patch_emb` is the [patch embeddings layer](#PatchEmbeddings). + * `pos_emb` is the [positional embeddings layer](#LearnedPositionalEmbeddings). + * `classification` is the [classification head](#ClassificationHead). + """ super().__init__() - # Make copies of the transformer layer - self.classification = classification - self.pos_emb = pos_emb + # Patch embeddings self.patch_emb = patch_emb + self.pos_emb = pos_emb + # Classification head + self.classification = classification + # Make copies of the transformer layer self.transformer_layers = clone_module_list(transformer_layer, n_layers) + # `[CLS]` token embedding self.cls_token_emb = nn.Parameter(torch.randn(1, 1, transformer_layer.size), requires_grad=True) + # Final normalization layer + self.ln = nn.LayerNorm([transformer_layer.size]) - def __call__(self, x): + def __call__(self, x: torch.Tensor): + """ + * `x` is the input image of shape `[batch_size, channels, height, width]` + """ + # Get patch embeddings. This gives a tensor of shape `[patches, batch_size, d_model]` x = self.patch_emb(x) + # Add positional embeddings x = self.pos_emb(x) + # Concatenate the `[CLS]` token embeddings before feeding the transformer cls_token_emb = self.cls_token_emb.expand(-1, x.shape[1], -1) x = torch.cat([cls_token_emb, x]) + + # Pass through transformer layers with no attention masking for layer in self.transformer_layers: x = layer(x=x, mask=None) + # Get the transformer output of the `[CLS]` token (which is the first in the sequence). x = x[0] + # Layer normalization + x = self.ln(x) + + # Classification head, to get logits x = self.classification(x) + # return x diff --git a/labml_nn/transformers/vit/experiment.py b/labml_nn/transformers/vit/experiment.py index d8655897..febcb186 100644 --- a/labml_nn/transformers/vit/experiment.py +++ b/labml_nn/transformers/vit/experiment.py @@ -1,11 +1,13 @@ """ --- -title: Train a ViT on CIFAR 10 +title: Train a Vision Transformer (ViT) on CIFAR 10 summary: > - Train a ViT on CIFAR 10 + Train a Vision Transformer (ViT) on CIFAR 10 --- -# Train a ViT on CIFAR 10 +# Train a [Vision Transformer (ViT)](index.html) on CIFAR 10 + +[![View Run](https://img.shields.io/badge/labml-experiment-brightgreen)](https://app.labml.ai/run/8b531d9ce3dc11eb84fc87df6756eb8f) """ from labml import experiment @@ -18,19 +20,27 @@ class Configs(CIFAR10Configs): """ ## Configurations - We use [`CIFAR10Configs`](../experiments/cifar10.html) which defines all the + We use [`CIFAR10Configs`](../../experiments/cifar10.html) which defines all the dataset related configurations, optimizer, and a training loop. 
""" + # [Transformer configurations](../configs.html#TransformerConfigs) + # to get [transformer layer](../models.html#TransformerLayer) transformer: TransformerConfigs + # Size of a patch patch_size: int = 4 - n_hidden: int = 2048 + # Size of the hidden layer in classification head + n_hidden_classification: int = 2048 + # Number of classes in the task n_classes: int = 10 @option(Configs.transformer) -def _transformer(c: Configs): +def _transformer(): + """ + Create transformer configs + """ return TransformerConfigs() @@ -42,11 +52,13 @@ def _vit(c: Configs): from labml_nn.transformers.vit import VisionTransformer, LearnedPositionalEmbeddings, ClassificationHead, \ PatchEmbeddings + # Transformer size from [Transformer configurations](../configs.html#TransformerConfigs) d_model = c.transformer.d_model + # Create a vision transformer return VisionTransformer(c.transformer.encoder_layer, c.transformer.n_layers, PatchEmbeddings(d_model, c.patch_size, 3), LearnedPositionalEmbeddings(d_model), - ClassificationHead(d_model, c.n_hidden, c.n_classes)).to(c.device) + ClassificationHead(d_model, c.n_hidden_classification, c.n_classes)).to(c.device) def main(): @@ -56,20 +68,20 @@ def main(): conf = Configs() # Load configurations experiment.configs(conf, { - 'device.cuda_device': 0, - - # 'optimizer.optimizer': 'Noam', - # 'optimizer.learning_rate': 1., + # Optimizer 'optimizer.optimizer': 'Adam', 'optimizer.learning_rate': 2.5e-4, - 'optimizer.d_model': 512, + # Transformer embedding size 'transformer.d_model': 512, + # Training epochs and batch size 'epochs': 1000, 'train_batch_size': 64, + # Augment CIFAR 10 images for training 'train_dataset': 'cifar10_train_augmented', + # Do not augment CIFAR 10 images for validation 'valid_dataset': 'cifar10_valid_no_augment', }) # Set model for saving/loading diff --git a/labml_nn/transformers/vit/readme.md b/labml_nn/transformers/vit/readme.md new file mode 100644 index 00000000..636ddb0c --- /dev/null +++ b/labml_nn/transformers/vit/readme.md @@ -0,0 +1,32 @@ +# [Vision Transformer (ViT)](https://nn.labml.ai/transformer/vit/index.html) + +This is a [PyTorch](https://pytorch.org) implementation of the paper +[An Image Is Worth 16x16 Words: Transformers For Image Recognition At Scale](https://arxiv.org/abs/2010.11929). + +Vision transformer applies a pure transformer to images +without any convolution layers. +They split the image into patches and apply a transformer on patch embeddings. +[Patch embeddings](https://nn.labml.ai/transformer/vit/index.html#PathEmbeddings) are generated by applying a simple linear transformation +to the flattened pixel values of the patch. +Then a standard transformer encoder is fed with the patch embeddings, along with a +classification token `[CLS]`. +The encoding on the `[CLS]` token is used to classify the image with an MLP. + +When feeding the transformer with the patches, learned positional embeddings are +added to the patch embeddings, because the patch embeddings do not have any information +about where that patch is from. +The positional embeddings are a set of vectors for each patch location that get trained +with gradient descent along with other parameters. + +ViTs perform well when they are pre-trained on large datasets. +The paper suggests pre-training them with an MLP classification head and +then using a single linear layer when fine-tuning. +The paper beats SOTA with a ViT pre-trained on a 300 million image dataset. 
+They also use higher resolution images during inference while keeping the
+patch size the same.
+The positional embeddings for new patch locations are calculated by interpolating
+the learned positional embeddings.
+
+Here's [an experiment](https://nn.labml.ai/transformer/vit/experiment.html) that trains a ViT on CIFAR-10.
+It doesn't do very well because it's trained on a small dataset,
+but it's a simple experiment that anyone can run and play with.
diff --git a/readme.md b/readme.md
index 21601c63..64b505dd 100644
--- a/readme.md
+++ b/readme.md
@@ -37,6 +37,7 @@ implementations almost weekly.
 * [Masked Language Model](https://nn.labml.ai/transformers/mlm/index.html)
 * [MLP-Mixer: An all-MLP Architecture for Vision](https://nn.labml.ai/transformers/mlp_mixer/index.html)
 * [Pay Attention to MLPs (gMLP)](https://nn.labml.ai/transformers/gmlp/index.html)
+* [Vision Transformer (ViT)](https://nn.labml.ai/transformers/vit/index.html)

 #### ✨ [Recurrent Highway Networks](https://nn.labml.ai/recurrent_highway_networks/index.html)
diff --git a/setup.py b/setup.py
index fbc15d2d..4b41b521 100644
--- a/setup.py
+++ b/setup.py
@@ -5,7 +5,7 @@ with open("readme.md", "r") as f:
 setuptools.setup(
     name='labml-nn',
-    version='0.4.102',
+    version='0.4.103',
     author="Varuna Jayasiri, Nipun Wijerathne",
     author_email="vpjayasiri@gmail.com, hnipun@gmail.com",
     description="A collection of PyTorch implementations of neural network architectures and layers.",
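For readers who want the whole pipeline in one self-contained piece, here is a sketch of the same forward pass: patch embeddings, learned positional embeddings, a prepended `[CLS]` token, a transformer encoder, and an MLP head on the `[CLS]` encoding. It uses `nn.TransformerEncoderLayer` as a stand-in for the repository's `TransformerLayer`, so it approximates, rather than reproduces, the code in this diff.

```python
# Minimal, self-contained ViT sketch following the structure of this diff.
# `nn.TransformerEncoderLayer` stands in for labml's `TransformerLayer`.
import torch
from torch import nn


class TinyViT(nn.Module):
    def __init__(self, d_model=128, patch_size=4, in_channels=3,
                 n_layers=4, n_heads=4, n_hidden=256, n_classes=10, max_patches=5000):
        super().__init__()
        # Patch embeddings: a strided convolution == a linear map on flattened patches
        self.patch_emb = nn.Conv2d(in_channels, d_model, patch_size, stride=patch_size)
        # Learned positional embeddings, one vector per patch position
        self.pos_emb = nn.Parameter(torch.zeros(max_patches, 1, d_model))
        # [CLS] token embedding
        self.cls_token = nn.Parameter(torch.randn(1, 1, d_model))
        # Transformer encoder (sequence-first layout, matching the diff)
        layer = nn.TransformerEncoderLayer(d_model, n_heads, dim_feedforward=4 * d_model)
        self.encoder = nn.TransformerEncoder(layer, n_layers)
        self.ln = nn.LayerNorm(d_model)
        # Two-layer MLP classification head on the [CLS] encoding
        self.head = nn.Sequential(nn.Linear(d_model, n_hidden), nn.ReLU(), nn.Linear(n_hidden, n_classes))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: [batch, channels, height, width] -> patch embeddings [patches, batch, d_model]
        x = self.patch_emb(x)
        bs, c, h, w = x.shape
        x = x.permute(2, 3, 0, 1).reshape(h * w, bs, c)
        # Add positional embeddings for the patches we actually have
        x = x + self.pos_emb[:x.shape[0]]
        # Prepend the [CLS] token
        cls = self.cls_token.expand(-1, bs, -1)
        x = torch.cat([cls, x])
        # Encode and classify the [CLS] position
        x = self.encoder(x)
        return self.head(self.ln(x[0]))


logits = TinyViT()(torch.randn(2, 3, 32, 32))
assert logits.shape == (2, 10)
```

The sequence-first `[patches, batch_size, d_model]` layout mirrors the diff; a `batch_first=True` encoder would work just as well with the permutes adjusted accordingly.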