📚 glu variants

This commit is contained in:
Varuna Jayasiri
2021-01-26 16:54:23 +05:30
parent 20d2e27a3c
commit abe5caba6f
10 changed files with 1390 additions and 552 deletions

View File

@ -19,6 +19,14 @@ from .models import EmbeddingsWithPositionalEncoding, EmbeddingsWithLearnedPosit
class FeedForwardConfigs(BaseConfigs):
"""
<a id="FFN">
## FFN Configurations
</a>
Creates a Position-wise FeedForward Network defined in
[`feed_forward.py`](feed_forward.html).
"""
# Position-wise feedforward layer
ffn: FeedForward
# Number of features in the embedding
@ -44,7 +52,9 @@ class FeedForwardConfigs(BaseConfigs):
@option(FeedForwardConfigs.activation, 'ReLU')
def _ffn_activation_relu():
"""
ReLU activation
### ReLU activation
$$\max(0, x)$$
"""
return nn.ReLU()
@ -52,7 +62,11 @@ def _ffn_activation_relu():
@option(FeedForwardConfigs.activation, 'GELU')
def _ffn_activation_gelu():
"""
GELU activation
### GELU activation
$$x \Phi(x)$$ where $\Phi(x) = P(X \le x), X \sim \mathcal{N}(0,1)$
It was introduced in the paper [Gaussian Error Linear Units](https://arxiv.org/abs/1606.08415).
"""
return nn.GELU()
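# As a minimal sketch (assuming `torch` is imported alongside `torch.nn`), the exact
# `nn.GELU()` can be checked numerically against this definition of $x \Phi(x)$:
x = torch.linspace(-2., 2., steps=5)
phi = 0.5 * (1 + torch.erf(x / 2 ** 0.5))
assert torch.allclose(nn.GELU()(x), x * phi, atol=1e-6)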
@ -60,7 +74,7 @@ def _ffn_activation_gelu():
@option(FeedForwardConfigs.ffn, 'default')
def _feed_forward(c: FeedForwardConfigs):
"""
Create feedforward layer
Initialize a [feed forward network](feed_forward.html)
"""
return FeedForward(c.d_model, c.d_ff,
dropout=c.dropout,
@ -70,7 +84,14 @@ def _feed_forward(c: FeedForwardConfigs):
bias2=c.bias2,
bias_gate=c.bias_gate)
# ## GLU Variants
# These are variants with gated hidden layers for the FFN
# as introduced in the paper [GLU Variants Improve Transformer](https://arxiv.org/abs/2002.05202).
# We have omitted the bias terms as specified in the paper.
# ### FFN with Gated Linear Units
#
# $$FFN_{GLU}(x, W_1, V, W_2) = (\sigma(x W_1) \otimes x V) W_2$$
aggregate(FeedForwardConfigs.glu_variant, 'GLU',
(FeedForwardConfigs.is_gated, True),
(FeedForwardConfigs.bias1, False),
@ -78,24 +99,40 @@ aggregate(FeedForwardConfigs.glu_variant, 'GLU',
(FeedForwardConfigs.bias_gate, False),
(FeedForwardConfigs.activation, nn.Sigmoid()))
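# As a rough sketch of what this configuration computes, in plain PyTorch (assuming
# `torch` is imported; the weights `w1`, `v` and `w2` below are illustrative stand-ins
# for $W_1$, $V$ and $W_2$, and the bias terms are omitted, matching the options above):
x = torch.randn(10, 512)
w1, v, w2 = torch.randn(512, 2048), torch.randn(512, 2048), torch.randn(2048, 512)
y = (torch.sigmoid(x @ w1) * (x @ v)) @ w2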
# ### FFN with Bilinear hidden layer
#
# $$FFN_{Bilinear}(x, W_1, V, W_2) = (x W_1 \otimes x V) W_2$$
aggregate(FeedForwardConfigs.glu_variant, 'Bilinear',
(FeedForwardConfigs.is_gated, True),
(FeedForwardConfigs.bias1, False),
(FeedForwardConfigs.bias2, False),
(FeedForwardConfigs.bias_gate, False),
(FeedForwardConfigs.activation, nn.Identity()))
# ### FFN with ReLU gate
#
# $$FFN_{ReGLU}(x, W_1, V, W_2) = (\max(0, x W_1) \otimes x V) W_2$$
aggregate(FeedForwardConfigs.glu_variant, 'ReGLU',
(FeedForwardConfigs.is_gated, True),
(FeedForwardConfigs.bias1, False),
(FeedForwardConfigs.bias2, False),
(FeedForwardConfigs.bias_gate, False),
(FeedForwardConfigs.activation, nn.ReLU()))
# ### FFN with GELU gate
#
# $$FFN_{GEGLU}(x, W_1, V, W_2) = (\text{GELU}(x W_1) \otimes x V) W_2$$
aggregate(FeedForwardConfigs.glu_variant, 'GEGLU',
(FeedForwardConfigs.is_gated, True),
(FeedForwardConfigs.bias1, False),
(FeedForwardConfigs.bias2, False),
(FeedForwardConfigs.bias_gate, False),
(FeedForwardConfigs.activation, nn.GELU()))
# ### FFN with Swish gate
#
# $$FFN_{SwiGLU}(x, W_1, V, W_2) = (\text{Swish}_1(x W_1) \otimes x V) W_2$$
# where $\text{Swish}_\beta(x) = x \sigma(\beta x)$
aggregate(FeedForwardConfigs.glu_variant, 'SwiGLU',
(FeedForwardConfigs.is_gated, True),
(FeedForwardConfigs.bias1, False),
@ -236,7 +273,7 @@ def _generator(c: TransformerConfigs):
return Generator(c.n_tgt_vocab, c.d_model)
# ## Positional Embeddings
# ### Fixed Positional Embeddings
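# (the fixed encodings used here are the sinusoidal ones from the Transformer paper,
# $PE_{(pos, 2i)} = \sin(pos / 10000^{2i/d_{model}})$ and
# $PE_{(pos, 2i+1)} = \cos(pos / 10000^{2i/d_{model}})$)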
@option(TransformerConfigs.src_embed, 'fixed_pos')
def _src_embed_with_positional(c: TransformerConfigs):
"""
@ -253,7 +290,7 @@ def _tgt_embed_with_positional(c: TransformerConfigs):
return EmbeddingsWithPositionalEncoding(c.d_model, c.n_tgt_vocab)
# ## Learned Positional Embeddings
# ### Learned Positional Embeddings
@option(TransformerConfigs.src_embed, 'learned_pos')
def _src_embed_with_learned_positional(c: TransformerConfigs):
"""
@ -270,7 +307,7 @@ def _tgt_embed_with_learned_positional(c: TransformerConfigs):
return EmbeddingsWithLearnedPositionalEncoding(c.d_model, c.n_tgt_vocab)
# ## No Positional Embeddings
# ### No Positional Embeddings
@option(TransformerConfigs.src_embed, 'no_pos')
def _src_embed_without_positional(c: TransformerConfigs):
"""

View File

@ -21,6 +21,15 @@ where $W_1$, $W_2$, $b_1$ and $b_2$ are learnable parameters.
Sometimes the
GELU (Gaussian Error Linear Unit) activation is also used instead of ReLU.
$$x \Phi(x)$$ where $\Phi(x) = P(X \le x), X \sim \mathcal{N}(0,1)$
### Gated Linear Units
This is a generic implementation that supports different variants including
[Gated Linear Units](https://arxiv.org/abs/2002.05202) (GLU).
We have also implemented experiments that use these variants:
* [experiment that uses `labml.configs`](glu_variants/experiment.html)
* [simpler version from scratch](glu_variants/simple.html)
"""
import torch
@ -31,7 +40,7 @@ from labml_helpers.module import Module
class FeedForward(Module):
"""
## Position-wise feed-forward network (FFN) module
## FFN module
"""
def __init__(self, d_model: int, d_ff: int,
@ -51,19 +60,32 @@ class FeedForward(Module):
* `bias_gate` specifies whether the fully connected layer for the gate should have a learnable bias
"""
super().__init__()
# Layer one parameterized by weight $W_1$ and bias $b_1$
self.layer1 = nn.Linear(d_model, d_ff, bias=bias1)
# Layer two parameterized by weight $W_2$ and bias $b_2$
self.layer2 = nn.Linear(d_ff, d_model, bias=bias2)
# Hidden layer dropout
self.dropout = nn.Dropout(dropout)
# Activation function $f$
self.activation = activation
# Whether there is a gate
self.is_gated = is_gated
if is_gated:
# If there is a gate, this is the linear layer that transforms inputs to
# be multiplied by the gate, parameterized by weight $V$ and bias $c$
self.linear_v = nn.Linear(d_model, d_ff, bias=bias_gate)
def __call__(self, x: torch.Tensor):
# $f(x W_1 + b_1)$
g = self.activation(self.layer1(x))
# If gated, $f(x W_1 + b_1) \otimes (x V + c)$
if self.is_gated:
x = g * self.linear_v(x)
# Otherwise
else:
x = g
# Apply dropout
x = self.dropout(x)
# $(f(x W_1 + b_1) \otimes (x V + c)) W_2 + b_2$ or $f(x W_1 + b_1) W_2 + b_2$
# depending on whether it is gated
return self.layer2(x)
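# A small usage sketch (sizes are illustrative): build a gated instance in the
# ReGLU style and check, with dropout disabled by `eval()`, that the gated path
# matches $(\max(0, x W_1) \otimes x V) W_2$.
ffn = FeedForward(512, 2048, 0.1, nn.ReLU(), True, False, False, False)
ffn.eval()
x = torch.randn(4, 512)
assert torch.allclose(ffn(x), ffn.layer2(ffn.activation(ffn.layer1(x)) * ffn.linear_v(x)))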

View File

@ -6,9 +6,11 @@ summary: >
for the position-wise feedforward network (FFN).
---
# Train Autoregressive Transformer
# Gated Linear Units and Variants
This trains a simple [transformer](../../) model for auto-regression.
We try different variants for the [position-wise feedforward network](../feed_forward).
The reusable & configurable components are defined in [`configs.py`](configs.html).
"""
import torch
@ -72,7 +74,7 @@ def autoregressive_model(c: Configs):
@option(Configs.transformer)
def transformer_c(c: Configs):
"""
Initialize the configurable transformer encoder for our autoregressive model
Initialize the [configurable transformer](../configs.html) encoder for our autoregressive model.
"""
tc = TransformerConfigs()
tc.n_src_vocab = c.n_tokens
@ -104,6 +106,9 @@ def main():
'inner_iterations': 10,
# GLU Variant, one of GLU, Bilinear, ReGLU, GEGLU, SwiGLU
#
# These are defined in the [configurable FFN](../configs.html#FFN)
# implementation
'transformer.ffn.glu_variant': 'Bilinear',
# Transformer configurations

View File

@ -6,9 +6,13 @@ summary: >
for the position-wise feedforward network (FFN).
---
# Train Autoregressive Transformer
# Gated Linear Units and Variants
This trains a simple [transformer](../../) model for auto-regression.
We try different variants for the [position-wise feedforward network](../feed_forward).
*This is a simpler implementation that doesn't use the [`labml.configs`](experiment.html) module.
We wrote it to make it easier for readers who are not familiar with that module.*
"""
import dataclasses
@ -56,6 +60,9 @@ class AutoregressiveModel(nn.Module):
@dataclasses.dataclass
class Configs:
"""
### Configurations
"""
d_model: int = 512
seq_len: int = 128
batch_size: int = 32
@ -69,71 +76,130 @@ class Configs:
class TinyShakespeareDataset(Dataset):
"""
### Tiny Shakespeare Dataset
"""
def __init__(self, seq_len: int):
# Location of the text file
path = lab.get_data_path() / 'tiny_shakespeare.txt'
# Download the file
download_file('https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt', path)
# Read the downloaded file
with open(str(path), 'r') as f:
text = f.read()
# Extract the characters
chars = list(set(text))
# Character to id (integer) map
self.stoi = {c: i for i, c in enumerate(chars)}
# Id to character map
self.itos = {i: c for i, c in enumerate(chars)}
# Length of a training sample
self.seq_len = seq_len
# Data in the form of a tensor of ids
self.data = self.text_to_i(text)
def text_to_i(self, text: str):
"""
Transform the text into a tensor of ids
"""
return torch.tensor([self.stoi[c] for c in text], dtype=torch.long)
def __len__(self):
"""
Number of samples in the dataset.
*This will read the dataset `seq_len` times in a single epoch.*
"""
return len(self.data) - self.seq_len - 1
def __getitem__(self, idx):
"""
Return a sample
"""
return self.data[idx:idx + self.seq_len], self.data[idx + 1:idx + self.seq_len + 1]
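# As a tiny illustration of that one-character shift (a sketch on a stand-in id
# tensor rather than the real data):
data = torch.arange(10)
seq_len = 3
src, tgt = data[0:seq_len], data[1:seq_len + 1]
# `src` is `[0, 1, 2]` and `tgt` is `[1, 2, 3]`; the model learns to predict the next character at every position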
class Trainer:
"""
## Trainer
"""
def __init__(self, configs: Configs):
# Get the device
self.device = torch.device('cpu')
if torch.cuda.is_available():
self.device = torch.device('cuda:0')
# Initialize the dataset
self.dataset = TinyShakespeareDataset(configs.seq_len)
self.dataloader = DataLoader(self.dataset, batch_size=configs.batch_size, collate_fn=transpose_batch,
# Initialize the dataloader
self.dataloader = DataLoader(self.dataset,
batch_size=configs.batch_size,
collate_fn=transpose_batch,
shuffle=True)
# FFN with Gated Linear Unit
# $$FFN_{GLU}(x, W_1, V, W_2) = (\sigma(x W_1) \otimes x V) W_2$$
if configs.glu_variant == 'GLU':
ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.Sigmoid(), True, False, False, False)
# FFN with Bilinear hidden layer
# $$FFN_{Bilinear}(x, W_1, V, W_2) = (x W_1 \otimes x V) W_2$$
elif configs.glu_variant == 'Bilinear':
ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.Identity(), True, False, False, False)
# FFN with ReLU gate
# $$FFN_{ReGLU}(x, W_1, V, W_2) = (\max(0, x W_1) \otimes x V) W_2$$
elif configs.glu_variant == 'ReGLU':
ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.ReLU(), True, False, False, False)
# FFN with GELU gate
# $$FFN_{GEGLU}(x, W_1, V, W_2) = (\text{GELU}(x W_1) \otimes x V) W_2$$
elif configs.glu_variant == 'GEGLU':
ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.GELU(), True, False, False, False)
# FFN with Swish gate
# $$FFN_{SwiGLU}(x, W_1, V, W_2) = (\text{Swish}_1(x W_1) \otimes x V) W_2$$
# where $\text{Swish}_\beta(x) = x \sigma(\beta x)$
elif configs.glu_variant == 'SwiGLU':
ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.SiLU(), True, False, False, False)
# FFN with ReLU activation
# $$FFN_{ReLU}(x, W_1, W_2, b_1, b_2) = \text{ReLU}(x W_1 + b_1) W_2 + b_2$$
elif configs.glu_variant == 'ReLU':
ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.ReLU())
# FFN with GELU activation
# $$FFN_{GELU}(x, W_1, W_2, b_1, b_2) = \text{GELU}(x W_1 + b_1) W_2 + b_2$$
elif configs.glu_variant == 'GELU':
ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.GELU())
else:
raise ValueError(f'Unknown variant {configs.glu_variant}')
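# As a side note, the `nn.SiLU()` used above for SwiGLU is exactly
# $\text{Swish}_1(x) = x \sigma(x)$; a quick sketch to check this is
# `torch.allclose(nn.SiLU()(x), x * torch.sigmoid(x))` for any tensor `x`.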
# Number of different characters
n_chars = len(self.dataset.stoi)
# Initialize [Multi-Head Attention module](../mha.html)
mha = MultiHeadAttention(configs.n_heads, configs.d_model, configs.dropout)
# Initialize the [Transformer Block](../models.html#TransformerLayer)
transformer_layer = TransformerLayer(d_model=configs.d_model, self_attn=mha, src_attn=None,
feed_forward=ffn, dropout_prob=configs.dropout)
# Initialize the model with an
# [embedding layer](../models.html#EmbeddingsWithPositionalEncoding)
# (with fixed positional encoding)
# [transformer encoder](../models.html#Encoder) and
# a linear layer to generate logits.
self.model = AutoregressiveModel(EmbeddingsWithPositionalEncoding(configs.d_model, n_chars),
Encoder(TransformerLayer(
d_model=configs.d_model,
self_attn=MultiHeadAttention(configs.n_heads, configs.d_model,
configs.dropout),
src_attn=None,
feed_forward=ffn,
dropout_prob=configs.dropout
), configs.n_layers),
Encoder(transformer_layer, configs.n_layers),
nn.Linear(configs.d_model, n_chars))
# Move the model to the current device
self.model.to(self.device)
# Initialize [Noam optimizer](../../optimizers/noam.html)
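# (the Noam schedule, from the Transformer paper, warms the learning rate up linearly
# for the first `warmup` steps and then decays it with the inverse square root of the
# step number, scaled by $d_{model}^{-0.5}$)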
self.optimizer = Noam(self.model.parameters(), lr=1.0, warmup=2_000, d_model=configs.d_model)
# Cross-entropy loss
self.loss_func = nn.CrossEntropyLoss()
# Number of training epochs;
# *note that our dataset definition reads the data `seq_len` times in a single epoch*
self.epochs = configs.epochs
# Gradient clipping norm
self.grad_norm_clip = configs.grad_norm_clip
# Set tracker configurations
@ -166,18 +232,28 @@ class Trainer:
logger.log(log)
def train(self):
"""
### Train the model
"""
# Loop for the given number of epochs
for _ in monit.loop(self.epochs):
# Iterate over the minibatches
for i, batch in monit.enum('Train', self.dataloader):
# Move data to the device
data, target = batch[0].to(self.device), batch[1].to(self.device)
# Set tracker step, as the number of characters trained on
tracker.add_global_step(data.shape[0] * data.shape[1])
# Set model state to training
self.model.train()
# Evaluate the model
output = self.model(data)
# Calculate and log loss
# Calculate loss
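# (the outputs and targets are flattened so that cross-entropy scores every
# character position as a separate classification)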
loss = self.loss_func(output.view(-1, output.shape[-1]), target.view(-1))
# Log the loss
tracker.add("loss.train", loss)
# Calculate gradients
@ -186,12 +262,13 @@ class Trainer:
torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.grad_norm_clip)
# Take optimizer step
self.optimizer.step()
# Log the model parameters and gradients on last batch of every epoch
# Log the model parameters and gradients
if (i + 1) % 100 == 0:
tracker.add('model', self.model)
# Clear the gradients
self.optimizer.zero_grad()
# Generate a sample
if (i + 1) % 100 == 0:
self.model.eval()
with torch.no_grad():
@ -201,6 +278,7 @@ class Trainer:
if (i + 1) % 10 == 0:
tracker.save()
# Save the model
experiment.save_checkpoint()
@ -212,12 +290,14 @@ def main():
# Load configurations
experiment.configs(dataclasses.asdict(configs))
# Create trainer
trainer = Trainer(configs)
# Set models for training and loading
experiment.add_pytorch_models({'model': trainer.model})
# Start the experiment
with experiment.start():
# `TrainValidConfigs.run`
# Train the model
trainer.train()