Mirror of https://github.com/labmlai/annotated_deep_learning_paper_implementations.git (synced 2025-10-30 02:08:50 +08:00)

ResNet (#68)
@@ -11,21 +11,71 @@ from typing import List

import torch.nn as nn

from labml import lab
from labml.configs import option
from labml_helpers.datasets.cifar10 import CIFAR10Configs as CIFAR10DatasetConfigs
from labml_helpers.module import Module
from labml_nn.experiments.mnist import MNISTConfigs


class CIFAR10Configs(CIFAR10DatasetConfigs, MNISTConfigs):
    """
    ## Configurations

    This extends the CIFAR-10 dataset configurations from
    [`labml_helpers`](https://github.com/labmlai/labml/tree/master/helpers)
    and [`MNISTConfigs`](mnist.html).
    """
    # Use CIFAR10 dataset by default
    dataset_name: str = 'CIFAR10'


@option(CIFAR10Configs.train_dataset)
def cifar10_train_augmented():
    """
    ### Augmented CIFAR 10 train dataset
    """
    from torchvision.datasets import CIFAR10
    from torchvision.transforms import transforms
    return CIFAR10(str(lab.get_data_path()),
                   train=True,
                   download=True,
                   transform=transforms.Compose([
                       # Pad and randomly crop back to 32x32
                       transforms.RandomCrop(32, padding=4),
                       # Random horizontal flip
                       transforms.RandomHorizontalFlip(),
                       # Convert to tensor and normalize
                       transforms.ToTensor(),
                       transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
                   ]))


@option(CIFAR10Configs.valid_dataset)
def cifar10_valid_no_augment():
    """
    ### Non-augmented CIFAR 10 validation dataset
    """
    from torchvision.datasets import CIFAR10
    from torchvision.transforms import transforms
    return CIFAR10(str(lab.get_data_path()),
                   train=False,
                   download=True,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
                   ]))


class CIFAR10VGGModel(Module):
    """
    ### VGG model for CIFAR-10 classification
    """

    def conv_block(self, in_channels, out_channels) -> nn.Module:
        """
        Convolution and activation combined
        """
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
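As a quick standalone check (not part of the commit), `ToTensor` followed by `Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))` maps 8-bit pixel values to roughly $[-1, 1]$ per channel via $(x - 0.5) / 0.5$:

import numpy as np
from torchvision import transforms

t = transforms.Compose([
    transforms.ToTensor(),                                   # uint8 HWC in [0, 255] -> float CHW in [0, 1]
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # (x - 0.5) / 0.5 -> [-1, 1]
])
white = np.full((32, 32, 3), 255, dtype=np.uint8)   # hypothetical all-white 32x32 RGB image
black = np.zeros((32, 32, 3), dtype=np.uint8)       # hypothetical all-black 32x32 RGB image
print(t(white).max().item(), t(black).min().item())  # 1.0 -1.0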
@@ -81,7 +81,8 @@ def _weight_decay(c: OptimizerConfigs):


@option(OptimizerConfigs.optimizer, 'SGD')
def _sgd_optimizer(c: OptimizerConfigs):
-    return torch.optim.SGD(c.parameters, c.learning_rate, c.momentum)
+    return torch.optim.SGD(c.parameters, c.learning_rate, c.momentum,
+                           weight_decay=c.weight_decay)


@option(OptimizerConfigs.optimizer, 'Adam')
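For context, the change simply forwards the configured weight decay to PyTorch's SGD; a minimal standalone sketch with a hypothetical model and hyper-parameter values:

import torch
from torch import nn

model = nn.Linear(10, 2)  # hypothetical model
# SGD with momentum and L2 weight decay, matching the updated option above
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4)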
labml_nn/resnet/__init__.py (new file, 325 lines)
@@ -0,0 +1,325 @@
"""
---
title: Deep Residual Learning for Image Recognition (ResNet)
summary: >
  A PyTorch implementation/tutorial of Deep Residual Learning for Image Recognition (ResNet).
---

# Deep Residual Learning for Image Recognition (ResNet)

This is a [PyTorch](https://pytorch.org) implementation of the paper
[Deep Residual Learning for Image Recognition](https://arxiv.org/abs/1512.03385).

ResNets train layers as residual functions to overcome the *degradation problem*.
The degradation problem is that the accuracy of deep neural networks degrades when
the number of layers becomes very high:
accuracy increases as layers are added, then saturates, and then starts to degrade.

The paper argues that deeper models should perform at least as well as shallower
models, because the extra layers can simply learn an identity mapping.

## Residual Learning

If $\mathcal{H}(x)$ is the mapping that needs to be learned by a few layers,
they train the residual function

$$\mathcal{F}(x) = \mathcal{H}(x) - x$$

instead, and the original function becomes $\mathcal{F}(x) + x$.

In this case, learning the identity mapping for $\mathcal{H}(x)$ is
equivalent to learning $\mathcal{F}(x)$ to be $0$, which is easier to learn.

In parameterized form this can be written as

$$\mathcal{F}(x, \{W_i\}) + x$$

and when the feature map sizes of $\mathcal{F}(x, \{W_i\})$ and $x$ are different
the paper suggests a linear projection with learned weights $W_s$:

$$\mathcal{F}(x, \{W_i\}) + W_s x$$

The paper experimented with zero padding instead of linear projections and found
linear projections to work better. It also found that, when the feature map sizes match,
identity mapping works better than a linear projection.

$\mathcal{F}$ should have more than one layer; otherwise the sum
$\mathcal{F}(x, \{W_i\}) + W_s x$ has no non-linearity and behaves like a single linear layer.

Here is [the training code](experiment.html) for training a ResNet on CIFAR-10.

[View run](https://app.labml.ai/run/fc5ad600e4af11ebbafd23b8665193c1)
"""

from typing import List, Optional

import torch
from torch import nn

from labml_helpers.module import Module

class ShortcutProjection(Module):
    """
    ## Linear projection for the shortcut connection

    This does the $W_s x$ projection described above.
    """

    def __init__(self, in_channels: int, out_channels: int, stride: int):
        """
        * `in_channels` is the number of channels in $x$
        * `out_channels` is the number of channels in $\mathcal{F}(x, \{W_i\})$
        * `stride` is the stride length in the convolution operation for $\mathcal{F}$.
          We use the same stride on the shortcut connection to match the feature map size.
        """
        super().__init__()

        # Convolution layer for the linear projection $W_s x$
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride)
        # The paper adds batch normalization after each convolution operation
        self.bn = nn.BatchNorm2d(out_channels)

    def __call__(self, x: torch.Tensor):
        # Convolution and batch normalization
        return self.bn(self.conv(x))

class ResidualBlock(Module):
    """
    <a id="residual_block"></a>
    ## Residual Block

    This implements the residual block described in the paper.
    It has two $3 \times 3$ convolution layers.

    The first convolution layer maps from `in_channels` to `out_channels`,
    where `out_channels` is higher than `in_channels` when we reduce the
    feature map size with a stride length greater than $1$.

    The second convolution layer maps from `out_channels` to `out_channels` and
    always has a stride length of $1$.

    Both convolution layers are followed by batch normalization.
    """

    def __init__(self, in_channels: int, out_channels: int, stride: int):
        """
        * `in_channels` is the number of channels in $x$
        * `out_channels` is the number of output channels
        * `stride` is the stride length in the convolution operation.
        """
        super().__init__()

        # First $3 \times 3$ convolution layer, this maps to `out_channels`
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1)
        # Batch normalization after the first convolution
        self.bn1 = nn.BatchNorm2d(out_channels)
        # First activation function (ReLU)
        self.act1 = nn.ReLU()

        # Second $3 \times 3$ convolution layer
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        # Batch normalization after the second convolution
        self.bn2 = nn.BatchNorm2d(out_channels)

        # The shortcut connection should be a projection if the stride length is not $1$
        # or if the number of channels changes
        if stride != 1 or in_channels != out_channels:
            # Projection $W_s x$
            self.shortcut = ShortcutProjection(in_channels, out_channels, stride)
        else:
            # Identity $x$
            self.shortcut = nn.Identity()

        # Second activation function (ReLU), applied after adding the shortcut
        self.act2 = nn.ReLU()

    def __call__(self, x: torch.Tensor):
        """
        * `x` is the input of shape `[batch_size, in_channels, height, width]`
        """
        # Get the shortcut connection
        shortcut = self.shortcut(x)
        # First convolution and activation
        x = self.act1(self.bn1(self.conv1(x)))
        # Second convolution
        x = self.bn2(self.conv2(x))
        # Activation function after adding the shortcut
        return self.act2(x + shortcut)

class BottleneckResidualBlock(Module):
    """
    <a id="bottleneck_residual_block"></a>
    ## Bottleneck Residual Block

    This implements the bottleneck block described in the paper.
    It has $1 \times 1$, $3 \times 3$, and $1 \times 1$ convolution layers.

    The first $1 \times 1$ convolution layer maps from `in_channels` to `bottleneck_channels`,
    where `bottleneck_channels` is lower than `in_channels`.

    The second $3 \times 3$ convolution layer maps from `bottleneck_channels` to `bottleneck_channels`.
    It can have a stride length greater than $1$ when we want to reduce the feature map size.

    The third, final $1 \times 1$ convolution layer maps to `out_channels`.
    `out_channels` is higher than `in_channels` if the stride length is greater than $1$;
    otherwise, `out_channels` is equal to `in_channels`.

    `bottleneck_channels` is less than `in_channels` and the $3 \times 3$ convolution is performed
    on this shrunk space (hence the bottleneck). The two $1 \times 1$ convolutions decrease and
    increase the number of channels.
    """

    def __init__(self, in_channels: int, bottleneck_channels: int, out_channels: int, stride: int):
        """
        * `in_channels` is the number of channels in $x$
        * `bottleneck_channels` is the number of channels for the $3 \times 3$ convolution
        * `out_channels` is the number of output channels
        * `stride` is the stride length in the $3 \times 3$ convolution operation.
        """
        super().__init__()

        # First $1 \times 1$ convolution layer, this maps to `bottleneck_channels`
        self.conv1 = nn.Conv2d(in_channels, bottleneck_channels, kernel_size=1, stride=1)
        # Batch normalization after the first convolution
        self.bn1 = nn.BatchNorm2d(bottleneck_channels)
        # First activation function (ReLU)
        self.act1 = nn.ReLU()

        # Second $3 \times 3$ convolution layer
        self.conv2 = nn.Conv2d(bottleneck_channels, bottleneck_channels, kernel_size=3, stride=stride, padding=1)
        # Batch normalization after the second convolution
        self.bn2 = nn.BatchNorm2d(bottleneck_channels)
        # Second activation function (ReLU)
        self.act2 = nn.ReLU()

        # Third $1 \times 1$ convolution layer, this maps to `out_channels`
        self.conv3 = nn.Conv2d(bottleneck_channels, out_channels, kernel_size=1, stride=1)
        # Batch normalization after the third convolution
        self.bn3 = nn.BatchNorm2d(out_channels)

        # The shortcut connection should be a projection if the stride length is not $1$
        # or if the number of channels changes
        if stride != 1 or in_channels != out_channels:
            # Projection $W_s x$
            self.shortcut = ShortcutProjection(in_channels, out_channels, stride)
        else:
            # Identity $x$
            self.shortcut = nn.Identity()

        # Third activation function (ReLU), applied after adding the shortcut
        self.act3 = nn.ReLU()

    def __call__(self, x: torch.Tensor):
        """
        * `x` is the input of shape `[batch_size, in_channels, height, width]`
        """
        # Get the shortcut connection
        shortcut = self.shortcut(x)
        # First convolution and activation
        x = self.act1(self.bn1(self.conv1(x)))
        # Second convolution and activation
        x = self.act2(self.bn2(self.conv2(x)))
        # Third convolution
        x = self.bn3(self.conv3(x))
        # Activation function after adding the shortcut
        return self.act3(x + shortcut)

class ResNetBase(Module):
    """
    ## ResNet Model

    This is the base of the ResNet model, without
    the final linear layer and softmax for classification.

    The ResNet is made of stacked [residual blocks](#residual_block) or
    [bottleneck residual blocks](#bottleneck_residual_block).
    The feature map size is halved every few blocks by a block with stride length $2$,
    and the number of channels is increased whenever the feature map size is reduced.
    Finally, the feature map is average pooled to get a vector representation.
    """

    def __init__(self, n_blocks: List[int], n_channels: List[int],
                 bottlenecks: Optional[List[int]] = None,
                 img_channels: int = 3, first_kernel_size: int = 7):
        """
        * `n_blocks` is a list of the number of blocks for each feature map size.
        * `n_channels` is the number of channels for each feature map size.
        * `bottlenecks` is the number of bottleneck channels for each feature map size.
          If this is `None`, [residual blocks](#residual_block) are used.
        * `img_channels` is the number of channels in the input.
        * `first_kernel_size` is the kernel size of the initial convolution layer.
        """
        super().__init__()

        # Number of blocks and number of channels for each feature map size
        assert len(n_blocks) == len(n_channels)
        # If [bottleneck residual blocks](#bottleneck_residual_block) are used,
        # the number of bottleneck channels should be provided for each feature map size
        assert bottlenecks is None or len(bottlenecks) == len(n_channels)

        # The initial convolution layer maps from `img_channels` to the number of channels
        # in the first residual block (`n_channels[0]`)
        self.conv = nn.Conv2d(img_channels, n_channels[0],
                              kernel_size=first_kernel_size, stride=1, padding=first_kernel_size // 2)
        # Batch norm after the initial convolution
        self.bn = nn.BatchNorm2d(n_channels[0])

        # List of blocks
        blocks = []
        # Number of channels from the previous layer (or block)
        prev_channels = n_channels[0]
        # Loop through each feature map size
        for i, channels in enumerate(n_channels):
            # The first block for a new feature map size has a stride length of $2$,
            # except for the very first block
            stride = 1 if len(blocks) == 0 else 2

            if bottlenecks is None:
                # [Residual block](#residual_block) that maps from `prev_channels` to `channels`
                blocks.append(ResidualBlock(prev_channels, channels, stride=stride))
            else:
                # [Bottleneck residual block](#bottleneck_residual_block)
                # that maps from `prev_channels` to `channels`
                blocks.append(BottleneckResidualBlock(prev_channels, bottlenecks[i], channels,
                                                      stride=stride))

            # Change the number of channels
            prev_channels = channels
            # Add the rest of the blocks - no change in feature map size or channels
            for _ in range(n_blocks[i] - 1):
                if bottlenecks is None:
                    # [Residual block](#residual_block)
                    blocks.append(ResidualBlock(channels, channels, stride=1))
                else:
                    # [Bottleneck residual block](#bottleneck_residual_block)
                    blocks.append(BottleneckResidualBlock(channels, bottlenecks[i], channels, stride=1))

        # Stack the blocks
        self.blocks = nn.Sequential(*blocks)

    def __call__(self, x: torch.Tensor):
        """
        * `x` has shape `[batch_size, img_channels, height, width]`
        """

        # Initial convolution and batch normalization
        x = self.bn(self.conv(x))
        # Residual (or bottleneck) blocks
        x = self.blocks(x)
        # Change `x` from shape `[batch_size, channels, h, w]` to `[batch_size, channels, h * w]`
        x = x.view(x.shape[0], x.shape[1], -1)
        # Global average pooling
        return x.mean(dim=-1)
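A minimal usage sketch of the classes above (illustrative only, not part of the file; it assumes `labml_nn.resnet` is importable). Each `ResidualBlock` computes $\mathcal{F}(x) + x$ (or $\mathcal{F}(x) + W_s x$ when the shortcut needs a projection), and `ResNetBase` stacks them:

import torch
from torch import nn
from labml_nn.resnet import ResidualBlock, ResNetBase

# A single residual block with an identity shortcut: 16 -> 16 channels, stride 1,
# computing act(F(x) + x)
block = ResidualBlock(16, 16, stride=1)
print(block(torch.randn(4, 16, 32, 32)).shape)  # torch.Size([4, 16, 32, 32])

# A CIFAR-style ResNet-20: 3 residual blocks per feature map size, channels 16/32/64,
# with a 3x3 initial convolution and a linear classification head on top
base = ResNetBase(n_blocks=[3, 3, 3], n_channels=[16, 32, 64], first_kernel_size=3)
model = nn.Sequential(base, nn.Linear(64, 10))
print(model(torch.randn(4, 3, 32, 32)).shape)   # torch.Size([4, 10])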
labml_nn/resnet/experiment.py (new file, 84 lines)
@@ -0,0 +1,84 @@
"""
---
title: Train a ResNet on CIFAR 10
summary: >
  Train a ResNet on CIFAR 10
---

# Train a [ResNet](index.html) on CIFAR 10

[View run](https://app.labml.ai/run/fc5ad600e4af11ebbafd23b8665193c1)
"""
from typing import List, Optional

from torch import nn

from labml import experiment
from labml.configs import option
from labml_nn.experiments.cifar10 import CIFAR10Configs
from labml_nn.resnet import ResNetBase


class Configs(CIFAR10Configs):
    """
    ## Configurations

    We use [`CIFAR10Configs`](../experiments/cifar10.html) which defines all the
    dataset related configurations, the optimizer, and the training loop.
    """

    # Number of blocks for each feature map size
    n_blocks: List[int] = [3, 3, 3]
    # Number of channels for each feature map size
    n_channels: List[int] = [16, 32, 64]
    # Bottleneck channel sizes
    bottlenecks: Optional[List[int]] = None
    # Kernel size of the initial convolution layer
    first_kernel_size: int = 7


@option(Configs.model)
def _resnet(c: Configs):
    """
    ### Create model
    """
    # [ResNet](index.html)
    base = ResNetBase(c.n_blocks, c.n_channels, c.bottlenecks, img_channels=3, first_kernel_size=c.first_kernel_size)
    # Linear layer for classification
    classification = nn.Linear(c.n_channels[-1], 10)

    # Stack them
    model = nn.Sequential(base, classification)
    # Move the model to the device
    return model.to(c.device)


def main():
    # Create experiment
    experiment.create(name='resnet', comment='cifar10')
    # Create configurations
    conf = Configs()
    # Load configurations
    experiment.configs(conf, {
        'bottlenecks': [8, 16, 16],
        'n_blocks': [6, 6, 6],

        'optimizer.optimizer': 'Adam',
        'optimizer.learning_rate': 2.5e-4,

        'epochs': 500,
        'train_batch_size': 256,

        'train_dataset': 'cifar10_train_augmented',
        'valid_dataset': 'cifar10_valid_no_augment',
    })
    # Set the model for saving/loading
    experiment.add_pytorch_models({'model': conf.model})
    # Start the experiment and run the training loop
    with experiment.start():
        conf.run()


#
if __name__ == '__main__':
    main()
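With the overrides above (`n_blocks=[6, 6, 6]`, `bottlenecks=[8, 16, 16]`), the `_resnet` option builds a bottleneck ResNet; a rough standalone sketch of the resulting model (illustrative only, assuming `labml_nn.resnet` is importable):

import torch
from torch import nn
from labml_nn.resnet import ResNetBase

# 6 bottleneck blocks per feature map size, 3 convolutions each, plus the initial
# convolution and the final linear layer: roughly a 56-layer network
base = ResNetBase(n_blocks=[6, 6, 6], n_channels=[16, 32, 64],
                  bottlenecks=[8, 16, 16], img_channels=3, first_kernel_size=7)
model = nn.Sequential(base, nn.Linear(64, 10))
print(model(torch.randn(2, 3, 32, 32)).shape)  # torch.Size([2, 10])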
labml_nn/transformers/vit/__init__.py (new file, 90 lines)
@@ -0,0 +1,90 @@
import torch
from torch import nn

from labml_helpers.module import Module
from labml_nn.transformers import TransformerLayer
from labml_nn.utils import clone_module_list


class PatchEmbeddings(Module):
    """
    <a id="PatchEmbeddings">
    ## Embed patches
    </a>
    """

    def __init__(self, d_model: int, patch_size: int, in_channels: int):
        super().__init__()
        self.patch_size = patch_size
        # A convolution with kernel size and stride equal to the patch size
        # splits the image into patches and projects each patch to `d_model`
        self.conv = nn.Conv2d(in_channels, d_model, patch_size, stride=patch_size)

    def __call__(self, x: torch.Tensor):
        """
        * `x` has shape `[batch_size, channels, height, width]`
        """
        # Apply the patch convolution
        x = self.conv(x)
        bs, c, h, w = x.shape
        # Rearrange to shape `[patches, batch_size, d_model]`
        x = x.permute(2, 3, 0, 1)
        x = x.view(h * w, bs, c)

        return x

class LearnedPositionalEmbeddings(Module):
    """
    <a id="LearnedPositionalEmbeddings">
    ## Add parameterized positional encodings
    </a>
    """

    def __init__(self, d_model: int, max_len: int = 5_000):
        super().__init__()
        # Learned positional encodings for up to `max_len` positions
        self.positional_encodings = nn.Parameter(torch.zeros(max_len, 1, d_model), requires_grad=True)

    def __call__(self, x: torch.Tensor):
        # Take the encodings for the positions present in `x` and add them
        pe = self.positional_encodings[:x.shape[0]]
        return x + pe

class ClassificationHead(Module):
    def __init__(self, d_model: int, n_hidden: int, n_classes: int):
        super().__init__()
        # Layer normalization, a hidden layer with ReLU, and a linear layer for class logits
        self.ln = nn.LayerNorm([d_model])
        self.linear1 = nn.Linear(d_model, n_hidden)
        self.act = nn.ReLU()
        self.linear2 = nn.Linear(n_hidden, n_classes)

    def __call__(self, x: torch.Tensor):
        x = self.ln(x)
        x = self.act(self.linear1(x))
        x = self.linear2(x)

        return x

class VisionTransformer(Module):
    def __init__(self, transformer_layer: TransformerLayer, n_layers: int,
                 patch_emb: PatchEmbeddings, pos_emb: LearnedPositionalEmbeddings,
                 classification: ClassificationHead):
        super().__init__()
        self.classification = classification
        self.pos_emb = pos_emb
        self.patch_emb = patch_emb
        # Make copies of the transformer layer
        self.transformer_layers = clone_module_list(transformer_layer, n_layers)

        # Learned embedding for the `[CLS]` token
        self.cls_token_emb = nn.Parameter(torch.randn(1, 1, transformer_layer.size), requires_grad=True)

    def __call__(self, x):
        # Patch embeddings and positional encodings
        x = self.patch_emb(x)
        x = self.pos_emb(x)
        # Prepend the `[CLS]` token to the sequence of patch embeddings
        cls_token_emb = self.cls_token_emb.expand(-1, x.shape[1], -1)
        x = torch.cat([cls_token_emb, x])
        # Pass through the transformer layers (no attention mask)
        for layer in self.transformer_layers:
            x = layer(x=x, mask=None)

        # Take the output corresponding to the `[CLS]` token
        x = x[0]

        # Classification head
        x = self.classification(x)

        return x
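A shape sketch of the patch embedding path above (standalone, not part of the file): with 32x32 CIFAR images and a patch size of 4, each image becomes an 8x8 grid of 64 patch tokens, and prepending the `[CLS]` token makes 65. The tensor below stands in for the learned `[CLS]` embedding:

import torch
from labml_nn.transformers.vit import PatchEmbeddings, LearnedPositionalEmbeddings

patch_emb = PatchEmbeddings(d_model=512, patch_size=4, in_channels=3)
pos_emb = LearnedPositionalEmbeddings(d_model=512)

x = torch.randn(2, 3, 32, 32)                    # two CIFAR-10 sized images
x = pos_emb(patch_emb(x))                        # [64, 2, 512]: 8x8 patches, batch, d_model
cls = torch.randn(1, 1, 512).expand(-1, 2, -1)   # stand-in for the learned [CLS] token embedding
x = torch.cat([cls, x])                          # [65, 2, 512]: [CLS] prepended along the patch axis
print(x.shape)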
labml_nn/transformers/vit/experiment.py (new file, 84 lines)
@@ -0,0 +1,84 @@
"""
---
title: Train a ViT on CIFAR 10
summary: >
  Train a ViT on CIFAR 10
---

# Train a ViT on CIFAR 10
"""

from labml import experiment
from labml.configs import option
from labml_nn.experiments.cifar10 import CIFAR10Configs
from labml_nn.transformers import TransformerConfigs


class Configs(CIFAR10Configs):
    """
    ## Configurations

    We use [`CIFAR10Configs`](../experiments/cifar10.html) which defines all the
    dataset related configurations, the optimizer, and the training loop.
    """

    # Transformer configurations
    transformer: TransformerConfigs

    # Size of image patches
    patch_size: int = 4
    # Size of the hidden layer in the classification head
    n_hidden: int = 2048
    # Number of classes
    n_classes: int = 10


@option(Configs.transformer)
def _transformer(c: Configs):
    return TransformerConfigs()


@option(Configs.model)
def _vit(c: Configs):
    """
    ### Create model
    """
    from labml_nn.transformers.vit import VisionTransformer, LearnedPositionalEmbeddings, ClassificationHead, \
        PatchEmbeddings

    d_model = c.transformer.d_model
    return VisionTransformer(c.transformer.encoder_layer, c.transformer.n_layers,
                             PatchEmbeddings(d_model, c.patch_size, 3),
                             LearnedPositionalEmbeddings(d_model),
                             ClassificationHead(d_model, c.n_hidden, c.n_classes)).to(c.device)


def main():
    # Create experiment
    experiment.create(name='ViT', comment='cifar10')
    # Create configurations
    conf = Configs()
    # Load configurations
    experiment.configs(conf, {
        'device.cuda_device': 0,

        # 'optimizer.optimizer': 'Noam',
        # 'optimizer.learning_rate': 1.,
        'optimizer.optimizer': 'Adam',
        'optimizer.learning_rate': 2.5e-4,
        'optimizer.d_model': 512,

        'transformer.d_model': 512,

        'epochs': 1000,
        'train_batch_size': 64,

        'train_dataset': 'cifar10_train_augmented',
        'valid_dataset': 'cifar10_valid_no_augment',
    })
    # Set the model for saving/loading
    experiment.add_pytorch_models({'model': conf.model})
    # Start the experiment and run the training loop
    with experiment.start():
        conf.run()


#
if __name__ == '__main__':
    main()