diff --git a/.gitignore b/.gitignore new file mode 100644 index 00000000..bee8a64b --- /dev/null +++ b/.gitignore @@ -0,0 +1 @@ +__pycache__ diff --git a/transformers/__init__.py b/transformers/__init__.py new file mode 100644 index 00000000..e5708f9d --- /dev/null +++ b/transformers/__init__.py @@ -0,0 +1,232 @@ +import copy +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from common.models import clone_module_list +from labml.configs import BaseConfigs, option +from labml.helpers.pytorch.module import Module +from transformer.models.multi_headed_attention import MultiHeadedAttention +from transformers.positional_encoding import PositionalEncoding, get_positional_encoding + + +class EmbeddingsWithPositionalEncoding(Module): + def __init__(self, d_model: int, n_vocab: int, max_len: int = 5000): + super().__init__() + self.linear = nn.Embedding(n_vocab, d_model) + self.d_model = d_model + self.register_buffer('positional_encodings', get_positional_encoding(d_model, max_len)) + + def __call__(self, x: torch.Tensor): + pe = self.positional_encodings[:x.shape[0]].requires_grad_(False) + return self.linear(x) * math.sqrt(self.d_model) + pe + + +class FeedForward(Module): + def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1): + super().__init__() + self.layer1 = nn.Linear(d_model, d_ff) + self.layer2 = nn.Linear(d_ff, d_model) + self.dropout = nn.Dropout(dropout) + + def __call__(self, x: torch.Tensor): + x = self.layer1(x) + x = F.relu(x) + x = self.dropout(x) + return self.layer2(x) + + +class TransformerLayer(Module): + def __init__(self, *, + d_model: int, + self_attn: MultiHeadedAttention, + src_attn: MultiHeadedAttention = None, + feed_forward: FeedForward, + dropout_prob: float): + super().__init__() + self.size = d_model + self.self_attn = self_attn + self.src_attn = src_attn + self.feed_forward = feed_forward + self.dropout = nn.Dropout(dropout_prob) + self.norm_self_attn = nn.LayerNorm([d_model]) + if self.src_attn is not None: + self.norm_src_attn = nn.LayerNorm([d_model]) + self.norm_ff = nn.LayerNorm([d_model]) + + def __call__(self, *, + x: torch.Tensor, + mask: torch.Tensor, + src: torch.Tensor = None, + src_mask: torch.Tensor = None): + z = self.norm_self_attn(x) + attn_self = self.self_attn(query=z, key=z, value=z, mask=mask) + x = x + self.dropout(attn_self) + + if src is not None: + z = self.norm_src_attn(x) + attn_src = self.src_attn(query=z, key=src, value=src, mask=src_mask) + x = x + self.dropout(attn_src) + + z = self.norm_ff(x) + ff = self.feed_forward(z) + x = x + self.dropout(ff) + + # guard(x.shape, attn_self.shape, attn_src.shape, ff.shape, + # '_batch_size', '_seq_len', 'd_model') + + return x + + +class Encoder(Module): + def __init__(self, layer: TransformerLayer, n_layers: int): + super().__init__() + self.layers = clone_module_list(layer, n_layers) + self.norm = nn.LayerNorm([layer.size]) + + def __call__(self, x: torch.Tensor, mask: torch.Tensor): + for layer in self.layers: + x = layer(x=x, mask=mask) + return self.norm(x) + + +class Decoder(Module): + def __init__(self, layer: TransformerLayer, n_layers: int): + super().__init__() + self.layers = clone_module_list(layer, n_layers) + self.norm = nn.LayerNorm([layer.size]) + + def __call__(self, x, memory, src_mask, tgt_mask): + for layer in self.layers: + x = layer(x=x, mask=tgt_mask, src=memory, src_mask=src_mask) + return self.norm(x) + + +class Generator(Module): + def __init__(self, n_vocab: int, d_model: int): + super().__init__() + self.projection = nn.Linear(d_model, n_vocab) + + def __call__(self, x): + return F.log_softmax(self.projection(x), dim=-1) + + +class EncoderDecoder(Module): + def __init__(self, encoder: Encoder, decoder: Decoder, src_embed: Module, tgt_embed: Module, generator: Module): + super().__init__() + self.encoder = encoder + self.decoder = decoder + self.src_embed = src_embed + self.tgt_embed = tgt_embed + self.generator = generator + + # This was important from their code. + # Initialize parameters with Glorot / fan_avg. + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + + def __call__(self, src: torch.Tensor, tgt: torch.Tensor, src_mask: torch.Tensor, + tgt_mask: torch.Tensor): + return self.decode(self.encode(src, src_mask), src_mask, + tgt, tgt_mask) + + def encode(self, src: torch.Tensor, src_mask: torch.Tensor): + return self.encoder(self.src_embed(src), src_mask) + + def decode(self, memory: torch.Tensor, src_mask: torch.Tensor, tgt: torch.Tensor, tgt_mask: torch.Tensor): + return self.decoder(self.tgt_embed(tgt), memory, src_mask, tgt_mask) + + +class TransformerConfigs(BaseConfigs): + n_heads: int = 8 + d_model: int = 512 + n_layers: int = 6 + d_ff: int = 2048 + dropout: float = 0.1 + n_src_vocab: int + n_tgt_vocab: int + + encoder_attn: MultiHeadedAttention = 'mha' + decoder_attn: MultiHeadedAttention = 'mha' + decoder_mem_attn: MultiHeadedAttention = 'mha' + feed_forward: FeedForward + + encoder_layer: TransformerLayer = 'normal' + decoder_layer: TransformerLayer = 'normal' + + encoder: Encoder = 'normal' + decoder: Decoder = 'normal' + + src_embed: Module = 'fixed_pos' + tgt_embed: Module = 'fixed_pos' + + generator: Generator = 'default' + + encoder_decoder: EncoderDecoder + + +@option(TransformerConfigs.feed_forward, 'default') +def _feed_forward(c: TransformerConfigs): + return FeedForward(c.d_model, c.d_ff, c.dropout) + + +@option(TransformerConfigs.encoder_attn, 'mha') +def _encoder_mha(c: TransformerConfigs): + return MultiHeadedAttention(c.n_heads, c.d_model) + + +@option(TransformerConfigs.decoder_attn, 'mha') +def _decoder_mha(c: TransformerConfigs): + return MultiHeadedAttention(c.n_heads, c.d_model) + + +@option(TransformerConfigs.decoder_mem_attn, 'mha') +def _decoder_mem_mha(c: TransformerConfigs): + return MultiHeadedAttention(c.n_heads, c.d_model) + + +@option(TransformerConfigs.encoder_layer, 'normal') +def _encoder_layer(c: TransformerConfigs): + return TransformerLayer(d_model=c.d_model, self_attn=c.encoder_attn, + src_attn=None, feed_forward=copy.deepcopy(c.feed_forward), + dropout_prob=c.dropout) + + +@option(TransformerConfigs.decoder_layer, 'normal') +def _decoder_layer(c: TransformerConfigs): + return TransformerLayer(d_model=c.d_model, self_attn=c.decoder_attn, + src_attn=c.decoder_mem_attn, feed_forward=copy.deepcopy(c.feed_forward), + dropout_prob=c.dropout) + + +@option(TransformerConfigs.encoder, 'normal') +def _encoder(c: TransformerConfigs): + return Encoder(c.encoder_layer, c.n_layers) + + +@option(TransformerConfigs.decoder, 'normal') +def _decoder(c: TransformerConfigs): + return Decoder(c.decoder_layer, c.n_layers) + + +@option(TransformerConfigs.generator, 'default') +def _generator(c: TransformerConfigs): + return Generator(c.n_tgt_vocab, c.d_model) + + +@option(TransformerConfigs.src_embed, 'fixed_pos') +def _src_embed_with_positional(c: TransformerConfigs): + return EmbeddingsWithPositionalEncoding(c.d_model, c.n_src_vocab) + + +@option(TransformerConfigs.tgt_embed, 'fixed_pos') +def _tgt_embed_with_positional(c: TransformerConfigs): + return EmbeddingsWithPositionalEncoding(c.d_model, c.n_tgt_vocab) + + +@option(TransformerConfigs.encoder_decoder, 'normal') +def _encoder_decoder(c: TransformerConfigs): + return EncoderDecoder(c.encoder, c.decoder, c.src_embed, c.tgt_embed, c.generator) diff --git a/transformers/label_smoothing_loss.py b/transformers/label_smoothing_loss.py new file mode 100644 index 00000000..2dff9f3d --- /dev/null +++ b/transformers/label_smoothing_loss.py @@ -0,0 +1,59 @@ +import matplotlib.pyplot as plt +import numpy as np +import torch +import torch.nn as nn + +from labml.helpers.pytorch.module import Module + + +class LabelSmoothingLoss(Module): + def __init__(self, size: int, padding_idx: int, smoothing: float = 0.0): + super().__init__() + self.loss = nn.KLDivLoss(reduction='sum') + self.padding_idx = padding_idx + self.confidence = 1.0 - smoothing + self.smoothing = smoothing + self.size = size + self.true_dist = None + + def __call__(self, x: torch.Tensor, target: torch.Tensor): + assert x.size(1) == self.size + true_dist = x.clone() + true_dist.fill_(self.smoothing / (self.size - 2)) + true_dist.scatter_(1, target.unsqueeze(1), self.confidence) + true_dist[:, self.padding_idx] = 0 + mask = torch.nonzero(target == self.padding_idx, as_tuple=False) + if mask.dim() > 0: + true_dist.index_fill_(0, mask.squeeze(), 0.0) + self.true_dist = true_dist + return self.loss(x, true_dist.detach()) + + +def _test_label_smoothing(): + smooth_loss = LabelSmoothingLoss(5, 0, 0.4) + predict = torch.tensor([[0, 0.2, 0.7, 0.1, 0], + [0, 0.2, 0.7, 0.1, 0], + [0, 0.2, 0.7, 0.1, 0]], dtype=torch.float) + _ = smooth_loss(predict.log(), + torch.tensor([2, 1, 0], dtype=torch.long)) + + # Show the target distributions expected by the system. + plt.imshow(smooth_loss.true_dist) + plt.show() + + smooth_loss = LabelSmoothingLoss(5, 0, 0.1) + + def loss_sample(x): + d = x + 3 * 1 + predict2 = torch.tensor([[0, x / d, 1 / d, 1 / d, 1 / d], + ], dtype=torch.float) + # print(predict) + return smooth_loss(predict2.log(), + torch.tensor([1], dtype=torch.long)).item() + + plt.plot(np.arange(1, 100), [loss_sample(x) for x in range(1, 100)]) + plt.show() + + +if __name__ == '__main__': + _test_label_smoothing() diff --git a/transformers/positional_encoding.py b/transformers/positional_encoding.py new file mode 100644 index 00000000..6421e8d5 --- /dev/null +++ b/transformers/positional_encoding.py @@ -0,0 +1,47 @@ +import math + +import matplotlib.pyplot as plt +import numpy as np +import torch +import torch.nn as nn + +from labml.helpers.pytorch.module import Module + + +class PositionalEncoding(Module): + def __init__(self, d_model: int, dropout_prob: float, max_len: int = 5000): + super().__init__() + self.dropout = nn.Dropout(dropout_prob) + + self.register_buffer('positional_encodings', get_positional_encoding(d_model, max_len)) + + def __call__(self, x: torch.Tensor): + pe = self.positional_encodings[:x.shape[0]].detach().requires_grad_(False) + x = x + pe + x = self.dropout(x) + return x + + +def get_positional_encoding(d_model: int, max_len: int = 5000): + encodings = torch.zeros(max_len, d_model) + position = torch.arange(0, max_len, dtype=torch.float32).unsqueeze(1) + two_i = torch.arange(0, d_model, 2, dtype=torch.float32) + div_term = torch.exp(two_i * -(math.log(10000.0) / d_model)) + encodings[:, 0::2] = torch.sin(position * div_term) + encodings[:, 1::2] = torch.cos(position * div_term) + encodings = encodings.unsqueeze(1).requires_grad_(False) + + return encodings + + +def _test_positional_encoding(): + plt.figure(figsize=(15, 5)) + pe = get_positional_encoding(20, 100) + plt.plot(np.arange(100), pe[:, 0, 4:8].numpy()) + plt.legend(["dim %d" % p for p in [4, 5, 6, 7]]) + plt.title("Positional encoding") + plt.show() + + +if __name__ == '__main__': + _test_positional_encoding()