This is a PyTorch implementation of the paper Pay Attention to MLPs.
The paper introduces a Multilayer Perceptron (MLP) based architecture with gating, which the authors call gMLP. It consists of a stack of $L$ gMLP blocks.
Here is the training code for an autoregressive model based on gMLP.
from typing import Optional

import torch
from torch import nn
Each block applies the following transformations to the input embeddings $X \in \mathbb{R}^{n \times d}$, where $n$ is the sequence length and $d$ is the dimensionality of the embeddings:

$$Z = \sigma(XU)$$
$$\tilde{Z} = s(Z)$$
$$Y = \tilde{Z}V$$
Here $V$ and $U$ are learnable projection weights, and $s(\cdot)$ is the Spatial Gating Unit defined below. The output dimensionality of $s(\cdot)$ is half that of $Z$. $\sigma$ is an activation function such as GELU.
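To make the shapes in these three equations concrete, here is a minimal sketch that applies them directly with plain tensor operations. It does not use the classes defined below. The sizes `n`, `d`, `d_ffn` and the random weights are arbitrary assumptions. The sketch also leaves out the pre-norm, the normalization inside $s(\cdot)$, the shortcut connection and the batch dimension, so it only illustrates how the dimensions flow.

```python
import torch
import torch.nn.functional as F

# Toy sizes (illustrative assumptions): n tokens, d model dims, d_ffn hidden dims
n, d, d_ffn = 4, 8, 16

X = torch.randn(n, d)            # input embeddings X, shape [n, d]
U = torch.randn(d, d_ffn)        # projection U
V = torch.randn(d_ffn // 2, d)   # projection V maps the gated half back to d
W = torch.randn(n, n) * 0.01     # spatial weights W act along the sequence
b = torch.ones(n, 1)             # bias b, one value per token position

Z = F.gelu(X @ U)                # Z = sigma(XU), shape [n, d_ffn]
Z1, Z2 = Z.chunk(2, dim=-1)      # split along the channel dimension
Z_tilde = Z1 * (W @ Z2 + b)      # s(Z) = Z1 * f_{W,b}(Z2), shape [n, d_ffn // 2]
Y = Z_tilde @ V                  # Y = Z_tilde V, shape [n, d]

print(Z.shape, Z_tilde.shape, Y.shape)
```

The gating halves the channel count, which is why the second projection $V$ takes `d_ffn // 2` inputs.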
class GMLPBlock(nn.Module):
d_model is the dimensionality ($d$) of $X$
d_ffn is the dimensionality of $Z$
seq_len is the length of the token sequence ($n$)
    def __init__(self, d_model: int, d_ffn: int, seq_len: int):
        super().__init__()
Normalization layer for Pre-Norm
        self.norm = nn.LayerNorm([d_model])
Activation function $\sigma$
        self.activation = nn.GELU()
Projection layer for $Z = \sigma(XU)$
        self.proj1 = nn.Linear(d_model, d_ffn)
Spatial Gating Unit $s(\cdot)$
        self.sgu = SpacialGatingUnit(d_ffn, seq_len)
Projection layer for $Y = \tilde{Z}V$
        self.proj2 = nn.Linear(d_ffn // 2, d_model)
Embedding size, required by Encoder. We use the encoder module from the transformer architecture and plug in the gMLP block as a replacement for the transformer layer.
        self.size = d_model
x is the input embedding tensor $X$ of shape [seq_len, batch_size, d_model]
mask is a boolean mask of shape [seq_len, seq_len, 1] that controls which tokens can see each other.
    def forward(self, *, x: torch.Tensor, mask: Optional[torch.Tensor] = None):
Keep a copy for shortcut connection
        shortcut = x
Normalize $X$
        x = self.norm(x)
Projection and activation $Z = \sigma(XU)$
        z = self.activation(self.proj1(x))
Spatial Gating Unit $\tilde{Z} = s(Z)$
        z = self.sgu(z, mask)
Final projection $Y = \tilde{Z}V$
        z = self.proj2(z)
Add the shortcut connection
        return z + shortcut
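The mask argument of GMLPBlock.forward is a boolean tensor of shape [seq_len, seq_len, 1]. For an autoregressive model, one common choice is a lower-triangular mask, so that token $i$ only sees tokens $j \le i$. This is an assumption here: the training code may build its mask differently. A sketch of building such a mask, and of how SpacialGatingUnit uses it to zero out entries of $W$:

```python
import torch

seq_len = 5

# Lower-triangular boolean mask: entry [i, j] is True when token i may see token j (j <= i).
# The trailing dimension of size 1 is the batch dimension expected by the blocks.
mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool)).unsqueeze(-1)

# SpacialGatingUnit drops the batch dimension and multiplies W by the mask,
# so every W[i, j] with j > i becomes 0.
weight = torch.rand(seq_len, seq_len)
masked_weight = weight * mask[:, :, 0]

print(mask.shape)                                                 # torch.Size([5, 5, 1])
print(torch.triu(masked_weight, diagonal=1).abs().sum().item())   # 0.0
```

Multiplying a float tensor by a boolean one promotes the mask to floats, so allowed entries keep their value and masked ones become $0$.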
The Spatial Gating Unit computes

$$s(Z) = Z_1 \odot f_{W,b}(Z_2)$$

where $f_{W,b}(Z) = W Z + b$ is a linear transformation along the sequence dimension and $\odot$ is element-wise multiplication. $Z$ is split into two parts of equal size, $Z_1$ and $Z_2$, along the channel (embedding) dimension.
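Because $f_{W,b}$ acts along the sequence dimension, it mixes information across token positions rather than across channels. SpacialGatingUnit.forward implements it with the einsum 'ij,jbd->ibd'. For every batch element, this is the ordinary matrix product of $W$ with that element's [seq_len, d] slice. A small check of that equivalence, with arbitrary toy sizes (assumptions for illustration):

```python
import torch

seq_len, batch_size, d = 6, 3, 4
W = torch.randn(seq_len, seq_len)          # spatial weights, one row per output token
bias = torch.ones(seq_len)                 # per-position bias, as in the unit below
z2 = torch.randn(seq_len, batch_size, d)   # [seq_len, batch_size, channels]

# Same expression as in SpacialGatingUnit.forward
out_einsum = torch.einsum('ij,jbd->ibd', W, z2) + bias[:, None, None]

# Reference: multiply W with each batch element's [seq_len, d] slice, then restack
out_loop = torch.stack([W @ z2[:, i, :] for i in range(batch_size)], dim=1) + bias[:, None, None]

print(torch.allclose(out_einsum, out_loop, atol=1e-5))  # True
```

Keeping the sequence dimension first, as the rest of this implementation does, lets the einsum contract over $j$ without transposing the activations.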
class SpacialGatingUnit(nn.Module):
d_z is the dimensionality of $Z$
seq_len is the sequence length
    def __init__(self, d_z: int, seq_len: int):
        super().__init__()
Normalization layer before applying $f_{W,b}(\cdot)$
        self.norm = nn.LayerNorm([d_z // 2])
Weight $W$ in $f_{W,b}(\cdot)$.
The paper notes that it’s important to initialize weights to small values and the bias to $1$, so that during the initial training $s(\cdot)$ is close to identity (apart from the split).
        self.weight = nn.Parameter(torch.zeros(seq_len, seq_len).uniform_(-0.01, 0.01), requires_grad=True)
        self.bias = nn.Parameter(torch.ones(seq_len), requires_grad=True)
z is the input $Z$ of shape [seq_len, batch_size, d_z]
mask is a boolean mask of shape [seq_len, seq_len, 1] that controls which tokens can see each other. The last dimension, of size 1, is the batch dimension; other transformer implementations have it, and it is kept here for compatibility.
    def forward(self, z: torch.Tensor, mask: Optional[torch.Tensor] = None):
Get sequence length
        seq_len = z.shape[0]
Split $Z$ into $Z_1$ and $Z_2$
        z1, z2 = torch.chunk(z, 2, dim=-1)
Check mask
        if mask is not None:
mask has shape [seq_len_q, seq_len_k, batch_size]. The batch dimension should be of size 1 because this implementation supports only the same mask for all samples in the batch.
            assert mask.shape[0] == 1 or mask.shape[0] == seq_len
            assert mask.shape[1] == seq_len
Here we only support the same mask for all samples
            assert mask.shape[2] == 1
Remove the batch dimension
            mask = mask[:, :, 0]
Normalize $Z_2$ before $f_{W,b}(\cdot)$
        z2 = self.norm(z2)
Get the weight matrix, truncated to [seq_len, seq_len] when the input sequence is shorter than the length the unit was built for
        weight = self.weight[:seq_len, :seq_len]
Apply mask to the weights.
If $W_{i,j}$ is $0$ then $f_{W,b}(Z_2)_i$ will not get any information from token $j$.
        if mask is not None:
            weight = weight * mask
$f_{W,b}(Z_2) = W Z_2 + b$
        z2 = torch.einsum('ij,jbd->ibd', weight, z2) + self.bias[:seq_len, None, None]
$Z_1 \odot f_{W,b}(Z_2)$
        return z1 * z2
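Two properties described above can be checked numerically: the near-identity behaviour at initialization and the effect of a causal mask. The sketch below does not import the class. It re-states its forward pass as a hypothetical standalone function `sgu`, with one simplification: the LayerNorm has no learnable affine parameters. The toy sizes and the lower-triangular mask are assumptions for illustration.

```python
import torch
import torch.nn.functional as F


def sgu(z, weight, bias, mask=None):
    """Hypothetical standalone re-statement of SpacialGatingUnit.forward.

    z: [seq_len, batch_size, d_z]; weight: [seq_len, seq_len]; bias: [seq_len];
    mask: optional bool tensor [seq_len, seq_len, 1].
    """
    seq_len = z.shape[0]
    z1, z2 = torch.chunk(z, 2, dim=-1)
    # LayerNorm over the channels of Z2 (no learnable affine parameters here)
    z2 = F.layer_norm(z2, z2.shape[-1:])
    w = weight[:seq_len, :seq_len]
    if mask is not None:
        w = w * mask[:, :, 0]
    z2 = torch.einsum('ij,jbd->ibd', w, z2) + bias[:seq_len, None, None]
    return z1 * z2


torch.manual_seed(0)
seq_len, batch_size, d_z = 5, 2, 8
weight = torch.zeros(seq_len, seq_len).uniform_(-0.01, 0.01)  # small initial weights
bias = torch.ones(seq_len)                                     # bias initialized to 1
mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool)).unsqueeze(-1)

z = torch.randn(seq_len, batch_size, d_z)
out = sgu(z, weight, bias, mask)

# 1) Near identity at init: s(Z) = Z1 * (1 + W norm(Z2)), and |W norm(Z2)| < 0.1 here,
#    so every output stays within 10% of the matching entry of Z1.
z1 = z[..., : d_z // 2]
print(torch.all((out - z1).abs() <= 0.1 * z1.abs() + 1e-6).item())  # True

# 2) Causality: replacing the last token leaves earlier outputs unchanged with the mask...
z_perturbed = z.clone()
z_perturbed[-1] = torch.randn(batch_size, d_z)
out_perturbed = sgu(z_perturbed, weight, bias, mask)
print(torch.equal(out[:-1], out_perturbed[:-1]))  # True

# ...but without the mask, earlier tokens do see the change.
out_nomask = sgu(z, weight, bias)
out_nomask_perturbed = sgu(z_perturbed, weight, bias)
print(torch.allclose(out_nomask[:-1], out_nomask_perturbed[:-1]))  # False
```

The last token is replaced with a fresh random vector rather than shifted by a constant. A constant shift across all channels would be removed by the LayerNorm and would not test anything.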