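"""
Transformer building blocks: token embeddings with fixed or learned positional
encodings, the position-wise feed-forward network, encoder/decoder layers, and
the full encoder-decoder model with a log-softmax generator.
"""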
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from labml_helpers.module import Module
from labml_nn.utils import clone_module_list
from .mha import MultiHeadAttention
from .positional_encoding import get_positional_encoding

class EmbeddingsWithPositionalEncoding(Module):
    """
    Token embeddings with fixed (sinusoidal) positional encodings.
    """
    def __init__(self, d_model: int, n_vocab: int, max_len: int = 5000):
        super().__init__()
        self.linear = nn.Embedding(n_vocab, d_model)
        self.d_model = d_model
        # Pre-computed positional encodings, registered as a buffer (not trained)
        self.register_buffer('positional_encodings', get_positional_encoding(d_model, max_len))

    def __call__(self, x: torch.Tensor):
        # Take encodings for the current sequence length and keep them out of the graph
        pe = self.positional_encodings[:x.shape[0]].requires_grad_(False)
        # Scale embeddings by sqrt(d_model) before adding the positional encodings
        return self.linear(x) * math.sqrt(self.d_model) + pe

class EmbeddingsWithLearnedPositionalEncoding(Module):
    """
    Token embeddings with learned positional encodings.
    """
    def __init__(self, d_model: int, n_vocab: int, max_len: int = 5000):
        super().__init__()
        self.linear = nn.Embedding(n_vocab, d_model)
        self.d_model = d_model
        # Positional encodings are a trainable parameter of shape [max_len, 1, d_model]
        self.positional_encodings = nn.Parameter(torch.zeros(max_len, 1, d_model))

    def __call__(self, x: torch.Tensor):
        pe = self.positional_encodings[:x.shape[0]]
        return self.linear(x) * math.sqrt(self.d_model) + pe

class FeedForward(Module):
    """
    Position-wise feed-forward network.
    """
    def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        self.layer1 = nn.Linear(d_model, d_ff)
        self.layer2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def __call__(self, x: torch.Tensor):
        # Expand to d_ff, apply ReLU and dropout, then project back to d_model
        x = self.layer1(x)
        x = F.relu(x)
        x = self.dropout(x)
        return self.layer2(x)

class TransformerLayer(Module):
    """
    A single transformer layer with pre-layer-norm residual connections.
    It acts as an encoder layer when `src_attn` is `None`, and as a decoder
    layer when `src_attn` (attention over the encoder output) is given.
    """
    def __init__(self, *,
                 d_model: int,
                 self_attn: MultiHeadAttention,
                 src_attn: MultiHeadAttention = None,
                 feed_forward: FeedForward,
                 dropout_prob: float):
        super().__init__()
        self.size = d_model
        self.self_attn = self_attn
        self.src_attn = src_attn
        self.feed_forward = feed_forward
        self.dropout = nn.Dropout(dropout_prob)
        self.norm_self_attn = nn.LayerNorm([d_model])
        if self.src_attn is not None:
            self.norm_src_attn = nn.LayerNorm([d_model])
        self.norm_ff = nn.LayerNorm([d_model])

    def __call__(self, *,
                 x: torch.Tensor,
                 mask: torch.Tensor,
                 src: torch.Tensor = None,
                 src_mask: torch.Tensor = None):
        # Self-attention with a residual connection; layer norm is applied to the input first
        z = self.norm_self_attn(x)
        attn_self = self.self_attn(query=z, key=z, value=z, mask=mask)
        x = x + self.dropout(attn_self)
        # Attention over the encoder output (decoder only)
        if src is not None:
            z = self.norm_src_attn(x)
            attn_src = self.src_attn(query=z, key=src, value=src, mask=src_mask)
            x = x + self.dropout(attn_src)
        # Position-wise feed-forward with a residual connection
        z = self.norm_ff(x)
        ff = self.feed_forward(z)
        x = x + self.dropout(ff)
        return x

class Encoder(Module):
    """
    Transformer encoder: a stack of `n_layers` identical layers.
    """
    def __init__(self, layer: TransformerLayer, n_layers: int):
        super().__init__()
        # Make `n_layers` copies of the given layer
        self.layers = clone_module_list(layer, n_layers)
        self.norm = nn.LayerNorm([layer.size])

    def __call__(self, x: torch.Tensor, mask: torch.Tensor):
        # Run through each layer, then normalize the output
        for layer in self.layers:
            x = layer(x=x, mask=mask)
        return self.norm(x)

class Decoder(Module):
    """
    Transformer decoder: a stack of `n_layers` identical layers that also attend to the encoder output.
    """
    def __init__(self, layer: TransformerLayer, n_layers: int):
        super().__init__()
        self.layers = clone_module_list(layer, n_layers)
        self.norm = nn.LayerNorm([layer.size])

    def __call__(self, x: torch.Tensor, memory: torch.Tensor, src_mask: torch.Tensor, tgt_mask: torch.Tensor):
        # `memory` is the encoder output; `src_mask` masks it for the cross-attention
        for layer in self.layers:
            x = layer(x=x, mask=tgt_mask, src=memory, src_mask=src_mask)
        return self.norm(x)

class Generator(Module):
    """
    Project the decoder output to vocabulary size and take the log-softmax.
    """
    def __init__(self, n_vocab: int, d_model: int):
        super().__init__()
        self.projection = nn.Linear(d_model, n_vocab)

    def __call__(self, x):
        return F.log_softmax(self.projection(x), dim=-1)

class EncoderDecoder(Module):
    """
    The full encoder-decoder transformer.
    """
    def __init__(self, encoder: Encoder, decoder: Decoder, src_embed: Module, tgt_embed: Module, generator: Module):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.src_embed = src_embed
        self.tgt_embed = tgt_embed
        self.generator = generator

        # Initialize parameters with Glorot / fan_avg.
        # This was important in the original implementation.
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def __call__(self, src: torch.Tensor, tgt: torch.Tensor, src_mask: torch.Tensor,
                 tgt_mask: torch.Tensor):
        # Encode the source, then decode conditioned on the encoder output
        return self.decode(self.encode(src, src_mask), src_mask, tgt, tgt_mask)

    def encode(self, src: torch.Tensor, src_mask: torch.Tensor):
        return self.encoder(self.src_embed(src), src_mask)

    def decode(self, memory: torch.Tensor, src_mask: torch.Tensor, tgt: torch.Tensor, tgt_mask: torch.Tensor):
        return self.decoder(self.tgt_embed(tgt), memory, src_mask, tgt_mask)
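

# A minimal usage sketch (not part of the original module) showing how the
# pieces above fit together. It assumes the `MultiHeadAttention(heads, d_model)`
# constructor from `.mha` and boolean attention masks broadcastable to
# `[query_len, key_len, batch_size]`; adjust if the actual conventions differ.
# Inputs follow the `[seq_len, batch_size]` token layout used by the embeddings.
if __name__ == '__main__':
    d_model, n_heads, d_ff, n_layers, n_vocab = 128, 4, 256, 2, 1000

    def make_layer(is_decoder: bool):
        # Encoder layers have no source attention; decoder layers attend to the encoder output
        return TransformerLayer(
            d_model=d_model,
            self_attn=MultiHeadAttention(n_heads, d_model),
            src_attn=MultiHeadAttention(n_heads, d_model) if is_decoder else None,
            feed_forward=FeedForward(d_model, d_ff),
            dropout_prob=0.1)

    model = EncoderDecoder(
        encoder=Encoder(make_layer(False), n_layers),
        decoder=Decoder(make_layer(True), n_layers),
        src_embed=EmbeddingsWithPositionalEncoding(d_model, n_vocab),
        tgt_embed=EmbeddingsWithPositionalEncoding(d_model, n_vocab),
        generator=Generator(n_vocab, d_model))

    src = torch.randint(0, n_vocab, (10, 2))  # [src_len, batch]
    tgt = torch.randint(0, n_vocab, (12, 2))  # [tgt_len, batch]
    # Source tokens may be attended to from anywhere; target self-attention is causal
    src_mask = torch.ones(1, src.shape[0], 1, dtype=torch.bool)
    tgt_mask = torch.tril(torch.ones(tgt.shape[0], tgt.shape[0], dtype=torch.bool)).unsqueeze(-1)

    out = model(src, tgt, src_mask, tgt_mask)
    print(out.shape)  # decoder output before the generator: [12, 2, 128]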