import copy

import torch
from torch import nn

from labml.helpers.pytorch.module import Module
from transformers.mha import MultiHeadAttention


class PrepareForMultiHeadAttention(Module):
    """Project the input and split the last dimension into `heads` heads of size `d_k`."""

    def __init__(self, d_model: int, heads: int, d_k: int):
        super().__init__()
        # A single linear projection produces all heads at once
        self.linear = nn.Linear(d_model, heads * d_k, bias=False)
        self.heads = heads
        self.d_k = d_k

    def __call__(self, x: torch.Tensor):
        # `x` has shape `[seq_len, batch_size, d_model]`
        seq_len, batch_size, _ = x.shape
        x = self.linear(x)
        # Split the projection into heads: `[seq_len, batch_size, heads, d_k]`
        x = x.view(seq_len, batch_size, self.heads, self.d_k)
        return x
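
# A quick usage sketch for the head-splitting projection (illustrative only;
# the sizes below are arbitrary example values, not anything the module requires):
#
#     prep = PrepareForMultiHeadAttention(d_model=512, heads=8, d_k=64)
#     x = torch.zeros(10, 2, 512)   # [seq_len, batch_size, d_model]
#     prep(x).shape                 # torch.Size([10, 2, 8, 64])
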
class RelativeMultiHeadAttention(MultiHeadAttention):
    """Multi-head attention with Transformer-XL style relative positional encoding."""

    @staticmethod
    def _rel_shift(x: torch.Tensor):
        # `x` has shape `[query_len, key_len, batch_size, heads]`.
        # Prepend a column of zeros along the key dimension,
        zero_pad = torch.zeros((x.shape[0], 1, *x.shape[2:]),
                               device=x.device, dtype=x.dtype)
        x_padded = torch.cat([zero_pad, x], dim=1)
        # then reinterpret the same memory with the first two dimensions swapped,
        # which shifts each row's entries so that relative positions line up
        # with absolute key positions (the Transformer-XL shift trick).
        x_padded = x_padded.view(x.shape[1] + 1, x.shape[0], *x.shape[2:])
        x = x_padded[1:].view_as(x)
        # Mask out keys that lie in a query's future.
        ones = torch.ones((x.size(0), x.size(1)), device=x.device)
        lower_triangle = torch.tril(ones, x.size(1) - x.size(0))
        x = x * lower_triangle[:, :, None, None]
        return x
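
    # In effect, after the shift, entry `[i, j]` holds the pre-shift value at
    # `[i, j + query_len - 1 - i]`, so the same column index refers to the same
    # absolute key position for every query row; entries whose shifted index
    # would run past the last key are the ones zeroed by the mask above.
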
    def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1):
        super().__init__(heads, d_model, dropout_prob)
        # Maximum number of key positions for which relative parameters are kept
        self.max_key_len = 2 ** 12
        # Relative positional embeddings for keys (term *b* of the Transformer-XL score)
        self.key_pos_embeddings = nn.Parameter(
            torch.zeros((self.max_key_len, heads, self.d_k)),
            requires_grad=True)
        # Global content bias added to the query (term *c*)
        self.query_pos_bias = nn.Parameter(
            torch.zeros((heads, self.d_k)),
            requires_grad=True)
        # Global positional bias per relative key position (term *d*)
        self.key_pos_bias = nn.Parameter(
            torch.zeros((self.max_key_len, heads)),
            requires_grad=True)

    def get_scores(self, query: torch.Tensor, key: torch.Tensor):
        # `query` is `[query_len, batch_size, heads, d_k]`,
        # `key` is `[key_len, batch_size, heads, d_k]`.
        key_len = key.shape[0]
        # Terms *a* + *c*: content score, with the global content bias added to the query
        ac = torch.einsum('ibhd,jbhd->ijbh', query + self.query_pos_bias[None, None, :, :], key)
        # Term *b*: query against the relative key position embeddings
        b = torch.einsum('ibhd,jhd->ijbh', query, self.key_pos_embeddings[-key_len:])
        # Term *d*: global bias for each relative key position
        d = self.key_pos_bias[None, -key_len:, None, :]
        # Shift *b* + *d* so relative positions align with absolute key positions
        bd = self._rel_shift(b + d)
        return ac + bd
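

# A minimal smoke test sketch of the pieces above. It assumes the imports at
# the top resolve and that `Module` behaves like a regular `torch.nn.Module`;
# instantiating `RelativeMultiHeadAttention` itself would additionally require
# the `MultiHeadAttention` base class to define `self.d_k`, so it is not
# exercised here.
if __name__ == '__main__':
    heads, d_k, d_model = 4, 16, 64
    query_len, key_len, batch_size = 5, 7, 2

    # `_rel_shift` is a static method, so it can be checked without
    # constructing the attention module; it must preserve the score shape.
    scores = torch.randn(query_len, key_len, batch_size, heads)
    shifted = RelativeMultiHeadAttention._rel_shift(scores)
    assert shifted.shape == scores.shape

    # The head-splitting projection maps
    # `[seq_len, batch_size, d_model]` to `[seq_len, batch_size, heads, d_k]`.
    prep = PrepareForMultiHeadAttention(d_model, heads, d_k)
    x = torch.randn(query_len, batch_size, d_model)
    assert prep(x).shape == (query_len, batch_size, heads, d_k)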