diff --git a/transformers/mha.py b/transformers/mha.py
index 933e5181..442e9dd8 100644
--- a/transformers/mha.py
+++ b/transformers/mha.py
@@ -9,9 +9,9 @@ from labml.helpers.pytorch.module import Module
 
 
 class PrepareForMultiHeadAttention(Module):
-    def __init__(self, d_model: int, heads: int, d_k: int):
+    def __init__(self, d_model: int, heads: int, d_k: int, bias: bool):
         super().__init__()
-        self.linear = nn.Linear(d_model, heads * d_k)
+        self.linear = nn.Linear(d_model, heads * d_k, bias=bias)
         self.heads = heads
         self.d_k = d_k
 
@@ -25,21 +25,19 @@ class PrepareForMultiHeadAttention(Module):
 
 
 class MultiHeadAttention(Module):
-    def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1):
+    def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1, bias=True):
         super().__init__()
-        # We assume d_v always equals d_k
         self.d_k = d_model // heads
         self.heads = heads
-        self.query = PrepareForMultiHeadAttention(d_model, heads, self.d_k)
-        self.key = PrepareForMultiHeadAttention(d_model, heads, self.d_k)
-        self.value = PrepareForMultiHeadAttention(d_model, heads, self.d_k)
+        self.query = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias)
+        self.key = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias)
+        self.value = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias)
         self.output = nn.Linear(d_model, d_model)
         self.attn = None
         self.dropout = nn.Dropout(dropout_prob)
         self.scale = 1 / math.sqrt(self.d_k)
 
-    def get_scores(self, query: torch.Tensor,
-                   key: torch.Tensor, ):
+    def get_scores(self, query: torch.Tensor, key: torch.Tensor, ):
         return torch.einsum('ibhd,jbhd->ijbh', query, key)
 
     def __call__(self, *,
diff --git a/transformers/relative_mha.py b/transformers/relative_mha.py
index c7a2d84f..41a14122 100644
--- a/transformers/relative_mha.py
+++ b/transformers/relative_mha.py
@@ -1,66 +1,58 @@
-import copy
+"""
+Implementation of "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"
+https://arxiv.org/abs/1901.02860
+"""
 import torch
 from torch import nn
 
 from labml.helpers.pytorch.module import Module
+from labml.logger import inspect
 
 from transformers.mha import MultiHeadAttention
 
 
-class PrepareForMultiHeadAttention(Module):
-    def __init__(self, d_model: int, heads: int, d_k: int):
-        super().__init__()
-        self.linear = nn.Linear(d_model, heads * d_k, bias=False)
-        self.heads = heads
-        self.d_k = d_k
+def relative_shift(x: torch.Tensor):
+    zero_pad = x.new_zeros(x.shape[0], 1, *x.shape[2:])
+    x_padded = torch.cat([x, zero_pad], dim=1)
 
-    def __call__(self, x: torch.Tensor):
-        seq_len, batch_size, _ = x.shape
+    x_padded = x_padded.view(x.shape[1] + 1, x.shape[0], *x.shape[2:])
 
-        x = self.linear(x)
-        x = x.view(seq_len, batch_size, self.heads, self.d_k)
+    x = x_padded[:-1].view_as(x)
 
-        return x
+    return x
 
 
 class RelativeMultiHeadAttention(MultiHeadAttention):
-    @staticmethod
-    def _rel_shift(x: torch.Tensor):
-        zero_pad = torch.zeros((x.shape[0], 1, *x.shape[2:]),
-                               device=x.device, dtype=x.dtype)
-        x_padded = torch.cat([zero_pad, x], dim=1)
-
-        x_padded = x_padded.view(x.shape[1] + 1, x.shape[0], *x.shape[2:])
-
-        x = x_padded[1:].view_as(x)
-
-        ones = torch.ones((x.size(0), x.size(1)), device=x.device)
-        lower_triangle = torch.tril(ones, x.size(1) - x.size(0))
-        x = x * lower_triangle[:, :, None, None]
-
-        return x
-
     def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1):
-        super().__init__(heads, d_model, dropout_prob)
-        self.max_key_len = 2 ** 12
+        super().__init__(heads, d_model, dropout_prob, False)
+        self.P = 2 ** 12
 
-        self.key_pos_embeddings = nn.Parameter(
-            torch.zeros((self.max_key_len, heads, self.d_k)),
-            requires_grad=True)
-        self.query_pos_bias = nn.Parameter(
-            torch.zeros((heads, self.d_k)),
-            requires_grad=True)
-        self.key_pos_bias = nn.Parameter(
-            torch.zeros((self.max_key_len, heads)),
-            requires_grad=True)
+        self.key_pos_embeddings = nn.Parameter(torch.zeros((self.P * 2, heads, self.d_k)), requires_grad=True)
+        self.query_pos_bias = nn.Parameter(torch.zeros((heads, self.d_k)), requires_grad=True)
+        self.key_pos_bias = nn.Parameter(torch.zeros((self.P * 2, heads)), requires_grad=True)
 
-    def get_scores(self, query: torch.Tensor,
-                   key: torch.Tensor, ):
-        key_len = key.shape[0]
+    def get_scores(self, query: torch.Tensor, key: torch.Tensor):
+        key_pos_emb = self.key_pos_embeddings[self.P - query.shape[0]:self.P + key.shape[0]]
+        key_pos_bias = self.key_pos_bias[self.P - query.shape[0]:self.P + key.shape[0]]
 
         ac = torch.einsum('ibhd,jbhd->ijbh', query + self.query_pos_bias[None, None, :, :], key)
-        b = torch.einsum('ibhd,jhd->ijbh', query, self.key_pos_embeddings[-key_len:])
-        d = self.key_pos_bias[None, -key_len:, None, :]
-        bd = self._rel_shift(b + d)
+        b = torch.einsum('ibhd,jhd->ijbh', query, key_pos_emb)
+        d = key_pos_bias[None, :, None, :]
+        bd = relative_shift(b + d)
+        bd = bd[:, -key.shape[0]:]
 
         return ac + bd
+
+
+def _test_relative_shift():
+    x = torch.arange(1, 6)[None, :, None, None].repeat(5, 1, 1, 1)
+    inspect(x[:, :, 0, 0])
+    inspect(relative_shift(x)[:, :, 0, 0])
+
+    x = torch.arange(1, 6)[None, :, None, None].repeat(3, 1, 1, 1)
+    inspect(x[:, :, 0, 0])
+    inspect(relative_shift(x)[:, :, 0, 0])
+
+
+if __name__ == '__main__':
+    _test_relative_shift()
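
Note on the new relative_shift (a standalone sketch, not part of the commit above): it appends a zero column and reinterprets the memory so each row i of the second axis ends up shifted right by i positions, i.e. entry (i, j) of the result equals entry (i, j - i) of the input for j >= i; the leading entries of a row are padding or wrap-around values, which get_scores discards via bd[:, -key.shape[0]:]. The script below reproduces what _test_relative_shift prints, assuming only plain torch and using print in place of labml.logger.inspect.

# Standalone illustration (hypothetical file, not part of this diff): the
# relative_shift trick from transformers/relative_mha.py with plain torch.
import torch


def relative_shift(x: torch.Tensor):
    # Append a zero column, reinterpret the memory so that row i of the second
    # axis is shifted right by i positions, then drop the surplus row.
    zero_pad = x.new_zeros(x.shape[0], 1, *x.shape[2:])
    x_padded = torch.cat([x, zero_pad], dim=1)
    x_padded = x_padded.view(x.shape[1] + 1, x.shape[0], *x.shape[2:])
    return x_padded[:-1].view_as(x)


if __name__ == '__main__':
    # Five rows, each [1, 2, 3, 4, 5], mirroring _test_relative_shift above.
    x = torch.arange(1, 6)[None, :, None, None].repeat(5, 1, 1, 1)
    print(x[:, :, 0, 0])
    # After the shift, row i is the original row pushed right by i; the values
    # that appear before the shifted content are padding/wrap-around:
    # [[1, 2, 3, 4, 5],
    #  [0, 1, 2, 3, 4],
    #  [5, 0, 1, 2, 3],
    #  [4, 5, 0, 1, 2],
    #  [3, 4, 5, 0, 1]]
    print(relative_shift(x)[:, :, 0, 0])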