basic transformer

This commit is contained in:
Varuna Jayasiri
2020-08-25 15:35:25 +05:30
parent 7393d91161
commit 5a7a2e0525
4 changed files with 339 additions and 0 deletions

1
.gitignore vendored Normal file
View File

@ -0,0 +1 @@
__pycache__

232
transformers/__init__.py Normal file
View File

@ -0,0 +1,232 @@
import copy
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from common.models import clone_module_list
from labml.configs import BaseConfigs, option
from labml.helpers.pytorch.module import Module
from transformer.models.multi_headed_attention import MultiHeadedAttention
from transformers.positional_encoding import PositionalEncoding, get_positional_encoding
class EmbeddingsWithPositionalEncoding(Module):
def __init__(self, d_model: int, n_vocab: int, max_len: int = 5000):
super().__init__()
self.linear = nn.Embedding(n_vocab, d_model)
self.d_model = d_model
self.register_buffer('positional_encodings', get_positional_encoding(d_model, max_len))
def __call__(self, x: torch.Tensor):
pe = self.positional_encodings[:x.shape[0]].requires_grad_(False)
return self.linear(x) * math.sqrt(self.d_model) + pe
class FeedForward(Module):
def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1):
super().__init__()
self.layer1 = nn.Linear(d_model, d_ff)
self.layer2 = nn.Linear(d_ff, d_model)
self.dropout = nn.Dropout(dropout)
def __call__(self, x: torch.Tensor):
x = self.layer1(x)
x = F.relu(x)
x = self.dropout(x)
return self.layer2(x)
class TransformerLayer(Module):
def __init__(self, *,
d_model: int,
self_attn: MultiHeadedAttention,
src_attn: MultiHeadedAttention = None,
feed_forward: FeedForward,
dropout_prob: float):
super().__init__()
self.size = d_model
self.self_attn = self_attn
self.src_attn = src_attn
self.feed_forward = feed_forward
self.dropout = nn.Dropout(dropout_prob)
self.norm_self_attn = nn.LayerNorm([d_model])
if self.src_attn is not None:
self.norm_src_attn = nn.LayerNorm([d_model])
self.norm_ff = nn.LayerNorm([d_model])
def __call__(self, *,
x: torch.Tensor,
mask: torch.Tensor,
src: torch.Tensor = None,
src_mask: torch.Tensor = None):
z = self.norm_self_attn(x)
attn_self = self.self_attn(query=z, key=z, value=z, mask=mask)
x = x + self.dropout(attn_self)
if src is not None:
z = self.norm_src_attn(x)
attn_src = self.src_attn(query=z, key=src, value=src, mask=src_mask)
x = x + self.dropout(attn_src)
z = self.norm_ff(x)
ff = self.feed_forward(z)
x = x + self.dropout(ff)
# guard(x.shape, attn_self.shape, attn_src.shape, ff.shape,
# '_batch_size', '_seq_len', 'd_model')
return x
class Encoder(Module):
def __init__(self, layer: TransformerLayer, n_layers: int):
super().__init__()
self.layers = clone_module_list(layer, n_layers)
self.norm = nn.LayerNorm([layer.size])
def __call__(self, x: torch.Tensor, mask: torch.Tensor):
for layer in self.layers:
x = layer(x=x, mask=mask)
return self.norm(x)
class Decoder(Module):
def __init__(self, layer: TransformerLayer, n_layers: int):
super().__init__()
self.layers = clone_module_list(layer, n_layers)
self.norm = nn.LayerNorm([layer.size])
def __call__(self, x, memory, src_mask, tgt_mask):
for layer in self.layers:
x = layer(x=x, mask=tgt_mask, src=memory, src_mask=src_mask)
return self.norm(x)
class Generator(Module):
def __init__(self, n_vocab: int, d_model: int):
super().__init__()
self.projection = nn.Linear(d_model, n_vocab)
def __call__(self, x):
return F.log_softmax(self.projection(x), dim=-1)
class EncoderDecoder(Module):
def __init__(self, encoder: Encoder, decoder: Decoder, src_embed: Module, tgt_embed: Module, generator: Module):
super().__init__()
self.encoder = encoder
self.decoder = decoder
self.src_embed = src_embed
self.tgt_embed = tgt_embed
self.generator = generator
# This was important from their code.
# Initialize parameters with Glorot / fan_avg.
for p in self.parameters():
if p.dim() > 1:
nn.init.xavier_uniform_(p)
def __call__(self, src: torch.Tensor, tgt: torch.Tensor, src_mask: torch.Tensor,
tgt_mask: torch.Tensor):
return self.decode(self.encode(src, src_mask), src_mask,
tgt, tgt_mask)
def encode(self, src: torch.Tensor, src_mask: torch.Tensor):
return self.encoder(self.src_embed(src), src_mask)
def decode(self, memory: torch.Tensor, src_mask: torch.Tensor, tgt: torch.Tensor, tgt_mask: torch.Tensor):
return self.decoder(self.tgt_embed(tgt), memory, src_mask, tgt_mask)
class TransformerConfigs(BaseConfigs):
n_heads: int = 8
d_model: int = 512
n_layers: int = 6
d_ff: int = 2048
dropout: float = 0.1
n_src_vocab: int
n_tgt_vocab: int
encoder_attn: MultiHeadedAttention = 'mha'
decoder_attn: MultiHeadedAttention = 'mha'
decoder_mem_attn: MultiHeadedAttention = 'mha'
feed_forward: FeedForward
encoder_layer: TransformerLayer = 'normal'
decoder_layer: TransformerLayer = 'normal'
encoder: Encoder = 'normal'
decoder: Decoder = 'normal'
src_embed: Module = 'fixed_pos'
tgt_embed: Module = 'fixed_pos'
generator: Generator = 'default'
encoder_decoder: EncoderDecoder
@option(TransformerConfigs.feed_forward, 'default')
def _feed_forward(c: TransformerConfigs):
return FeedForward(c.d_model, c.d_ff, c.dropout)
@option(TransformerConfigs.encoder_attn, 'mha')
def _encoder_mha(c: TransformerConfigs):
return MultiHeadedAttention(c.n_heads, c.d_model)
@option(TransformerConfigs.decoder_attn, 'mha')
def _decoder_mha(c: TransformerConfigs):
return MultiHeadedAttention(c.n_heads, c.d_model)
@option(TransformerConfigs.decoder_mem_attn, 'mha')
def _decoder_mem_mha(c: TransformerConfigs):
return MultiHeadedAttention(c.n_heads, c.d_model)
@option(TransformerConfigs.encoder_layer, 'normal')
def _encoder_layer(c: TransformerConfigs):
return TransformerLayer(d_model=c.d_model, self_attn=c.encoder_attn,
src_attn=None, feed_forward=copy.deepcopy(c.feed_forward),
dropout_prob=c.dropout)
@option(TransformerConfigs.decoder_layer, 'normal')
def _decoder_layer(c: TransformerConfigs):
return TransformerLayer(d_model=c.d_model, self_attn=c.decoder_attn,
src_attn=c.decoder_mem_attn, feed_forward=copy.deepcopy(c.feed_forward),
dropout_prob=c.dropout)
@option(TransformerConfigs.encoder, 'normal')
def _encoder(c: TransformerConfigs):
return Encoder(c.encoder_layer, c.n_layers)
@option(TransformerConfigs.decoder, 'normal')
def _decoder(c: TransformerConfigs):
return Decoder(c.decoder_layer, c.n_layers)
@option(TransformerConfigs.generator, 'default')
def _generator(c: TransformerConfigs):
return Generator(c.n_tgt_vocab, c.d_model)
@option(TransformerConfigs.src_embed, 'fixed_pos')
def _src_embed_with_positional(c: TransformerConfigs):
return EmbeddingsWithPositionalEncoding(c.d_model, c.n_src_vocab)
@option(TransformerConfigs.tgt_embed, 'fixed_pos')
def _tgt_embed_with_positional(c: TransformerConfigs):
return EmbeddingsWithPositionalEncoding(c.d_model, c.n_tgt_vocab)
@option(TransformerConfigs.encoder_decoder, 'normal')
def _encoder_decoder(c: TransformerConfigs):
return EncoderDecoder(c.encoder, c.decoder, c.src_embed, c.tgt_embed, c.generator)

