fix dependencies and include relative attention

Varuna Jayasiri
2020-08-25 15:54:47 +05:30
parent 4ae3d770a9
commit 5ccc095c2a
4 changed files with 188 additions and 17 deletions


@@ -5,11 +5,11 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from common.models import clone_module_list
-from labml.configs import BaseConfigs, option
+from labml.configs import BaseConfigs, option, calculate
 from labml.helpers.pytorch.module import Module
-from transformer.models.multi_headed_attention import MultiHeadedAttention
+from transformers.mha import MultiHeadAttention
+from transformers.positional_encoding import PositionalEncoding, get_positional_encoding
+from transformers.utils import clone_module_list
 class EmbeddingsWithPositionalEncoding(Module):
@@ -53,8 +53,8 @@ class FeedForward(Module):
 class TransformerLayer(Module):
     def __init__(self, *,
                  d_model: int,
-                 self_attn: MultiHeadedAttention,
-                 src_attn: MultiHeadedAttention = None,
+                 self_attn: MultiHeadAttention,
+                 src_attn: MultiHeadAttention = None,
                  feed_forward: FeedForward,
                  dropout_prob: float):
         super().__init__()
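For reference, a minimal usage sketch (not part of this commit) of wiring the renamed classes together directly; the hyper-parameter values below are hypothetical, while the constructor signatures follow the calls visible elsewhere in this diff:

    # Hypothetical hyper-parameters; constructor signatures taken from the calls in this diff.
    d_model, n_heads, d_ff, dropout = 512, 8, 2048, 0.1
    layer = TransformerLayer(d_model=d_model,
                             self_attn=MultiHeadAttention(n_heads, d_model),
                             src_attn=None,
                             feed_forward=FeedForward(d_model, d_ff, dropout),
                             dropout_prob=dropout)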
@@ -161,9 +161,9 @@ class TransformerConfigs(BaseConfigs):
     n_src_vocab: int
     n_tgt_vocab: int
-    encoder_attn: MultiHeadedAttention = 'mha'
-    decoder_attn: MultiHeadedAttention = 'mha'
-    decoder_mem_attn: MultiHeadedAttention = 'mha'
+    encoder_attn: MultiHeadAttention = 'mha'
+    decoder_attn: MultiHeadAttention = 'mha'
+    decoder_mem_attn: MultiHeadAttention = 'mha'
     feed_forward: FeedForward
     encoder_layer: TransformerLayer = 'normal'
@@ -185,19 +185,25 @@ def _feed_forward(c: TransformerConfigs):
     return FeedForward(c.d_model, c.d_ff, c.dropout)
-@option(TransformerConfigs.encoder_attn, 'mha')
-def _encoder_mha(c: TransformerConfigs):
-    return MultiHeadedAttention(c.n_heads, c.d_model)
+### MHA
+def _mha(c: TransformerConfigs):
+    return MultiHeadAttention(c.n_heads, c.d_model)
-@option(TransformerConfigs.decoder_attn, 'mha')
-def _decoder_mha(c: TransformerConfigs):
-    return MultiHeadedAttention(c.n_heads, c.d_model)
+calculate(TransformerConfigs.encoder_attn, 'mha', _mha)
+calculate(TransformerConfigs.decoder_attn, 'mha', _mha)
+calculate(TransformerConfigs.decoder_mem_attn, 'mha', _mha)
-@option(TransformerConfigs.decoder_mem_attn, 'mha')
-def _decoder_mem_mha(c: TransformerConfigs):
-    return MultiHeadedAttention(c.n_heads, c.d_model)
+### Relative MHA
+def _relative_mha(c: TransformerConfigs):
+    from transformers.relative_mha import RelativeMultiHeadAttention
+    return RelativeMultiHeadAttention(c.n_heads, c.d_model)
+calculate(TransformerConfigs.encoder_attn, 'relative', _relative_mha)
+calculate(TransformerConfigs.decoder_attn, 'relative', _relative_mha)
+calculate(TransformerConfigs.decoder_mem_attn, 'relative', _relative_mha)
 @option(TransformerConfigs.encoder_layer, 'normal')
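The calculate() calls above register 'mha' and 'relative' as selectable keys for the three attention config items. A hedged sketch (not part of this commit) of how an experiment might pick the relative variant, assuming the usual labml pattern of assigning option names to a BaseConfigs instance:

    conf = TransformerConfigs()
    conf.n_src_vocab = conf.n_tgt_vocab = 10000  # hypothetical vocabulary sizes
    conf.encoder_attn = 'relative'      # RelativeMultiHeadAttention in the encoder
    conf.decoder_attn = 'relative'      # and in decoder self-attention
    conf.decoder_mem_attn = 'mha'       # standard attention over encoder memory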
@@ -229,6 +235,7 @@ def _generator(c: TransformerConfigs):
     return Generator(c.n_tgt_vocab, c.d_model)
+### Positional Embeddings
 @option(TransformerConfigs.src_embed, 'fixed_pos')
 def _src_embed_with_positional(c: TransformerConfigs):
     return EmbeddingsWithPositionalEncoding(c.d_model, c.n_src_vocab)
@@ -239,6 +246,7 @@ def _tgt_embed_with_positional(c: TransformerConfigs):
     return EmbeddingsWithPositionalEncoding(c.d_model, c.n_tgt_vocab)
+### Learned Positional Embeddings
 @option(TransformerConfigs.src_embed, 'learned_pos')
 def _src_embed_with_learned_positional(c: TransformerConfigs):
     return EmbeddingsWithLearnedPositionalEncoding(c.d_model, c.n_src_vocab)
@@ -249,6 +257,16 @@ def _tgt_embed_with_learned_positional(c: TransformerConfigs):
     return EmbeddingsWithLearnedPositionalEncoding(c.d_model, c.n_tgt_vocab)
+### No Positional Embeddings
+@option(TransformerConfigs.src_embed, 'no_pos')
+def _src_embed_without_positional(c: TransformerConfigs):
+    return nn.Embedding(c.n_src_vocab, c.d_model)
+@option(TransformerConfigs.tgt_embed, 'no_pos')
+def _tgt_embed_without_positional(c: TransformerConfigs):
+    return nn.Embedding(c.n_tgt_vocab, c.d_model)
 @option(TransformerConfigs.encoder_decoder, 'normal')
 def _encoder_decoder(c: TransformerConfigs):
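The positional-embedding variants registered above are likewise chosen by key ('fixed_pos', 'learned_pos', 'no_pos'). A short hedged sketch (not part of this commit), again assuming the standard labml option-selection pattern:

    conf = TransformerConfigs()
    conf.src_embed = 'learned_pos'  # learned positional embeddings on the source side
    conf.tgt_embed = 'no_pos'       # plain nn.Embedding, no positional information, on the target side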