This is an annotated PyTorch experiment to train a gMLP model.
This is based on the training loop and configurations for a simple transformer auto-regressive NLP task.
from labml import experiment
from labml.configs import option
from labml_nn.transformers import TransformerConfigs
from labml_nn.transformers.basic.autoregressive_experiment import Configs as BasicAutoRegressionConfigs
from labml_nn.transformers.gmlp import GMLPBlock
This inherits from the training loop and configurations for a simple transformer auto-regressive NLP task.
class Configs(BasicAutoRegressionConfigs):
Transformer
    transformer: TransformerConfigs = 'gMLP'
gMLP Block
    gmlp: GMLPBlock
d_ffn for the gMLP projection layer
    d_ffn: int = 2048
@option(Configs.gmlp, 'gMLP')
def _gmlp_configs(c: Configs):
    # Create a gMLP block with the configured model size, projection size, and sequence length
    return GMLPBlock(c.d_model, c.d_ffn, c.seq_len)
@option(Configs.transformer, 'gMLP')
def _transformer_configs(c: Configs):
We use our configurable transformer implementation
    conf = TransformerConfigs()
Set the vocabulary sizes for embeddings and generating logits
    conf.n_src_vocab = c.n_tokens
    conf.n_tgt_vocab = c.n_tokens
Set model size
    conf.d_model = c.d_model
Replace the encoder layer with a gMLP layer
    conf.encoder_layer = c.gmlp

    return conf
def main():
Create experiment
    experiment.create(name="gMLP")
Create configs
    conf = Configs()
Override configurations
    experiment.configs(conf, {
Use character level tokenizer
        'tokenizer': 'character',
Prompt separator is blank
        'prompt_separator': '',
Starting prompt for sampling
        'prompt': 'It is ',
Use Tiny Shakespeare dataset
        'text': 'tiny_shakespeare',
Use a context size of $256$
        'seq_len': 256,
Train for $128$ epochs
        'epochs': 128,
Batch size $32$
        'batch_size': 32,
Switch between training and validation $10$ times per epoch
        'inner_iterations': 10,
Model size
        'd_model': 512,
        'd_ffn': 2048,
Use the Noam optimizer (see the schedule note after this configuration block)
        'optimizer.optimizer': 'Noam',
        'optimizer.learning_rate': 1.,
    })
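The 'Noam' option above refers to the learning-rate schedule from the original transformer paper; assuming labml's Noam optimizer follows that standard schedule, the effective learning rate at step $t$ is $\text{lr}(t) = \text{lr}_{\text{base}} \cdot d_{model}^{-0.5} \cdot \min\big(t^{-0.5},\ t \cdot warmup^{-1.5}\big)$, so with the configured base learning rate of $1.$ the rate warms up linearly and then decays as the inverse square root of the step count.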
Set models for saving and loading
    experiment.add_pytorch_models({'model': conf.model})
Start the experiment
    with experiment.start():
Run training
        conf.run()
if __name__ == '__main__':
    main()
This is a PyTorch implementation of the paper Pay Attention to MLPs.
This paper introduces a Multilayer Perceptron (MLP) based architecture with gating, which they name gMLP. It consists of a stack of $L$ gMLP blocks.
The training code for a gMLP-based autoregressive model is shown above.
from typing import Optional

import torch
from torch import nn
Each block does the following transformations to input embeddings $X \in \mathbb{R}^{n \times d}$, where $n$ is the sequence length and $d$ is the dimensionality of the embeddings:
$$\begin{aligned} Z &= \sigma(XU) \\ \tilde{Z} &= s(Z) \\ Y &= \tilde{Z}V \end{aligned}$$
where $V$ and $U$ are learnable projection weights, $s(\cdot)$ is the Spatial Gating Unit defined below (implemented as SpacialGatingUnit), and $\sigma$ is an activation function such as GeLU. The output dimensionality of $s(\cdot)$ is half that of $Z$.
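As a quick shape check before the implementation, here is a minimal standalone sketch of these transformations. The sizes are arbitrary, and the simple elementwise product is only a placeholder for $s(\cdot)$, which is defined properly further below.
import torch
from torch import nn

seq_len, batch_size, d_model, d_ffn = 8, 2, 16, 32
x = torch.randn(seq_len, batch_size, d_model)

u = nn.Linear(d_model, d_ffn)       # projection U
v = nn.Linear(d_ffn // 2, d_model)  # projection V; s(.) halves the channel dimension

z = nn.functional.gelu(u(x))        # Z = sigma(X U)
z1, z2 = torch.chunk(z, 2, dim=-1)  # the gating unit splits Z along the channels
z_tilde = z1 * z2                   # placeholder for s(Z); see SpacialGatingUnit below
y = v(z_tilde)                      # Y = Z~ V
assert y.shape == x.shape           # the block preserves the input shape
# (end of shape-check sketch; the actual implementation follows)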
class GMLPBlock(nn.Module):
d_model is the dimensionality ($d$) of $X$
d_ffn is the dimensionality of $Z$
seq_len is the length of the token sequence ($n$)
    def __init__(self, d_model: int, d_ffn: int, seq_len: int):
        super().__init__()
Normalization layer for Pre-Norm
        self.norm = nn.LayerNorm([d_model])
Activation function $\sigma$
        self.activation = nn.GELU()
Projection layer for $Z = \sigma(XU)$
        self.proj1 = nn.Linear(d_model, d_ffn)
Spatial Gating Unit $s(\cdot)$
        self.sgu = SpacialGatingUnit(d_ffn, seq_len)
Projection layer for $Y = \tilde{Z}V$
        self.proj2 = nn.Linear(d_ffn // 2, d_model)
Embedding size (required by the Encoder). We use the encoder module from the transformer architecture and plug in the gMLP block as a replacement for the transformer layer.
        self.size = d_model
x is the input embedding tensor $X$ of shape [seq_len, batch_size, d_model]
mask is a boolean mask of shape [seq_len, seq_len, 1] that controls the visibility of tokens among each other
    def forward(self, *, x: torch.Tensor, mask: Optional[torch.Tensor] = None):
Keep a copy for shortcut connection
        shortcut = x
Normalize $X$
        x = self.norm(x)
Projection and activation $Z = \sigma(XU)$
        z = self.activation(self.proj1(x))
Spatial Gating Unit $\tilde{Z} = s(Z)$
        z = self.sgu(z, mask)
Final projection $Y = \tilde{Z}V$
        z = self.proj2(z)
Add the shortcut connection
        return z + shortcut
The Spatial Gating Unit is defined as
$$s(Z) = Z_1 \odot f_{W,b}(Z_2)$$
where $f_{W,b}(Z) = W Z + b$ is a linear transformation along the sequence dimension, and $\odot$ is element-wise multiplication. $Z$ is split into two parts of equal size, $Z_1$ and $Z_2$, along the channel dimension (embedding dimension).
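To make the linear transformation along the sequence dimension concrete, here is a minimal standalone sketch of the gating computation with arbitrary sizes; the module below adds normalization, masking, and the special initialization.
import torch

seq_len, batch_size, d_z = 8, 2, 32
z = torch.randn(seq_len, batch_size, d_z)

z1, z2 = torch.chunk(z, 2, dim=-1)        # split along the channel dimension
w = torch.randn(seq_len, seq_len) * 0.01  # W mixes information across positions
b = torch.ones(seq_len)                   # b

f_z2 = torch.einsum('ij,jbd->ibd', w, z2) + b[:, None, None]  # f_{W,b}(Z_2) = W Z_2 + b
s_z = z1 * f_z2                           # s(Z) = Z_1 * f_{W,b}(Z_2)
assert s_z.shape == (seq_len, batch_size, d_z // 2)
# (end of sketch; the actual module follows)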
class SpacialGatingUnit(nn.Module):
d_z is the dimensionality of $Z$
seq_len is the sequence length
    def __init__(self, d_z: int, seq_len: int):
        super().__init__()
Normalization layer before applying $f_{W,b}(\cdot)$
        self.norm = nn.LayerNorm([d_z // 2])
Weight $W$ in $f_{W,b}(\cdot)$.
The paper notes that it's important to initialize the weights to small values and the bias to $1$, so that during the initial training $s(\cdot)$ is close to the identity (apart from the split).
        self.weight = nn.Parameter(torch.zeros(seq_len, seq_len).uniform_(-0.01, 0.01), requires_grad=True)
Bias $b$ in $f_{W,b}(\cdot)$
The paper notes that it's important to initialize the bias to $1$.
        self.bias = nn.Parameter(torch.ones(seq_len), requires_grad=True)
z is the input $Z$ of shape [seq_len, batch_size, d_z]
mask is a boolean mask of shape [seq_len, seq_len, 1] that controls the visibility of tokens among each other. The last dimension of size 1 is the batch dimension, which we have in other transformer implementations and keep here for compatibility.
    def forward(self, z: torch.Tensor, mask: Optional[torch.Tensor] = None):
Get sequence length
        seq_len = z.shape[0]
Split $Z$ into $Z_1$ and $Z_2$
        z1, z2 = torch.chunk(z, 2, dim=-1)
Check mask
        if mask is not None:
mask has shape [seq_len_q, seq_len_k, batch_size]. The batch dimension should be of size 1 because this implementation supports only the same mask for all samples in the batch.
            assert mask.shape[0] == 1 or mask.shape[0] == seq_len
            assert mask.shape[1] == seq_len
Here we only support the same mask for all samples
            assert mask.shape[2] == 1
Remove the batch dimension
            mask = mask[:, :, 0]
Normalize $Z_2$ before $f_{W,b}(\cdot)$
        z2 = self.norm(z2)
Get the weight matrix; truncate if larger than seq_len
        weight = self.weight[:seq_len, :seq_len]
Apply mask to the weights.
If $W_{i,j}$ is $0$ then $f_{W,b}(Z_2)_i$ will not get any information from token $j$.
        if mask is not None:
            weight = weight * mask
$f_{W,b}(Z_2) = W Z_2 + b$
        z2 = torch.einsum('ij,jbd->ibd', weight, z2) + self.bias[:seq_len, None, None]
$Z_1 \odot f_{W,b}(Z_2)$
        return z1 * z2
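A small usage sketch, not from the source: the tensor sizes are arbitrary, and the simple lower-triangular causal mask is my own construction for illustration; labml's transformer code has its own mask helpers. It also illustrates the near-identity initialization discussed above.
import torch

from labml_nn.transformers.gmlp import GMLPBlock  # or use the class defined above

seq_len, batch_size, d_model, d_ffn = 16, 4, 128, 512
block = GMLPBlock(d_model, d_ffn, seq_len)

x = torch.randn(seq_len, batch_size, d_model)
# Causal mask of shape [seq_len, seq_len, 1]: position i may only see positions j <= i
mask = torch.tril(torch.ones(seq_len, seq_len)).unsqueeze(-1)

y = block(x=x, mask=mask)
assert y.shape == x.shape

# At initialization W is close to zero and b is all ones, so f_{W,b}(Z_2) is roughly 1
# and s(Z) is roughly Z_1: the gating is effectively inactive, and the block behaves
# like a regular feed-forward layer early in training.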