View File

@ -0,0 +1,59 @@
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from labml.helpers.pytorch.module import Module
class LabelSmoothingLoss(Module):
def __init__(self, size: int, padding_idx: int, smoothing: float = 0.0):
super().__init__()
self.loss = nn.KLDivLoss(reduction='sum')
self.padding_idx = padding_idx
self.confidence = 1.0 - smoothing
self.smoothing = smoothing
self.size = size
self.true_dist = None
def __call__(self, x: torch.Tensor, target: torch.Tensor):
assert x.size(1) == self.size
true_dist = x.clone()
true_dist.fill_(self.smoothing / (self.size - 2))
true_dist.scatter_(1, target.unsqueeze(1), self.confidence)
true_dist[:, self.padding_idx] = 0
mask = torch.nonzero(target == self.padding_idx, as_tuple=False)
if mask.dim() > 0:
true_dist.index_fill_(0, mask.squeeze(), 0.0)
self.true_dist = true_dist
return self.loss(x, true_dist.detach())
def _test_label_smoothing():
smooth_loss = LabelSmoothingLoss(5, 0, 0.4)
predict = torch.tensor([[0, 0.2, 0.7, 0.1, 0],
[0, 0.2, 0.7, 0.1, 0],
[0, 0.2, 0.7, 0.1, 0]], dtype=torch.float)
_ = smooth_loss(predict.log(),
torch.tensor([2, 1, 0], dtype=torch.long))
# Show the target distributions expected by the system.
plt.imshow(smooth_loss.true_dist)
plt.show()
smooth_loss = LabelSmoothingLoss(5, 0, 0.1)
def loss_sample(x):
d = x + 3 * 1
predict2 = torch.tensor([[0, x / d, 1 / d, 1 / d, 1 / d],
], dtype=torch.float)
# print(predict)
return smooth_loss(predict2.log(),
torch.tensor([1], dtype=torch.long)).item()
plt.plot(np.arange(1, 100), [loss_sample(x) for x in range(1, 100)])
plt.show()
if __name__ == '__main__':
_test_label_smoothing()

