Mirror of https://github.com/labmlai/annotated_deep_learning_paper_implementations.git (synced 2025-10-31 18:58:43 +08:00)
233 lines · 7.6 KiB · Python

import copy
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from common.models import clone_module_list
from labml.configs import BaseConfigs, option
from labml.helpers.pytorch.module import Module
from transformers.multi_headed_attention import MultiHeadedAttention
from transformers.positional_encoding import get_positional_encoding


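# Token embeddings plus fixed sinusoidal positional encodings. The encodings
# are computed once by `get_positional_encoding` and stored as a buffer; the
# embeddings are scaled by sqrt(d_model) before the encodings are added.
# Inputs are sequence-first, i.e. `x` has shape `[seq_len, batch_size]`.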
class EmbeddingsWithPositionalEncoding(Module):
    def __init__(self, d_model: int, n_vocab: int, max_len: int = 5000):
        super().__init__()
        self.linear = nn.Embedding(n_vocab, d_model)
        self.d_model = d_model
        self.register_buffer('positional_encodings', get_positional_encoding(d_model, max_len))

    def __call__(self, x: torch.Tensor):
        pe = self.positional_encodings[:x.shape[0]].requires_grad_(False)
        return self.linear(x) * math.sqrt(self.d_model) + pe


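# Position-wise feed-forward network: two linear layers with a ReLU and
# dropout in between, applied identically at every position.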
class FeedForward(Module):
    def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        self.layer1 = nn.Linear(d_model, d_ff)
        self.layer2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def __call__(self, x: torch.Tensor):
        x = self.layer1(x)
        x = F.relu(x)
        x = self.dropout(x)
        return self.layer2(x)


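# A single transformer layer with pre-layer-norm residual connections.
# It acts as an encoder layer when `src_attn` is `None`, and as a decoder
# layer when `src_attn` (attention over the encoder memory `src`) is given.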
class TransformerLayer(Module):
    def __init__(self, *,
                 d_model: int,
                 self_attn: MultiHeadedAttention,
                 src_attn: MultiHeadedAttention = None,
                 feed_forward: FeedForward,
                 dropout_prob: float):
        super().__init__()
        self.size = d_model
        self.self_attn = self_attn
        self.src_attn = src_attn
        self.feed_forward = feed_forward
        self.dropout = nn.Dropout(dropout_prob)
        self.norm_self_attn = nn.LayerNorm([d_model])
        if self.src_attn is not None:
            self.norm_src_attn = nn.LayerNorm([d_model])
        self.norm_ff = nn.LayerNorm([d_model])

    def __call__(self, *,
                 x: torch.Tensor,
                 mask: torch.Tensor,
                 src: torch.Tensor = None,
                 src_mask: torch.Tensor = None):
        z = self.norm_self_attn(x)
        attn_self = self.self_attn(query=z, key=z, value=z, mask=mask)
        x = x + self.dropout(attn_self)

        if src is not None:
            z = self.norm_src_attn(x)
            attn_src = self.src_attn(query=z, key=src, value=src, mask=src_mask)
            x = x + self.dropout(attn_src)

        z = self.norm_ff(x)
        ff = self.feed_forward(z)
        x = x + self.dropout(ff)

        # guard(x.shape, attn_self.shape, attn_src.shape, ff.shape,
        #       '_batch_size', '_seq_len', 'd_model')

        return x


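# A stack of `n_layers` identical encoder layers, followed by a final layer norm.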
class Encoder(Module):
    def __init__(self, layer: TransformerLayer, n_layers: int):
        super().__init__()
        self.layers = clone_module_list(layer, n_layers)
        self.norm = nn.LayerNorm([layer.size])

    def __call__(self, x: torch.Tensor, mask: torch.Tensor):
        for layer in self.layers:
            x = layer(x=x, mask=mask)
        return self.norm(x)


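# A stack of `n_layers` identical decoder layers; each layer combines masked
# self-attention with attention over the encoder output (`memory`).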
class Decoder(Module):
    def __init__(self, layer: TransformerLayer, n_layers: int):
        super().__init__()
        self.layers = clone_module_list(layer, n_layers)
        self.norm = nn.LayerNorm([layer.size])

    def __call__(self, x, memory, src_mask, tgt_mask):
        for layer in self.layers:
            x = layer(x=x, mask=tgt_mask, src=memory, src_mask=src_mask)
        return self.norm(x)


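# Projects decoder outputs to vocabulary size and returns log-probabilities.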
class Generator(Module):
    def __init__(self, n_vocab: int, d_model: int):
        super().__init__()
        self.projection = nn.Linear(d_model, n_vocab)

    def __call__(self, x):
        return F.log_softmax(self.projection(x), dim=-1)


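# The full encoder-decoder transformer: source/target embeddings, encoder,
# decoder, and the output generator.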
class EncoderDecoder(Module):
    def __init__(self, encoder: Encoder, decoder: Decoder, src_embed: Module, tgt_embed: Module, generator: Module):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.src_embed = src_embed
        self.tgt_embed = tgt_embed
        self.generator = generator

        # Initialize parameters with Glorot / fan_avg.
        # This initialization was important in the reference implementation.
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def __call__(self, src: torch.Tensor, tgt: torch.Tensor, src_mask: torch.Tensor,
                 tgt_mask: torch.Tensor):
        return self.decode(self.encode(src, src_mask), src_mask,
                           tgt, tgt_mask)

    def encode(self, src: torch.Tensor, src_mask: torch.Tensor):
        return self.encoder(self.src_embed(src), src_mask)

    def decode(self, memory: torch.Tensor, src_mask: torch.Tensor, tgt: torch.Tensor, tgt_mask: torch.Tensor):
        return self.decoder(self.tgt_embed(tgt), memory, src_mask, tgt_mask)


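# Configurable hyper-parameters and components. String values such as 'mha',
# 'normal', and 'fixed_pos' name the options registered below with `@option`;
# labml resolves them lazily, so any component can be swapped out by selecting
# a different registered option.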
class TransformerConfigs(BaseConfigs):
    n_heads: int = 8
    d_model: int = 512
    n_layers: int = 6
    d_ff: int = 2048
    dropout: float = 0.1
    n_src_vocab: int
    n_tgt_vocab: int

    encoder_attn: MultiHeadedAttention = 'mha'
    decoder_attn: MultiHeadedAttention = 'mha'
    decoder_mem_attn: MultiHeadedAttention = 'mha'
    feed_forward: FeedForward

    encoder_layer: TransformerLayer = 'normal'
    decoder_layer: TransformerLayer = 'normal'

    encoder: Encoder = 'normal'
    decoder: Decoder = 'normal'

    src_embed: Module = 'fixed_pos'
    tgt_embed: Module = 'fixed_pos'

    generator: Generator = 'default'

    encoder_decoder: EncoderDecoder


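# Default implementations for each configurable component, registered under
# the option names referenced by `TransformerConfigs` above.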
@option(TransformerConfigs.feed_forward, 'default')
def _feed_forward(c: TransformerConfigs):
    return FeedForward(c.d_model, c.d_ff, c.dropout)


@option(TransformerConfigs.encoder_attn, 'mha')
def _encoder_mha(c: TransformerConfigs):
    return MultiHeadedAttention(c.n_heads, c.d_model)


@option(TransformerConfigs.decoder_attn, 'mha')
def _decoder_mha(c: TransformerConfigs):
    return MultiHeadedAttention(c.n_heads, c.d_model)


@option(TransformerConfigs.decoder_mem_attn, 'mha')
def _decoder_mem_mha(c: TransformerConfigs):
    return MultiHeadedAttention(c.n_heads, c.d_model)


@option(TransformerConfigs.encoder_layer, 'normal')
def _encoder_layer(c: TransformerConfigs):
    return TransformerLayer(d_model=c.d_model, self_attn=c.encoder_attn,
                            src_attn=None, feed_forward=copy.deepcopy(c.feed_forward),
                            dropout_prob=c.dropout)


@option(TransformerConfigs.decoder_layer, 'normal')
def _decoder_layer(c: TransformerConfigs):
    return TransformerLayer(d_model=c.d_model, self_attn=c.decoder_attn,
                            src_attn=c.decoder_mem_attn, feed_forward=copy.deepcopy(c.feed_forward),
                            dropout_prob=c.dropout)


@option(TransformerConfigs.encoder, 'normal')
def _encoder(c: TransformerConfigs):
    return Encoder(c.encoder_layer, c.n_layers)


@option(TransformerConfigs.decoder, 'normal')
def _decoder(c: TransformerConfigs):
    return Decoder(c.decoder_layer, c.n_layers)


@option(TransformerConfigs.generator, 'default')
def _generator(c: TransformerConfigs):
    return Generator(c.n_tgt_vocab, c.d_model)


@option(TransformerConfigs.src_embed, 'fixed_pos')
def _src_embed_with_positional(c: TransformerConfigs):
    return EmbeddingsWithPositionalEncoding(c.d_model, c.n_src_vocab)


@option(TransformerConfigs.tgt_embed, 'fixed_pos')
def _tgt_embed_with_positional(c: TransformerConfigs):
    return EmbeddingsWithPositionalEncoding(c.d_model, c.n_tgt_vocab)


@option(TransformerConfigs.encoder_decoder, 'normal')
def _encoder_decoder(c: TransformerConfigs):
    return EncoderDecoder(c.encoder, c.decoder, c.src_embed, c.tgt_embed, c.generator)
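

# Usage sketch (an assumption for illustration, not part of the original file):
# labml resolves config options when they are accessed through its configs
# machinery, so building the full model could look roughly like this, with
# hypothetical vocabulary sizes:
#
#     conf = TransformerConfigs()
#     conf.n_src_vocab = 10000
#     conf.n_tgt_vocab = 10000
#     model = conf.encoder_decoder  # resolves embeddings, encoder, decoder, generator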