From ee5a34aa598a03319692fa988eb25e6a3885f156 Mon Sep 17 00:00:00 2001
From: Varuna Jayasiri
Date: Tue, 28 Jun 2022 19:02:20 +0530
Subject: [PATCH] experiment links transformer

---
 docs/sitemap.xml                            |   2 +-
 .../basic/autoregressive_experiment.html    | 126 +++++-----
 docs/transformers/mha.html                  | 128 +++++------
 docs/transformers/models.html               | 215 +++++++++---------
 .../basic/autoregressive_experiment.ipynb   |   1 +
 .../basic/autoregressive_experiment.py      |   5 +-
 labml_nn/transformers/mha.py                |   5 +-
 labml_nn/transformers/models.py             |   3 +
 setup.py                                    |   2 +-
 9 files changed, 247 insertions(+), 240 deletions(-)

diff --git a/docs/sitemap.xml b/docs/sitemap.xml
index b9e28b79..83d8f53a 100644
--- a/docs/sitemap.xml
+++ b/docs/sitemap.xml
@@ -239,7 +239,7 @@
     https://nn.labml.ai/experiments/nlp_autoregression.html
-    2022-06-25T16:30:00+00:00
+    2022-06-27T16:30:00+00:00
     1.00

diff --git a/docs/transformers/basic/autoregressive_experiment.html b/docs/transformers/basic/autoregressive_experiment.html
index a3dd00cf..521a8bcb 100644
--- a/docs/transformers/basic/autoregressive_experiment.html
+++ b/docs/transformers/basic/autoregressive_experiment.html
@@ -70,19 +70,19 @@
 #

Transformer Auto-Regression Experiment

+Open In Colab Open In Comet

This trains a simple transformer introduced in Attention Is All You Need on an NLP auto-regression task (with the Tiny Shakespeare dataset).

-Open In Colab
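As a reference while reading the renumbered hunks below (the only change in this file is the added Comet badge line, which shifts every documented code line by one), here is a minimal, self-contained sketch of what the AutoregressiveTransformer in this experiment does: build a causal mask lazily, embed the tokens, run the encoder, and project back to the vocabulary. It uses plain torch.nn stand-ins instead of the labml_nn Encoder/TransformerConfigs, with hyperparameters mirroring the config later in this file, so treat it as an illustration rather than the repo's implementation.

import torch
import torch.nn as nn


class TinyAutoregressiveTransformer(nn.Module):
    """Illustrative stand-in: embed -> causally masked encoder -> vocabulary projection."""

    def __init__(self, n_tokens: int, d_model: int = 256, n_heads: int = 16, n_layers: int = 2):
        super().__init__()
        self.src_embed = nn.Embedding(n_tokens, d_model)
        layer = nn.TransformerEncoderLayer(d_model, n_heads, dim_feedforward=1024)
        self.encoder = nn.TransformerEncoder(layer, n_layers)
        self.generator = nn.Linear(d_model, n_tokens)
        self.mask = None  # built lazily on the first call, as in the experiment

    def forward(self, x: torch.Tensor):
        # x: [seq_len, batch_size] token ids
        if self.mask is None or self.mask.size(0) != len(x):
            # subsequent (causal) mask: position i may only attend to positions <= i
            self.mask = torch.triu(torch.full((len(x), len(x)), float('-inf')), diagonal=1).to(x.device)
        x = self.src_embed(x)
        x = self.encoder(x, self.mask)
        return self.generator(x)


tokens = torch.randint(0, 65, (512, 16))                  # seq_len=512, batch_size=16, ~65-symbol character vocab
print(TinyAutoregressiveTransformer(65)(tokens).shape)    # torch.Size([512, 16, 65])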

-
16import torch
-17
-18from labml import experiment
-19from labml.configs import option
-20from labml_helpers.module import Module
-21from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs
-22from labml_nn.transformers import TransformerConfigs, Encoder
-23from labml_nn.transformers.utils import subsequent_mask
+
17import torch
+18
+19from labml import experiment
+20from labml.configs import option
+21from labml_helpers.module import Module
+22from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs
+23from labml_nn.transformers import TransformerConfigs, Encoder
+24from labml_nn.transformers.utils import subsequent_mask
@@ -94,7 +94,7 @@
-
26class AutoregressiveTransformer(Module):
+
27class AutoregressiveTransformer(Module):
@@ -111,7 +111,7 @@
-
30    def __init__(self, encoder: Encoder, src_embed: Module, generator: Module):
+
31    def __init__(self, encoder: Encoder, src_embed: Module, generator: Module):
@@ -122,10 +122,10 @@
-
37        super().__init__()
-38        self.src_embed = src_embed
-39        self.encoder = encoder
-40        self.generator = generator
+
38        super().__init__()
+39        self.src_embed = src_embed
+40        self.encoder = encoder
+41        self.generator = generator
@@ -137,7 +137,7 @@
-
43        self.mask = None
+
44        self.mask = None
@@ -148,7 +148,7 @@
-
45    def forward(self, x: torch.Tensor):
+
46    def forward(self, x: torch.Tensor):
@@ -160,7 +160,7 @@
-
48        if self.mask is None or self.mask.size(0) != len(x):
+
49        if self.mask is None or self.mask.size(0) != len(x):
@@ -172,7 +172,7 @@
-
50            self.mask = subsequent_mask(len(x)).to(x.device)
+
51            self.mask = subsequent_mask(len(x)).to(x.device)
@@ -184,7 +184,7 @@
-
52        x = self.src_embed(x)
+
53        x = self.src_embed(x)
@@ -196,7 +196,7 @@
-
54        x = self.encoder(x, self.mask)
+
55        x = self.encoder(x, self.mask)
@@ -208,7 +208,7 @@
-
56        x = self.generator(x)
+
57        x = self.generator(x)
@@ -220,7 +220,7 @@
-
60        return x, None
+
61        return x, None
@@ -234,7 +234,7 @@
-
63class Configs(NLPAutoRegressionConfigs):
+
64class Configs(NLPAutoRegressionConfigs):
@@ -246,7 +246,7 @@
-
72    model: AutoregressiveTransformer
+
73    model: AutoregressiveTransformer
@@ -258,7 +258,7 @@
-
74    transformer: TransformerConfigs
+
75    transformer: TransformerConfigs
@@ -270,8 +270,8 @@
-
77@option(Configs.transformer, 'Transformer')
-78def _transformer_configs(c: Configs):
+
78@option(Configs.transformer, 'Transformer')
+79def _transformer_configs(c: Configs):
@@ -283,7 +283,7 @@
-
85    conf = TransformerConfigs()
+
86    conf = TransformerConfigs()
@@ -295,8 +295,8 @@
-
87    conf.n_src_vocab = c.n_tokens
-88    conf.n_tgt_vocab = c.n_tokens
+
88    conf.n_src_vocab = c.n_tokens
+89    conf.n_tgt_vocab = c.n_tokens
@@ -308,7 +308,7 @@
-
90    conf.d_model = c.d_model
+
91    conf.d_model = c.d_model
@@ -320,7 +320,7 @@
-
93    return conf
+
94    return conf
@@ -332,8 +332,8 @@
-
96@option(Configs.model)
-97def _model(c: Configs):
+
97@option(Configs.model)
+98def _model(c: Configs):
@@ -344,11 +344,11 @@
-
101    m = AutoregressiveTransformer(c.transformer.encoder,
-102                                  c.transformer.src_embed,
-103                                  c.transformer.generator).to(c.device)
-104
-105    return m
+
102    m = AutoregressiveTransformer(c.transformer.encoder,
+103                                  c.transformer.src_embed,
+104                                  c.transformer.generator).to(c.device)
+105
+106    return m
@@ -359,7 +359,7 @@
-
108def main():
+
109def main():
@@ -371,7 +371,7 @@
-
110    experiment.create(name="transformer")
+
111    experiment.create(name="transformer")
@@ -383,7 +383,7 @@
-
112    conf = Configs()
+
113    conf = Configs()
@@ -395,7 +395,7 @@
-
114    experiment.configs(conf, {
+
115    experiment.configs(conf, {
@@ -407,7 +407,7 @@
-
116        'tokenizer': 'character',
+
117        'tokenizer': 'character',
@@ -419,7 +419,7 @@
-
118        'prompt_separator': '',
+
119        'prompt_separator': '',
@@ -431,7 +431,7 @@
-
120        'prompt': 'It is ',
+
121        'prompt': 'It is ',
@@ -443,7 +443,7 @@
-
122        'text': 'tiny_shakespeare',
+
123        'text': 'tiny_shakespeare',
@@ -455,7 +455,7 @@
-
125        'seq_len': 512,
+
126        'seq_len': 512,
@@ -467,7 +467,7 @@
-
127        'epochs': 32,
+
128        'epochs': 32,
@@ -479,7 +479,7 @@
-
129        'batch_size': 16,
+
130        'batch_size': 16,
@@ -491,7 +491,7 @@
-
132        'inner_iterations': 10,
+
133        'inner_iterations': 10,
@@ -503,9 +503,9 @@
-
135        'd_model': 256,
-136        'transformer.n_heads': 16,
-137        'transformer.ffn.d_ff': 1024,
+
136        'd_model': 256,
+137        'transformer.n_heads': 16,
+138        'transformer.ffn.d_ff': 1024,
@@ -517,9 +517,9 @@
-
140        'optimizer.optimizer': 'Noam',
-141        'optimizer.learning_rate': 1.,
-142    })
+
141        'optimizer.optimizer': 'Noam',
+142        'optimizer.learning_rate': 1.,
+143    })
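A note on the last two settings above: with the Noam scheduler, 'optimizer.learning_rate': 1. is not a fixed rate but a multiplicative factor on the warmup/decay schedule from Attention Is All You Need. A hedged sketch (warmup=4000 is the paper's default, not a value set in this diff):

def noam_lr(step: int, d_model: int = 256, warmup: int = 4000, factor: float = 1.0) -> float:
    # Linear warmup for `warmup` steps, then 1/sqrt(step) decay; step starts at 1.
    # `factor` plays the role of 'optimizer.learning_rate' above, hence the value 1.
    return factor * d_model ** -0.5 * min(step ** -0.5, step * warmup ** -1.5)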
@@ -531,7 +531,7 @@
-
145    experiment.add_pytorch_models({'model': conf.model})
+
146    experiment.add_pytorch_models({'model': conf.model})
@@ -543,7 +543,7 @@
-
148    with experiment.start():
+
149    with experiment.start():
@@ -555,7 +555,7 @@
-
150        conf.run()
+
151        conf.run()
@@ -567,8 +567,8 @@
-
154if __name__ == '__main__':
-155    main()
+
155if __name__ == '__main__':
+156    main()

diff --git a/docs/transformers/mha.html b/docs/transformers/mha.html

Multi-Headed Attention (MHA)

+Open In Colab Open In Comet

This is a tutorial/implementation of multi-headed attention from the paper Attention Is All You Need in PyTorch. The implementation is inspired by the Annotated Transformer.

Here is the training code that uses a basic transformer with MHA for NLP auto-regression.

Here is an experiment implementation that trains a simple transformer.

-Open In Colab
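Before the renumbered listing, a compact sketch of the computation these hunks document: split d_model into heads, take dot-product scores with an einsum, scale by 1/sqrt(d_k), mask, softmax over the key positions, and recombine the heads. The learned query/key/value/output projections and dropout of the real MultiHeadAttention are omitted, so this only illustrates the shape bookkeeping, not the module itself.

import math
import torch


def multi_head_attention(query, key, value, mask=None, heads: int = 8):
    # Shape convention follows mha.py: tensors are [seq_len, batch_size, d_model].
    d_k = query.shape[-1] // heads
    # PrepareForMultiHeadAttention (minus its linear layer): split d_model into (heads, d_k)
    q = query.view(*query.shape[:-1], heads, d_k)
    k = key.view(*key.shape[:-1], heads, d_k)
    v = value.view(*value.shape[:-1], heads, d_k)
    # get_scores: scores[i, j, b, h] = q_i . k_j for each batch and head, scaled by 1/sqrt(d_k)
    scores = torch.einsum('ibhd,jbhd->ijbh', q, k) / math.sqrt(d_k)
    if mask is not None:
        # mask is [query_seq, key_seq, batch]; add a head axis so it broadcasts over heads
        scores = scores.masked_fill(mask.unsqueeze(-1) == 0, float('-inf'))
    attn = scores.softmax(dim=1)               # normalize over key positions (dim=1), as in mha.py
    x = torch.einsum('ijbh,jbhd->ibhd', attn, v)
    return x.reshape(*x.shape[:2], -1)         # concatenate heads back into d_model


q = k = v = torch.randn(10, 2, 64)             # seq_len=10, batch_size=2, d_model=64
print(multi_head_attention(q, k, v).shape)     # torch.Size([10, 2, 64])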

-
24import math
-25from typing import Optional, List
-26
-27import torch
-28from torch import nn as nn
-29
-30from labml import tracker
-31from labml_helpers.module import Module
+
25import math
+26from typing import Optional, List
+27
+28import torch
+29from torch import nn as nn
+30
+31from labml import tracker
+32from labml_helpers.module import Module
@@ -97,7 +97,7 @@
-
34class PrepareForMultiHeadAttention(Module):
+
35class PrepareForMultiHeadAttention(Module):
@@ -108,8 +108,8 @@
-
45    def __init__(self, d_model: int, heads: int, d_k: int, bias: bool):
-46        super().__init__()
+
46    def __init__(self, d_model: int, heads: int, d_k: int, bias: bool):
+47        super().__init__()
@@ -121,7 +121,7 @@
-
48        self.linear = nn.Linear(d_model, heads * d_k, bias=bias)
+
49        self.linear = nn.Linear(d_model, heads * d_k, bias=bias)
@@ -133,7 +133,7 @@
-
50        self.heads = heads
+
51        self.heads = heads
@@ -145,7 +145,7 @@
-
52        self.d_k = d_k
+
53        self.d_k = d_k
@@ -156,7 +156,7 @@
-
54    def forward(self, x: torch.Tensor):
+
55    def forward(self, x: torch.Tensor):
@@ -170,7 +170,7 @@
-
58        head_shape = x.shape[:-1]
+
59        head_shape = x.shape[:-1]
@@ -182,7 +182,7 @@
-
61        x = self.linear(x)
+
62        x = self.linear(x)
@@ -194,7 +194,7 @@
-
64        x = x.view(*head_shape, self.heads, self.d_k)
+
65        x = x.view(*head_shape, self.heads, self.d_k)
@@ -208,7 +208,7 @@
-
67        return x
+
68        return x
@@ -251,7 +251,7 @@
70class MultiHeadAttention(Module):
+
71class MultiHeadAttention(Module):
@@ -269,7 +269,7 @@
91    def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1, bias: bool = True):
+
92    def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1, bias: bool = True):
@@ -280,7 +280,7 @@
97        super().__init__()
+
98        super().__init__()
@@ -292,7 +292,7 @@
100        self.d_k = d_model // heads
+
101        self.d_k = d_model // heads
@@ -304,7 +304,7 @@
102        self.heads = heads
+
103        self.heads = heads
@@ -319,9 +319,9 @@
105        self.query = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=bias)
-106        self.key = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=bias)
-107        self.value = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=True)
+
106        self.query = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=bias)
+107        self.key = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=bias)
+108        self.value = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=True)
@@ -334,7 +334,7 @@
110        self.softmax = nn.Softmax(dim=1)
+
111        self.softmax = nn.Softmax(dim=1)
@@ -346,7 +346,7 @@
113        self.output = nn.Linear(d_model, d_model)
+
114        self.output = nn.Linear(d_model, d_model)
@@ -358,7 +358,7 @@
115        self.dropout = nn.Dropout(dropout_prob)
+
116        self.dropout = nn.Dropout(dropout_prob)
@@ -370,7 +370,7 @@
117        self.scale = 1 / math.sqrt(self.d_k)
+
118        self.scale = 1 / math.sqrt(self.d_k)
@@ -382,7 +382,7 @@
120        self.attn = None
+
121        self.attn = None
@@ -395,7 +395,7 @@
122    def get_scores(self, query: torch.Tensor, key: torch.Tensor):
+
123    def get_scores(self, query: torch.Tensor, key: torch.Tensor):
@@ -407,7 +407,7 @@
130        return torch.einsum('ibhd,jbhd->ijbh', query, key)
+
131        return torch.einsum('ibhd,jbhd->ijbh', query, key)
@@ -421,7 +421,7 @@
132    def prepare_mask(self, mask: torch.Tensor, query_shape: List[int], key_shape: List[int]):
+
133    def prepare_mask(self, mask: torch.Tensor, query_shape: List[int], key_shape: List[int]):
@@ -432,9 +432,9 @@
138        assert mask.shape[0] == 1 or mask.shape[0] == query_shape[0]
-139        assert mask.shape[1] == key_shape[0]
-140        assert mask.shape[2] == 1 or mask.shape[2] == query_shape[1]
+
139        assert mask.shape[0] == 1 or mask.shape[0] == query_shape[0]
+140        assert mask.shape[1] == key_shape[0]
+141        assert mask.shape[2] == 1 or mask.shape[2] == query_shape[1]
@@ -446,7 +446,7 @@
143        mask = mask.unsqueeze(-1)
+
144        mask = mask.unsqueeze(-1)
@@ -459,7 +459,7 @@
146        return mask
+
147        return mask
@@ -482,11 +482,11 @@
148    def forward(self, *,
-149                query: torch.Tensor,
-150                key: torch.Tensor,
-151                value: torch.Tensor,
-152                mask: Optional[torch.Tensor] = None):
+
149    def forward(self, *,
+150                query: torch.Tensor,
+151                key: torch.Tensor,
+152                value: torch.Tensor,
+153                mask: Optional[torch.Tensor] = None):
@@ -502,10 +502,10 @@
164        seq_len, batch_size, _ = query.shape
-165
-166        if mask is not None:
-167            mask = self.prepare_mask(mask, query.shape, key.shape)
+
165        seq_len, batch_size, _ = query.shape
+166
+167        if mask is not None:
+168            mask = self.prepare_mask(mask, query.shape, key.shape)
@@ -521,9 +521,9 @@
171        query = self.query(query)
-172        key = self.key(key)
-173        value = self.value(value)
+
172        query = self.query(query)
+173        key = self.key(key)
+174        value = self.value(value)
@@ -536,7 +536,7 @@
177        scores = self.get_scores(query, key)
+
178        scores = self.get_scores(query, key)
@@ -559,7 +559,7 @@
180        scores *= self.scale
+
181        scores *= self.scale
@@ -571,8 +571,8 @@
183        if mask is not None:
-184            scores = scores.masked_fill(mask == 0, float('-inf'))
+
184        if mask is not None:
+185            scores = scores.masked_fill(mask == 0, float('-inf'))
@@ -595,7 +595,7 @@
188        attn = self.softmax(scores)
+
189        attn = self.softmax(scores)
@@ -607,7 +607,7 @@
191        tracker.debug('attn', attn)
+
192        tracker.debug('attn', attn)
@@ -619,7 +619,7 @@
194        attn = self.dropout(attn)
+
195        attn = self.dropout(attn)
@@ -642,7 +642,7 @@
198        x = torch.einsum("ijbh,jbhd->ibhd", attn, value)
+
199        x = torch.einsum("ijbh,jbhd->ibhd", attn, value)
@@ -654,7 +654,7 @@
201        self.attn = attn.detach()
+
202        self.attn = attn.detach()
@@ -666,7 +666,7 @@
204        x = x.reshape(seq_len, batch_size, -1)
+
205        x = x.reshape(seq_len, batch_size, -1)
@@ -678,7 +678,7 @@
207        return self.output(x)
+
208        return self.output(x)

diff --git a/docs/transformers/models.html b/docs/transformers/models.html

Transformer Encoder and Decoder Models

+Open In Colab Open In Comet

-
11import math
-12
-13import torch
-14import torch.nn as nn
-15from labml_helpers.module import Module
-16
-17from labml_nn.utils import clone_module_list
-18from .feed_forward import FeedForward
-19from .mha import MultiHeadAttention
-20from .positional_encoding import get_positional_encoding
+
14import math
+15
+16import torch
+17import torch.nn as nn
+18from labml_helpers.module import Module
+19
+20from labml_nn.utils import clone_module_list
+21from .feed_forward import FeedForward
+22from .mha import MultiHeadAttention
+23from .positional_encoding import get_positional_encoding
@@ -94,7 +95,7 @@
-
23class EmbeddingsWithPositionalEncoding(Module):
+
26class EmbeddingsWithPositionalEncoding(Module):
@@ -105,11 +106,11 @@
-
30    def __init__(self, d_model: int, n_vocab: int, max_len: int = 5000):
-31        super().__init__()
-32        self.linear = nn.Embedding(n_vocab, d_model)
-33        self.d_model = d_model
-34        self.register_buffer('positional_encodings', get_positional_encoding(d_model, max_len))
+
33    def __init__(self, d_model: int, n_vocab: int, max_len: int = 5000):
+34        super().__init__()
+35        self.linear = nn.Embedding(n_vocab, d_model)
+36        self.d_model = d_model
+37        self.register_buffer('positional_encodings', get_positional_encoding(d_model, max_len))
@@ -120,9 +121,9 @@
-
36    def forward(self, x: torch.Tensor):
-37        pe = self.positional_encodings[:x.shape[0]].requires_grad_(False)
-38        return self.linear(x) * math.sqrt(self.d_model) + pe
+
39    def forward(self, x: torch.Tensor):
+40        pe = self.positional_encodings[:x.shape[0]].requires_grad_(False)
+41        return self.linear(x) * math.sqrt(self.d_model) + pe
@@ -135,7 +136,7 @@
-
41class EmbeddingsWithLearnedPositionalEncoding(Module):
+
44class EmbeddingsWithLearnedPositionalEncoding(Module):
@@ -146,11 +147,11 @@
-
48    def __init__(self, d_model: int, n_vocab: int, max_len: int = 5000):
-49        super().__init__()
-50        self.linear = nn.Embedding(n_vocab, d_model)
-51        self.d_model = d_model
-52        self.positional_encodings = nn.Parameter(torch.zeros(max_len, 1, d_model), requires_grad=True)
+
51    def __init__(self, d_model: int, n_vocab: int, max_len: int = 5000):
+52        super().__init__()
+53        self.linear = nn.Embedding(n_vocab, d_model)
+54        self.d_model = d_model
+55        self.positional_encodings = nn.Parameter(torch.zeros(max_len, 1, d_model), requires_grad=True)
@@ -161,9 +162,9 @@
-
54    def forward(self, x: torch.Tensor):
-55        pe = self.positional_encodings[:x.shape[0]]
-56        return self.linear(x) * math.sqrt(self.d_model) + pe
+
57    def forward(self, x: torch.Tensor):
+58        pe = self.positional_encodings[:x.shape[0]]
+59        return self.linear(x) * math.sqrt(self.d_model) + pe
@@ -178,7 +179,7 @@
-
59class TransformerLayer(Module):
+
62class TransformerLayer(Module):
@@ -199,12 +200,12 @@
-
77    def __init__(self, *,
-78                 d_model: int,
-79                 self_attn: MultiHeadAttention,
-80                 src_attn: MultiHeadAttention = None,
-81                 feed_forward: FeedForward,
-82                 dropout_prob: float):
+
80    def __init__(self, *,
+81                 d_model: int,
+82                 self_attn: MultiHeadAttention,
+83                 src_attn: MultiHeadAttention = None,
+84                 feed_forward: FeedForward,
+85                 dropout_prob: float):
@@ -215,16 +216,16 @@
-
90        super().__init__()
-91        self.size = d_model
-92        self.self_attn = self_attn
-93        self.src_attn = src_attn
-94        self.feed_forward = feed_forward
-95        self.dropout = nn.Dropout(dropout_prob)
-96        self.norm_self_attn = nn.LayerNorm([d_model])
-97        if self.src_attn is not None:
-98            self.norm_src_attn = nn.LayerNorm([d_model])
-99        self.norm_ff = nn.LayerNorm([d_model])
+
93        super().__init__()
+94        self.size = d_model
+95        self.self_attn = self_attn
+96        self.src_attn = src_attn
+97        self.feed_forward = feed_forward
+98        self.dropout = nn.Dropout(dropout_prob)
+99        self.norm_self_attn = nn.LayerNorm([d_model])
+100        if self.src_attn is not None:
+101            self.norm_src_attn = nn.LayerNorm([d_model])
+102        self.norm_ff = nn.LayerNorm([d_model])
@@ -236,7 +237,7 @@
-
101        self.is_save_ff_input = False
+
104        self.is_save_ff_input = False
@@ -247,11 +248,11 @@
-
103    def forward(self, *,
-104                x: torch.Tensor,
-105                mask: torch.Tensor,
-106                src: torch.Tensor = None,
-107                src_mask: torch.Tensor = None):
+
106    def forward(self, *,
+107                x: torch.Tensor,
+108                mask: torch.Tensor,
+109                src: torch.Tensor = None,
+110                src_mask: torch.Tensor = None):
@@ -263,7 +264,7 @@
-
109        z = self.norm_self_attn(x)
+
112        z = self.norm_self_attn(x)
@@ -275,7 +276,7 @@
-
111        self_attn = self.self_attn(query=z, key=z, value=z, mask=mask)
+
114        self_attn = self.self_attn(query=z, key=z, value=z, mask=mask)
@@ -287,7 +288,7 @@
-
113        x = x + self.dropout(self_attn)
+
116        x = x + self.dropout(self_attn)
@@ -299,7 +300,7 @@
-
118        if src is not None:
+
121        if src is not None:
@@ -311,7 +312,7 @@
-
120            z = self.norm_src_attn(x)
+
123            z = self.norm_src_attn(x)
@@ -323,7 +324,7 @@
-
122            attn_src = self.src_attn(query=z, key=src, value=src, mask=src_mask)
+
125            attn_src = self.src_attn(query=z, key=src, value=src, mask=src_mask)
@@ -335,7 +336,7 @@
-
124            x = x + self.dropout(attn_src)
+
127            x = x + self.dropout(attn_src)
@@ -347,7 +348,7 @@
-
127        z = self.norm_ff(x)
+
130        z = self.norm_ff(x)
@@ -359,8 +360,8 @@
-
129        if self.is_save_ff_input:
-130            self.ff_input = z.clone()
+
132        if self.is_save_ff_input:
+133            self.ff_input = z.clone()
@@ -372,7 +373,7 @@
-
132        ff = self.feed_forward(z)
+
135        ff = self.feed_forward(z)
@@ -384,9 +385,9 @@
-
134        x = x + self.dropout(ff)
-135
-136        return x
+
137        x = x + self.dropout(ff)
+138
+139        return x
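The hunks in this file only shift the listing by three lines; as a consolidated view of what TransformerLayer.forward does, here is a schematic of the pre-norm residual pattern (LayerNorm before each sub-layer, dropout on the residual branch). The decoder-only src_attn branch and the is_save_ff_input hook are left out; self_attn and feed_forward stand in for the repo's MultiHeadAttention and FeedForward modules.

import torch
import torch.nn as nn


class PreNormLayer(nn.Module):
    def __init__(self, d_model: int, self_attn: nn.Module, feed_forward: nn.Module, dropout_prob: float):
        super().__init__()
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.dropout = nn.Dropout(dropout_prob)
        self.norm_self_attn = nn.LayerNorm([d_model])
        self.norm_ff = nn.LayerNorm([d_model])

    def forward(self, x: torch.Tensor, mask: torch.Tensor):
        z = self.norm_self_attn(x)                            # normalize *before* self-attention
        x = x + self.dropout(self.self_attn(query=z, key=z, value=z, mask=mask))  # residual add
        z = self.norm_ff(x)
        x = x + self.dropout(self.feed_forward(z))            # position-wise FFN + residual add
        return x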
@@ -399,7 +400,7 @@
-
139class Encoder(Module):
+
142class Encoder(Module):
@@ -410,8 +411,8 @@
-
146    def __init__(self, layer: TransformerLayer, n_layers: int):
-147        super().__init__()
+
149    def __init__(self, layer: TransformerLayer, n_layers: int):
+150        super().__init__()
@@ -423,7 +424,7 @@
-
149        self.layers = clone_module_list(layer, n_layers)
+
152        self.layers = clone_module_list(layer, n_layers)
@@ -435,7 +436,7 @@
-
151        self.norm = nn.LayerNorm([layer.size])
+
154        self.norm = nn.LayerNorm([layer.size])
@@ -446,7 +447,7 @@
-
153    def forward(self, x: torch.Tensor, mask: torch.Tensor):
+
156    def forward(self, x: torch.Tensor, mask: torch.Tensor):
@@ -458,8 +459,8 @@
-
155        for layer in self.layers:
-156            x = layer(x=x, mask=mask)
+
158        for layer in self.layers:
+159            x = layer(x=x, mask=mask)
@@ -471,7 +472,7 @@
-
158        return self.norm(x)
+
161        return self.norm(x)
@@ -484,7 +485,7 @@
-
161class Decoder(Module):
+
164class Decoder(Module):
@@ -495,8 +496,8 @@
-
168    def __init__(self, layer: TransformerLayer, n_layers: int):
-169        super().__init__()
+
171    def __init__(self, layer: TransformerLayer, n_layers: int):
+172        super().__init__()
@@ -508,7 +509,7 @@
-
171        self.layers = clone_module_list(layer, n_layers)
+
174        self.layers = clone_module_list(layer, n_layers)
@@ -520,7 +521,7 @@
-
173        self.norm = nn.LayerNorm([layer.size])
+
176        self.norm = nn.LayerNorm([layer.size])
@@ -531,7 +532,7 @@
-
175    def forward(self, x: torch.Tensor, memory: torch.Tensor, src_mask: torch.Tensor, tgt_mask: torch.Tensor):
+
178    def forward(self, x: torch.Tensor, memory: torch.Tensor, src_mask: torch.Tensor, tgt_mask: torch.Tensor):
@@ -543,8 +544,8 @@
-
177        for layer in self.layers:
-178            x = layer(x=x, mask=tgt_mask, src=memory, src_mask=src_mask)
+
180        for layer in self.layers:
+181            x = layer(x=x, mask=tgt_mask, src=memory, src_mask=src_mask)
@@ -556,7 +557,7 @@
-
180        return self.norm(x)
+
183        return self.norm(x)
@@ -571,7 +572,7 @@
-
183class Generator(Module):
+
186class Generator(Module):
@@ -582,9 +583,9 @@
-
193    def __init__(self, n_vocab: int, d_model: int):
-194        super().__init__()
-195        self.projection = nn.Linear(d_model, n_vocab)
+
196    def __init__(self, n_vocab: int, d_model: int):
+197        super().__init__()
+198        self.projection = nn.Linear(d_model, n_vocab)
@@ -595,8 +596,8 @@
-
197    def forward(self, x):
-198        return self.projection(x)
+
200    def forward(self, x):
+201        return self.projection(x)
@@ -609,7 +610,7 @@
-
201class EncoderDecoder(Module):
+
204class EncoderDecoder(Module):
@@ -620,13 +621,13 @@
-
208    def __init__(self, encoder: Encoder, decoder: Decoder, src_embed: Module, tgt_embed: Module, generator: Module):
-209        super().__init__()
-210        self.encoder = encoder
-211        self.decoder = decoder
-212        self.src_embed = src_embed
-213        self.tgt_embed = tgt_embed
-214        self.generator = generator
+
211    def __init__(self, encoder: Encoder, decoder: Decoder, src_embed: Module, tgt_embed: Module, generator: Module):
+212        super().__init__()
+213        self.encoder = encoder
+214        self.decoder = decoder
+215        self.src_embed = src_embed
+216        self.tgt_embed = tgt_embed
+217        self.generator = generator
@@ -638,9 +639,9 @@
-
218        for p in self.parameters():
-219            if p.dim() > 1:
-220                nn.init.xavier_uniform_(p)
+
221        for p in self.parameters():
+222            if p.dim() > 1:
+223                nn.init.xavier_uniform_(p)
@@ -651,7 +652,7 @@
-
222    def forward(self, src: torch.Tensor, tgt: torch.Tensor, src_mask: torch.Tensor, tgt_mask: torch.Tensor):
+
225    def forward(self, src: torch.Tensor, tgt: torch.Tensor, src_mask: torch.Tensor, tgt_mask: torch.Tensor):
@@ -663,7 +664,7 @@
-
224        enc = self.encode(src, src_mask)
+
227        enc = self.encode(src, src_mask)
@@ -675,7 +676,7 @@
-
226        return self.decode(enc, src_mask, tgt, tgt_mask)
+
229        return self.decode(enc, src_mask, tgt, tgt_mask)
@@ -686,8 +687,8 @@
-
228    def encode(self, src: torch.Tensor, src_mask: torch.Tensor):
-229        return self.encoder(self.src_embed(src), src_mask)
+
231    def encode(self, src: torch.Tensor, src_mask: torch.Tensor):
+232        return self.encoder(self.src_embed(src), src_mask)
@@ -698,8 +699,8 @@
-
231    def decode(self, memory: torch.Tensor, src_mask: torch.Tensor, tgt: torch.Tensor, tgt_mask: torch.Tensor):
-232        return self.decoder(self.tgt_embed(tgt), memory, src_mask, tgt_mask)
+
234    def decode(self, memory: torch.Tensor, src_mask: torch.Tensor, tgt: torch.Tensor, tgt_mask: torch.Tensor):
+235        return self.decoder(self.tgt_embed(tgt), memory, src_mask, tgt_mask)