Files
Varuna Jayasiri 6924f4580c lab_helpers
2020-09-01 08:05:08 +05:30

59 lines
1.9 KiB
Python

"""
Implementation of "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"
https://arxiv.org/abs/1901.02860
"""
import torch
from torch import nn
from labml_helpers.module import Module
from labml.logger import inspect
from transformers.mha import MultiHeadAttention
def relative_shift(x: torch.Tensor):
    """
    Shift row ``i`` of ``x`` to the right by ``i`` positions along dim 1.

    A zero column is appended and the flat storage is reinterpreted with one
    row dropped. In the result, ``out[i, j] == x[i, j - i]`` for ``j >= i``,
    ``out[i, i - 1] == 0``, and entries with ``j < i - 1`` hold values wrapped
    around from row ``i - 1`` (callers are expected to slice those away).
    Dimensions after the first two are carried along unchanged. The output has
    the same shape as ``x``.
    """
    rows, cols = x.shape[0], x.shape[1]
    trailing = tuple(x.shape[2:])

    # One extra zero column at the end of every row -> [rows, cols + 1, ...]
    with_zero_col = torch.cat((x, x.new_zeros((rows, 1) + trailing)), dim=1)

    # Reinterpret the same storage as [cols + 1, rows, ...], drop the final
    # row, and read the remaining elements back in the shape of ``x``.
    reinterpreted = with_zero_col.view((cols + 1, rows) + trailing)
    return reinterpreted[:-1].view_as(x)
class RelativeMultiHeadAttention(MultiHeadAttention):
    """
    Multi-head attention with Transformer-XL style relative position scores.

    The pre-softmax score for query ``i`` and key ``j`` is the sum of

    * a content term ``(q_i + u) . k_j``, where ``u`` is ``query_pos_bias``, and
    * a position term ``q_i . R[r] + S[r]``, where ``R`` is
      ``key_pos_embeddings``, ``S`` is ``key_pos_bias`` and ``r`` is the table
      row for the relative distance between query ``i`` and key ``j``
      (see ``get_scores``).
    """

    def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1):
        # NOTE(review): the positional ``False`` is forwarded to
        # MultiHeadAttention; presumably it disables bias in the projections.
        # Confirm against transformers.mha.
        super().__init__(heads, d_model, dropout_prob, False)
        # Centre row of the relative-position tables. Each table has 2 * P
        # rows, so get_scores supports query and key lengths up to P.
        self.P = 2 ** 12
        # ``self.d_k`` is set by the base class (presumably d_model // heads;
        # TODO confirm).
        # R: one d_k-vector per (relative position, head).
        self.key_pos_embeddings = nn.Parameter(torch.zeros((self.P * 2, heads, self.d_k)))
        # u: per-head bias added to every query in the content term.
        self.query_pos_bias = nn.Parameter(torch.zeros((heads, self.d_k)))
        # S: one scalar per (relative position, head).
        self.key_pos_bias = nn.Parameter(torch.zeros((self.P * 2, heads)))

    def get_scores(self, query: torch.Tensor, key: torch.Tensor):
        """
        Compute pre-softmax attention scores with relative positions.

        :param query: tensor indexed ``[i, b, h, d]`` (presumably
            ``[query_len, batch, heads, d_k]``).
        :param key: tensor indexed ``[j, b, h, d]``. ``key_len`` may be larger
            than ``query_len``; the queries are then treated as the last
            ``query_len`` positions of the key sequence.
        :return: scores indexed ``[i, j, b, h]``.
        :raises ValueError: if ``query_len`` or ``key_len`` exceeds ``self.P``.

        The position term for query ``i`` and key ``j`` reads table row
        ``P - (key_len - query_len + i - j)``, i.e. ``P`` minus the distance
        from key ``j`` to query ``i``. It therefore depends only on that
        distance, not on how many extra (e.g. memory) keys precede the queries.
        """
        query_len, key_len = query.shape[0], key.shape[0]
        if query_len > self.P or key_len > self.P:
            raise ValueError(
                f"query length {query_len} and key length {key_len} "
                f"must not exceed P={self.P}")

        # Window of table rows [P - key_len, P + query_len). Window column m is
        # table row P - key_len + m. After the shift below, query i / key j
        # reads column query_len + j - i, giving the distance-only row above.
        # Starting key_len (not query_len) rows before P is what removes the
        # dependence on key_len - query_len.
        window = slice(self.P - key_len, self.P + query_len)
        key_pos_emb = self.key_pos_embeddings[window]  # [key_len + query_len, h, d]
        key_pos_bias = self.key_pos_bias[window]       # [key_len + query_len, h]

        # Content term (A + C): [i, j, b, h]
        ac = torch.einsum('ibhd,jbhd->ijbh', query + self.query_pos_bias[None, None, :, :], key)
        # Position terms over the whole window: [i, key_len + query_len, b, h]
        b = torch.einsum('ibhd,jhd->ijbh', query, key_pos_emb)
        d = key_pos_bias[None, :, None, :]
        # Shift row i right by i, then keep the last key_len columns. Every kept
        # column index is >= query_len > i, so no wrapped-around entries survive.
        bd = relative_shift(b + d)
        bd = bd[:, -key_len:]

        return ac + bd
def _test_relative_shift():
    """
    Print small example inputs next to their ``relative_shift`` output.

    Each input repeats the row ``[1, 2, 3, 4, 5]``, first 5 times and then
    3 times, with two trailing singleton dimensions. Only the first two
    dimensions are printed.
    """
    for n_rows in (5, 3):
        grid = torch.arange(1, 6)[None, :, None, None].repeat(n_rows, 1, 1, 1)
        inspect(grid[:, :, 0, 0])
        inspect(relative_shift(grid)[:, :, 0, 0])
# When executed as a script, run the relative_shift demo.
if '__main__' == __name__:
    _test_relative_shift()