Mirror of https://github.com/labmlai/annotated_deep_learning_paper_implementations.git (synced 2025-11-01 20:28:41 +08:00)
fix dependencies and include relative attention
@@ -5,11 +5,11 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F

-from common.models import clone_module_list
-from labml.configs import BaseConfigs, option
+from labml.configs import BaseConfigs, option, calculate
 from labml.helpers.pytorch.module import Module
-from transformer.models.multi_headed_attention import MultiHeadedAttention
+from transformers.mha import MultiHeadAttention
 from transformers.positional_encoding import PositionalEncoding, get_positional_encoding
+from transformers.utils import clone_module_list


 class EmbeddingsWithPositionalEncoding(Module):

@@ -53,8 +53,8 @@ class FeedForward(Module):
 class TransformerLayer(Module):
     def __init__(self, *,
                  d_model: int,
-                 self_attn: MultiHeadedAttention,
-                 src_attn: MultiHeadedAttention = None,
+                 self_attn: MultiHeadAttention,
+                 src_attn: MultiHeadAttention = None,
                  feed_forward: FeedForward,
                  dropout_prob: float):
         super().__init__()

@@ -161,9 +161,9 @@ class TransformerConfigs(BaseConfigs):
     n_src_vocab: int
     n_tgt_vocab: int

-    encoder_attn: MultiHeadedAttention = 'mha'
-    decoder_attn: MultiHeadedAttention = 'mha'
-    decoder_mem_attn: MultiHeadedAttention = 'mha'
+    encoder_attn: MultiHeadAttention = 'mha'
+    decoder_attn: MultiHeadAttention = 'mha'
+    decoder_mem_attn: MultiHeadAttention = 'mha'
     feed_forward: FeedForward

     encoder_layer: TransformerLayer = 'normal'

@@ -185,19 +185,25 @@ def _feed_forward(c: TransformerConfigs):
     return FeedForward(c.d_model, c.d_ff, c.dropout)


-@option(TransformerConfigs.encoder_attn, 'mha')
-def _encoder_mha(c: TransformerConfigs):
-    return MultiHeadedAttention(c.n_heads, c.d_model)
+### MHA
+def _mha(c: TransformerConfigs):
+    return MultiHeadAttention(c.n_heads, c.d_model)


-@option(TransformerConfigs.decoder_attn, 'mha')
-def _decoder_mha(c: TransformerConfigs):
-    return MultiHeadedAttention(c.n_heads, c.d_model)
+calculate(TransformerConfigs.encoder_attn, 'mha', _mha)
+calculate(TransformerConfigs.decoder_attn, 'mha', _mha)
+calculate(TransformerConfigs.decoder_mem_attn, 'mha', _mha)


-@option(TransformerConfigs.decoder_mem_attn, 'mha')
-def _decoder_mem_mha(c: TransformerConfigs):
-    return MultiHeadedAttention(c.n_heads, c.d_model)
+### Relative MHA
+def _relative_mha(c: TransformerConfigs):
+    from transformers.relative_mha import RelativeMultiHeadAttention
+    return RelativeMultiHeadAttention(c.n_heads, c.d_model)
+
+
+calculate(TransformerConfigs.encoder_attn, 'relative', _relative_mha)
+calculate(TransformerConfigs.decoder_attn, 'relative', _relative_mha)
+calculate(TransformerConfigs.decoder_mem_attn, 'relative', _relative_mha)


 @option(TransformerConfigs.encoder_layer, 'normal')

@@ -229,6 +235,7 @@ def _generator(c: TransformerConfigs):
     return Generator(c.n_tgt_vocab, c.d_model)


+### Positional Embeddings
 @option(TransformerConfigs.src_embed, 'fixed_pos')
 def _src_embed_with_positional(c: TransformerConfigs):
     return EmbeddingsWithPositionalEncoding(c.d_model, c.n_src_vocab)

@@ -239,6 +246,7 @@ def _tgt_embed_with_positional(c: TransformerConfigs):
     return EmbeddingsWithPositionalEncoding(c.d_model, c.n_tgt_vocab)


+### Learned Positional Embeddings
 @option(TransformerConfigs.src_embed, 'learned_pos')
 def _src_embed_with_learned_positional(c: TransformerConfigs):
     return EmbeddingsWithLearnedPositionalEncoding(c.d_model, c.n_src_vocab)

@@ -249,6 +257,16 @@ def _tgt_embed_with_learned_positional(c: TransformerConfigs):
     return EmbeddingsWithLearnedPositionalEncoding(c.d_model, c.n_tgt_vocab)


+### No Positional Embeddings
+@option(TransformerConfigs.src_embed, 'no_pos')
+def _src_embed_without_positional(c: TransformerConfigs):
+    return nn.Embedding(c.n_src_vocab, c.d_model)
+
+
+@option(TransformerConfigs.tgt_embed, 'no_pos')
+def _tgt_embed_without_positional(c: TransformerConfigs):
+    return nn.Embedding(c.n_tgt_vocab, c.d_model)
+

 @option(TransformerConfigs.encoder_decoder, 'normal')
 def _encoder_decoder(c: TransformerConfigs):
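The calculate(...) registrations above make 'mha' and 'relative' selectable by name for each attention slot of TransformerConfigs. A minimal usage sketch, not part of this commit: the vocabulary sizes below are placeholders, and the lazy resolution of the string options is handled by labml's config machinery.

conf = TransformerConfigs()
conf.n_src_vocab = 10000            # placeholder vocabulary sizes, for illustration only
conf.n_tgt_vocab = 10000
conf.encoder_attn = 'relative'      # use RelativeMultiHeadAttention in encoder layers
conf.decoder_attn = 'relative'      # ... and for decoder self-attention
conf.decoder_mem_attn = 'relative'  # ... and for decoder attention over encoder memory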

New file: transformers/mha.py (78 lines)

import math
from typing import Optional

import torch
from torch import nn as nn
from torch.nn import functional as F

from labml.helpers.pytorch.module import Module


class PrepareForMultiHeadAttention(Module):
    def __init__(self, d_model: int, heads: int, d_k: int):
        super().__init__()
        self.linear = nn.Linear(d_model, heads * d_k)
        self.heads = heads
        self.d_k = d_k

    def __call__(self, x: torch.Tensor):
        seq_len, batch_size, _ = x.shape

        x = self.linear(x)
        x = x.view(seq_len, batch_size, self.heads, self.d_k)

        return x


class MultiHeadAttention(Module):
    def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1):
        super().__init__()
        # We assume d_v always equals d_k
        self.d_k = d_model // heads
        self.heads = heads
        self.query = PrepareForMultiHeadAttention(d_model, heads, self.d_k)
        self.key = PrepareForMultiHeadAttention(d_model, heads, self.d_k)
        self.value = PrepareForMultiHeadAttention(d_model, heads, self.d_k)
        self.output = nn.Linear(d_model, d_model)
        self.attn = None
        self.dropout = nn.Dropout(dropout_prob)
        self.scale = 1 / math.sqrt(self.d_k)

    def get_scores(self, query: torch.Tensor,
                   key: torch.Tensor, ):
        return torch.einsum('ibhd,jbhd->ijbh', query, key)

    def __call__(self, *,
                 query: torch.Tensor,
                 key: torch.Tensor,
                 value: torch.Tensor,
                 mask: Optional[torch.Tensor] = None):
        seq_len, batch_size, *_ = query.shape

        if mask is not None:
            # mask = ijb
            assert mask.shape[0] == 1 or mask.shape[0] == mask.shape[1]
            # Same mask applied to all h heads.
            mask = mask.unsqueeze(-1)

        query = self.query(query)
        key = self.key(key)
        value = self.value(value)

        scores = self.get_scores(query, key)

        scores *= self.scale
        if mask is not None:
            # mask = ijbh
            assert mask.shape[0] == 1 or mask.shape[0] == mask.shape[1]
            scores = scores.masked_fill(mask == 0, -1e9)
        attn = F.softmax(scores, dim=1)
        attn = self.dropout(attn)

        x = torch.einsum("ijbh,jbhd->ibhd", attn, value)

        self.attn = attn.detach()

        x = x.reshape(seq_len, batch_size, -1)

        return self.output(x)
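A small illustrative shape check for MultiHeadAttention, not part of the commit. It assumes inputs follow the [seq_len, batch_size, d_model] layout used throughout the file and that the mask is [query_len, key_len, batch], which __call__ broadcasts over the heads.

import torch
from transformers.mha import MultiHeadAttention

seq_len, batch_size, d_model, heads = 10, 4, 512, 8
mha = MultiHeadAttention(heads, d_model)

x = torch.randn(seq_len, batch_size, d_model)
# Causal mask: query position i may attend only to key positions j <= i.
mask = torch.tril(torch.ones(seq_len, seq_len)).unsqueeze(-1)  # [seq_len, seq_len, 1]

out = mha(query=x, key=x, value=x, mask=mask)
assert out.shape == (seq_len, batch_size, d_model)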

New file: transformers/relative_mha.py (66 lines)

import copy

import torch
from torch import nn

from labml.helpers.pytorch.module import Module
from transformers.mha import MultiHeadAttention


class PrepareForMultiHeadAttention(Module):
    def __init__(self, d_model: int, heads: int, d_k: int):
        super().__init__()
        self.linear = nn.Linear(d_model, heads * d_k, bias=False)
        self.heads = heads
        self.d_k = d_k

    def __call__(self, x: torch.Tensor):
        seq_len, batch_size, _ = x.shape

        x = self.linear(x)
        x = x.view(seq_len, batch_size, self.heads, self.d_k)

        return x


class RelativeMultiHeadAttention(MultiHeadAttention):
    @staticmethod
    def _rel_shift(x: torch.Tensor):
        zero_pad = torch.zeros((x.shape[0], 1, *x.shape[2:]),
                               device=x.device, dtype=x.dtype)
        x_padded = torch.cat([zero_pad, x], dim=1)

        x_padded = x_padded.view(x.shape[1] + 1, x.shape[0], *x.shape[2:])

        x = x_padded[1:].view_as(x)

        ones = torch.ones((x.size(0), x.size(1)), device=x.device)
        lower_triangle = torch.tril(ones, x.size(1) - x.size(0))
        x = x * lower_triangle[:, :, None, None]

        return x

    def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1):
        super().__init__(heads, d_model, dropout_prob)
        self.max_key_len = 2 ** 12

        self.key_pos_embeddings = nn.Parameter(
            torch.zeros((self.max_key_len, heads, self.d_k)),
            requires_grad=True)
        self.query_pos_bias = nn.Parameter(
            torch.zeros((heads, self.d_k)),
            requires_grad=True)
        self.key_pos_bias = nn.Parameter(
            torch.zeros((self.max_key_len, heads)),
            requires_grad=True)

    def get_scores(self, query: torch.Tensor,
                   key: torch.Tensor, ):
        key_len = key.shape[0]

        ac = torch.einsum('ibhd,jbhd->ijbh', query + self.query_pos_bias[None, None, :, :], key)
        b = torch.einsum('ibhd,jhd->ijbh', query, self.key_pos_embeddings[-key_len:])
        d = self.key_pos_bias[None, -key_len:, None, :]
        bd = self._rel_shift(b + d)

        return ac + bd
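The scoring above follows the Transformer-XL style decomposition: a content term (query plus a global query bias against the keys) and a position term that is re-aligned with the _rel_shift trick. A tiny illustrative example of the shift, not part of the commit; the printed values were worked out by hand from the code above.

import torch
from transformers.relative_mha import RelativeMultiHeadAttention

# x[i, j] pairs query i with the j-th of the last key_len position embeddings.
x = torch.arange(6.).view(2, 3, 1, 1)  # [query_len=2, key_len=3, batch=1, heads=1]
shifted = RelativeMultiHeadAttention._rel_shift(x)
print(shifted[:, :, 0, 0])
# tensor([[1., 2., 0.],
#         [3., 4., 5.]])
# Each row is shifted so that column j now lines up with key position j, and entries
# for keys ahead of the query are zeroed by the lower-triangular mask.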

New file: transformers/utils.py (9 lines)

import copy

from torch import nn

from labml.helpers.pytorch.module import Module


def clone_module_list(module: Module, n: int):
    return nn.ModuleList([copy.deepcopy(module) for _ in range(n)])
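For context, clone_module_list is the helper the config file now imports from transformers.utils. A minimal illustrative use, not from this commit; the nn.Linear prototype is a stand-in for something like a TransformerLayer.

from torch import nn
from transformers.utils import clone_module_list

proto = nn.Linear(512, 512)           # stand-in prototype module
layers = clone_module_list(proto, 6)  # nn.ModuleList of 6 independent deep copies
assert len(layers) == 6 and layers[0] is not layers[1]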