"""
---
title: Fuzzy Tiling Activation Experiment
summary: >
  Training a transformer with FTA in FFN on Tiny Shakespeare.
---
# [Fuzzy Tiling Activation](index.html) Experiment
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/activations/fta/experiment.ipynb)
Here we train a transformer that uses [Fuzzy Tiling Activation](index.html) in the
[Feed-Forward Network](../../transformers/feed_forward.html).
We use it for a language model and train it on the Tiny Shakespeare dataset
for demonstration.
However, this is probably not the ideal task for FTA, and we
believe FTA is more suitable for modeling data with continuous variables.
"""
import copy

import torch
import torch.nn as nn

from labml import experiment
from labml.configs import option
from labml_nn.activations.fta import FTA
from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs
from labml_nn.transformers import MultiHeadAttention, TransformerLayer
from labml_nn.transformers.utils import subsequent_mask
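

# A quick note on how [FTA](index.html) behaves (see the linked implementation for details):
# it tiles the interval from the lower limit to the upper limit into bins of width $\delta$
# with fuzziness $\eta$, and maps every scalar input to one value per tile.
# It keeps every tensor dimension except the last, which gets multiplied by
# `FTA.expansion_factor` (the number of tiles). The feed-forward module below has to
# account for this expansion when sizing its second linear layer.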
class FeedForwardFTA(nn.Module):
    """
    ## FFN module with [FTA](index.html) activation
    """

    def __init__(self, d_model: int, d_ff: int,
                 activation: FTA,
                 dropout: float = 0.1):
        """
        * `d_model` is the number of features in a token embedding
        * `d_ff` is the number of features in the hidden layer of the FFN
        * `activation` is the FTA activation module
        * `dropout` is the dropout probability for the hidden layer
        """
        super().__init__()
        # Layer one parameterized by weight $W_1$ and bias $b_1$
        self.layer1 = nn.Linear(d_model, d_ff)
        # Layer two parameterized by weight $W_2$ and bias $b_2$.
        # Its input size is `d_ff * activation.expansion_factor` because
        # FTA expands each hidden feature into `expansion_factor` values.
        self.layer2 = nn.Linear(d_ff * activation.expansion_factor, d_model)
        # Hidden layer dropout
        self.dropout = nn.Dropout(dropout)
        # Activation function $f$
        self.activation = activation

    def forward(self, x: torch.Tensor):
        # $f(x W_1 + b_1)$
        x = self.activation(self.layer1(x))
        # Apply dropout
        x = self.dropout(x)
        # $f(x W_1 + b_1) W_2 + b_2$
        return self.layer2(x)
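

# A minimal shape sketch for `FeedForwardFTA` (illustrative only; the sizes are the
# defaults from `Configs` below, picked as an assumption, and this is not part of the
# training code):
#
#     fta = FTA(-1., 1., 0.2, 0.05)
#     ffn = FeedForwardFTA(d_model=256, d_ff=256, activation=fta)
#     y = ffn(torch.randn(256, 16, 256))
#
# `y` keeps the shape `[256, 16, 256]`, while the hidden activation in between has
# `256 * fta.expansion_factor` features before `layer2` projects back to `d_model`.

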
class AutoregressiveTransformer(nn.Module):
    """
    ## Auto-Regressive model

    This is an autoregressive transformer model that uses Feed-Forward Networks with
    [Fuzzy Tiling Activations](index.html).
    """

    def __init__(self, n_tokens: int, d_model: int, n_layers: int, layer: TransformerLayer):
        """
        :param n_tokens: is the number of tokens in the vocabulary
        :param d_model: is the embedding size
        :param n_layers: is the number of transformer layers
        :param layer: is a transformer layer. We make `n_layers` copies of it for the transformer.
        """
        super().__init__()
        # Transformer with `n_layers` layers
        self.transformer_layers = nn.ModuleList([copy.deepcopy(layer) for _ in range(n_layers)])
        # Token embedding layer
        self.emb = nn.Embedding(n_tokens, d_model)
        # Readout layer
        self.readout = nn.Linear(d_model, n_tokens)
        # The mask will be initialized on the first call
        self.mask = None

    def forward(self, x: torch.Tensor):
        """
        :param x: are the input tokens of shape `[seq_len, batch_size]`
        """
        # Create the auto-regressive mask if it hasn't been created or if the sequence length changed
        if self.mask is None or self.mask.size(0) != len(x):
            # Subsequent mask, which stops tokens from attending to future tokens
            self.mask = subsequent_mask(len(x)).to(x.device)
        # Get the token embeddings
        x = self.emb(x)
        # Run the embeddings through the transformer layers
        for layer in self.transformer_layers:
            x = layer(x=x, mask=self.mask)
        # Get logits
        x = self.readout(x)
        # Return results
        # (the second value is a placeholder for state, which this model doesn't use)
        return x, None
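

# A usage sketch for the model interface (shapes only; the vocabulary size of $65$ is an
# assumed value for character-level Tiny Shakespeare, and `layer` stands for a
# `TransformerLayer` built as in `_model` below):
#
#     model = AutoregressiveTransformer(n_tokens=65, d_model=256, n_layers=4, layer=layer)
#     logits, _ = model(tokens)  # `tokens` has shape `[seq_len, batch_size]`
#
# `logits` has shape `[seq_len, batch_size, n_tokens]`, one set of next-token logits
# per position.

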
class Configs(NLPAutoRegressionConfigs):
    """
    ## Configurations

    This inherits from
    [`NLPAutoRegressionConfigs`](../../experiments/nlp_autoregression.html#NLPAutoRegressionConfigs)
    """

    # Model
    model: AutoregressiveTransformer

    # Number of layers
    n_layers: int = 4

    # $\alpha$ and $\beta$ for DeepNorm
    deep_norm_alpha: float
    deep_norm_beta: float

    # Number of heads in the attention
    n_heads: int = 4
    # Embedding size
    d_model: int = 256
    # Size of each attention head
    d_k: int = 16
    # Feed forward layer size
    d_ff: int = 256

    # FTA
    fta_lower_limit: float = -1.
    fta_upper_limit: float = +1.
    fta_delta: float = 0.2
    fta_eta: float = 0.05
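
    # With these defaults FTA creates $\frac{1 - (-1)}{0.2} = 10$ tiles (assuming
    # `expansion_factor` equals the tile count), so `layer2` of `FeedForwardFTA`
    # gets `d_ff * 10 = 2560` input features.

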
@option(Configs.model)
def _model(c: Configs):
    """
    #### Initialize the model
    """
    # Create FTA activation module
    fta = FTA(c.fta_lower_limit, c.fta_upper_limit, c.fta_delta, c.fta_eta)

    # Create the transformer.
    # We re-use [`TransformerLayer`](../../transformers/models.html#TransformerLayer) and
    # [`MultiHeadAttention`](../../transformers/mha.html) implementations.
    m = AutoregressiveTransformer(c.n_tokens, c.d_model, c.n_layers,
                                  TransformerLayer(d_model=c.d_model,
                                                   feed_forward=FeedForwardFTA(d_model=c.d_model,
                                                                               d_ff=c.d_ff,
                                                                               activation=fta,
                                                                               dropout=0.1),
                                                   self_attn=MultiHeadAttention(c.n_heads, c.d_model,
                                                                                dropout_prob=0.0),
                                                   dropout_prob=0.0))

    # Move the model to the device
    return m.to(c.device)
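

# Note: `@option(Configs.model)` registers `_model` with labml's configuration system as the
# calculator for `Configs.model`, so accessing `conf.model` in `main()` below builds the model
# lazily, after the overrides passed to `experiment.configs` have been applied.

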
def main():
    """
    #### Create and run the experiment
    """
    # Create experiment
    experiment.create(name="fta", writers={'screen', 'labml'})
    # Create configs
    conf = Configs()
    # Override configurations
    experiment.configs(conf, {
        # Use character level tokenizer
        'tokenizer': 'character',
        # Prompt separator is blank
        'prompt_separator': '',
        # Starting prompt for sampling
        'prompt': 'It is ',
        # Use Tiny Shakespeare dataset
        'text': 'tiny_shakespeare',

        # Use a context size of $256$
        'seq_len': 256,
        # Train for $32$ epochs
        'epochs': 32,
        # Batch size $16$
        'batch_size': 16,
        # Switch between training and validation $10$ times per epoch
        'inner_iterations': 10,

        # Adam optimizer with no warmup
        'optimizer.optimizer': 'Adam',
        'optimizer.learning_rate': 3e-4,
    })

    # Set model(s) for saving and loading
    experiment.add_pytorch_models({'model': conf.model})

    # Start the experiment
    with experiment.start():
        # Run training
        conf.run()


#
if __name__ == '__main__':
    main()