model & positional encodings annotations
@@ -1,6 +1,4 @@
 """
-<a class="github-button" href="https://github.com/lab-ml/labml_nn" data-size="large" data-show-count="true" aria-label="Star lab-ml/labml_nn on GitHub">Star</a>
-
 # Transformers

 * [Multi-head attention](mha.html)

@@ -1,6 +1,4 @@
 """
-<a class="github-button" href="https://github.com/lab-ml/labml_nn" data-size="large" data-show-count="true" aria-label="Star lab-ml/labml_nn on GitHub">Star</a>
-
 # Multi-Headed Attention

 The implementation is inspired from [Annotated Transformer](https://nlp.seas.harvard.edu/2018/04/03/attention.html)

@@ -11,6 +11,9 @@ from .positional_encoding import get_positional_encoding
 
 
 class EmbeddingsWithPositionalEncoding(Module):
+    """
+    ## Embed tokens and add [fixed positional encoding](positional_encoding.html)
+    """
     def __init__(self, d_model: int, n_vocab: int, max_len: int = 5000):
         super().__init__()
         self.linear = nn.Embedding(n_vocab, d_model)
@@ -23,6 +26,9 @@ class EmbeddingsWithPositionalEncoding(Module):
 
 
 class EmbeddingsWithLearnedPositionalEncoding(Module):
+    """
+    ## Embed tokens and add parameterized positional encodings
+    """
     def __init__(self, d_model: int, n_vocab: int, max_len: int = 5000):
         super().__init__()
         self.linear = nn.Embedding(n_vocab, d_model)
@@ -35,6 +41,9 @@ class EmbeddingsWithLearnedPositionalEncoding(Module):
 
 
 class FeedForward(Module):
+    """
+    ## Position-wise feed-forward network with a hidden layer
+    """
     def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1):
         super().__init__()
         self.layer1 = nn.Linear(d_model, d_ff)
@@ -49,6 +58,20 @@ class FeedForward(Module):
 
 
 class TransformerLayer(Module):
+    """
+    ## Transformer Layer
+
+    This can act as an encoder layer or a decoder layer.
+
+    🗒 Some implementations, including the paper, seem to differ in
+    where the layer normalization is applied.
+    Here we apply layer normalization before the attention and feed-forward networks,
+    and add the original residual vectors.
+    The alternative is to apply layer normalization after adding the residuals,
+    but we found that to be less stable when training.
+    There is a detailed discussion about this in the paper
+    [On Layer Normalization in the Transformer Architecture](https://arxiv.org/abs/2002.04745).
+    """
     def __init__(self, *,
                  d_model: int,
                  self_attn: MultiHeadAttention,
@@ -71,47 +94,77 @@ class TransformerLayer(Module):
                  mask: torch.Tensor,
                  src: torch.Tensor = None,
                  src_mask: torch.Tensor = None):
+        # Normalize the vectors before doing self attention
         z = self.norm_self_attn(x)
-        attn_self = self.self_attn(query=z, key=z, value=z, mask=mask)
-        x = x + self.dropout(attn_self)
+        # Run through self attention, i.e. keys and values are from self
+        self_attn = self.self_attn(query=z, key=z, value=z, mask=mask)
+        # Add the self attention results
+        x = x + self.dropout(self_attn)
 
+        # If a source is provided, get results from attention to the source.
+        # This is when you have a decoder layer that pays attention to
+        # encoder outputs
         if src is not None:
+            # Normalize vectors
             z = self.norm_src_attn(x)
+            # Attention to source, i.e. keys and values are from the source
             attn_src = self.src_attn(query=z, key=src, value=src, mask=src_mask)
+            # Add the source attention results
             x = x + self.dropout(attn_src)
 
+        # Normalize for feed-forward
         z = self.norm_ff(x)
+        # Pass through the feed-forward network
         ff = self.feed_forward(z)
+        # Add the feed-forward results back
         x = x + self.dropout(ff)
 
         return x
 
 
 class Encoder(Module):
+    """
+    ## Transformer Encoder
+    """
     def __init__(self, layer: TransformerLayer, n_layers: int):
         super().__init__()
+        # Make copies of the transformer layer
         self.layers = clone_module_list(layer, n_layers)
         self.norm = nn.LayerNorm([layer.size])
 
     def __call__(self, x: torch.Tensor, mask: torch.Tensor):
+        # Run through each transformer layer
         for layer in self.layers:
             x = layer(x=x, mask=mask)
+        # Finally, normalize the vectors
         return self.norm(x)
 
 
 class Decoder(Module):
+    """
+    ## Transformer Decoder
+    """
     def __init__(self, layer: TransformerLayer, n_layers: int):
         super().__init__()
+        # Make copies of the transformer layer
         self.layers = clone_module_list(layer, n_layers)
         self.norm = nn.LayerNorm([layer.size])
 
-    def __call__(self, x, memory, src_mask, tgt_mask):
+    def __call__(self, x: torch.Tensor, memory: torch.Tensor, src_mask: torch.Tensor, tgt_mask: torch.Tensor):
+        # Run through each transformer layer
         for layer in self.layers:
             x = layer(x=x, mask=tgt_mask, src=memory, src_mask=src_mask)
+        # Finally, normalize the vectors
         return self.norm(x)
 
 
 class Generator(Module):
+    """
+    ## Generator
+
+    This predicts the tokens and gives the log softmax of those.
+    You don't need this if you are using `nn.CrossEntropyLoss`.
+    """
     def __init__(self, n_vocab: int, d_model: int):
         super().__init__()
         self.projection = nn.Linear(d_model, n_vocab)
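The `Generator` docstring above points out that its log-softmax is only needed when the loss expects log-probabilities. A minimal sketch of that equivalence in plain PyTorch (the shapes and tensor names below are illustrative, not taken from the repository):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

d_model, n_vocab, batch = 16, 100, 4
x = torch.randn(batch, d_model)                # stand-in for decoder outputs
targets = torch.randint(0, n_vocab, (batch,))  # stand-in for target token ids

projection = nn.Linear(d_model, n_vocab)
logits = projection(x)

# Generator-style: project, take log-softmax, then use NLLLoss
loss_nll = nn.NLLLoss()(F.log_softmax(logits, dim=-1), targets)

# Without the Generator's log-softmax: feed raw logits to CrossEntropyLoss,
# which applies log-softmax internally
loss_ce = nn.CrossEntropyLoss()(logits, targets)

assert torch.allclose(loss_nll, loss_ce)       # the two losses agree
```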
@@ -121,6 +174,9 @@ class Generator(Module):
 
 
 class EncoderDecoder(Module):
+    """
+    ## Combined Encoder-Decoder
+    """
     def __init__(self, encoder: Encoder, decoder: Decoder, src_embed: Module, tgt_embed: Module, generator: Module):
         super().__init__()
         self.encoder = encoder
@@ -135,10 +191,11 @@ class EncoderDecoder(Module):
             if p.dim() > 1:
                 nn.init.xavier_uniform_(p)
 
-    def __call__(self, src: torch.Tensor, tgt: torch.Tensor, src_mask: torch.Tensor,
-                 tgt_mask: torch.Tensor):
-        return self.decode(self.encode(src, src_mask), src_mask,
-                           tgt, tgt_mask)
+    def __call__(self, src: torch.Tensor, tgt: torch.Tensor, src_mask: torch.Tensor, tgt_mask: torch.Tensor):
+        # Run the source through the encoder
+        enc = self.encode(src, src_mask)
+        # Run the encodings and targets through the decoder
+        return self.decode(enc, src_mask, tgt, tgt_mask)
 
     def encode(self, src: torch.Tensor, src_mask: torch.Tensor):
         return self.encoder(self.src_embed(src), src_mask)

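The `TransformerLayer` docstring added above argues for normalizing before each sub-layer rather than after the residual addition. A minimal sketch of the two orderings with a generic sub-layer, using only standard PyTorch modules (the sub-layer here is a stand-in, not the repository's attention or feed-forward classes):

```python
import torch
import torch.nn as nn

d_model = 16
norm = nn.LayerNorm(d_model)
sublayer = nn.Linear(d_model, d_model)  # stand-in for attention or feed-forward
dropout = nn.Dropout(0.1)

x = torch.randn(2, 8, d_model)          # (batch, seq, d_model), illustrative

# Pre-norm (used in this implementation): normalize, apply the sub-layer,
# then add the original residual vector
pre_norm_out = x + dropout(sublayer(norm(x)))

# Post-norm (as in the original paper): apply the sub-layer, add the residual,
# then normalize the sum
post_norm_out = norm(x + dropout(sublayer(x)))
```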
@@ -1,3 +1,18 @@
+"""
+# Fixed Positional Encodings
+
+The positional encoding encodes the position along the sequence into
+a vector of size `d_model`.
+
+\begin{align}
+PE_{p,2i} &= \sin\Bigg(\frac{p}{10000^{\frac{2i}{d_{model}}}}\Bigg) \\
+PE_{p,2i + 1} &= \cos\Bigg(\frac{p}{10000^{\frac{2i}{d_{model}}}}\Bigg)
+\end{align}
+
+where $1 \leq 2i, 2i + 1 \leq d_{model}$ are the feature indexes in the encoding,
+and $p$ is the position.
+"""
+
 import math
 
 import matplotlib.pyplot as plt
@@ -23,12 +38,20 @@ class PositionalEncoding(Module):
 
 
 def get_positional_encoding(d_model: int, max_len: int = 5000):
+    # Empty encodings vectors
     encodings = torch.zeros(max_len, d_model)
+    # Position indexes
     position = torch.arange(0, max_len, dtype=torch.float32).unsqueeze(1)
+    # $2 * i$
     two_i = torch.arange(0, d_model, 2, dtype=torch.float32)
+    # $10000^{\frac{2i}{d_{model}}}$
     div_term = torch.exp(two_i * -(math.log(10000.0) / d_model))
+    # $PE_{p,2i} = \sin\Bigg(\frac{p}{10000^{\frac{2i}{d_{model}}}}\Bigg)$
     encodings[:, 0::2] = torch.sin(position * div_term)
+    # $PE_{p,2i + 1} = \cos\Bigg(\frac{p}{10000^{\frac{2i}{d_{model}}}}\Bigg)$
     encodings[:, 1::2] = torch.cos(position * div_term)
+
+    # Add batch dimension
     encodings = encodings.unsqueeze(1).requires_grad_(False)
 
     return encodings

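A quick sketch of how the output of `get_positional_encoding` above can be used; the shapes follow the code (position first, then a batch dimension of size 1, then `d_model`), while the embedding tensor and the import path are illustrative assumptions:

```python
import torch
# assumed import path; adjust to wherever the module above lives
from labml_nn.transformers.positional_encoding import get_positional_encoding

d_model, seq_len, batch = 32, 10, 4
pe = get_positional_encoding(d_model, max_len=seq_len)
print(pe.shape)  # torch.Size([10, 1, 32]) -> (position, batch, d_model)

# Broadcast-add the encodings to embeddings of shape (seq_len, batch, d_model)
embeddings = torch.randn(seq_len, batch, d_model)
x = embeddings + pe[:seq_len]
```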
@@ -1,6 +1,4 @@
 """
-<a class="github-button" href="https://github.com/lab-ml/labml_nn" data-size="large" data-show-count="true" aria-label="Star lab-ml/labml_nn on GitHub">Star</a>
-
 # Relative Multi-head Attention
 
 This is an implementation of