mirror of
				https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
				synced 2025-10-31 02:39:16 +08:00 
			
		
		
		
	relative transformer fix for encoder decoder
This commit is contained in:
		| @ -9,9 +9,9 @@ from labml.helpers.pytorch.module import Module | ||||
|  | ||||
|  | ||||
| class PrepareForMultiHeadAttention(Module): | ||||
|     def __init__(self, d_model: int, heads: int, d_k: int): | ||||
|     def __init__(self, d_model: int, heads: int, d_k: int, bias: bool): | ||||
|         super().__init__() | ||||
|         self.linear = nn.Linear(d_model, heads * d_k) | ||||
|         self.linear = nn.Linear(d_model, heads * d_k, bias=bias) | ||||
|         self.heads = heads | ||||
|         self.d_k = d_k | ||||
|  | ||||
| @ -25,21 +25,19 @@ class PrepareForMultiHeadAttention(Module): | ||||
|  | ||||
|  | ||||
| class MultiHeadAttention(Module): | ||||
|     def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1): | ||||
|     def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1, bias=True): | ||||
|         super().__init__() | ||||
|         # We assume d_v always equals d_k | ||||
|         self.d_k = d_model // heads | ||||
|         self.heads = heads | ||||
|         self.query = PrepareForMultiHeadAttention(d_model, heads, self.d_k) | ||||
|         self.key = PrepareForMultiHeadAttention(d_model, heads, self.d_k) | ||||
|         self.value = PrepareForMultiHeadAttention(d_model, heads, self.d_k) | ||||
|         self.query = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias) | ||||
|         self.key = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias) | ||||
|         self.value = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias) | ||||
|         self.output = nn.Linear(d_model, d_model) | ||||
|         self.attn = None | ||||
|         self.dropout = nn.Dropout(dropout_prob) | ||||
|         self.scale = 1 / math.sqrt(self.d_k) | ||||
|  | ||||
|     def get_scores(self, query: torch.Tensor, | ||||
|                    key: torch.Tensor, ): | ||||
|     def get_scores(self, query: torch.Tensor, key: torch.Tensor, ): | ||||
|         return torch.einsum('ibhd,jbhd->ijbh', query, key) | ||||
|  | ||||
|     def __call__(self, *, | ||||
|  | ||||
| @ -1,66 +1,58 @@ | ||||
| import copy | ||||
| """ | ||||
| Implementation of "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" | ||||
| https://arxiv.org/abs/1901.02860 | ||||
| """ | ||||
|  | ||||
| import torch | ||||
| from torch import nn | ||||
|  | ||||
| from labml.helpers.pytorch.module import Module | ||||
| from labml.logger import inspect | ||||
| from transformers.mha import MultiHeadAttention | ||||
|  | ||||
|  | ||||
| class PrepareForMultiHeadAttention(Module): | ||||
|     def __init__(self, d_model: int, heads: int, d_k: int): | ||||
|         super().__init__() | ||||
|         self.linear = nn.Linear(d_model, heads * d_k, bias=False) | ||||
|         self.heads = heads | ||||
|         self.d_k = d_k | ||||
| def relative_shift(x: torch.Tensor): | ||||
|     zero_pad = x.new_zeros(x.shape[0], 1, *x.shape[2:]) | ||||
|     x_padded = torch.cat([x, zero_pad], dim=1) | ||||
|  | ||||
|     def __call__(self, x: torch.Tensor): | ||||
|         seq_len, batch_size, _ = x.shape | ||||
|     x_padded = x_padded.view(x.shape[1] + 1, x.shape[0], *x.shape[2:]) | ||||
|  | ||||
|         x = self.linear(x) | ||||
|         x = x.view(seq_len, batch_size, self.heads, self.d_k) | ||||
|     x = x_padded[:-1].view_as(x) | ||||
|  | ||||
|         return x | ||||
|     return x | ||||
|  | ||||
|  | ||||
| class RelativeMultiHeadAttention(MultiHeadAttention): | ||||
|     @staticmethod | ||||
|     def _rel_shift(x: torch.Tensor): | ||||
|         zero_pad = torch.zeros((x.shape[0], 1, *x.shape[2:]), | ||||
|                                device=x.device, dtype=x.dtype) | ||||
|         x_padded = torch.cat([zero_pad, x], dim=1) | ||||
|  | ||||
|         x_padded = x_padded.view(x.shape[1] + 1, x.shape[0], *x.shape[2:]) | ||||
|  | ||||
|         x = x_padded[1:].view_as(x) | ||||
|  | ||||
|         ones = torch.ones((x.size(0), x.size(1)), device=x.device) | ||||
|         lower_triangle = torch.tril(ones, x.size(1) - x.size(0)) | ||||
|         x = x * lower_triangle[:, :, None, None] | ||||
|  | ||||
|         return x | ||||
|  | ||||
|     def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1): | ||||
|         super().__init__(heads, d_model, dropout_prob) | ||||
|         self.max_key_len = 2 ** 12 | ||||
|         super().__init__(heads, d_model, dropout_prob, False) | ||||
|         self.P = 2 ** 12 | ||||
|  | ||||
|         self.key_pos_embeddings = nn.Parameter( | ||||
|             torch.zeros((self.max_key_len, heads, self.d_k)), | ||||
|             requires_grad=True) | ||||
|         self.query_pos_bias = nn.Parameter( | ||||
|             torch.zeros((heads, self.d_k)), | ||||
|             requires_grad=True) | ||||
|         self.key_pos_bias = nn.Parameter( | ||||
|             torch.zeros((self.max_key_len, heads)), | ||||
|             requires_grad=True) | ||||
|         self.key_pos_embeddings = nn.Parameter(torch.zeros((self.P * 2, heads, self.d_k)), requires_grad=True) | ||||
|         self.query_pos_bias = nn.Parameter(torch.zeros((heads, self.d_k)), requires_grad=True) | ||||
|         self.key_pos_bias = nn.Parameter(torch.zeros((self.P * 2, heads)), requires_grad=True) | ||||
|  | ||||
|     def get_scores(self, query: torch.Tensor, | ||||
|                    key: torch.Tensor, ): | ||||
|         key_len = key.shape[0] | ||||
|     def get_scores(self, query: torch.Tensor, key: torch.Tensor): | ||||
|         key_pos_emb = self.key_pos_embeddings[self.P - query.shape[0]:self.P + key.shape[0]] | ||||
|         key_pos_bias = self.key_pos_bias[self.P - query.shape[0]:self.P + key.shape[0]] | ||||
|  | ||||
|         ac = torch.einsum('ibhd,jbhd->ijbh', query + self.query_pos_bias[None, None, :, :], key) | ||||
|         b = torch.einsum('ibhd,jhd->ijbh', query, self.key_pos_embeddings[-key_len:]) | ||||
|         d = self.key_pos_bias[None, -key_len:, None, :] | ||||
|         bd = self._rel_shift(b + d) | ||||
|         b = torch.einsum('ibhd,jhd->ijbh', query, key_pos_emb) | ||||
|         d = key_pos_bias[None, :, None, :] | ||||
|         bd = relative_shift(b + d) | ||||
|         bd = bd[:, -key.shape[0]:] | ||||
|  | ||||
|         return ac + bd | ||||
|  | ||||
|  | ||||
| def _test_relative_shift(): | ||||
|     x = torch.arange(1, 6)[None, :, None, None].repeat(5, 1, 1, 1) | ||||
|     inspect(x[:, :, 0, 0]) | ||||
|     inspect(relative_shift(x)[:, :, 0, 0]) | ||||
|  | ||||
|     x = torch.arange(1, 6)[None, :, None, None].repeat(3, 1, 1, 1) | ||||
|     inspect(x[:, :, 0, 0]) | ||||
|     inspect(relative_shift(x)[:, :, 0, 0]) | ||||
|  | ||||
|  | ||||
| if __name__ == '__main__': | ||||
|     _test_relative_shift() | ||||
|  | ||||
		Reference in New Issue
	
	Block a user
	 Varuna Jayasiri
					Varuna Jayasiri