📚 glu variants
@@ -19,6 +19,14 @@ from .models import EmbeddingsWithPositionalEncoding, EmbeddingsWithLearnedPosit


class FeedForwardConfigs(BaseConfigs):
    """
    <a id="FFN">
    ## FFN Configurations
    </a>

    Creates a Position-wise FeedForward Network defined in
    [`feed_forward.py`](feed_forward.html).
    """
    # Position-wise feedforward layer
    ffn: FeedForward
    # Number of features in the embedding
@@ -44,7 +52,9 @@ class FeedForwardConfigs(BaseConfigs):

@option(FeedForwardConfigs.activation, 'ReLU')
def _ffn_activation_relu():
    """
    ReLU activation
    ### ReLU activation

    $$\max(0, x)$$
    """
    return nn.ReLU()

@@ -52,7 +62,11 @@ def _ffn_activation_relu():

@option(FeedForwardConfigs.activation, 'GELU')
def _ffn_activation_gelu():
    """
    GELU activation
    ### GELU activation

    $$x \Phi(x)$$ where $\Phi(x) = P(X \le x), X \sim \mathcal{N}(0,1)$

    It was introduced in the paper [Gaussian Error Linear Units](https://arxiv.org/abs/1606.08415).
    """
    return nn.GELU()

@@ -60,7 +74,7 @@ def _ffn_activation_gelu():

@option(FeedForwardConfigs.ffn, 'default')
def _feed_forward(c: FeedForwardConfigs):
    """
    Create feedforward layer
    Initialize a [feed forward network](feed_forward.html)
    """
    return FeedForward(c.d_model, c.d_ff,
                       dropout=c.dropout,
@@ -70,7 +84,14 @@ def _feed_forward(c: FeedForwardConfigs):
                       bias2=c.bias2,
                       bias_gate=c.bias_gate)

# ## GLU Variants
# These are variants with gated hidden layers for the FFN,
# as introduced in the paper [GLU Variants Improve Transformer](https://arxiv.org/abs/2002.05202).
# We have omitted the bias terms as specified in the paper.

# ### FFN with Gated Linear Units
#
# $$FFN_{GLU}(x, W_1, V, W_2) = (\sigma(x W_1) \otimes x V) W_2$$
aggregate(FeedForwardConfigs.glu_variant, 'GLU',
          (FeedForwardConfigs.is_gated, True),
          (FeedForwardConfigs.bias1, False),
@@ -78,24 +99,40 @@ aggregate(FeedForwardConfigs.glu_variant, 'GLU',
          (FeedForwardConfigs.bias_gate, False),
          (FeedForwardConfigs.activation, nn.Sigmoid()))

# ### FFN with Bilinear hidden layer
#
# $$FFN_{Bilinear}(x, W_1, V, W_2) = (x W_1 \otimes x V) W_2$$
aggregate(FeedForwardConfigs.glu_variant, 'Bilinear',
          (FeedForwardConfigs.is_gated, True),
          (FeedForwardConfigs.bias1, False),
          (FeedForwardConfigs.bias2, False),
          (FeedForwardConfigs.bias_gate, False),
          (FeedForwardConfigs.activation, nn.Identity()))

# ### FFN with ReLU gate
#
# $$FFN_{ReGLU}(x, W_1, V, W_2) = (\max(0, x W_1) \otimes x V) W_2$$
aggregate(FeedForwardConfigs.glu_variant, 'ReGLU',
          (FeedForwardConfigs.is_gated, True),
          (FeedForwardConfigs.bias1, False),
          (FeedForwardConfigs.bias2, False),
          (FeedForwardConfigs.bias_gate, False),
          (FeedForwardConfigs.activation, nn.ReLU()))

# ### FFN with GELU gate
#
# $$FFN_{GEGLU}(x, W_1, V, W_2) = (\text{GELU}(x W_1) \otimes x V) W_2$$
aggregate(FeedForwardConfigs.glu_variant, 'GEGLU',
          (FeedForwardConfigs.is_gated, True),
          (FeedForwardConfigs.bias1, False),
          (FeedForwardConfigs.bias2, False),
          (FeedForwardConfigs.bias_gate, False),
          (FeedForwardConfigs.activation, nn.GELU()))

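# A minimal sketch (an illustration with toy sizes and plain PyTorch layers, not part
# of the configurable implementation) of what the 'GEGLU' flag combination above
# computes: bias-free $W_1$, $V$ and $W_2$ with a GELU gate,
# $(\text{GELU}(x W_1) \otimes x V) W_2$.
import torch
from torch import nn
import torch.nn.functional as F

_d_model, _d_ff = 8, 32
_w1 = nn.Linear(_d_model, _d_ff, bias=False)  # $W_1$
_v = nn.Linear(_d_model, _d_ff, bias=False)   # $V$
_w2 = nn.Linear(_d_ff, _d_model, bias=False)  # $W_2$
_x = torch.randn(4, _d_model)
# Gated hidden layer followed by the output projection $W_2$
_y = _w2(F.gelu(_w1(_x)) * _v(_x))            # shape: (4, _d_model)
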
# ### FFN with Swish gate
#
# $$FFN_{SwiGLU}(x, W_1, V, W_2) = (\text{Swish}_1(x W_1) \otimes x V) W_2$$
# where $\text{Swish}_\beta(x) = x \sigma(\beta x)$
aggregate(FeedForwardConfigs.glu_variant, 'SwiGLU',
          (FeedForwardConfigs.is_gated, True),
          (FeedForwardConfigs.bias1, False),
@@ -236,7 +273,7 @@ def _generator(c: TransformerConfigs):
    return Generator(c.n_tgt_vocab, c.d_model)


# ## Positional Embeddings
# ### Fixed Positional Embeddings
@option(TransformerConfigs.src_embed, 'fixed_pos')
def _src_embed_with_positional(c: TransformerConfigs):
    """
@@ -253,7 +290,7 @@ def _tgt_embed_with_positional(c: TransformerConfigs):
    return EmbeddingsWithPositionalEncoding(c.d_model, c.n_tgt_vocab)


# ## Learned Positional Embeddings
# ### Learned Positional Embeddings
@option(TransformerConfigs.src_embed, 'learned_pos')
def _src_embed_with_learned_positional(c: TransformerConfigs):
    """
@@ -270,7 +307,7 @@ def _tgt_embed_with_learned_positional(c: TransformerConfigs):
    return EmbeddingsWithLearnedPositionalEncoding(c.d_model, c.n_tgt_vocab)


# ## No Positional Embeddings
# ### No Positional Embeddings
@option(TransformerConfigs.src_embed, 'no_pos')
def _src_embed_without_positional(c: TransformerConfigs):
    """

@@ -21,6 +21,15 @@ where $W_1$, $W_2$, $b_1$ and $b_2$ are learnable parameters.

Sometimes the
GELU (Gaussian Error Linear Unit) activation is also used instead of ReLU.
$$x \Phi(x)$$ where $\Phi(x) = P(X \le x), X \sim \mathcal{N}(0,1)$

### Gated Linear Units

This is a generic implementation that supports different variants including
[Gated Linear Units](https://arxiv.org/abs/2002.05202) (GLU).
We have also implemented experiments on these:

* [experiment that uses `labml.configs`](glu_variants/experiment.html)
* [simpler version from scratch](glu_variants/simple.html)
"""

import torch
@@ -31,7 +40,7 @@ from labml_helpers.module import Module


class FeedForward(Module):
    """
    ## Position-wise feed-forward network (FFN) module
    ## FFN module
    """

    def __init__(self, d_model: int, d_ff: int,
@@ -51,19 +60,32 @@ class FeedForward(Module):
        * `bias_gate` specifies whether the fully connected layer for the gate should have a learnable bias
        """
        super().__init__()
        # Layer one parameterized by weight $W_1$ and bias $b_1$
        self.layer1 = nn.Linear(d_model, d_ff, bias=bias1)
        # Layer two parameterized by weight $W_2$ and bias $b_2$
        self.layer2 = nn.Linear(d_ff, d_model, bias=bias2)
        # Hidden layer dropout
        self.dropout = nn.Dropout(dropout)
        # Activation function $f$
        self.activation = activation
        # Whether there is a gate
        self.is_gated = is_gated
        if is_gated:
            # If there is a gate, the linear layer to transform inputs to
            # be multiplied by the gate, parameterized by weight $V$ and bias $c$
            self.linear_v = nn.Linear(d_model, d_ff, bias=bias_gate)

    def __call__(self, x: torch.Tensor):
        # $f(x W_1 + b_1)$
        g = self.activation(self.layer1(x))
        # If gated, $f(x W_1 + b_1) \otimes (x V + b)$
        if self.is_gated:
            x = g * self.linear_v(x)
        # Otherwise
        else:
            x = g
        # Apply dropout
        x = self.dropout(x)
        # $(f(x W_1 + b_1) \otimes (x V + b)) W_2 + b_2$ or $f(x W_1 + b_1) W_2 + b_2$
        # depending on whether it is gated
        return self.layer2(x)

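# A brief usage sketch with assumed toy shapes; the positional argument order follows
# the constructor above (d_model, d_ff, dropout, activation, is_gated, bias1, bias2,
# bias_gate). With nn.SiLU() and the biases disabled this gives a SwiGLU-style FFN,
# matching how the experiments later in this commit construct it.
if __name__ == '__main__':
    _ffn = FeedForward(512, 2048, 0.1, nn.SiLU(), True, False, False, False)
    _x = torch.rand(10, 32, 512)  # (seq_len, batch_size, d_model)
    _y = _ffn(_x)                 # output has the same shape as the input
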
@@ -6,9 +6,11 @@ summary: >
  for the position-wise feedforward network (FFN).
---

# Train Autoregressive Transformer
# Gated Linear Units and Variants

This trains a simple [transformer](../../) model for auto-regression.
We try different variants for the [position-wise feedforward network](../feed_forward).
The reusable & configurable components are defined in [`configs.py`](configs.html).
"""

import torch
@@ -72,7 +74,7 @@ def autoregressive_model(c: Configs):

@option(Configs.transformer)
def transformer_c(c: Configs):
    """
    Initialize the configurable transformer encoder for our autoregressive model
    Initialize the [configurable transformer](../configs.html) encoder for our autoregressive model.
    """
    tc = TransformerConfigs()
    tc.n_src_vocab = c.n_tokens
@@ -104,6 +106,9 @@ def main():
        'inner_iterations': 10,

        # GLU Variant, one of GLU, Bilinear, ReGLU, GEGLU, SwiGLU
        #
        # These are defined in the [configurable FFN](../configs.html#FFN)
        # implementation
        'transformer.ffn.glu_variant': 'Bilinear',

        # Transformer configurations

@@ -6,9 +6,13 @@ summary: >
  for the position-wise feedforward network (FFN).
---

# Train Autoregressive Transformer
# Gated Linear Units and Variants

This trains a simple [transformer](../../) model for auto-regression.
We try different variants for the [position-wise feedforward network](../feed_forward).

*This is a simpler implementation that doesn't use the [`labml.configs`](experiment.html) module.
We decided to write a simpler implementation to make it easier for readers who are not familiar with it.*
"""
import dataclasses

@@ -56,6 +60,9 @@ class AutoregressiveModel(nn.Module):

@dataclasses.dataclass
class Configs:
    """
    ### Configurations
    """
    d_model: int = 512
    seq_len: int = 128
    batch_size: int = 32
@@ -69,71 +76,130 @@ class Configs:


class TinyShakespeareDataset(Dataset):
    """
    ### Tiny Shakespeare Dataset
    """

    def __init__(self, seq_len: int):
        # Location of the text file
        path = lab.get_data_path() / 'tiny_shakespeare.txt'
        # Download the file
        download_file('https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt', path)
        # Read the downloaded file
        with open(str(path), 'r') as f:
            text = f.read()

        # Extract the characters
        chars = list(set(text))
        # Character to id (integer) map
        self.stoi = {c: i for i, c in enumerate(chars)}
        # Id to character map
        self.itos = {i: c for i, c in enumerate(chars)}
        # Length of a training sample
        self.seq_len = seq_len
        # Data in the form of a tensor of ids
        self.data = self.text_to_i(text)

    def text_to_i(self, text: str):
        """
        Transform the text into a tensor of ids
        """
        return torch.tensor([self.stoi[c] for c in text], dtype=torch.long)

    def __len__(self):
        """
        Number of samples in the dataset.

        *This will read the dataset `seq_len` times in a single epoch.*
        """
        return len(self.data) - self.seq_len - 1

    def __getitem__(self, idx):
        """
        Return a sample
        """
        return self.data[idx:idx + self.seq_len], self.data[idx + 1:idx + self.seq_len + 1]

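# A toy illustration (hypothetical values, not part of the training code) of how
# `__getitem__` pairs each input with the next-character targets shifted by one position.
_data = torch.tensor([10, 11, 12, 13, 14])      # ids of five characters
_seq_len, _idx = 3, 0
_inputs = _data[_idx:_idx + _seq_len]            # tensor([10, 11, 12])
_targets = _data[_idx + 1:_idx + _seq_len + 1]   # tensor([11, 12, 13]), the next character at each position
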
class Trainer:
    """
    ## Trainer
    """

    def __init__(self, configs: Configs):
        # Get the device
        self.device = torch.device('cpu')
        if torch.cuda.is_available():
            self.device = torch.device('cuda:0')
        # Initialize the dataset
        self.dataset = TinyShakespeareDataset(configs.seq_len)
        self.dataloader = DataLoader(self.dataset, batch_size=configs.batch_size, collate_fn=transpose_batch,
        # Initialize the dataloader
        self.dataloader = DataLoader(self.dataset,
                                     batch_size=configs.batch_size,
                                     collate_fn=transpose_batch,
                                     shuffle=True)

        # FFN with Gated Linear Unit
        # $$FFN_{GLU}(x, W_1, V, W_2) = (\sigma(x W_1) \otimes x V) W_2$$
        if configs.glu_variant == 'GLU':
            ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.Sigmoid(), True, False, False, False)
        # FFN with Bilinear hidden layer
        # $$FFN_{Bilinear}(x, W_1, V, W_2) = (x W_1 \otimes x V) W_2$$
        elif configs.glu_variant == 'Bilinear':
            ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.Identity(), True, False, False, False)
        # FFN with ReLU gate
        # $$FFN_{ReGLU}(x, W_1, V, W_2) = (\max(0, x W_1) \otimes x V) W_2$$
        elif configs.glu_variant == 'ReGLU':
            ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.ReLU(), True, False, False, False)
        # FFN with GELU gate
        # $$FFN_{GEGLU}(x, W_1, V, W_2) = (\text{GELU}(x W_1) \otimes x V) W_2$$
        elif configs.glu_variant == 'GEGLU':
            ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.GELU(), True, False, False, False)
        # FFN with Swish gate
        # $$FFN_{SwiGLU}(x, W_1, V, W_2) = (\text{Swish}_1(x W_1) \otimes x V) W_2$$
        # where $\text{Swish}_\beta(x) = x \sigma(\beta x)$
        elif configs.glu_variant == 'SwiGLU':
            ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.SiLU(), True, False, False, False)
        # FFN with ReLU activation
        # $$FFN_{ReLU}(x, W_1, W_2, b_1, b_2) = \text{ReLU}(x W_1 + b_1) W_2 + b_2$$
        elif configs.glu_variant == 'ReLU':
            ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.ReLU())
        # FFN with GELU activation
        # $$FFN_{GELU}(x, W_1, W_2, b_1, b_2) = \text{GELU}(x W_1 + b_1) W_2 + b_2$$
        elif configs.glu_variant == 'GELU':
            ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.GELU())
        else:
            raise ValueError(f'Unknown variant {configs.glu_variant}')

        # Number of different characters
        n_chars = len(self.dataset.stoi)

        # Initialize [Multi-Head Attention module](../mha.html)
        mha = MultiHeadAttention(configs.n_heads, configs.d_model, configs.dropout)
        # Initialize the [Transformer Block](../models.html#TransformerLayer)
        transformer_layer = TransformerLayer(d_model=configs.d_model, self_attn=mha, src_attn=None,
                                             feed_forward=ffn, dropout_prob=configs.dropout)
        # Initialize the model with an
        # [embedding layer](../models.html#EmbeddingsWithPositionalEncoding)
        # (with fixed positional encoding),
        # a [transformer encoder](../models.html#Encoder) and
        # a linear layer to generate logits.
        self.model = AutoregressiveModel(EmbeddingsWithPositionalEncoding(configs.d_model, n_chars),
                                         Encoder(TransformerLayer(
                                             d_model=configs.d_model,
                                             self_attn=MultiHeadAttention(configs.n_heads, configs.d_model,
                                                                          configs.dropout),
                                             src_attn=None,
                                             feed_forward=ffn,
                                             dropout_prob=configs.dropout
                                         ), configs.n_layers),
                                         Encoder(transformer_layer, configs.n_layers),
                                         nn.Linear(configs.d_model, n_chars))

        # Move the model to the current device
        self.model.to(self.device)

        # Initialize [Noam optimizer](../../optimizers/noam.html)
        self.optimizer = Noam(self.model.parameters(), lr=1.0, warmup=2_000, d_model=configs.d_model)

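        # As a rough sketch of what the Noam schedule from "Attention Is All You Need" does
        # (the exact `Noam` implementation may apply an additional scaling factor):
        # the learning rate is warmed up for `warmup` steps and then decayed,
        # $$lr_{step} = lr \cdot d_{model}^{-0.5} \cdot \min(step^{-0.5}, step \cdot warmup^{-1.5})$$
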
        # Cross-entropy loss
        self.loss_func = nn.CrossEntropyLoss()
        # Number of training epochs;
        # *note that our dataset definition repeats the data `seq_len` times in a single epoch*
        self.epochs = configs.epochs
        # Gradient clipping norm
        self.grad_norm_clip = configs.grad_norm_clip

        # Set tracker configurations
@@ -166,18 +232,28 @@ class Trainer:
            logger.log(log)

    def train(self):
        """
        ### Train the model
        """

        # Loop for the given number of epochs
        for _ in monit.loop(self.epochs):
            # Iterate over the minibatches
            for i, batch in monit.enum('Train', self.dataloader):
                # Move data to the device
                data, target = batch[0].to(self.device), batch[1].to(self.device)

                # Set tracker step, as the number of characters trained on
                tracker.add_global_step(data.shape[0] * data.shape[1])

                # Set model state to training
                self.model.train()
                # Run the model
                output = self.model(data)

                # Calculate and log loss
                # Calculate loss
                loss = self.loss_func(output.view(-1, output.shape[-1]), target.view(-1))
                # Log the loss
                tracker.add("loss.train", loss)

                # Calculate gradients
@@ -186,12 +262,13 @@ class Trainer:
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.grad_norm_clip)
                # Take optimizer step
                self.optimizer.step()
                # Log the model parameters and gradients on last batch of every epoch
                # Log the model parameters and gradients
                if (i + 1) % 100 == 0:
                    tracker.add('model', self.model)
                # Clear the gradients
                self.optimizer.zero_grad()

                # Generate a sample
                if (i + 1) % 100 == 0:
                    self.model.eval()
                    with torch.no_grad():
@@ -201,6 +278,7 @@ class Trainer:
                if (i + 1) % 10 == 0:
                    tracker.save()

            # Save the model
            experiment.save_checkpoint()

@@ -212,12 +290,14 @@ def main():
    # Load configurations
    experiment.configs(dataclasses.asdict(configs))

    # Create trainer
    trainer = Trainer(configs)
    # Set models for training and loading
    experiment.add_pytorch_models({'model': trainer.model})

    # Start the experiment
    with experiment.start():
        # `TrainValidConfigs.run`
        # Train the model
        trainer.train()