From 5ccc095c2a6f7f06b54ec0d5dcf6330c8b699eb0 Mon Sep 17 00:00:00 2001 From: Varuna Jayasiri Date: Tue, 25 Aug 2020 15:54:47 +0530 Subject: [PATCH] fix dependecies and include relative attention --- transformers/__init__.py | 52 ++++++++++++++++-------- transformers/mha.py | 78 ++++++++++++++++++++++++++++++++++++ transformers/relative_mha.py | 66 ++++++++++++++++++++++++++++++ transformers/utils.py | 9 +++++ 4 files changed, 188 insertions(+), 17 deletions(-) create mode 100644 transformers/mha.py create mode 100644 transformers/relative_mha.py create mode 100644 transformers/utils.py diff --git a/transformers/__init__.py b/transformers/__init__.py index 86906796..65b23399 100644 --- a/transformers/__init__.py +++ b/transformers/__init__.py @@ -5,11 +5,11 @@ import torch import torch.nn as nn import torch.nn.functional as F -from common.models import clone_module_list -from labml.configs import BaseConfigs, option +from labml.configs import BaseConfigs, option, calculate from labml.helpers.pytorch.module import Module -from transformer.models.multi_headed_attention import MultiHeadedAttention +from transformers.mha import MultiHeadAttention from transformers.positional_encoding import PositionalEncoding, get_positional_encoding +from transformers.utils import clone_module_list class EmbeddingsWithPositionalEncoding(Module): @@ -53,8 +53,8 @@ class FeedForward(Module): class TransformerLayer(Module): def __init__(self, *, d_model: int, - self_attn: MultiHeadedAttention, - src_attn: MultiHeadedAttention = None, + self_attn: MultiHeadAttention, + src_attn: MultiHeadAttention = None, feed_forward: FeedForward, dropout_prob: float): super().__init__() @@ -161,9 +161,9 @@ class TransformerConfigs(BaseConfigs): n_src_vocab: int n_tgt_vocab: int - encoder_attn: MultiHeadedAttention = 'mha' - decoder_attn: MultiHeadedAttention = 'mha' - decoder_mem_attn: MultiHeadedAttention = 'mha' + encoder_attn: MultiHeadAttention = 'mha' + decoder_attn: MultiHeadAttention = 'mha' + decoder_mem_attn: MultiHeadAttention = 'mha' feed_forward: FeedForward encoder_layer: TransformerLayer = 'normal' @@ -185,19 +185,25 @@ def _feed_forward(c: TransformerConfigs): return FeedForward(c.d_model, c.d_ff, c.dropout) -@option(TransformerConfigs.encoder_attn, 'mha') -def _encoder_mha(c: TransformerConfigs): - return MultiHeadedAttention(c.n_heads, c.d_model) +### MHA +def _mha(c: TransformerConfigs): + return MultiHeadAttention(c.n_heads, c.d_model) -@option(TransformerConfigs.decoder_attn, 'mha') -def _decoder_mha(c: TransformerConfigs): - return MultiHeadedAttention(c.n_heads, c.d_model) +calculate(TransformerConfigs.encoder_attn, 'mha', _mha) +calculate(TransformerConfigs.decoder_attn, 'mha', _mha) +calculate(TransformerConfigs.decoder_mem_attn, 'mha', _mha) -@option(TransformerConfigs.decoder_mem_attn, 'mha') -def _decoder_mem_mha(c: TransformerConfigs): - return MultiHeadedAttention(c.n_heads, c.d_model) +### Relative MHA +def _relative_mha(c: TransformerConfigs): + from transformers.relative_mha import RelativeMultiHeadAttention + return RelativeMultiHeadAttention(c.n_heads, c.d_model) + + +calculate(TransformerConfigs.encoder_attn, 'relative', _relative_mha) +calculate(TransformerConfigs.decoder_attn, 'relative', _relative_mha) +calculate(TransformerConfigs.decoder_mem_attn, 'relative', _relative_mha) @option(TransformerConfigs.encoder_layer, 'normal') @@ -229,6 +235,7 @@ def _generator(c: TransformerConfigs): return Generator(c.n_tgt_vocab, c.d_model) +### Positional Embeddings @option(TransformerConfigs.src_embed, 'fixed_pos') def _src_embed_with_positional(c: TransformerConfigs): return EmbeddingsWithPositionalEncoding(c.d_model, c.n_src_vocab) @@ -239,6 +246,7 @@ def _tgt_embed_with_positional(c: TransformerConfigs): return EmbeddingsWithPositionalEncoding(c.d_model, c.n_tgt_vocab) +### Learned Positional Embeddings @option(TransformerConfigs.src_embed, 'learned_pos') def _src_embed_with_learned_positional(c: TransformerConfigs): return EmbeddingsWithLearnedPositionalEncoding(c.d_model, c.n_src_vocab) @@ -249,6 +257,16 @@ def _tgt_embed_with_learned_positional(c: TransformerConfigs): return EmbeddingsWithLearnedPositionalEncoding(c.d_model, c.n_tgt_vocab) +### No Positional Embeddings +@option(TransformerConfigs.src_embed, 'no_pos') +def _src_embed_without_positional(c: TransformerConfigs): + return nn.Embedding(c.n_src_vocab, c.d_model) + + +@option(TransformerConfigs.tgt_embed, 'no_pos') +def _tgt_embed_without_positional(c: TransformerConfigs): + return nn.Embedding(c.n_tgt_vocab, c.d_model) + @option(TransformerConfigs.encoder_decoder, 'normal') def _encoder_decoder(c: TransformerConfigs): diff --git a/transformers/mha.py b/transformers/mha.py new file mode 100644 index 00000000..933e5181 --- /dev/null +++ b/transformers/mha.py @@ -0,0 +1,78 @@ +import math +from typing import Optional + +import torch +from torch import nn as nn +from torch.nn import functional as F + +from labml.helpers.pytorch.module import Module + + +class PrepareForMultiHeadAttention(Module): + def __init__(self, d_model: int, heads: int, d_k: int): + super().__init__() + self.linear = nn.Linear(d_model, heads * d_k) + self.heads = heads + self.d_k = d_k + + def __call__(self, x: torch.Tensor): + seq_len, batch_size, _ = x.shape + + x = self.linear(x) + x = x.view(seq_len, batch_size, self.heads, self.d_k) + + return x + + +class MultiHeadAttention(Module): + def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1): + super().__init__() + # We assume d_v always equals d_k + self.d_k = d_model // heads + self.heads = heads + self.query = PrepareForMultiHeadAttention(d_model, heads, self.d_k) + self.key = PrepareForMultiHeadAttention(d_model, heads, self.d_k) + self.value = PrepareForMultiHeadAttention(d_model, heads, self.d_k) + self.output = nn.Linear(d_model, d_model) + self.attn = None + self.dropout = nn.Dropout(dropout_prob) + self.scale = 1 / math.sqrt(self.d_k) + + def get_scores(self, query: torch.Tensor, + key: torch.Tensor, ): + return torch.einsum('ibhd,jbhd->ijbh', query, key) + + def __call__(self, *, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + mask: Optional[torch.Tensor] = None): + seq_len, batch_size, *_ = query.shape + + if mask is not None: + # mask = ijb + assert mask.shape[0] == 1 or mask.shape[0] == mask.shape[1] + # Same mask applied to all h heads. + mask = mask.unsqueeze(-1) + + query = self.query(query) + key = self.key(key) + value = self.value(value) + + scores = self.get_scores(query, key) + + scores *= self.scale + if mask is not None: + # mask = ijbh + assert mask.shape[0] == 1 or mask.shape[0] == mask.shape[1] + scores = scores.masked_fill(mask == 0, -1e9) + attn = F.softmax(scores, dim=1) + attn = self.dropout(attn) + + x = torch.einsum("ijbh,jbhd->ibhd", attn, value) + + self.attn = attn.detach() + + x = x.reshape(seq_len, batch_size, -1) + + return self.output(x) diff --git a/transformers/relative_mha.py b/transformers/relative_mha.py new file mode 100644 index 00000000..c7a2d84f --- /dev/null +++ b/transformers/relative_mha.py @@ -0,0 +1,66 @@ +import copy + +import torch +from torch import nn + +from labml.helpers.pytorch.module import Module +from transformers.mha import MultiHeadAttention + + +class PrepareForMultiHeadAttention(Module): + def __init__(self, d_model: int, heads: int, d_k: int): + super().__init__() + self.linear = nn.Linear(d_model, heads * d_k, bias=False) + self.heads = heads + self.d_k = d_k + + def __call__(self, x: torch.Tensor): + seq_len, batch_size, _ = x.shape + + x = self.linear(x) + x = x.view(seq_len, batch_size, self.heads, self.d_k) + + return x + + +class RelativeMultiHeadAttention(MultiHeadAttention): + @staticmethod + def _rel_shift(x: torch.Tensor): + zero_pad = torch.zeros((x.shape[0], 1, *x.shape[2:]), + device=x.device, dtype=x.dtype) + x_padded = torch.cat([zero_pad, x], dim=1) + + x_padded = x_padded.view(x.shape[1] + 1, x.shape[0], *x.shape[2:]) + + x = x_padded[1:].view_as(x) + + ones = torch.ones((x.size(0), x.size(1)), device=x.device) + lower_triangle = torch.tril(ones, x.size(1) - x.size(0)) + x = x * lower_triangle[:, :, None, None] + + return x + + def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1): + super().__init__(heads, d_model, dropout_prob) + self.max_key_len = 2 ** 12 + + self.key_pos_embeddings = nn.Parameter( + torch.zeros((self.max_key_len, heads, self.d_k)), + requires_grad=True) + self.query_pos_bias = nn.Parameter( + torch.zeros((heads, self.d_k)), + requires_grad=True) + self.key_pos_bias = nn.Parameter( + torch.zeros((self.max_key_len, heads)), + requires_grad=True) + + def get_scores(self, query: torch.Tensor, + key: torch.Tensor, ): + key_len = key.shape[0] + + ac = torch.einsum('ibhd,jbhd->ijbh', query + self.query_pos_bias[None, None, :, :], key) + b = torch.einsum('ibhd,jhd->ijbh', query, self.key_pos_embeddings[-key_len:]) + d = self.key_pos_bias[None, -key_len:, None, :] + bd = self._rel_shift(b + d) + + return ac + bd diff --git a/transformers/utils.py b/transformers/utils.py new file mode 100644 index 00000000..4f59b341 --- /dev/null +++ b/transformers/utils.py @@ -0,0 +1,9 @@ +import copy + +from torch import nn + +from labml.helpers.pytorch.module import Module + + +def clone_module_list(module: Module, n: int): + return nn.ModuleList([copy.deepcopy(module) for _ in range(n)])