"""
---
title: Transformer Encoder and Decoder Models
summary: >
These are PyTorch implementations of Transformer-based encoder and decoder models,
as well as other related modules.
---
# Transformer Encoder and Decoder Models
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/transformers/basic/autoregressive_experiment.ipynb)
"""
import math
import torch
import torch.nn as nn
from labml_nn.utils import clone_module_list
from .feed_forward import FeedForward
from .mha import MultiHeadAttention
from .positional_encoding import get_positional_encoding
class EmbeddingsWithPositionalEncoding(nn.Module):
"""
<a id="EmbeddingsWithPositionalEncoding"></a>
## Embed tokens and add [fixed positional encoding](positional_encoding.html)
"""
def __init__(self, d_model: int, n_vocab: int, max_len: int = 5000):
super().__init__()
self.linear = nn.Embedding(n_vocab, d_model)
self.d_model = d_model
self.register_buffer('positional_encodings', get_positional_encoding(d_model, max_len))
def forward(self, x: torch.Tensor):
pe = self.positional_encodings[:x.shape[0]].requires_grad_(False)
return self.linear(x) * math.sqrt(self.d_model) + pe
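# A minimal usage sketch (not part of the original module); the function name and
# sizes below are illustrative. Token indices are `[seq_len, batch_size]`, and the
# fixed encodings returned by `get_positional_encoding` are assumed to broadcast
# over the batch dimension (shape `[max_len, 1, d_model]`).
def _fixed_embeddings_sketch():
    d_model, n_vocab = 64, 1000
    emb = EmbeddingsWithPositionalEncoding(d_model, n_vocab)
    # `10` time steps for a batch of `2` sequences
    tokens = torch.randint(0, n_vocab, (10, 2))
    out = emb(tokens)
    # Output is `[seq_len, batch_size, d_model]`
    assert out.shape == (10, 2, d_model)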
class EmbeddingsWithLearnedPositionalEncoding(nn.Module):
"""
<a id="EmbeddingsWithLearnedPositionalEncoding"></a>
## Embed tokens and add parameterized positional encodings
"""
def __init__(self, d_model: int, n_vocab: int, max_len: int = 5000):
super().__init__()
self.linear = nn.Embedding(n_vocab, d_model)
self.d_model = d_model
self.positional_encodings = nn.Parameter(torch.zeros(max_len, 1, d_model), requires_grad=True)
def forward(self, x: torch.Tensor):
pe = self.positional_encodings[:x.shape[0]]
return self.linear(x) * math.sqrt(self.d_model) + pe
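# The learned variant differs only in that the positional encodings are a trainable
# parameter (initialized to zeros) rather than a fixed buffer. A small sketch
# (not part of the original module; name and sizes are illustrative):
def _learned_embeddings_sketch():
    emb = EmbeddingsWithLearnedPositionalEncoding(d_model=64, n_vocab=1000)
    # The encodings are updated by the optimizer along with the embedding weights
    assert emb.positional_encodings.requires_grad
    out = emb(torch.randint(0, 1000, (10, 2)))
    assert out.shape == (10, 2, 64)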
class TransformerLayer(nn.Module):
"""
<a id="TransformerLayer"></a>
## Transformer Layer
This can act as an encoder layer or a decoder layer. We use pre-norm.
"""
def __init__(self, *,
d_model: int,
self_attn: MultiHeadAttention,
src_attn: MultiHeadAttention = None,
feed_forward: FeedForward,
dropout_prob: float):
"""
* `d_model` is the token embedding size
* `self_attn` is the self attention module
* `src_attn` is the source attention module (when this is used in a decoder)
* `feed_forward` is the feed forward module
* `dropout_prob` is the probability of dropping out after self attention and FFN
"""
super().__init__()
self.size = d_model
self.self_attn = self_attn
self.src_attn = src_attn
self.feed_forward = feed_forward
self.dropout = nn.Dropout(dropout_prob)
self.norm_self_attn = nn.LayerNorm([d_model])
if self.src_attn is not None:
self.norm_src_attn = nn.LayerNorm([d_model])
self.norm_ff = nn.LayerNorm([d_model])
# Whether to save input to the feed forward layer
self.is_save_ff_input = False
def forward(self, *,
x: torch.Tensor,
mask: torch.Tensor,
src: torch.Tensor = None,
src_mask: torch.Tensor = None):
# Normalize the vectors before doing self attention
z = self.norm_self_attn(x)
# Run through self attention, i.e. keys and values are from self
self_attn = self.self_attn(query=z, key=z, value=z, mask=mask)
# Add the self attention results
x = x + self.dropout(self_attn)
# If a source is provided, get results from attention to source.
# This is when you have a decoder layer that pays attention to
# encoder outputs
if src is not None:
# Normalize vectors
z = self.norm_src_attn(x)
# Attention to the source, i.e. keys and values are from the source (encoder outputs)
attn_src = self.src_attn(query=z, key=src, value=src, mask=src_mask)
# Add the source attention results
x = x + self.dropout(attn_src)
# Normalize for feed-forward
z = self.norm_ff(x)
# Save the input to the feed forward layer if specified
if self.is_save_ff_input:
self.ff_input = z.clone()
# Pass through the feed-forward network
ff = self.feed_forward(z)
# Add the feed-forward results back
x = x + self.dropout(ff)
return x
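# A minimal sketch (not part of the original module) that uses the layer as an
# encoder layer: self attention only, no `src`. It assumes `MultiHeadAttention(heads,
# d_model, dropout_prob)` and `FeedForward(d_model, d_ff, dropout)` constructors from
# the sibling `mha.py` and `feed_forward.py` modules, and a boolean mask of shape
# `[query_len, key_len, batch_size]` (with `1`s where broadcasting is intended), which
# is the layout this repository's attention implementation expects. The name and
# sizes are illustrative.
def _transformer_layer_sketch():
    d_model, heads, d_ff, dropout = 64, 4, 256, 0.1
    layer = TransformerLayer(d_model=d_model,
                             self_attn=MultiHeadAttention(heads, d_model, dropout),
                             feed_forward=FeedForward(d_model, d_ff, dropout),
                             dropout_prob=dropout)
    # Input of shape `[seq_len, batch_size, d_model]`
    x = torch.randn(10, 2, d_model)
    # Allow every position to attend to every position
    mask = torch.ones(10, 10, 1, dtype=torch.bool)
    out = layer(x=x, mask=mask)
    assert out.shape == x.shape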
class Encoder(nn.Module):
"""
<a id="Encoder"></a>
## Transformer Encoder
"""
def __init__(self, layer: TransformerLayer, n_layers: int):
super().__init__()
# Make copies of the transformer layer
self.layers = clone_module_list(layer, n_layers)
# Final normalization layer
self.norm = nn.LayerNorm([layer.size])
def forward(self, x: torch.Tensor, mask: torch.Tensor):
# Run through each transformer layer
for layer in self.layers:
x = layer(x=x, mask=mask)
# Finally, normalize the vectors
return self.norm(x)
class Decoder(nn.Module):
"""
<a id="Decoder"></a>
## Transformer Decoder
"""
def __init__(self, layer: TransformerLayer, n_layers: int):
super().__init__()
# Make copies of the transformer layer
self.layers = clone_module_list(layer, n_layers)
# Final normalization layer
self.norm = nn.LayerNorm([layer.size])
def forward(self, x: torch.Tensor, memory: torch.Tensor, src_mask: torch.Tensor, tgt_mask: torch.Tensor):
# Run through each transformer layer
for layer in self.layers:
x = layer(x=x, mask=tgt_mask, src=memory, src_mask=src_mask)
# Finally, normalize the vectors
return self.norm(x)
class Generator(nn.Module):
"""
<a id="Generator"></a>
## Generator
This predicts the tokens by projecting to vocabulary logits.
You don't need an explicit (log-)softmax here if you are using `nn.CrossEntropyLoss`,
since it is applied internally.
"""
def __init__(self, n_vocab: int, d_model: int):
super().__init__()
self.projection = nn.Linear(d_model, n_vocab)
def forward(self, x):
return self.projection(x)
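# A small sketch (not part of the original module; name and sizes are illustrative)
# showing why no softmax is applied above: `nn.CrossEntropyLoss` takes the raw logits
# directly and applies log-softmax internally.
def _generator_sketch():
    n_vocab, d_model = 1000, 64
    generator = Generator(n_vocab, d_model)
    # Decoder output of shape `[seq_len, batch_size, d_model]`
    x = torch.randn(10, 2, d_model)
    logits = generator(x)
    assert logits.shape == (10, 2, n_vocab)
    targets = torch.randint(0, n_vocab, (10, 2))
    # Flatten the sequence and batch dimensions for the loss
    loss = nn.CrossEntropyLoss()(logits.reshape(-1, n_vocab), targets.reshape(-1))
    assert loss.dim() == 0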
class EncoderDecoder(nn.Module):
"""
<a id="EncoderDecoder"></a>
## Combined Encoder-Decoder
"""
def __init__(self, encoder: Encoder, decoder: Decoder, src_embed: nn.Module, tgt_embed: nn.Module, generator: nn.Module):
super().__init__()
self.encoder = encoder
self.decoder = decoder
self.src_embed = src_embed
self.tgt_embed = tgt_embed
self.generator = generator
# Initialize parameters with Glorot / Xavier uniform (fan_avg) initialization,
# as in the reference code this implementation follows.
for p in self.parameters():
if p.dim() > 1:
nn.init.xavier_uniform_(p)
def forward(self, src: torch.Tensor, tgt: torch.Tensor, src_mask: torch.Tensor, tgt_mask: torch.Tensor):
# Run the source through encoder
enc = self.encode(src, src_mask)
# Run encodings and targets through decoder
return self.decode(enc, src_mask, tgt, tgt_mask)
def encode(self, src: torch.Tensor, src_mask: torch.Tensor):
return self.encoder(self.src_embed(src), src_mask)
def decode(self, memory: torch.Tensor, src_mask: torch.Tensor, tgt: torch.Tensor, tgt_mask: torch.Tensor):
return self.decoder(self.tgt_embed(tgt), memory, src_mask, tgt_mask)
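# A minimal end-to-end sketch (not part of the original module) that assembles the
# pieces above and runs one forward pass. The helper `make_layer` and all names and
# sizes are illustrative; the `MultiHeadAttention` and `FeedForward` constructor
# arguments are assumed from the sibling modules, and masks follow the boolean
# `[query_len, key_len, batch_size]` layout used by this repository's attention
# implementation.
def _encoder_decoder_sketch():
    d_model, heads, d_ff, dropout, n_vocab, n_layers = 64, 4, 256, 0.1, 1000, 2

    def make_layer(with_src_attn: bool):
        return TransformerLayer(d_model=d_model,
                                self_attn=MultiHeadAttention(heads, d_model, dropout),
                                src_attn=MultiHeadAttention(heads, d_model, dropout) if with_src_attn else None,
                                feed_forward=FeedForward(d_model, d_ff, dropout),
                                dropout_prob=dropout)

    model = EncoderDecoder(encoder=Encoder(make_layer(False), n_layers),
                           decoder=Decoder(make_layer(True), n_layers),
                           src_embed=EmbeddingsWithPositionalEncoding(d_model, n_vocab),
                           tgt_embed=EmbeddingsWithPositionalEncoding(d_model, n_vocab),
                           generator=Generator(n_vocab, d_model))

    src = torch.randint(0, n_vocab, (12, 2))  # `[src_len, batch_size]`
    tgt = torch.randint(0, n_vocab, (10, 2))  # `[tgt_len, batch_size]`
    # All source positions are visible; the leading `1` broadcasts over query positions
    src_mask = torch.ones(1, 12, 1, dtype=torch.bool)
    # Causal mask so each target position only attends to itself and earlier positions
    tgt_mask = torch.tril(torch.ones(10, 10, dtype=torch.bool)).unsqueeze(-1)
    out = model(src, tgt, src_mask, tgt_mask)
    # Project the decoder output to vocabulary logits
    logits = model.generator(out)
    assert logits.shape == (10, 2, n_vocab)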