mirror of
				https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
				synced 2025-10-31 02:39:16 +08:00 
			
		
		
		
	model & positional encodings annotations
This commit is contained in:
		| @ -1,6 +1,4 @@ | ||||
| """ | ||||
| <a class="github-button" href="https://github.com/lab-ml/labml_nn" data-size="large" data-show-count="true" aria-label="Star lab-ml/labml_nn on GitHub">Star</a> | ||||
|  | ||||
| # Transformers | ||||
|  | ||||
| * [Multi-head attention](mha.html) | ||||
|  | ||||
| @ -1,6 +1,4 @@ | ||||
| """ | ||||
| <a class="github-button" href="https://github.com/lab-ml/labml_nn" data-size="large" data-show-count="true" aria-label="Star lab-ml/labml_nn on GitHub">Star</a> | ||||
|  | ||||
| # Multi-Headed Attention | ||||
|  | ||||
| The implementation is inspired from [Annotated Transformer](https://nlp.seas.harvard.edu/2018/04/03/attention.html) | ||||
|  | ||||
| @ -11,6 +11,9 @@ from .positional_encoding import get_positional_encoding | ||||
|  | ||||
|  | ||||
| class EmbeddingsWithPositionalEncoding(Module): | ||||
|     """ | ||||
|     ## Embed tokenas and add [fixed positional encoding](positional_encoding.html) | ||||
|     """ | ||||
|     def __init__(self, d_model: int, n_vocab: int, max_len: int = 5000): | ||||
|         super().__init__() | ||||
|         self.linear = nn.Embedding(n_vocab, d_model) | ||||
| @ -23,6 +26,9 @@ class EmbeddingsWithPositionalEncoding(Module): | ||||
|  | ||||
|  | ||||
| class EmbeddingsWithLearnedPositionalEncoding(Module): | ||||
|     """ | ||||
|     ## Embed tokenas and add parameterized positional encodings | ||||
|     """ | ||||
|     def __init__(self, d_model: int, n_vocab: int, max_len: int = 5000): | ||||
|         super().__init__() | ||||
|         self.linear = nn.Embedding(n_vocab, d_model) | ||||
| @ -35,6 +41,9 @@ class EmbeddingsWithLearnedPositionalEncoding(Module): | ||||
|  | ||||
|  | ||||
| class FeedForward(Module): | ||||
|     """ | ||||
|     ## Position-wise feed-forward network with hidden layer | ||||
|     """ | ||||
|     def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1): | ||||
|         super().__init__() | ||||
|         self.layer1 = nn.Linear(d_model, d_ff) | ||||
| @ -49,6 +58,20 @@ class FeedForward(Module): | ||||
|  | ||||
|  | ||||
| class TransformerLayer(Module): | ||||
|     """ | ||||
|     ## Transformer Layer | ||||
|  | ||||
|     This can act as a encoder layer or a decoder layer. | ||||
|  | ||||
|     🗒 Some implementations, including the paper seem to have differences | ||||
|     in where the layer-normalization is done. | ||||
|     Here we do a layer normalization before attention and feed-forward networks, | ||||
|     and add the original residual vectors. | ||||
|     Alternative is to do a layer normalzation after adding the residuals. | ||||
|     But we found this to be less stable when training. | ||||
|     We found a detailed discussion about this in paper | ||||
|      [On Layer Normalization in the Transformer Architecture](https://arxiv.org/abs/2002.04745). | ||||
|     """ | ||||
|     def __init__(self, *, | ||||
|                  d_model: int, | ||||
|                  self_attn: MultiHeadAttention, | ||||
| @ -71,47 +94,77 @@ class TransformerLayer(Module): | ||||
|                  mask: torch.Tensor, | ||||
|                  src: torch.Tensor = None, | ||||
|                  src_mask: torch.Tensor = None): | ||||
|         # Normalize the vectors before doing self attention | ||||
|         z = self.norm_self_attn(x) | ||||
|         attn_self = self.self_attn(query=z, key=z, value=z, mask=mask) | ||||
|         x = x + self.dropout(attn_self) | ||||
|         # Run through self attention, i.e. keys and values are from self | ||||
|         self_attn = self.self_attn(query=z, key=z, value=z, mask=mask) | ||||
|         # Add the self attention results | ||||
|         x = x + self.dropout(self_attn) | ||||
|  | ||||
|         # If a source is provided, get results from attention to source. | ||||
|         # This is when you have a decoder layer that pays attention to  | ||||
|         # encoder outputs | ||||
|         if src is not None: | ||||
|             # Normalize vectors | ||||
|             z = self.norm_src_attn(x) | ||||
|             # Attention to source. i.e. keys and values are from source | ||||
|             attn_src = self.src_attn(query=z, key=src, value=src, mask=src_mask) | ||||
|             # Add the source attention results | ||||
|             x = x + self.dropout(attn_src) | ||||
|  | ||||
|         # Normalize for feed-forward | ||||
|         z = self.norm_ff(x) | ||||
|         # Pass through the feed-forward network | ||||
|         ff = self.feed_forward(z) | ||||
|         # Add the feed-forward results back | ||||
|         x = x + self.dropout(ff) | ||||
|  | ||||
|         return x | ||||
|  | ||||
|  | ||||
| class Encoder(Module): | ||||
|     """ | ||||
|     ## Transformer Encoder | ||||
|     """ | ||||
|     def __init__(self, layer: TransformerLayer, n_layers: int): | ||||
|         super().__init__() | ||||
|         # Make copies of the transformer layer | ||||
|         self.layers = clone_module_list(layer, n_layers) | ||||
|         self.norm = nn.LayerNorm([layer.size]) | ||||
|  | ||||
|     def __call__(self, x: torch.Tensor, mask: torch.Tensor): | ||||
|         # Run through each transformer layer | ||||
|         for layer in self.layers: | ||||
|             x = layer(x=x, mask=mask) | ||||
|         # Finally, normalize the vectors | ||||
|         return self.norm(x) | ||||
|  | ||||
|  | ||||
| class Decoder(Module): | ||||
|     """ | ||||
|     ## Transformer Decoder | ||||
|     """ | ||||
|     def __init__(self, layer: TransformerLayer, n_layers: int): | ||||
|         super().__init__() | ||||
|         # Make copies of the transformer layer | ||||
|         self.layers = clone_module_list(layer, n_layers) | ||||
|         self.norm = nn.LayerNorm([layer.size]) | ||||
|  | ||||
|     def __call__(self, x, memory, src_mask, tgt_mask): | ||||
|     def __call__(self, x: torch.Tensor, memory: torch.Tensor, src_mask: torch.Tensor, tgt_mask: torch.Tensor): | ||||
|         # Run through each transformer layer | ||||
|         for layer in self.layers: | ||||
|             x = layer(x=x, mask=tgt_mask, src=memory, src_mask=src_mask) | ||||
|         # Finally, normalize the vectors | ||||
|         return self.norm(x) | ||||
|  | ||||
|  | ||||
| class Generator(Module): | ||||
|     """ | ||||
|     ## Generator | ||||
|  | ||||
|     This predicts the tokens and gives the lof softmaxes of those. | ||||
|     You don't need this if you are using `nn.CrossEntropyLoss`. | ||||
|     """ | ||||
|     def __init__(self, n_vocab: int, d_model: int): | ||||
|         super().__init__() | ||||
|         self.projection = nn.Linear(d_model, n_vocab) | ||||
| @ -121,6 +174,9 @@ class Generator(Module): | ||||
|  | ||||
|  | ||||
| class EncoderDecoder(Module): | ||||
|     """ | ||||
|     ## Combined Encoder-Decoder | ||||
|     """ | ||||
|     def __init__(self, encoder: Encoder, decoder: Decoder, src_embed: Module, tgt_embed: Module, generator: Module): | ||||
|         super().__init__() | ||||
|         self.encoder = encoder | ||||
| @ -135,10 +191,11 @@ class EncoderDecoder(Module): | ||||
|             if p.dim() > 1: | ||||
|                 nn.init.xavier_uniform_(p) | ||||
|  | ||||
|     def __call__(self, src: torch.Tensor, tgt: torch.Tensor, src_mask: torch.Tensor, | ||||
|                  tgt_mask: torch.Tensor): | ||||
|         return self.decode(self.encode(src, src_mask), src_mask, | ||||
|                            tgt, tgt_mask) | ||||
|     def __call__(self, src: torch.Tensor, tgt: torch.Tensor, src_mask: torch.Tensor, tgt_mask: torch.Tensor): | ||||
|         # Runs the source through encoder | ||||
|         enc = self.encode(src, src_mask) | ||||
|         # Run encodings and targets through decoder | ||||
|         return self.decode(enc, src_mask, tgt, tgt_mask) | ||||
|  | ||||
|     def encode(self, src: torch.Tensor, src_mask: torch.Tensor): | ||||
|         return self.encoder(self.src_embed(src), src_mask) | ||||
|  | ||||
| @ -1,3 +1,18 @@ | ||||
| """ | ||||
| # Fixed Positional Encodings | ||||
|  | ||||
| The positional encoding encodes the position along the sequence into | ||||
|  a vector of size `d_model`. | ||||
|  | ||||
| \begin{align} | ||||
| PE_{p,2i} &= sin\Bigg(\frac{p}{10000^{\frac{2i}{d_{model}}}}\Bigg) \\ | ||||
| PE_{p,2i + 1} &= cos\Bigg(\frac{p}{10000^{\frac{2i}{d_{model}}}}\Bigg) | ||||
| \end{align} | ||||
|  | ||||
| Where $1 \leq 2i, 2i + 1 \leq d_{model}$ are the feature indexes in the encoding, | ||||
| and $p$ is the position. | ||||
| """ | ||||
|  | ||||
| import math | ||||
|  | ||||
| import matplotlib.pyplot as plt | ||||
| @ -23,12 +38,20 @@ class PositionalEncoding(Module): | ||||
|  | ||||
|  | ||||
| def get_positional_encoding(d_model: int, max_len: int = 5000): | ||||
|     # Empty encodings vectors | ||||
|     encodings = torch.zeros(max_len, d_model) | ||||
|     # Position indexes | ||||
|     position = torch.arange(0, max_len, dtype=torch.float32).unsqueeze(1) | ||||
|     # $2 * i$ | ||||
|     two_i = torch.arange(0, d_model, 2, dtype=torch.float32) | ||||
|     # $10000^{\frac{2i}{d_{model}}$ | ||||
|     div_term = torch.exp(two_i * -(math.log(10000.0) / d_model)) | ||||
|     # $PE_{p,2i} = sin\Bigg(\frac{p}{10000^{\frac{2i}{d_{model}}}}\Bigg)$ | ||||
|     encodings[:, 0::2] = torch.sin(position * div_term) | ||||
|     # $PE_{p,2i + 1} = cos\Bigg(\frac{p}{10000^{\frac{2i}{d_{model}}}}\Bigg)$ | ||||
|     encodings[:, 1::2] = torch.cos(position * div_term) | ||||
|  | ||||
|     # Add batch dimension | ||||
|     encodings = encodings.unsqueeze(1).requires_grad_(False) | ||||
|  | ||||
|     return encodings | ||||
|  | ||||
| @ -1,6 +1,4 @@ | ||||
| """ | ||||
| <a class="github-button" href="https://github.com/lab-ml/labml_nn" data-size="large" data-show-count="true" aria-label="Star lab-ml/labml_nn on GitHub">Star</a> | ||||
|  | ||||
| # Relative Multi-head Attention | ||||
|  | ||||
| This is an implementation of  | ||||
|  | ||||
		Reference in New Issue
	
	Block a user
	 Varuna Jayasiri
					Varuna Jayasiri