Mirror of https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
Commit: relative transformer fix for encoder decoder
@@ -9,9 +9,9 @@ from labml.helpers.pytorch.module import Module
 
 
 class PrepareForMultiHeadAttention(Module):
-    def __init__(self, d_model: int, heads: int, d_k: int):
+    def __init__(self, d_model: int, heads: int, d_k: int, bias: bool):
         super().__init__()
-        self.linear = nn.Linear(d_model, heads * d_k)
+        self.linear = nn.Linear(d_model, heads * d_k, bias=bias)
         self.heads = heads
         self.d_k = d_k
 
@@ -25,21 +25,19 @@ class PrepareForMultiHeadAttention(Module):
 
 
 class MultiHeadAttention(Module):
-    def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1):
+    def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1, bias=True):
         super().__init__()
-        # We assume d_v always equals d_k
         self.d_k = d_model // heads
         self.heads = heads
-        self.query = PrepareForMultiHeadAttention(d_model, heads, self.d_k)
-        self.key = PrepareForMultiHeadAttention(d_model, heads, self.d_k)
-        self.value = PrepareForMultiHeadAttention(d_model, heads, self.d_k)
+        self.query = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias)
+        self.key = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias)
+        self.value = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias)
         self.output = nn.Linear(d_model, d_model)
         self.attn = None
         self.dropout = nn.Dropout(dropout_prob)
         self.scale = 1 / math.sqrt(self.d_k)
 
-    def get_scores(self, query: torch.Tensor,
-                   key: torch.Tensor, ):
+    def get_scores(self, query: torch.Tensor, key: torch.Tensor, ):
         return torch.einsum('ibhd,jbhd->ijbh', query, key)
 
     def __call__(self, *,
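For reference, here is a minimal standalone sketch of the projection module after this change. It uses plain torch.nn.Module in place of labml's Module wrapper (an assumption made so the snippet runs on its own) and shows how the new bias flag reaches nn.Linear and how the projected vector is split into heads.

import torch
from torch import nn


class PrepareForMultiHeadAttention(nn.Module):
    """Sketch: projects d_model -> heads * d_k and splits the result per head."""

    def __init__(self, d_model: int, heads: int, d_k: int, bias: bool):
        super().__init__()
        # The new `bias` flag is forwarded to the linear projection, so callers
        # such as the relative attention variant can request bias-free projections.
        self.linear = nn.Linear(d_model, heads * d_k, bias=bias)
        self.heads = heads
        self.d_k = d_k

    def forward(self, x: torch.Tensor):
        seq_len, batch_size, _ = x.shape
        x = self.linear(x)
        # Split the last dimension into (heads, d_k)
        return x.view(seq_len, batch_size, self.heads, self.d_k)


if __name__ == '__main__':
    proj = PrepareForMultiHeadAttention(d_model=512, heads=8, d_k=64, bias=False)
    x = torch.randn(10, 4, 512)  # [seq_len, batch_size, d_model]
    print(proj(x).shape)         # torch.Size([10, 4, 8, 64])
    print(proj.linear.bias)      # None, because bias=False

Since MultiHeadAttention defaults to bias=True, existing callers keep their behaviour; RelativeMultiHeadAttention passes False, as the second file in this commit shows.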
@@ -1,66 +1,58 @@
-import copy
+"""
+Implementation of "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"
+https://arxiv.org/abs/1901.02860
+"""
 
 import torch
 from torch import nn
 
 from labml.helpers.pytorch.module import Module
+from labml.logger import inspect
 from transformers.mha import MultiHeadAttention
 
 
-class PrepareForMultiHeadAttention(Module):
-    def __init__(self, d_model: int, heads: int, d_k: int):
-        super().__init__()
-        self.linear = nn.Linear(d_model, heads * d_k, bias=False)
-        self.heads = heads
-        self.d_k = d_k
-
-    def __call__(self, x: torch.Tensor):
-        seq_len, batch_size, _ = x.shape
-
-        x = self.linear(x)
-        x = x.view(seq_len, batch_size, self.heads, self.d_k)
-
-        return x
+def relative_shift(x: torch.Tensor):
+    zero_pad = x.new_zeros(x.shape[0], 1, *x.shape[2:])
+    x_padded = torch.cat([x, zero_pad], dim=1)
+
+    x_padded = x_padded.view(x.shape[1] + 1, x.shape[0], *x.shape[2:])
+
+    x = x_padded[:-1].view_as(x)
+
+    return x
 
 
 class RelativeMultiHeadAttention(MultiHeadAttention):
-    @staticmethod
-    def _rel_shift(x: torch.Tensor):
-        zero_pad = torch.zeros((x.shape[0], 1, *x.shape[2:]),
-                               device=x.device, dtype=x.dtype)
-        x_padded = torch.cat([zero_pad, x], dim=1)
-
-        x_padded = x_padded.view(x.shape[1] + 1, x.shape[0], *x.shape[2:])
-
-        x = x_padded[1:].view_as(x)
-
-        ones = torch.ones((x.size(0), x.size(1)), device=x.device)
-        lower_triangle = torch.tril(ones, x.size(1) - x.size(0))
-        x = x * lower_triangle[:, :, None, None]
-
-        return x
-
     def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1):
-        super().__init__(heads, d_model, dropout_prob)
-        self.max_key_len = 2 ** 12
+        super().__init__(heads, d_model, dropout_prob, False)
+        self.P = 2 ** 12
 
-        self.key_pos_embeddings = nn.Parameter(
-            torch.zeros((self.max_key_len, heads, self.d_k)),
-            requires_grad=True)
-        self.query_pos_bias = nn.Parameter(
-            torch.zeros((heads, self.d_k)),
-            requires_grad=True)
-        self.key_pos_bias = nn.Parameter(
-            torch.zeros((self.max_key_len, heads)),
-            requires_grad=True)
+        self.key_pos_embeddings = nn.Parameter(torch.zeros((self.P * 2, heads, self.d_k)), requires_grad=True)
+        self.query_pos_bias = nn.Parameter(torch.zeros((heads, self.d_k)), requires_grad=True)
+        self.key_pos_bias = nn.Parameter(torch.zeros((self.P * 2, heads)), requires_grad=True)
 
-    def get_scores(self, query: torch.Tensor,
-                   key: torch.Tensor, ):
-        key_len = key.shape[0]
+    def get_scores(self, query: torch.Tensor, key: torch.Tensor):
+        key_pos_emb = self.key_pos_embeddings[self.P - query.shape[0]:self.P + key.shape[0]]
+        key_pos_bias = self.key_pos_bias[self.P - query.shape[0]:self.P + key.shape[0]]
 
         ac = torch.einsum('ibhd,jbhd->ijbh', query + self.query_pos_bias[None, None, :, :], key)
-        b = torch.einsum('ibhd,jhd->ijbh', query, self.key_pos_embeddings[-key_len:])
-        d = self.key_pos_bias[None, -key_len:, None, :]
-        bd = self._rel_shift(b + d)
+        b = torch.einsum('ibhd,jhd->ijbh', query, key_pos_emb)
+        d = key_pos_bias[None, :, None, :]
+        bd = relative_shift(b + d)
+        bd = bd[:, -key.shape[0]:]
 
         return ac + bd
+
+
+def _test_relative_shift():
+    x = torch.arange(1, 6)[None, :, None, None].repeat(5, 1, 1, 1)
+    inspect(x[:, :, 0, 0])
+    inspect(relative_shift(x)[:, :, 0, 0])
+
+    x = torch.arange(1, 6)[None, :, None, None].repeat(3, 1, 1, 1)
+    inspect(x[:, :, 0, 0])
+    inspect(relative_shift(x)[:, :, 0, 0])
+
+
+if __name__ == '__main__':
+    _test_relative_shift()
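To see what the new relative_shift does (it replaces the old _rel_shift with its lower-triangle masking), here is a standalone sketch that depends only on torch and swaps labml.logger.inspect for plain print; the expected outputs are shown in comments.

import torch


def relative_shift(x: torch.Tensor):
    # x is indexed [query, positional column, ...]. Appending a zero column and
    # reinterpreting the flat memory shifts row i to the right by i positions,
    # which is the Transformer-XL trick for realigning one shared set of
    # positional columns into per-query relative offsets.
    zero_pad = x.new_zeros(x.shape[0], 1, *x.shape[2:])
    x_padded = torch.cat([x, zero_pad], dim=1)
    x_padded = x_padded.view(x.shape[1] + 1, x.shape[0], *x.shape[2:])
    return x_padded[:-1].view_as(x)


if __name__ == '__main__':
    x = torch.arange(1, 6)[None, :, None, None].repeat(5, 1, 1, 1)
    print(x[:, :, 0, 0])
    # tensor([[1, 2, 3, 4, 5],
    #         [1, 2, 3, 4, 5],
    #         [1, 2, 3, 4, 5],
    #         [1, 2, 3, 4, 5],
    #         [1, 2, 3, 4, 5]])
    print(relative_shift(x)[:, :, 0, 0])
    # tensor([[1, 2, 3, 4, 5],
    #         [0, 1, 2, 3, 4],
    #         [5, 0, 1, 2, 3],
    #         [4, 5, 0, 1, 2],
    #         [3, 4, 5, 0, 1]])

In get_scores this shift is applied to b + d, which are computed over query_len + key_len positional slots sliced around self.P, and bd[:, -key.shape[0]:] then keeps only the columns that correspond to actual keys.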
Varuna Jayasiri