From 4ae3d770a91ca4d52a3ecabd073850b502dc693d Mon Sep 17 00:00:00 2001
From: Varuna Jayasiri
Date: Tue, 25 Aug 2020 15:41:10 +0530
Subject: [PATCH] learned positional encodings

---
 transformers/__init__.py | 23 +++++++++++++++++++++++
 1 file changed, 23 insertions(+)

diff --git a/transformers/__init__.py b/transformers/__init__.py
index e5708f9d..86906796 100644
--- a/transformers/__init__.py
+++ b/transformers/__init__.py
@@ -24,6 +24,18 @@ class EmbeddingsWithPositionalEncoding(Module):
         return self.linear(x) * math.sqrt(self.d_model) + pe
 
 
+class EmbeddingsWithLearnedPositionalEncoding(Module):  # token embeddings + a trainable positional-encoding table (vs. the fixed sinusoidal one above)
+    def __init__(self, d_model: int, n_vocab: int, max_len: int = 5000):  # max_len caps the longest supported sequence
+        super().__init__()
+        self.linear = nn.Embedding(n_vocab, d_model)  # named `linear` to mirror EmbeddingsWithPositionalEncoding; it is an embedding lookup
+        self.d_model = d_model
+        self.positional_encodings = nn.Parameter(torch.zeros(max_len, 1, d_model))  # learned (max_len, 1, d_model) table, zero-initialized; middle dim broadcasts over batch
+
+    def __call__(self, x: torch.Tensor):  # __call__ (not forward) matches the sibling Module classes in this file
+        pe = self.positional_encodings[:x.shape[0]]  # slice to x.shape[0] — assumes dim 0 of x is sequence length, TODO confirm against caller
+        return self.linear(x) * math.sqrt(self.d_model) + pe  # scale embeddings by sqrt(d_model), then add positional encodings
+
+
 class FeedForward(Module):
     def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1):
         super().__init__()
@@ -227,6 +239,17 @@ def _tgt_embed_with_positional(c: TransformerConfigs):
     return EmbeddingsWithPositionalEncoding(c.d_model, c.n_tgt_vocab)
 
 
+@option(TransformerConfigs.src_embed, 'learned_pos')
+def _src_embed_with_learned_positional(c: TransformerConfigs):  # registers 'learned_pos' as a src_embed config option
+    return EmbeddingsWithLearnedPositionalEncoding(c.d_model, c.n_src_vocab)
+
+
+@option(TransformerConfigs.tgt_embed, 'learned_pos')
+def _tgt_embed_with_learned_positional(c: TransformerConfigs):  # registers 'learned_pos' as a tgt_embed config option
+    return EmbeddingsWithLearnedPositionalEncoding(c.d_model, c.n_tgt_vocab)
+
+
+# NOTE(review): this hunk leaves three blank lines before _encoder_decoder; PEP 8 wants two
 @option(TransformerConfigs.encoder_decoder, 'normal')
 def _encoder_decoder(c: TransformerConfigs):
     return EncoderDecoder(c.encoder, c.decoder, c.src_embed, c.tgt_embed, c.generator)