This is a PyTorch implementation of the paper Pay Attention to MLPs.
The paper introduces a Multilayer Perceptron (MLP) based architecture with gating, named gMLP. It consists of a stack of gMLP blocks.
Here is the training code for a gMLP-based autoregressive model.
from typing import Optional

import torch
from torch import nn

Each block applies the following transformations to input embeddings $X \in \mathbb{R}^{n \times d}$, where $n$ is the sequence length and $d$ is the dimensionality of the embeddings:

$$Z = \sigma(XU), \qquad \tilde{Z} = s(Z), \qquad Y = \tilde{Z}V$$

where $U$ and $V$ are learnable projection weights, $s(\cdot)$ is the Spatial Gating Unit defined below, and $\sigma$ is an activation function such as GeLU. The output dimensionality of $s(\cdot)$ is half that of $Z$.
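As a rough shape check of the three transformations $Z = \sigma(XU)$, $\tilde{Z} = s(Z)$ and $Y = \tilde{Z}V$, the sketch below traces tensor sizes through each step. The sizes and the helper `placeholder_gate` are assumptions made for illustration. The helper only splits the channels and multiplies the halves, standing in for the actual Spatial Gating Unit. Note that `nn.Linear` also adds a bias term, which the equations omit.

```python
import torch
from torch import nn

# Illustrative sizes (assumptions, not values from the paper)
seq_len, batch_size, d_model, d_ffn = 6, 2, 8, 16


def placeholder_gate(z: torch.Tensor) -> torch.Tensor:
    """Hypothetical stand-in for s(Z): split channels in two and multiply.

    The real gating unit also normalizes the second half and projects it
    along the sequence dimension; only the channel halving is shown here.
    """
    z1, z2 = torch.chunk(z, 2, dim=-1)
    return z1 * z2


proj_u = nn.Linear(d_model, d_ffn)       # plays the role of U
proj_v = nn.Linear(d_ffn // 2, d_model)  # plays the role of V
activation = nn.GELU()                   # sigma

x = torch.randn(seq_len, batch_size, d_model)
z = activation(proj_u(x))      # [seq_len, batch_size, d_ffn]
z_tilde = placeholder_gate(z)  # [seq_len, batch_size, d_ffn // 2]
y = proj_v(z_tilde)            # [seq_len, batch_size, d_model]

print(z.shape, z_tilde.shape, y.shape)
```

The second projection takes `d_ffn // 2` input features because the gate halves the channel dimension.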
class GMLPBlock(nn.Module):

d_model is the dimensionality ($d$) of $X$
d_ffn is the dimensionality of $Z$
seq_len is the length of the token sequence ($n$)

    def __init__(self, d_model: int, d_ffn: int, seq_len: int):
        super().__init__()

Normalization layer for Pre-Norm

        self.norm = nn.LayerNorm([d_model])

Activation function $\sigma$

        self.activation = nn.GELU()

Projection layer for $Z = \sigma(XU)$

        self.proj1 = nn.Linear(d_model, d_ffn)

Spatial Gating Unit $s(\cdot)$

        self.sgu = SpacialGatingUnit(d_ffn, seq_len)

Projection layer for $Y = \tilde{Z}V$

        self.proj2 = nn.Linear(d_ffn // 2, d_model)

Embedding size (required by Encoder). We use the encoder module from the transformer architecture and plug in the gMLP block as a replacement for the Transformer Layer.
        self.size = d_model

x is the input embedding tensor $X$ of shape [seq_len, batch_size, d_model]
mask is a boolean mask of shape [seq_len, seq_len, 1] that controls the visibility of tokens among each other.

    def forward(self, *, x: torch.Tensor, mask: Optional[torch.Tensor] = None):

Keep a copy for the shortcut connection

        shortcut = x

Normalize $X$

        x = self.norm(x)

Projection and activation $Z = \sigma(XU)$

        z = self.activation(self.proj1(x))

Spatial Gating Unit $\tilde{Z} = s(Z)$

        z = self.sgu(z, mask)

Final projection $Y = \tilde{Z}V$

        z = self.proj2(z)

Add the shortcut connection

        return z + shortcut
$$s(Z) = Z_1 \odot f_{W,b}(Z_2)$$

where $f_{W,b}(Z) = WZ + b$ is a linear transformation along the sequence dimension, and $\odot$ is element-wise multiplication. $Z$ is split into two parts of equal size, $Z_1$ and $Z_2$, along the channel dimension (embedding dimension).
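A small worked example may make the gating $s(Z) = Z_1 \odot (W Z_2 + b)$ concrete. The values of `z`, `w` and `b` below are hand-picked assumptions chosen so the arithmetic is easy to follow; they are not the initialization used for training, and the normalization of $Z_2$ is left out to keep the numbers simple.

```python
import torch

# Toy input Z with seq_len = 2, batch_size = 1, d_z = 2 (hand-picked values)
z = torch.tensor([[[1., 3.]],
                  [[2., 4.]]])  # shape [2, 1, 2]

# Split along the channel dimension: Z1 holds (1, 2), Z2 holds (3, 4)
z1, z2 = torch.chunk(z, 2, dim=-1)

# Hand-picked W and b: position 1 reads from position 0, position 0 reads nothing
w = torch.tensor([[0., 0.],
                  [1., 0.]])
b = torch.ones(2)

# f(Z2) = W Z2 + b, mixing information along the sequence dimension
f_z2 = torch.einsum('ij,jbd->ibd', w, z2) + b[:, None, None]
# position 0: 0*3 + 0*4 + 1 = 1
# position 1: 1*3 + 0*4 + 1 = 4

# s(Z) = Z1 * f(Z2), element-wise
out = z1 * f_z2
print(out.flatten().tolist())  # [1.0, 8.0]
```

The output has half as many channels as the input, because only $Z_1$ is carried forward while $Z_2$ serves as the gate.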
class SpacialGatingUnit(nn.Module):

d_z is the dimensionality of $Z$
seq_len is the sequence length

    def __init__(self, d_z: int, seq_len: int):
        super().__init__()

Normalization layer before applying $f_{W,b}(\cdot)$

        self.norm = nn.LayerNorm([d_z // 2])

Weight $W$ in $f_{W,b}(\cdot)$.
The paper notes that it's important to initialize the weights to small values and the bias to $1$, so that during the initial training $s(\cdot)$ is close to identity (apart from the split).
        self.weight = nn.Parameter(torch.zeros(seq_len, seq_len).uniform_(-0.01, 0.01), requires_grad=True)
        self.bias = nn.Parameter(torch.ones(seq_len), requires_grad=True)

z is the input $Z$ of shape [seq_len, batch_size, d_z]
mask is a boolean mask of shape [seq_len, seq_len, 1] that controls the visibility of tokens among each other. The last dimension of size 1 is the batch, which we have in other transformer implementations and was left for compatibility.

    def forward(self, z: torch.Tensor, mask: Optional[torch.Tensor] = None):

Get sequence length

        seq_len = z.shape[0]

Split $Z$ into $Z_1$ and $Z_2$

        z1, z2 = torch.chunk(z, 2, dim=-1)

Check mask
        if mask is not None:

mask has shape [seq_len_q, seq_len_k, batch_size]. The batch dimension should be of size 1 because this implementation supports only the same mask for all samples in the batch.

            assert mask.shape[0] == 1 or mask.shape[0] == seq_len
            assert mask.shape[1] == seq_len

Here we only support the same mask for all samples

            assert mask.shape[2] == 1

Remove the batch dimension

            mask = mask[:, :, 0]

Normalize $Z_2$ before applying $f_{W,b}(\cdot)$
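        z2 = self.norm(z2)

Get the weight matrix; truncate if larger than seq_len

        weight = self.weight[:seq_len, :seq_len]

Apply the mask to the weights. If $W_{i,j}$ is $0$, position $i$ of $f_{W,b}(Z_2)$ gets no information from token $j$.

        if mask is not None:
            weight = weight * mask

$f_{W,b}(Z_2) = W Z_2 + b$

        z2 = torch.einsum('ij,jbd->ibd', weight, z2) + self.bias[:seq_len, None, None]

$Z_1 \odot f_{W,b}(Z_2)$

        return z1 * z2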
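Two properties of the gating unit can be checked numerically: the gate is close to $1$ at initialization, and a lower-triangular mask stops information flowing from later tokens to earlier ones. The sketch below is a functional re-implementation written for illustration. The helper `spatial_gate`, the sizes, and the seed are assumptions, and the layer normalization is applied without learnable scale and shift, which matches a freshly initialized `nn.LayerNorm`.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
seq_len, batch_size, d_z = 5, 2, 8  # illustrative sizes (assumptions)


def spatial_gate(z, weight, bias, mask=None):
    """Hypothetical functional gating: returns Z1 and the gate W norm(Z2) + b."""
    z1, z2 = torch.chunk(z, 2, dim=-1)
    # Layer norm over the channel dimension, no learnable parameters
    z2 = F.layer_norm(z2, z2.shape[-1:])
    if mask is not None:
        # Drop the size-1 batch dimension and zero out masked weights
        weight = weight * mask[:, :, 0]
    gate = torch.einsum('ij,jbd->ibd', weight, z2) + bias[:, None, None]
    return z1, gate


# Small uniform weights and a bias of ones, as in the initialization above
weight = torch.zeros(seq_len, seq_len).uniform_(-0.01, 0.01)
bias = torch.ones(seq_len)
z = torch.randn(seq_len, batch_size, d_z)

# 1. At initialization the gate stays close to 1, so the output is close to Z1
z1, gate = spatial_gate(z, weight, bias)
max_gate_dev = (gate - 1).abs().max().item()
print(f"largest deviation of the gate from 1: {max_gate_dev:.4f}")

# 2. Causal mask of shape [seq_len, seq_len, 1]: position i sees positions j <= i
mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool)).unsqueeze(-1)
z1_a, gate_a = spatial_gate(z, weight, bias, mask)
out_a = z1_a * gate_a

# Change only the last token and recompute
z_changed = z.clone()
z_changed[-1] = torch.randn(batch_size, d_z)
z1_b, gate_b = spatial_gate(z_changed, weight, bias, mask)
out_b = z1_b * gate_b

# Earlier positions should not be affected by the change
unchanged = torch.allclose(out_a[:-1], out_b[:-1])
print("earlier positions unchanged:", unchanged)
```

With weights bounded by 0.01 and a short sequence, the gate cannot move far from the bias, which is why the unit starts out close to passing $Z_1$ through unchanged.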