mirror of https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
model & positional encodings annotations
labml_nn/transformers/__init__.py
@@ -1,6 +1,4 @@
 """
-<a class="github-button" href="https://github.com/lab-ml/labml_nn" data-size="large" data-show-count="true" aria-label="Star lab-ml/labml_nn on GitHub">Star</a>
-
 # Transformers
 
 * [Multi-head attention](mha.html)
labml_nn/transformers/mha.py
@@ -1,6 +1,4 @@
 """
-<a class="github-button" href="https://github.com/lab-ml/labml_nn" data-size="large" data-show-count="true" aria-label="Star lab-ml/labml_nn on GitHub">Star</a>
-
 # Multi-Headed Attention
 
 The implementation is inspired from [Annotated Transformer](https://nlp.seas.harvard.edu/2018/04/03/attention.html)
labml_nn/transformers/models.py
@@ -11,6 +11,9 @@ from .positional_encoding import get_positional_encoding
 
 
 class EmbeddingsWithPositionalEncoding(Module):
+    """
+    ## Embed tokens and add [fixed positional encoding](positional_encoding.html)
+    """
     def __init__(self, d_model: int, n_vocab: int, max_len: int = 5000):
         super().__init__()
         self.linear = nn.Embedding(n_vocab, d_model)
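The hunk above only shows the start of `EmbeddingsWithPositionalEncoding.__init__`; the forward pass is not part of this diff. Below is a minimal standalone sketch of what the docstring describes (embed token ids, then add a fixed, non-trained positional encoding). The `sqrt(d_model)` scaling and the `[seq_len, batch_size]` shape convention are assumptions, and the random buffer merely stands in for the sinusoidal encodings defined in positional_encoding.py.

```python
import math
import torch
import torch.nn as nn

d_model, n_vocab, max_len = 512, 10000, 5000
embed = nn.Embedding(n_vocab, d_model)
fixed_pe = torch.randn(max_len, 1, d_model)   # stand-in for the sinusoidal encodings
x = torch.randint(0, n_vocab, (20, 4))        # token ids: [seq_len=20, batch_size=4]
emb = embed(x) * math.sqrt(d_model)           # scale the embeddings (assumed convention)
out = emb + fixed_pe[:x.shape[0]]             # add the fixed (non-trained) encoding
print(out.shape)                              # torch.Size([20, 4, 512])
```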
@@ -23,6 +26,9 @@ class EmbeddingsWithPositionalEncoding(Module):
 
 
 class EmbeddingsWithLearnedPositionalEncoding(Module):
+    """
+    ## Embed tokens and add parameterized positional encodings
+    """
     def __init__(self, d_model: int, n_vocab: int, max_len: int = 5000):
         super().__init__()
         self.linear = nn.Embedding(n_vocab, d_model)
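For the learned ("parameterized") variant the structural difference is that the positional encodings are a trainable parameter rather than a fixed buffer. A sketch under the same assumed shapes as above; `learned_pe` and its zero initialization are hypothetical, not taken from the diff.

```python
import torch
import torch.nn as nn

d_model, n_vocab, max_len = 512, 10000, 5000
embed = nn.Embedding(n_vocab, d_model)
# Trainable positional encodings (initialization is an assumption)
learned_pe = nn.Parameter(torch.zeros(max_len, 1, d_model), requires_grad=True)
x = torch.randint(0, n_vocab, (20, 4))        # token ids: [seq_len=20, batch_size=4]
out = embed(x) + learned_pe[:x.shape[0]]      # these encodings are updated by the optimizer
print(out.shape)                              # torch.Size([20, 4, 512])
```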
@@ -35,6 +41,9 @@ class EmbeddingsWithLearnedPositionalEncoding(Module):
 
 
 class FeedForward(Module):
+    """
+    ## Position-wise feed-forward network with hidden layer
+    """
     def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1):
         super().__init__()
         self.layer1 = nn.Linear(d_model, d_ff)
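The hunk shows only the first linear layer of `FeedForward`. A standard position-wise feed-forward block with the same constructor shape is sketched below; the ReLU activation and the placement of dropout are assumptions, not read from the diff.

```python
import torch
import torch.nn as nn

class ToyFeedForward(nn.Module):
    def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        self.layer1 = nn.Linear(d_model, d_ff)
        self.layer2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)
        self.activation = nn.ReLU()

    def forward(self, x: torch.Tensor):
        # Applied independently at every position: x is [seq_len, batch_size, d_model]
        return self.layer2(self.dropout(self.activation(self.layer1(x))))

ff = ToyFeedForward(d_model=512, d_ff=2048)
print(ff(torch.randn(20, 4, 512)).shape)      # torch.Size([20, 4, 512])
```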
@@ -49,6 +58,20 @@ class FeedForward(Module):
 
 
 class TransformerLayer(Module):
+    """
+    ## Transformer Layer
+
+    This can act as an encoder layer or a decoder layer.
+
+    🗒 Some implementations, including the paper, seem to differ in
+    where the layer normalization is done.
+    Here we do a layer normalization before the attention and feed-forward networks,
+    and add the original residual vectors.
+    The alternative is to do a layer normalization after adding the residuals,
+    but we found that to be less stable when training.
+    There is a detailed discussion about this in the paper
+    [On Layer Normalization in the Transformer Architecture](https://arxiv.org/abs/2002.04745).
+    """
     def __init__(self, *,
                  d_model: int,
                  self_attn: MultiHeadAttention,
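The two layer-normalization placements this docstring contrasts can be written as one-liners. This is an illustration of the ordering only, with `sublayer` standing in for either the attention block or the feed-forward block; it is not the repo's code.

```python
import torch
import torch.nn as nn

def pre_norm_step(x: torch.Tensor, sublayer: nn.Module, norm: nn.LayerNorm, dropout: nn.Dropout):
    # The arrangement used in this commit: normalize first, then add the original residual
    return x + dropout(sublayer(norm(x)))

def post_norm_step(x: torch.Tensor, sublayer: nn.Module, norm: nn.LayerNorm, dropout: nn.Dropout):
    # The original paper's arrangement: add the residual first, then normalize
    return norm(x + dropout(sublayer(x)))
```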
@@ -71,47 +94,77 @@ class TransformerLayer(Module):
                  mask: torch.Tensor,
                  src: torch.Tensor = None,
                  src_mask: torch.Tensor = None):
+        # Normalize the vectors before doing self attention
         z = self.norm_self_attn(x)
-        attn_self = self.self_attn(query=z, key=z, value=z, mask=mask)
-        x = x + self.dropout(attn_self)
+        # Run through self attention, i.e. keys and values are from self
+        self_attn = self.self_attn(query=z, key=z, value=z, mask=mask)
+        # Add the self attention results
+        x = x + self.dropout(self_attn)
 
+        # If a source is provided, get results from attention to the source.
+        # This is when you have a decoder layer that pays attention to
+        # encoder outputs
         if src is not None:
+            # Normalize vectors
             z = self.norm_src_attn(x)
+            # Attention to the source, i.e. keys and values are from the source
             attn_src = self.src_attn(query=z, key=src, value=src, mask=src_mask)
+            # Add the source attention results
             x = x + self.dropout(attn_src)
 
+        # Normalize for the feed-forward network
         z = self.norm_ff(x)
+        # Pass through the feed-forward network
         ff = self.feed_forward(z)
+        # Add the feed-forward results back
         x = x + self.dropout(ff)
 
         return x
 
 
 class Encoder(Module):
+    """
+    ## Transformer Encoder
+    """
     def __init__(self, layer: TransformerLayer, n_layers: int):
         super().__init__()
+        # Make copies of the transformer layer
         self.layers = clone_module_list(layer, n_layers)
         self.norm = nn.LayerNorm([layer.size])
 
     def __call__(self, x: torch.Tensor, mask: torch.Tensor):
+        # Run through each transformer layer
         for layer in self.layers:
             x = layer(x=x, mask=mask)
+        # Finally, normalize the vectors
         return self.norm(x)
 
 
 class Decoder(Module):
+    """
+    ## Transformer Decoder
+    """
     def __init__(self, layer: TransformerLayer, n_layers: int):
         super().__init__()
+        # Make copies of the transformer layer
         self.layers = clone_module_list(layer, n_layers)
         self.norm = nn.LayerNorm([layer.size])
 
-    def __call__(self, x, memory, src_mask, tgt_mask):
+    def __call__(self, x: torch.Tensor, memory: torch.Tensor, src_mask: torch.Tensor, tgt_mask: torch.Tensor):
+        # Run through each transformer layer
         for layer in self.layers:
             x = layer(x=x, mask=tgt_mask, src=memory, src_mask=src_mask)
+        # Finally, normalize the vectors
         return self.norm(x)
 
 
 class Generator(Module):
+    """
+    ## Generator
+
+    This predicts the tokens and gives the log softmax of those.
+    You don't need this if you are using `nn.CrossEntropyLoss`.
+    """
     def __init__(self, n_vocab: int, d_model: int):
         super().__init__()
         self.projection = nn.Linear(d_model, n_vocab)
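A quick way to see why the `Generator`'s log-softmax is optional: `nn.CrossEntropyLoss` applied to the raw projection outputs gives the same value as `nn.NLLLoss` applied to their log-softmax. A toy check, with shapes chosen arbitrarily:

```python
import torch
import torch.nn as nn

d_model, n_vocab = 512, 1000
projection = nn.Linear(d_model, n_vocab)
x = torch.randn(8, d_model)                          # 8 arbitrary "token" vectors
target = torch.randint(0, n_vocab, (8,))

logits = projection(x)
loss_ce = nn.CrossEntropyLoss()(logits, target)
loss_nll = nn.NLLLoss()(torch.log_softmax(logits, dim=-1), target)
print(torch.allclose(loss_ce, loss_nll))             # True
```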
@@ -121,6 +174,9 @@ class Generator(Module):
 
 
 class EncoderDecoder(Module):
+    """
+    ## Combined Encoder-Decoder
+    """
     def __init__(self, encoder: Encoder, decoder: Decoder, src_embed: Module, tgt_embed: Module, generator: Module):
         super().__init__()
         self.encoder = encoder
@@ -135,10 +191,11 @@ class EncoderDecoder(Module):
             if p.dim() > 1:
                 nn.init.xavier_uniform_(p)
 
-    def __call__(self, src: torch.Tensor, tgt: torch.Tensor, src_mask: torch.Tensor,
-                 tgt_mask: torch.Tensor):
-        return self.decode(self.encode(src, src_mask), src_mask,
-                           tgt, tgt_mask)
+    def __call__(self, src: torch.Tensor, tgt: torch.Tensor, src_mask: torch.Tensor, tgt_mask: torch.Tensor):
+        # Run the source through the encoder
+        enc = self.encode(src, src_mask)
+        # Run the encodings and targets through the decoder
+        return self.decode(enc, src_mask, tgt, tgt_mask)
 
     def encode(self, src: torch.Tensor, src_mask: torch.Tensor):
         return self.encoder(self.src_embed(src), src_mask)
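None of the hunks above construct the `src_mask`/`tgt_mask` tensors that `EncoderDecoder.__call__` passes along. A common way to build the decoder's causal ("subsequent") mask is a lower-triangular boolean matrix, sketched below; the exact shape and dtype the attention module expects are assumptions that depend on its implementation, which is not shown in this diff.

```python
import torch

def subsequent_mask(seq_len: int) -> torch.Tensor:
    # mask[i, j] is True when position i may attend to position j (i.e. j <= i)
    return torch.tril(torch.ones(seq_len, seq_len)).bool()

print(subsequent_mask(4))
# tensor([[ True, False, False, False],
#         [ True,  True, False, False],
#         [ True,  True,  True, False],
#         [ True,  True,  True,  True]])
```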
labml_nn/transformers/positional_encoding.py
@@ -1,3 +1,18 @@
+"""
+# Fixed Positional Encodings
+
+The positional encoding encodes the position along the sequence into
+a vector of size `d_model`.
+
+\begin{align}
+PE_{p,2i} &= sin\Bigg(\frac{p}{10000^{\frac{2i}{d_{model}}}}\Bigg) \\
+PE_{p,2i + 1} &= cos\Bigg(\frac{p}{10000^{\frac{2i}{d_{model}}}}\Bigg)
+\end{align}
+
+Where $1 \leq 2i, 2i + 1 \leq d_{model}$ are the feature indexes in the encoding,
+and $p$ is the position.
+"""
+
 import math
 
 import matplotlib.pyplot as plt
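As a worked instance of the formulas above, take $d_{model} = 4$ and follow the zero-based feature indices used by the code in the next hunk, so $2i \in \{0, 2\}$:

\begin{align}
PE_{p,0} &= \sin\Big(\frac{p}{10000^{0/4}}\Big) = \sin(p) \\
PE_{p,1} &= \cos(p) \\
PE_{p,2} &= \sin\Big(\frac{p}{10000^{2/4}}\Big) = \sin\Big(\frac{p}{100}\Big) \\
PE_{p,3} &= \cos\Big(\frac{p}{100}\Big)
\end{align}

So for $p = 1$ the four features are approximately $(0.841, 0.540, 0.010, 1.000)$.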
@@ -23,12 +38,20 @@ class PositionalEncoding(Module):
 
 
 def get_positional_encoding(d_model: int, max_len: int = 5000):
+    # Empty encodings vectors
     encodings = torch.zeros(max_len, d_model)
+    # Position indexes
     position = torch.arange(0, max_len, dtype=torch.float32).unsqueeze(1)
+    # $2 * i$
     two_i = torch.arange(0, d_model, 2, dtype=torch.float32)
+    # $10000^{\frac{2i}{d_{model}}}$
     div_term = torch.exp(two_i * -(math.log(10000.0) / d_model))
+    # $PE_{p,2i} = sin\Bigg(\frac{p}{10000^{\frac{2i}{d_{model}}}}\Bigg)$
     encodings[:, 0::2] = torch.sin(position * div_term)
+    # $PE_{p,2i + 1} = cos\Bigg(\frac{p}{10000^{\frac{2i}{d_{model}}}}\Bigg)$
     encodings[:, 1::2] = torch.cos(position * div_term)
 
+    # Add batch dimension
     encodings = encodings.unsqueeze(1).requires_grad_(False)
 
     return encodings
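To spot-check the shape and a couple of values against the formula, the function from the hunk can be re-stated standalone (same math, copied here so it runs without the repo):

```python
import math
import torch

def get_positional_encoding(d_model: int, max_len: int = 5000):
    encodings = torch.zeros(max_len, d_model)
    position = torch.arange(0, max_len, dtype=torch.float32).unsqueeze(1)
    two_i = torch.arange(0, d_model, 2, dtype=torch.float32)
    div_term = torch.exp(two_i * -(math.log(10000.0) / d_model))
    encodings[:, 0::2] = torch.sin(position * div_term)
    encodings[:, 1::2] = torch.cos(position * div_term)
    return encodings.unsqueeze(1).requires_grad_(False)

pe = get_positional_encoding(d_model=4, max_len=10)
print(pe.shape)     # torch.Size([10, 1, 4]) -> [position, batch, feature]
print(pe[1, 0])     # ~[sin(1), cos(1), sin(0.01), cos(0.01)]
# tensor([0.8415, 0.5403, 0.0100, 1.0000])
```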
labml_nn/transformers/relative_mha.py
@@ -1,6 +1,4 @@
 """
-<a class="github-button" href="https://github.com/lab-ml/labml_nn" data-size="large" data-show-count="true" aria-label="Star lab-ml/labml_nn on GitHub">Star</a>
-
 # Relative Multi-head Attention
 
 This is an implementation of