fix dependencies and include relative attention

Varuna Jayasiri
2020-08-25 15:54:47 +05:30
parent 4ae3d770a9
commit 5ccc095c2a
4 changed files with 188 additions and 17 deletions


@@ -5,11 +5,11 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from common.models import clone_module_list
from labml.configs import BaseConfigs, option
from labml.configs import BaseConfigs, option, calculate
from labml.helpers.pytorch.module import Module
from transformer.models.multi_headed_attention import MultiHeadedAttention
from transformers.mha import MultiHeadAttention
from transformers.positional_encoding import PositionalEncoding, get_positional_encoding
from transformers.utils import clone_module_list
class EmbeddingsWithPositionalEncoding(Module):
@@ -53,8 +53,8 @@ class FeedForward(Module):
class TransformerLayer(Module):
def __init__(self, *,
d_model: int,
self_attn: MultiHeadedAttention,
src_attn: MultiHeadedAttention = None,
self_attn: MultiHeadAttention,
src_attn: MultiHeadAttention = None,
feed_forward: FeedForward,
dropout_prob: float):
super().__init__()
@@ -161,9 +161,9 @@ class TransformerConfigs(BaseConfigs):
n_src_vocab: int
n_tgt_vocab: int
encoder_attn: MultiHeadedAttention = 'mha'
decoder_attn: MultiHeadedAttention = 'mha'
decoder_mem_attn: MultiHeadedAttention = 'mha'
encoder_attn: MultiHeadAttention = 'mha'
decoder_attn: MultiHeadAttention = 'mha'
decoder_mem_attn: MultiHeadAttention = 'mha'
feed_forward: FeedForward
encoder_layer: TransformerLayer = 'normal'
@@ -185,19 +185,25 @@ def _feed_forward(c: TransformerConfigs):
return FeedForward(c.d_model, c.d_ff, c.dropout)
@option(TransformerConfigs.encoder_attn, 'mha')
def _encoder_mha(c: TransformerConfigs):
return MultiHeadedAttention(c.n_heads, c.d_model)
### MHA
def _mha(c: TransformerConfigs):
return MultiHeadAttention(c.n_heads, c.d_model)
@option(TransformerConfigs.decoder_attn, 'mha')
def _decoder_mha(c: TransformerConfigs):
return MultiHeadedAttention(c.n_heads, c.d_model)
calculate(TransformerConfigs.encoder_attn, 'mha', _mha)
calculate(TransformerConfigs.decoder_attn, 'mha', _mha)
calculate(TransformerConfigs.decoder_mem_attn, 'mha', _mha)
@option(TransformerConfigs.decoder_mem_attn, 'mha')
def _decoder_mem_mha(c: TransformerConfigs):
return MultiHeadedAttention(c.n_heads, c.d_model)
### Relative MHA
def _relative_mha(c: TransformerConfigs):
from transformers.relative_mha import RelativeMultiHeadAttention
return RelativeMultiHeadAttention(c.n_heads, c.d_model)
calculate(TransformerConfigs.encoder_attn, 'relative', _relative_mha)
calculate(TransformerConfigs.decoder_attn, 'relative', _relative_mha)
calculate(TransformerConfigs.decoder_mem_attn, 'relative', _relative_mha)
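With `calculate`, one constructor can be registered for several config options under the same name, which replaces the three separate `@option`-decorated functions removed above. A minimal sketch of selecting these options (the vocabulary sizes are hypothetical, and this assumes labml resolves string config values through the registered calculators):

conf = TransformerConfigs()
conf.n_src_vocab = 10_000        # hypothetical, for illustration only
conf.n_tgt_vocab = 10_000
conf.encoder_attn = 'relative'   # built by _relative_mha -> RelativeMultiHeadAttention
conf.decoder_attn = 'mha'        # built by _mha -> MultiHeadAttention (the default)
conf.decoder_mem_attn = 'mha'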
@option(TransformerConfigs.encoder_layer, 'normal')
@@ -229,6 +235,7 @@ def _generator(c: TransformerConfigs):
return Generator(c.n_tgt_vocab, c.d_model)
### Positional Embeddings
@option(TransformerConfigs.src_embed, 'fixed_pos')
def _src_embed_with_positional(c: TransformerConfigs):
return EmbeddingsWithPositionalEncoding(c.d_model, c.n_src_vocab)
@@ -239,6 +246,7 @@ def _tgt_embed_with_positional(c: TransformerConfigs):
return EmbeddingsWithPositionalEncoding(c.d_model, c.n_tgt_vocab)
### Learned Positional Embeddings
@option(TransformerConfigs.src_embed, 'learned_pos')
def _src_embed_with_learned_positional(c: TransformerConfigs):
return EmbeddingsWithLearnedPositionalEncoding(c.d_model, c.n_src_vocab)
@@ -249,6 +257,16 @@ def _tgt_embed_with_learned_positional(c: TransformerConfigs):
return EmbeddingsWithLearnedPositionalEncoding(c.d_model, c.n_tgt_vocab)
### No Positional Embeddings
@option(TransformerConfigs.src_embed, 'no_pos')
def _src_embed_without_positional(c: TransformerConfigs):
return nn.Embedding(c.n_src_vocab, c.d_model)
@option(TransformerConfigs.tgt_embed, 'no_pos')
def _tgt_embed_without_positional(c: TransformerConfigs):
return nn.Embedding(c.n_tgt_vocab, c.d_model)
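Since relative attention encodes position information inside the attention scores, it is typically paired with plain token embeddings rather than added positional encodings. A short sketch of that pairing (an assumption about usage, not something this commit configures):

conf = TransformerConfigs()
conf.encoder_attn = 'relative'
conf.src_embed = 'no_pos'   # plain nn.Embedding, no positional encoding added
conf.tgt_embed = 'no_pos'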
@option(TransformerConfigs.encoder_decoder, 'normal')
def _encoder_decoder(c: TransformerConfigs):

transformers/mha.py (new file)

@@ -0,0 +1,78 @@
import math
from typing import Optional
import torch
from torch import nn as nn
from torch.nn import functional as F
from labml.helpers.pytorch.module import Module
class PrepareForMultiHeadAttention(Module):
def __init__(self, d_model: int, heads: int, d_k: int):
super().__init__()
self.linear = nn.Linear(d_model, heads * d_k)
self.heads = heads
self.d_k = d_k
def __call__(self, x: torch.Tensor):
seq_len, batch_size, _ = x.shape
x = self.linear(x)
x = x.view(seq_len, batch_size, self.heads, self.d_k)
return x
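A minimal shape sketch for the projection above (hypothetical sizes): the linear layer maps d_model to heads * d_k, and the result is split into one d_k-dimensional vector per head.

import torch

prep = PrepareForMultiHeadAttention(d_model=512, heads=8, d_k=64)
x = torch.randn(10, 32, 512)             # [seq_len, batch_size, d_model]
assert prep(x).shape == (10, 32, 8, 64)  # [seq_len, batch_size, heads, d_k]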
class MultiHeadAttention(Module):
def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1):
super().__init__()
# We assume d_v always equals d_k
self.d_k = d_model // heads
self.heads = heads
self.query = PrepareForMultiHeadAttention(d_model, heads, self.d_k)
self.key = PrepareForMultiHeadAttention(d_model, heads, self.d_k)
self.value = PrepareForMultiHeadAttention(d_model, heads, self.d_k)
self.output = nn.Linear(d_model, d_model)
self.attn = None
self.dropout = nn.Dropout(dropout_prob)
self.scale = 1 / math.sqrt(self.d_k)
    def get_scores(self, query: torch.Tensor,
                   key: torch.Tensor):
        # Dot-product scores; scaling is applied by the caller.
        # The result has shape `[query_len, key_len, batch, heads]`.
        return torch.einsum('ibhd,jbhd->ijbh', query, key)
def __call__(self, *,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
mask: Optional[torch.Tensor] = None):
seq_len, batch_size, *_ = query.shape
if mask is not None:
# mask = ijb
assert mask.shape[0] == 1 or mask.shape[0] == mask.shape[1]
# Same mask applied to all h heads.
mask = mask.unsqueeze(-1)
query = self.query(query)
key = self.key(key)
value = self.value(value)
scores = self.get_scores(query, key)
scores *= self.scale
if mask is not None:
# mask = ijbh
assert mask.shape[0] == 1 or mask.shape[0] == mask.shape[1]
scores = scores.masked_fill(mask == 0, -1e9)
attn = F.softmax(scores, dim=1)
attn = self.dropout(attn)
x = torch.einsum("ijbh,jbhd->ibhd", attn, value)
self.attn = attn.detach()
x = x.reshape(seq_len, batch_size, -1)
return self.output(x)
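A minimal usage sketch (hypothetical shapes, using the module paths from this commit): inputs and outputs are [seq_len, batch_size, d_model], and the mask is a broadcastable [query, key, batch] tensor; the head dimension is added internally.

import torch
from transformers.mha import MultiHeadAttention

mha = MultiHeadAttention(heads=8, d_model=512)
x = torch.randn(10, 32, 512)                          # seq_len=10, batch_size=32
mask = torch.tril(torch.ones(10, 10)).unsqueeze(-1)   # causal mask, broadcast over batch
out = mha(query=x, key=x, value=x, mask=mask)
assert out.shape == (10, 32, 512)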

transformers/relative_mha.py (new file)

@@ -0,0 +1,66 @@
import copy
import torch
from torch import nn
from labml.helpers.pytorch.module import Module
from transformers.mha import MultiHeadAttention
class PrepareForMultiHeadAttention(Module):
def __init__(self, d_model: int, heads: int, d_k: int):
super().__init__()
self.linear = nn.Linear(d_model, heads * d_k, bias=False)
self.heads = heads
self.d_k = d_k
def __call__(self, x: torch.Tensor):
seq_len, batch_size, _ = x.shape
x = self.linear(x)
x = x.view(seq_len, batch_size, self.heads, self.d_k)
return x
class RelativeMultiHeadAttention(MultiHeadAttention):
    @staticmethod
    def _rel_shift(x: torch.Tensor):
        # `x` is indexed `[query, relative_position, batch, heads]`.
        # Prepend a zero column and reinterpret the memory layout so each row is
        # shifted by one step, turning relative-position indexing into key indexing.
        zero_pad = torch.zeros((x.shape[0], 1, *x.shape[2:]),
                               device=x.device, dtype=x.dtype)
        x_padded = torch.cat([zero_pad, x], dim=1)
        x_padded = x_padded.view(x.shape[1] + 1, x.shape[0], *x.shape[2:])
        x = x_padded[1:].view_as(x)
        # Zero out positional scores for keys that are ahead of the query
        ones = torch.ones((x.size(0), x.size(1)), device=x.device)
        lower_triangle = torch.tril(ones, x.size(1) - x.size(0))
        x = x * lower_triangle[:, :, None, None]
        return x
def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1):
super().__init__(heads, d_model, dropout_prob)
self.max_key_len = 2 ** 12
self.key_pos_embeddings = nn.Parameter(
torch.zeros((self.max_key_len, heads, self.d_k)),
requires_grad=True)
self.query_pos_bias = nn.Parameter(
torch.zeros((heads, self.d_k)),
requires_grad=True)
self.key_pos_bias = nn.Parameter(
torch.zeros((self.max_key_len, heads)),
requires_grad=True)
    def get_scores(self, query: torch.Tensor,
                   key: torch.Tensor):
        key_len = key.shape[0]
        # Content-based scores with a learned per-head query bias
        ac = torch.einsum('ibhd,jbhd->ijbh', query + self.query_pos_bias[None, None, :, :], key)
        # Position-based scores with a learned per-position bias,
        # shifted from relative-position indexing to key indexing
        b = torch.einsum('ibhd,jhd->ijbh', query, self.key_pos_embeddings[-key_len:])
        d = self.key_pos_bias[None, -key_len:, None, :]
        bd = self._rel_shift(b + d)
        return ac + bd
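The score computation above appears to follow Transformer-XL-style relative attention: `ac` combines the content-content term with a learned per-head query bias, while `bd` combines the content-position term with a learned per-position bias, realigned by `_rel_shift`. A minimal drop-in sketch (hypothetical shapes), showing that the call signature is the same as the standard MultiHeadAttention:

import torch
from transformers.relative_mha import RelativeMultiHeadAttention

attn = RelativeMultiHeadAttention(heads=8, d_model=512)
x = torch.randn(10, 32, 512)                          # [seq_len, batch_size, d_model]
mask = torch.tril(torch.ones(10, 10)).unsqueeze(-1)   # causal mask
out = attn(query=x, key=x, value=x, mask=mask)
assert out.shape == (10, 32, 512)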

transformers/utils.py (new file)

@@ -0,0 +1,9 @@
import copy
from torch import nn
from labml.helpers.pytorch.module import Module
def clone_module_list(module: Module, n: int):
return nn.ModuleList([copy.deepcopy(module) for _ in range(n)])
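A minimal usage sketch (hypothetical layer): each clone is a deep copy, so the copies do not share parameters.

from torch import nn

layer = nn.Linear(512, 512)
layers = clone_module_list(layer, 6)
assert len(layers) == 6
assert layers[0].weight.data_ptr() != layers[1].weight.data_ptr()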