Mirror of https://github.com/labmlai/annotated_deep_learning_paper_implementations.git (synced 2025-11-02 04:37:46 +08:00)

Commit: learned positional encodings
@@ -24,6 +24,18 @@ class EmbeddingsWithPositionalEncoding(Module):
         return self.linear(x) * math.sqrt(self.d_model) + pe
 
 
+class EmbeddingsWithLearnedPositionalEncoding(Module):
+    def __init__(self, d_model: int, n_vocab: int, max_len: int = 5000):
+        super().__init__()
+        self.linear = nn.Embedding(n_vocab, d_model)
+        self.d_model = d_model
+        self.positional_encodings = nn.Parameter(torch.zeros(max_len, 1, d_model))
+
+    def __call__(self, x: torch.Tensor):
+        pe = self.positional_encodings[:x.shape[0]]
+        return self.linear(x) * math.sqrt(self.d_model) + pe
+
+
 class FeedForward(Module):
     def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1):
         super().__init__()
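The added class replaces the fixed sinusoidal table with a trainable nn.Parameter of shape (max_len, 1, d_model), sliced to the sequence length and added to the scaled token embeddings. A minimal standalone sketch of the same idea, assuming a plain PyTorch nn.Module with forward in place of labml's Module/__call__, the (seq_len, batch) input layout used in the diff, and an illustrative class name not taken from the commit:

import math

import torch
import torch.nn as nn


class LearnedPositionalEmbeddings(nn.Module):
    """Token embeddings plus a trainable positional table (sketch of the class above)."""

    def __init__(self, d_model: int, n_vocab: int, max_len: int = 5000):
        super().__init__()
        self.linear = nn.Embedding(n_vocab, d_model)
        self.d_model = d_model
        # One trainable vector per position, broadcast over the batch dimension.
        self.positional_encodings = nn.Parameter(torch.zeros(max_len, 1, d_model))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (seq_len, batch) token ids; slice the table to the sequence length.
        pe = self.positional_encodings[:x.shape[0]]
        return self.linear(x) * math.sqrt(self.d_model) + pe


# Usage: a batch of 2 sequences, 10 tokens each, from a 100-token vocabulary.
emb = LearnedPositionalEmbeddings(d_model=512, n_vocab=100)
tokens = torch.randint(0, 100, (10, 2))  # (seq_len, batch)
out = emb(tokens)                        # (10, 2, 512)

Unlike the sinusoidal variant, these positions start at zero and are learned during training, and the model cannot address positions beyond max_len.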
@@ -227,6 +239,17 @@ def _tgt_embed_with_positional(c: TransformerConfigs):
     return EmbeddingsWithPositionalEncoding(c.d_model, c.n_tgt_vocab)
 
 
+@option(TransformerConfigs.src_embed, 'learned_pos')
+def _src_embed_with_learned_positional(c: TransformerConfigs):
+    return EmbeddingsWithLearnedPositionalEncoding(c.d_model, c.n_src_vocab)
+
+
+@option(TransformerConfigs.tgt_embed, 'learned_pos')
+def _tgt_embed_with_learned_positional(c: TransformerConfigs):
+    return EmbeddingsWithLearnedPositionalEncoding(c.d_model, c.n_tgt_vocab)
+
+
+
 @option(TransformerConfigs.encoder_decoder, 'normal')
 def _encoder_decoder(c: TransformerConfigs):
     return EncoderDecoder(c.encoder, c.decoder, c.src_embed, c.tgt_embed, c.generator)
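These two @option registrations expose the new class under the name 'learned_pos' for both the source and target embedding configs. A hedged sketch of selecting it, assuming the labml experiment override API (labml.experiment.configs) and the labml_nn.transformers import path from the surrounding repo, neither of which appears in this diff:

from labml import experiment
from labml_nn.transformers import TransformerConfigs

conf = TransformerConfigs()
# Vocabulary sizes are required by the embedding constructors above;
# the values here are placeholders.
experiment.configs(conf, {
    'n_src_vocab': 10000,
    'n_tgt_vocab': 10000,
    # Select the options registered in this commit; the config system then
    # calls _src_embed_with_learned_positional / _tgt_embed_with_learned_positional.
    'src_embed': 'learned_pos',
    'tgt_embed': 'learned_pos',
})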