relative transformer fix for encoder decoder

Varuna Jayasiri
2020-08-26 08:41:49 +05:30
parent 5ccc095c2a
commit 7db0ced04c
2 changed files with 44 additions and 54 deletions

View File

@@ -9,9 +9,9 @@ from labml.helpers.pytorch.module import Module
 
 
 class PrepareForMultiHeadAttention(Module):
-    def __init__(self, d_model: int, heads: int, d_k: int):
+    def __init__(self, d_model: int, heads: int, d_k: int, bias: bool):
         super().__init__()
-        self.linear = nn.Linear(d_model, heads * d_k)
+        self.linear = nn.Linear(d_model, heads * d_k, bias=bias)
         self.heads = heads
         self.d_k = d_k
@@ -25,21 +25,19 @@ class PrepareForMultiHeadAttention(Module):
 
 class MultiHeadAttention(Module):
-    def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1):
+    def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1, bias=True):
         super().__init__()
 
-        # We assume d_v always equals d_k
         self.d_k = d_model // heads
         self.heads = heads
-        self.query = PrepareForMultiHeadAttention(d_model, heads, self.d_k)
-        self.key = PrepareForMultiHeadAttention(d_model, heads, self.d_k)
-        self.value = PrepareForMultiHeadAttention(d_model, heads, self.d_k)
+        self.query = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias)
+        self.key = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias)
+        self.value = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias)
         self.output = nn.Linear(d_model, d_model)
         self.attn = None
         self.dropout = nn.Dropout(dropout_prob)
         self.scale = 1 / math.sqrt(self.d_k)
 
-    def get_scores(self, query: torch.Tensor,
-                   key: torch.Tensor, ):
+    def get_scores(self, query: torch.Tensor, key: torch.Tensor, ):
         return torch.einsum('ibhd,jbhd->ijbh', query, key)
 
     def __call__(self, *,
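
The `bias` flag added above is what the relative attention change below relies on: `RelativeMultiHeadAttention` can now ask the shared `MultiHeadAttention` base class for bias-free query/key/value projections instead of carrying its own copy of `PrepareForMultiHeadAttention`. A minimal sketch of the new argument, assuming the class layout in this diff (the `heads`/`d_model` values are arbitrary, chosen only for illustration):

import torch
from transformers.mha import MultiHeadAttention

# Default keeps biased projections, matching the old behaviour.
mha = MultiHeadAttention(heads=8, d_model=512)
assert mha.query.linear.bias is not None

# bias=False drops the bias from the query/key/value projections; this is what
# RelativeMultiHeadAttention now requests via super().__init__(heads, d_model, dropout_prob, False).
rel_style = MultiHeadAttention(heads=8, d_model=512, bias=False)
assert rel_style.query.linear.bias is None

# Shapes are unchanged: input is [seq_len, batch_size, d_model],
# projections come out as [seq_len, batch_size, heads, d_k].
x = torch.zeros(10, 2, 512)
assert rel_style.query(x).shape == (10, 2, 8, 64)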

View File

@@ -1,66 +1,58 @@
-import copy
+"""
+Implementation of "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"
+https://arxiv.org/abs/1901.02860
+"""
 import torch
 from torch import nn
 
 from labml.helpers.pytorch.module import Module
+from labml.logger import inspect
 from transformers.mha import MultiHeadAttention
 
 
-class PrepareForMultiHeadAttention(Module):
-    def __init__(self, d_model: int, heads: int, d_k: int):
-        super().__init__()
-        self.linear = nn.Linear(d_model, heads * d_k, bias=False)
-        self.heads = heads
-        self.d_k = d_k
-
-    def __call__(self, x: torch.Tensor):
-        seq_len, batch_size, _ = x.shape
-        x = self.linear(x)
-        x = x.view(seq_len, batch_size, self.heads, self.d_k)
-        return x
+def relative_shift(x: torch.Tensor):
+    zero_pad = x.new_zeros(x.shape[0], 1, *x.shape[2:])
+    x_padded = torch.cat([x, zero_pad], dim=1)
+    x_padded = x_padded.view(x.shape[1] + 1, x.shape[0], *x.shape[2:])
+    x = x_padded[:-1].view_as(x)
+
+    return x
 
 
 class RelativeMultiHeadAttention(MultiHeadAttention):
-    @staticmethod
-    def _rel_shift(x: torch.Tensor):
-        zero_pad = torch.zeros((x.shape[0], 1, *x.shape[2:]),
-                               device=x.device, dtype=x.dtype)
-        x_padded = torch.cat([zero_pad, x], dim=1)
-        x_padded = x_padded.view(x.shape[1] + 1, x.shape[0], *x.shape[2:])
-        x = x_padded[1:].view_as(x)
-
-        ones = torch.ones((x.size(0), x.size(1)), device=x.device)
-        lower_triangle = torch.tril(ones, x.size(1) - x.size(0))
-        x = x * lower_triangle[:, :, None, None]
-
-        return x
-
     def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1):
-        super().__init__(heads, d_model, dropout_prob)
-        self.max_key_len = 2 ** 12
-        self.key_pos_embeddings = nn.Parameter(
-            torch.zeros((self.max_key_len, heads, self.d_k)),
-            requires_grad=True)
-        self.query_pos_bias = nn.Parameter(
-            torch.zeros((heads, self.d_k)),
-            requires_grad=True)
-        self.key_pos_bias = nn.Parameter(
-            torch.zeros((self.max_key_len, heads)),
-            requires_grad=True)
+        super().__init__(heads, d_model, dropout_prob, False)
+        self.P = 2 ** 12
+        self.key_pos_embeddings = nn.Parameter(torch.zeros((self.P * 2, heads, self.d_k)), requires_grad=True)
+        self.query_pos_bias = nn.Parameter(torch.zeros((heads, self.d_k)), requires_grad=True)
+        self.key_pos_bias = nn.Parameter(torch.zeros((self.P * 2, heads)), requires_grad=True)
 
-    def get_scores(self, query: torch.Tensor,
-                   key: torch.Tensor, ):
-        key_len = key.shape[0]
+    def get_scores(self, query: torch.Tensor, key: torch.Tensor):
+        key_pos_emb = self.key_pos_embeddings[self.P - query.shape[0]:self.P + key.shape[0]]
+        key_pos_bias = self.key_pos_bias[self.P - query.shape[0]:self.P + key.shape[0]]
 
         ac = torch.einsum('ibhd,jbhd->ijbh', query + self.query_pos_bias[None, None, :, :], key)
-        b = torch.einsum('ibhd,jhd->ijbh', query, self.key_pos_embeddings[-key_len:])
-        d = self.key_pos_bias[None, -key_len:, None, :]
-        bd = self._rel_shift(b + d)
+        b = torch.einsum('ibhd,jhd->ijbh', query, key_pos_emb)
+        d = key_pos_bias[None, :, None, :]
+        bd = relative_shift(b + d)
+        bd = bd[:, -key.shape[0]:]
 
         return ac + bd
+
+
+def _test_relative_shift():
+    x = torch.arange(1, 6)[None, :, None, None].repeat(5, 1, 1, 1)
+    inspect(x[:, :, 0, 0])
+    inspect(relative_shift(x)[:, :, 0, 0])
+
+    x = torch.arange(1, 6)[None, :, None, None].repeat(3, 1, 1, 1)
+    inspect(x[:, :, 0, 0])
+    inspect(relative_shift(x)[:, :, 0, 0])
+
+
+if __name__ == '__main__':
+    _test_relative_shift()
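
For reference, a hand-worked check of what the new `relative_shift` produces (this mirrors `_test_relative_shift` above; the expected values in the comments are worked out by hand from the implementation, not copied from a run): row i of the output is row i of the input shifted right by i positions, with zeros padded in and wrapped values filling the lower-left corner, which `get_scores` then crops with `bd = bd[:, -key.shape[0]:]`.

import torch

def relative_shift(x: torch.Tensor):
    # Same shift as in the diff above: append one zero column to each row, then
    # reinterpret the flat memory so that row i comes out shifted right by i positions.
    zero_pad = x.new_zeros(x.shape[0], 1, *x.shape[2:])
    x_padded = torch.cat([x, zero_pad], dim=1)
    x_padded = x_padded.view(x.shape[1] + 1, x.shape[0], *x.shape[2:])
    return x_padded[:-1].view_as(x)

x = torch.arange(1, 6)[None, :, None, None].repeat(5, 1, 1, 1)
print(relative_shift(x)[:, :, 0, 0])
# tensor([[1, 2, 3, 4, 5],
#         [0, 1, 2, 3, 4],
#         [5, 0, 1, 2, 3],
#         [4, 5, 0, 1, 2],
#         [3, 4, 5, 0, 1]])

The 3-row case in the test gives the first three rows of the same pattern; in `get_scores` the positional embeddings are sliced over a span of `query.shape[0] + key.shape[0]`, and only the last `key.shape[0]` columns are kept after the shift.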