learned positional encodings

Varuna Jayasiri
2020-08-25 15:41:10 +05:30
parent 5a7a2e0525
commit 4ae3d770a9


@@ -24,6 +24,18 @@ class EmbeddingsWithPositionalEncoding(Module):
        return self.linear(x) * math.sqrt(self.d_model) + pe
class EmbeddingsWithLearnedPositionalEncoding(Module):
    """
    Embed tokens and add learned (parameterized) positional encodings.
    """
    def __init__(self, d_model: int, n_vocab: int, max_len: int = 5000):
        super().__init__()
        self.linear = nn.Embedding(n_vocab, d_model)
        self.d_model = d_model
        # Positional encodings are trainable parameters, initialized to zeros
        self.positional_encodings = nn.Parameter(torch.zeros(max_len, 1, d_model))

    def __call__(self, x: torch.Tensor):
        # Slice the encodings to the input sequence length
        pe = self.positional_encodings[:x.shape[0]]
        # Scale embeddings by sqrt(d_model) and add the positional encodings
        return self.linear(x) * math.sqrt(self.d_model) + pe
class FeedForward(Module):
    def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
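For reference, here is a minimal usage sketch of the new module. It is not part of this commit; it assumes this file's imports (math, torch, nn, and labml's Module) are in scope and follows the repo's (seq_len, batch_size) layout for token indices.

# Hypothetical usage sketch, not part of this commit.
d_model, n_vocab = 512, 10000
embed = EmbeddingsWithLearnedPositionalEncoding(d_model, n_vocab)

x = torch.randint(0, n_vocab, (20, 4))  # token indices: 20 positions, batch of 4
out = embed(x)
assert out.shape == (20, 4, d_model)  # scaled embeddings + learned positions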
@@ -227,6 +239,17 @@ def _tgt_embed_with_positional(c: TransformerConfigs):
    return EmbeddingsWithPositionalEncoding(c.d_model, c.n_tgt_vocab)
# Register learned positional encodings as the 'learned_pos' option
# for the source and target embeddings.
@option(TransformerConfigs.src_embed, 'learned_pos')
def _src_embed_with_learned_positional(c: TransformerConfigs):
    return EmbeddingsWithLearnedPositionalEncoding(c.d_model, c.n_src_vocab)


@option(TransformerConfigs.tgt_embed, 'learned_pos')
def _tgt_embed_with_learned_positional(c: TransformerConfigs):
    return EmbeddingsWithLearnedPositionalEncoding(c.d_model, c.n_tgt_vocab)
@option(TransformerConfigs.encoder_decoder, 'normal')
def _encoder_decoder(c: TransformerConfigs):
    return EncoderDecoder(c.encoder, c.decoder, c.src_embed, c.tgt_embed, c.generator)
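A hedged sketch of how the new options could be selected. It assumes labml's Configs resolve string-valued fields against the registered @option functions; the exact experiment wiring is not shown in this commit.

# Hypothetical option selection, assuming labml's config mechanics.
conf = TransformerConfigs()
conf.src_embed = 'learned_pos'  # resolved to _src_embed_with_learned_positional
conf.tgt_embed = 'learned_pos'  # resolved to _tgt_embed_with_learned_positional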