View File

@ -0,0 +1,47 @@
import math
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from labml.helpers.pytorch.module import Module
class PositionalEncoding(Module):
def __init__(self, d_model: int, dropout_prob: float, max_len: int = 5000):
super().__init__()
self.dropout = nn.Dropout(dropout_prob)
self.register_buffer('positional_encodings', get_positional_encoding(d_model, max_len))
def __call__(self, x: torch.Tensor):
pe = self.positional_encodings[:x.shape[0]].detach().requires_grad_(False)
x = x + pe
x = self.dropout(x)
return x
def get_positional_encoding(d_model: int, max_len: int = 5000):
encodings = torch.zeros(max_len, d_model)
position = torch.arange(0, max_len, dtype=torch.float32).unsqueeze(1)
two_i = torch.arange(0, d_model, 2, dtype=torch.float32)
div_term = torch.exp(two_i * -(math.log(10000.0) / d_model))
encodings[:, 0::2] = torch.sin(position * div_term)
encodings[:, 1::2] = torch.cos(position * div_term)
encodings = encodings.unsqueeze(1).requires_grad_(False)
return encodings
def _test_positional_encoding():
plt.figure(figsize=(15, 5))
pe = get_positional_encoding(20, 100)
plt.plot(np.arange(100), pe[:, 0, 4:8].numpy())
plt.legend(["dim %d" % p for p in [4, 5, 6, 7]])
plt.title("Positional encoding")
plt.show()
if __name__ == '__main__':
_test_positional_encoding()