mirror of
https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
synced 2025-10-31 02:39:16 +08:00
relative transformer fix for encoder decoder
This commit is contained in:
@ -9,9 +9,9 @@ from labml.helpers.pytorch.module import Module
|
|||||||
|
|
||||||
|
|
||||||
class PrepareForMultiHeadAttention(Module):
|
class PrepareForMultiHeadAttention(Module):
|
||||||
def __init__(self, d_model: int, heads: int, d_k: int):
|
def __init__(self, d_model: int, heads: int, d_k: int, bias: bool):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.linear = nn.Linear(d_model, heads * d_k)
|
self.linear = nn.Linear(d_model, heads * d_k, bias=bias)
|
||||||
self.heads = heads
|
self.heads = heads
|
||||||
self.d_k = d_k
|
self.d_k = d_k
|
||||||
|
|
||||||
@ -25,21 +25,19 @@ class PrepareForMultiHeadAttention(Module):
|
|||||||
|
|
||||||
|
|
||||||
class MultiHeadAttention(Module):
|
class MultiHeadAttention(Module):
|
||||||
def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1):
|
def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1, bias=True):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
# We assume d_v always equals d_k
|
|
||||||
self.d_k = d_model // heads
|
self.d_k = d_model // heads
|
||||||
self.heads = heads
|
self.heads = heads
|
||||||
self.query = PrepareForMultiHeadAttention(d_model, heads, self.d_k)
|
self.query = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias)
|
||||||
self.key = PrepareForMultiHeadAttention(d_model, heads, self.d_k)
|
self.key = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias)
|
||||||
self.value = PrepareForMultiHeadAttention(d_model, heads, self.d_k)
|
self.value = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias)
|
||||||
self.output = nn.Linear(d_model, d_model)
|
self.output = nn.Linear(d_model, d_model)
|
||||||
self.attn = None
|
self.attn = None
|
||||||
self.dropout = nn.Dropout(dropout_prob)
|
self.dropout = nn.Dropout(dropout_prob)
|
||||||
self.scale = 1 / math.sqrt(self.d_k)
|
self.scale = 1 / math.sqrt(self.d_k)
|
||||||
|
|
||||||
def get_scores(self, query: torch.Tensor,
|
def get_scores(self, query: torch.Tensor, key: torch.Tensor, ):
|
||||||
key: torch.Tensor, ):
|
|
||||||
return torch.einsum('ibhd,jbhd->ijbh', query, key)
|
return torch.einsum('ibhd,jbhd->ijbh', query, key)
|
||||||
|
|
||||||
def __call__(self, *,
|
def __call__(self, *,
|
||||||
|
|||||||
@ -1,66 +1,58 @@
|
|||||||
import copy
|
"""
|
||||||
|
Implementation of "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"
|
||||||
|
https://arxiv.org/abs/1901.02860
|
||||||
|
"""
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
from labml.helpers.pytorch.module import Module
|
from labml.helpers.pytorch.module import Module
|
||||||
|
from labml.logger import inspect
|
||||||
from transformers.mha import MultiHeadAttention
|
from transformers.mha import MultiHeadAttention
|
||||||
|
|
||||||
|
|
||||||
class PrepareForMultiHeadAttention(Module):
|
def relative_shift(x: torch.Tensor):
|
||||||
def __init__(self, d_model: int, heads: int, d_k: int):
|
zero_pad = x.new_zeros(x.shape[0], 1, *x.shape[2:])
|
||||||
super().__init__()
|
x_padded = torch.cat([x, zero_pad], dim=1)
|
||||||
self.linear = nn.Linear(d_model, heads * d_k, bias=False)
|
|
||||||
self.heads = heads
|
|
||||||
self.d_k = d_k
|
|
||||||
|
|
||||||
def __call__(self, x: torch.Tensor):
|
x_padded = x_padded.view(x.shape[1] + 1, x.shape[0], *x.shape[2:])
|
||||||
seq_len, batch_size, _ = x.shape
|
|
||||||
|
|
||||||
x = self.linear(x)
|
x = x_padded[:-1].view_as(x)
|
||||||
x = x.view(seq_len, batch_size, self.heads, self.d_k)
|
|
||||||
|
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
class RelativeMultiHeadAttention(MultiHeadAttention):
|
class RelativeMultiHeadAttention(MultiHeadAttention):
|
||||||
@staticmethod
|
|
||||||
def _rel_shift(x: torch.Tensor):
|
|
||||||
zero_pad = torch.zeros((x.shape[0], 1, *x.shape[2:]),
|
|
||||||
device=x.device, dtype=x.dtype)
|
|
||||||
x_padded = torch.cat([zero_pad, x], dim=1)
|
|
||||||
|
|
||||||
x_padded = x_padded.view(x.shape[1] + 1, x.shape[0], *x.shape[2:])
|
|
||||||
|
|
||||||
x = x_padded[1:].view_as(x)
|
|
||||||
|
|
||||||
ones = torch.ones((x.size(0), x.size(1)), device=x.device)
|
|
||||||
lower_triangle = torch.tril(ones, x.size(1) - x.size(0))
|
|
||||||
x = x * lower_triangle[:, :, None, None]
|
|
||||||
|
|
||||||
return x
|
|
||||||
|
|
||||||
def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1):
|
def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1):
|
||||||
super().__init__(heads, d_model, dropout_prob)
|
super().__init__(heads, d_model, dropout_prob, False)
|
||||||
self.max_key_len = 2 ** 12
|
self.P = 2 ** 12
|
||||||
|
|
||||||
self.key_pos_embeddings = nn.Parameter(
|
self.key_pos_embeddings = nn.Parameter(torch.zeros((self.P * 2, heads, self.d_k)), requires_grad=True)
|
||||||
torch.zeros((self.max_key_len, heads, self.d_k)),
|
self.query_pos_bias = nn.Parameter(torch.zeros((heads, self.d_k)), requires_grad=True)
|
||||||
requires_grad=True)
|
self.key_pos_bias = nn.Parameter(torch.zeros((self.P * 2, heads)), requires_grad=True)
|
||||||
self.query_pos_bias = nn.Parameter(
|
|
||||||
torch.zeros((heads, self.d_k)),
|
|
||||||
requires_grad=True)
|
|
||||||
self.key_pos_bias = nn.Parameter(
|
|
||||||
torch.zeros((self.max_key_len, heads)),
|
|
||||||
requires_grad=True)
|
|
||||||
|
|
||||||
def get_scores(self, query: torch.Tensor,
|
def get_scores(self, query: torch.Tensor, key: torch.Tensor):
|
||||||
key: torch.Tensor, ):
|
key_pos_emb = self.key_pos_embeddings[self.P - query.shape[0]:self.P + key.shape[0]]
|
||||||
key_len = key.shape[0]
|
key_pos_bias = self.key_pos_bias[self.P - query.shape[0]:self.P + key.shape[0]]
|
||||||
|
|
||||||
ac = torch.einsum('ibhd,jbhd->ijbh', query + self.query_pos_bias[None, None, :, :], key)
|
ac = torch.einsum('ibhd,jbhd->ijbh', query + self.query_pos_bias[None, None, :, :], key)
|
||||||
b = torch.einsum('ibhd,jhd->ijbh', query, self.key_pos_embeddings[-key_len:])
|
b = torch.einsum('ibhd,jhd->ijbh', query, key_pos_emb)
|
||||||
d = self.key_pos_bias[None, -key_len:, None, :]
|
d = key_pos_bias[None, :, None, :]
|
||||||
bd = self._rel_shift(b + d)
|
bd = relative_shift(b + d)
|
||||||
|
bd = bd[:, -key.shape[0]:]
|
||||||
|
|
||||||
return ac + bd
|
return ac + bd
|
||||||
|
|
||||||
|
|
||||||
|
def _test_relative_shift():
|
||||||
|
x = torch.arange(1, 6)[None, :, None, None].repeat(5, 1, 1, 1)
|
||||||
|
inspect(x[:, :, 0, 0])
|
||||||
|
inspect(relative_shift(x)[:, :, 0, 0])
|
||||||
|
|
||||||
|
x = torch.arange(1, 6)[None, :, None, None].repeat(3, 1, 1, 1)
|
||||||
|
inspect(x[:, :, 0, 0])
|
||||||
|
inspect(relative_shift(x)[:, :, 0, 0])
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
_test_relative_shift()
|
||||||
|
|||||||
Reference in New Issue
Block a user