From b6bef1d2fe42cf78287303d816fbb5ce539ff37e Mon Sep 17 00:00:00 2001
From: Varuna Jayasiri
Date: Sat, 2 Jul 2022 14:31:16 +0530
Subject: [PATCH] cleanup

---
 .../basic/autoregressive_experiment.html     |  12 +-
 docs/transformers/mha.html                   | 115 +++++-----
 docs/transformers/models.html                | 205 +++++++++---------
 docs/transformers/positional_encoding.html   |  72 +++---
 .../basic/autoregressive_experiment.py       |   6 +-
 labml_nn/transformers/mha.py                 |   7 +-
 labml_nn/transformers/models.py              |  17 +-
 labml_nn/transformers/positional_encoding.py |   4 +-
 8 files changed, 215 insertions(+), 223 deletions(-)

diff --git a/docs/transformers/basic/autoregressive_experiment.html b/docs/transformers/basic/autoregressive_experiment.html
index 5a9b637c..fdbd657f 100644
--- a/docs/transformers/basic/autoregressive_experiment.html
+++ b/docs/transformers/basic/autoregressive_experiment.html
@@ -76,10 +76,10 @@
17import torch
-18
-19from labml import experiment
-20from labml.configs import option
-21from labml_helpers.module import Module
+18from torch import nn
+19
+20from labml import experiment
+21from labml.configs import option
 22from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs
 23from labml_nn.transformers import TransformerConfigs, Encoder
 24from labml_nn.transformers.utils import subsequent_mask
@@ -94,7 +94,7 @@
-27class AutoregressiveTransformer(Module):
+27class AutoregressiveTransformer(nn.Module):
@@ -111,7 +111,7 @@
-31    def __init__(self, encoder: Encoder, src_embed: Module, generator: Module):
+31    def __init__(self, encoder: Encoder, src_embed: nn.Module, generator: nn.Module):
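Note on the hunks above: `AutoregressiveTransformer` now subclasses `torch.nn.Module` directly instead of `labml_helpers.module.Module`, and the component types in `__init__` follow suit. A minimal sketch of the resulting class, using only names already imported in this file (`Encoder`, `subsequent_mask`); the `forward` body is a rough illustration of the autoregressive flow, not code copied from the repository:

```python
import torch
from torch import nn

from labml_nn.transformers import Encoder
from labml_nn.transformers.utils import subsequent_mask


class AutoregressiveTransformer(nn.Module):
    def __init__(self, encoder: Encoder, src_embed: nn.Module, generator: nn.Module):
        super().__init__()
        self.src_embed = src_embed    # token embedding (+ positional encoding)
        self.encoder = encoder        # transformer encoder stack
        self.generator = generator    # projection back to vocabulary logits

    def forward(self, x: torch.Tensor):
        # causal mask so position i only attends to positions <= i (illustrative)
        mask = subsequent_mask(x.shape[0]).to(x.device)
        x = self.encoder(self.src_embed(x), mask)
        return self.generator(x)
```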
diff --git a/docs/transformers/mha.html b/docs/transformers/mha.html
index 069f5787..9fb8f486 100644
--- a/docs/transformers/mha.html
+++ b/docs/transformers/mha.html
@@ -80,10 +80,9 @@
 26from typing import Optional, List
 27
 28import torch
-29from torch import nn as nn
+29from torch import nn
 30
-31from labml import tracker
-32from labml_helpers.module import Module
+31from labml import tracker
@@ -97,7 +96,7 @@
-35class PrepareForMultiHeadAttention(Module):
+34class PrepareForMultiHeadAttention(nn.Module):
@@ -108,8 +107,8 @@
-46    def __init__(self, d_model: int, heads: int, d_k: int, bias: bool):
-47        super().__init__()
+45    def __init__(self, d_model: int, heads: int, d_k: int, bias: bool):
+46        super().__init__()
@@ -121,7 +120,7 @@
-49        self.linear = nn.Linear(d_model, heads * d_k, bias=bias)
+48        self.linear = nn.Linear(d_model, heads * d_k, bias=bias)
@@ -133,7 +132,7 @@
-51        self.heads = heads
+50        self.heads = heads
@@ -145,7 +144,7 @@
-53        self.d_k = d_k
+52        self.d_k = d_k
@@ -156,7 +155,7 @@
-55    def forward(self, x: torch.Tensor):
+54    def forward(self, x: torch.Tensor):
@@ -170,7 +169,7 @@
-59        head_shape = x.shape[:-1]
+58        head_shape = x.shape[:-1]
@@ -182,7 +181,7 @@
-62        x = self.linear(x)
+61        x = self.linear(x)
@@ -194,7 +193,7 @@
-65        x = x.view(*head_shape, self.heads, self.d_k)
+64        x = x.view(*head_shape, self.heads, self.d_k)
@@ -208,7 +207,7 @@
-68        return x
+67        return x
@@ -251,7 +250,7 @@
-71class MultiHeadAttention(Module):
+70class MultiHeadAttention(nn.Module):
@@ -269,7 +268,7 @@
-92    def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1, bias: bool = True):
+91    def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1, bias: bool = True):
@@ -280,7 +279,7 @@
-98        super().__init__()
+97        super().__init__()
@@ -292,7 +291,7 @@
-101        self.d_k = d_model // heads
+100        self.d_k = d_model // heads
@@ -304,7 +303,7 @@
-103        self.heads = heads
+102        self.heads = heads
@@ -319,9 +318,9 @@
-106        self.query = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=bias)
-107        self.key = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=bias)
-108        self.value = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=True)
+105        self.query = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=bias)
+106        self.key = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=bias)
+107        self.value = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=True)
@@ -334,7 +333,7 @@
-111        self.softmax = nn.Softmax(dim=1)
+110        self.softmax = nn.Softmax(dim=1)
@@ -346,7 +345,7 @@
-114        self.output = nn.Linear(d_model, d_model)
+113        self.output = nn.Linear(d_model, d_model)
@@ -358,7 +357,7 @@
-116        self.dropout = nn.Dropout(dropout_prob)
+115        self.dropout = nn.Dropout(dropout_prob)
@@ -370,7 +369,7 @@
-118        self.scale = 1 / math.sqrt(self.d_k)
+117        self.scale = 1 / math.sqrt(self.d_k)
@@ -382,7 +381,7 @@
-121        self.attn = None
+120        self.attn = None
@@ -395,7 +394,7 @@
-123    def get_scores(self, query: torch.Tensor, key: torch.Tensor):
+122    def get_scores(self, query: torch.Tensor, key: torch.Tensor):
@@ -407,7 +406,7 @@
-131        return torch.einsum('ibhd,jbhd->ijbh', query, key)
+130        return torch.einsum('ibhd,jbhd->ijbh', query, key)
@@ -421,7 +420,7 @@
-133    def prepare_mask(self, mask: torch.Tensor, query_shape: List[int], key_shape: List[int]):
+132    def prepare_mask(self, mask: torch.Tensor, query_shape: List[int], key_shape: List[int]):
@@ -432,9 +431,9 @@
-139        assert mask.shape[0] == 1 or mask.shape[0] == query_shape[0]
-140        assert mask.shape[1] == key_shape[0]
-141        assert mask.shape[2] == 1 or mask.shape[2] == query_shape[1]
+138        assert mask.shape[0] == 1 or mask.shape[0] == query_shape[0]
+139        assert mask.shape[1] == key_shape[0]
+140        assert mask.shape[2] == 1 or mask.shape[2] == query_shape[1]
@@ -446,7 +445,7 @@
-144        mask = mask.unsqueeze(-1)
+143        mask = mask.unsqueeze(-1)
@@ -459,7 +458,7 @@
-147        return mask
+146        return mask
@@ -482,11 +481,11 @@
-149    def forward(self, *,
-150                query: torch.Tensor,
-151                key: torch.Tensor,
-152                value: torch.Tensor,
-153                mask: Optional[torch.Tensor] = None):
+148    def forward(self, *,
+149                query: torch.Tensor,
+150                key: torch.Tensor,
+151                value: torch.Tensor,
+152                mask: Optional[torch.Tensor] = None):
@@ -502,10 +501,10 @@
-165        seq_len, batch_size, _ = query.shape
-166
-167        if mask is not None:
-168            mask = self.prepare_mask(mask, query.shape, key.shape)
+164        seq_len, batch_size, _ = query.shape
+165
+166        if mask is not None:
+167            mask = self.prepare_mask(mask, query.shape, key.shape)
@@ -521,9 +520,9 @@
-172        query = self.query(query)
-173        key = self.key(key)
-174        value = self.value(value)
+171        query = self.query(query)
+172        key = self.key(key)
+173        value = self.value(value)
@@ -536,7 +535,7 @@
-178        scores = self.get_scores(query, key)
+177        scores = self.get_scores(query, key)
@@ -559,7 +558,7 @@
-181        scores *= self.scale
+180        scores *= self.scale
@@ -571,8 +570,8 @@
-184        if mask is not None:
-185            scores = scores.masked_fill(mask == 0, float('-inf'))
+183        if mask is not None:
+184            scores = scores.masked_fill(mask == 0, float('-inf'))
@@ -595,7 +594,7 @@
-189        attn = self.softmax(scores)
+188        attn = self.softmax(scores)
@@ -607,7 +606,7 @@
-192        tracker.debug('attn', attn)
+191        tracker.debug('attn', attn)
@@ -619,7 +618,7 @@
-195        attn = self.dropout(attn)
+194        attn = self.dropout(attn)
@@ -642,7 +641,7 @@
-199        x = torch.einsum("ijbh,jbhd->ibhd", attn, value)
+198        x = torch.einsum("ijbh,jbhd->ibhd", attn, value)
@@ -654,7 +653,7 @@
-202        self.attn = attn.detach()
+201        self.attn = attn.detach()
@@ -666,7 +665,7 @@
-205        x = x.reshape(seq_len, batch_size, -1)
+204        x = x.reshape(seq_len, batch_size, -1)
@@ -678,7 +677,7 @@
-208        return self.output(x)
+207        return self.output(x)
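The `mha.html` hunks above track `labml_nn/transformers/mha.py`: queries, keys and values are projected to shape `[seq_len, batch_size, heads, d_k]`, scores are computed with `einsum` and scaled by `1 / sqrt(d_k)`, masked positions are filled with `-inf`, and softmax runs over the key dimension (`dim=1`). A compact, self-contained restatement of that computation; it is an illustration of the math in the diff, not the full `MultiHeadAttention` class:

```python
import math
import torch


def attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor,
              mask: torch.Tensor = None) -> torch.Tensor:
    # query, key, value: [seq_len, batch_size, heads, d_k]
    d_k = query.shape[-1]
    # scores[i, j, b, h] = q_i . k_j, scaled by 1/sqrt(d_k)
    scores = torch.einsum('ibhd,jbhd->ijbh', query, key) * (1 / math.sqrt(d_k))
    if mask is not None:
        # mask: [seq_len_q, seq_len_k, batch_size, 1]; zeros mark hidden positions
        scores = scores.masked_fill(mask == 0, float('-inf'))
    # softmax along the key dimension, matching nn.Softmax(dim=1) in the module
    attn = scores.softmax(dim=1)
    # weighted sum of values: [seq_len, batch_size, heads, d_k]
    return torch.einsum('ijbh,jbhd->ibhd', attn, value)
```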
diff --git a/docs/transformers/models.html b/docs/transformers/models.html
+18
+19from labml_nn.utils import clone_module_list
+20from .feed_forward import FeedForward
+21from .mha import MultiHeadAttention
+22from .positional_encoding import get_positional_encoding
@@ -95,7 +94,7 @@
-26class EmbeddingsWithPositionalEncoding(Module):
+25class EmbeddingsWithPositionalEncoding(nn.Module):
@@ -106,11 +105,11 @@
-33    def __init__(self, d_model: int, n_vocab: int, max_len: int = 5000):
-34        super().__init__()
-35        self.linear = nn.Embedding(n_vocab, d_model)
-36        self.d_model = d_model
-37        self.register_buffer('positional_encodings', get_positional_encoding(d_model, max_len))
+32    def __init__(self, d_model: int, n_vocab: int, max_len: int = 5000):
+33        super().__init__()
+34        self.linear = nn.Embedding(n_vocab, d_model)
+35        self.d_model = d_model
+36        self.register_buffer('positional_encodings', get_positional_encoding(d_model, max_len))
@@ -121,9 +120,9 @@
-39    def forward(self, x: torch.Tensor):
-40        pe = self.positional_encodings[:x.shape[0]].requires_grad_(False)
-41        return self.linear(x) * math.sqrt(self.d_model) + pe
+38    def forward(self, x: torch.Tensor):
+39        pe = self.positional_encodings[:x.shape[0]].requires_grad_(False)
+40        return self.linear(x) * math.sqrt(self.d_model) + pe
@@ -136,7 +135,7 @@
-44class EmbeddingsWithLearnedPositionalEncoding(Module):
+43class EmbeddingsWithLearnedPositionalEncoding(nn.Module):
@@ -147,11 +146,11 @@
-51    def __init__(self, d_model: int, n_vocab: int, max_len: int = 5000):
-52        super().__init__()
-53        self.linear = nn.Embedding(n_vocab, d_model)
-54        self.d_model = d_model
-55        self.positional_encodings = nn.Parameter(torch.zeros(max_len, 1, d_model), requires_grad=True)
+50    def __init__(self, d_model: int, n_vocab: int, max_len: int = 5000):
+51        super().__init__()
+52        self.linear = nn.Embedding(n_vocab, d_model)
+53        self.d_model = d_model
+54        self.positional_encodings = nn.Parameter(torch.zeros(max_len, 1, d_model), requires_grad=True)
@@ -162,9 +161,9 @@
-57    def forward(self, x: torch.Tensor):
-58        pe = self.positional_encodings[:x.shape[0]]
-59        return self.linear(x) * math.sqrt(self.d_model) + pe
+56    def forward(self, x: torch.Tensor):
+57        pe = self.positional_encodings[:x.shape[0]]
+58        return self.linear(x) * math.sqrt(self.d_model) + pe
@@ -179,7 +178,7 @@
-62class TransformerLayer(Module):
+61class TransformerLayer(nn.Module):
@@ -200,12 +199,12 @@
-80    def __init__(self, *,
-81                 d_model: int,
-82                 self_attn: MultiHeadAttention,
-83                 src_attn: MultiHeadAttention = None,
-84                 feed_forward: FeedForward,
-85                 dropout_prob: float):
+79    def __init__(self, *,
+80                 d_model: int,
+81                 self_attn: MultiHeadAttention,
+82                 src_attn: MultiHeadAttention = None,
+83                 feed_forward: FeedForward,
+84                 dropout_prob: float):
@@ -216,16 +215,16 @@
-93        super().__init__()
-94        self.size = d_model
-95        self.self_attn = self_attn
-96        self.src_attn = src_attn
-97        self.feed_forward = feed_forward
-98        self.dropout = nn.Dropout(dropout_prob)
-99        self.norm_self_attn = nn.LayerNorm([d_model])
-100        if self.src_attn is not None:
-101            self.norm_src_attn = nn.LayerNorm([d_model])
-102        self.norm_ff = nn.LayerNorm([d_model])
+92        super().__init__()
+93        self.size = d_model
+94        self.self_attn = self_attn
+95        self.src_attn = src_attn
+96        self.feed_forward = feed_forward
+97        self.dropout = nn.Dropout(dropout_prob)
+98        self.norm_self_attn = nn.LayerNorm([d_model])
+99        if self.src_attn is not None:
+100            self.norm_src_attn = nn.LayerNorm([d_model])
+101        self.norm_ff = nn.LayerNorm([d_model])
@@ -237,7 +236,7 @@
-104        self.is_save_ff_input = False
+103        self.is_save_ff_input = False
@@ -248,11 +247,11 @@
-106    def forward(self, *,
-107                x: torch.Tensor,
-108                mask: torch.Tensor,
-109                src: torch.Tensor = None,
-110                src_mask: torch.Tensor = None):
+105    def forward(self, *,
+106                x: torch.Tensor,
+107                mask: torch.Tensor,
+108                src: torch.Tensor = None,
+109                src_mask: torch.Tensor = None):
@@ -264,7 +263,7 @@
-112        z = self.norm_self_attn(x)
+111        z = self.norm_self_attn(x)
@@ -276,7 +275,7 @@
-114        self_attn = self.self_attn(query=z, key=z, value=z, mask=mask)
+113        self_attn = self.self_attn(query=z, key=z, value=z, mask=mask)
@@ -288,7 +287,7 @@
-116        x = x + self.dropout(self_attn)
+115        x = x + self.dropout(self_attn)
@@ -300,7 +299,7 @@
-121        if src is not None:
+120        if src is not None:
@@ -312,7 +311,7 @@
-123            z = self.norm_src_attn(x)
+122            z = self.norm_src_attn(x)
@@ -324,7 +323,7 @@
-125            attn_src = self.src_attn(query=z, key=src, value=src, mask=src_mask)
+124            attn_src = self.src_attn(query=z, key=src, value=src, mask=src_mask)
@@ -336,7 +335,7 @@
-127            x = x + self.dropout(attn_src)
+126            x = x + self.dropout(attn_src)
@@ -348,7 +347,7 @@
-130        z = self.norm_ff(x)
+129        z = self.norm_ff(x)
@@ -360,8 +359,8 @@
-132        if self.is_save_ff_input:
-133            self.ff_input = z.clone()
+131        if self.is_save_ff_input:
+132            self.ff_input = z.clone()
@@ -373,7 +372,7 @@
-135        ff = self.feed_forward(z)
+134        ff = self.feed_forward(z)
@@ -385,9 +384,9 @@
-137        x = x + self.dropout(ff)
-138
-139        return x
+136        x = x + self.dropout(ff)
+137
+138        return x
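`TransformerLayer.forward` above applies the same pre-norm residual pattern three times (self-attention, optional source attention, feed-forward): normalize the input, run the sub-layer, then add the result back through dropout. A minimal sketch of that pattern, with `sublayer` standing in for any of the three blocks:

```python
import torch
from torch import nn


def pre_norm_residual(x: torch.Tensor, norm: nn.LayerNorm, sublayer, dropout: nn.Dropout) -> torch.Tensor:
    z = norm(x)                       # LayerNorm first (pre-norm)
    return x + dropout(sublayer(z))   # residual connection around the sub-layer
```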
@@ -400,7 +399,7 @@
-142class Encoder(Module):
+141class Encoder(nn.Module):
@@ -411,8 +410,8 @@
-149    def __init__(self, layer: TransformerLayer, n_layers: int):
-150        super().__init__()
+148    def __init__(self, layer: TransformerLayer, n_layers: int):
+149        super().__init__()
@@ -424,7 +423,7 @@
-152        self.layers = clone_module_list(layer, n_layers)
+151        self.layers = clone_module_list(layer, n_layers)
@@ -436,7 +435,7 @@
-154        self.norm = nn.LayerNorm([layer.size])
+153        self.norm = nn.LayerNorm([layer.size])
@@ -447,7 +446,7 @@
-156    def forward(self, x: torch.Tensor, mask: torch.Tensor):
+155    def forward(self, x: torch.Tensor, mask: torch.Tensor):
@@ -459,8 +458,8 @@
-158        for layer in self.layers:
-159            x = layer(x=x, mask=mask)
+157        for layer in self.layers:
+158            x = layer(x=x, mask=mask)
@@ -472,7 +471,7 @@
-161        return self.norm(x)
+160        return self.norm(x)
@@ -485,7 +484,7 @@
-164class Decoder(Module):
+163class Decoder(nn.Module):
@@ -496,8 +495,8 @@
-171    def __init__(self, layer: TransformerLayer, n_layers: int):
-172        super().__init__()
+170    def __init__(self, layer: TransformerLayer, n_layers: int):
+171        super().__init__()
@@ -509,7 +508,7 @@
-174        self.layers = clone_module_list(layer, n_layers)
+173        self.layers = clone_module_list(layer, n_layers)
@@ -521,7 +520,7 @@
-176        self.norm = nn.LayerNorm([layer.size])
+175        self.norm = nn.LayerNorm([layer.size])
@@ -532,7 +531,7 @@
-178    def forward(self, x: torch.Tensor, memory: torch.Tensor, src_mask: torch.Tensor, tgt_mask: torch.Tensor):
+177    def forward(self, x: torch.Tensor, memory: torch.Tensor, src_mask: torch.Tensor, tgt_mask: torch.Tensor):
@@ -544,8 +543,8 @@
-180        for layer in self.layers:
-181            x = layer(x=x, mask=tgt_mask, src=memory, src_mask=src_mask)
+179        for layer in self.layers:
+180            x = layer(x=x, mask=tgt_mask, src=memory, src_mask=src_mask)
@@ -557,7 +556,7 @@
-183        return self.norm(x)
+182        return self.norm(x)
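For context on the `Encoder` and `Decoder` hunks above: `clone_module_list` (imported from `labml_nn.utils` at the top of this file) builds `n_layers` independent copies of one `TransformerLayer`, which `forward` then runs in sequence before a final `LayerNorm`. A self-contained stand-in, assuming the helper simply deep-copies the prototype layer:

```python
import copy
from torch import nn


def make_layer_stack(layer: nn.Module, n_layers: int) -> nn.ModuleList:
    # each copy gets its own parameters; nothing is shared across layers
    return nn.ModuleList([copy.deepcopy(layer) for _ in range(n_layers)])
```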
@@ -572,7 +571,7 @@
-186class Generator(Module):
+185class Generator(nn.Module):
@@ -583,9 +582,9 @@
-196    def __init__(self, n_vocab: int, d_model: int):
-197        super().__init__()
-198        self.projection = nn.Linear(d_model, n_vocab)
+195    def __init__(self, n_vocab: int, d_model: int):
+196        super().__init__()
+197        self.projection = nn.Linear(d_model, n_vocab)
@@ -596,8 +595,8 @@
-200    def forward(self, x):
-201        return self.projection(x)
+199    def forward(self, x):
+200        return self.projection(x)
@@ -610,7 +609,7 @@
-204class EncoderDecoder(Module):
+203class EncoderDecoder(nn.Module):
@@ -621,13 +620,13 @@
-211    def __init__(self, encoder: Encoder, decoder: Decoder, src_embed: Module, tgt_embed: Module, generator: Module):
-212        super().__init__()
-213        self.encoder = encoder
-214        self.decoder = decoder
-215        self.src_embed = src_embed
-216        self.tgt_embed = tgt_embed
-217        self.generator = generator
+210    def __init__(self, encoder: Encoder, decoder: Decoder, src_embed: nn.Module, tgt_embed: nn.Module, generator: nn.Module):
+211        super().__init__()
+212        self.encoder = encoder
+213        self.decoder = decoder
+214        self.src_embed = src_embed
+215        self.tgt_embed = tgt_embed
+216        self.generator = generator
@@ -639,9 +638,9 @@
-221        for p in self.parameters():
-222            if p.dim() > 1:
-223                nn.init.xavier_uniform_(p)
+220        for p in self.parameters():
+221            if p.dim() > 1:
+222                nn.init.xavier_uniform_(p)
@@ -652,7 +651,7 @@
-225    def forward(self, src: torch.Tensor, tgt: torch.Tensor, src_mask: torch.Tensor, tgt_mask: torch.Tensor):
+224    def forward(self, src: torch.Tensor, tgt: torch.Tensor, src_mask: torch.Tensor, tgt_mask: torch.Tensor):
@@ -664,7 +663,7 @@
-227        enc = self.encode(src, src_mask)
+226        enc = self.encode(src, src_mask)
@@ -676,7 +675,7 @@
-229        return self.decode(enc, src_mask, tgt, tgt_mask)
+228        return self.decode(enc, src_mask, tgt, tgt_mask)
@@ -687,8 +686,8 @@
-231    def encode(self, src: torch.Tensor, src_mask: torch.Tensor):
-232        return self.encoder(self.src_embed(src), src_mask)
+230    def encode(self, src: torch.Tensor, src_mask: torch.Tensor):
+231        return self.encoder(self.src_embed(src), src_mask)
@@ -699,8 +698,8 @@
-234    def decode(self, memory: torch.Tensor, src_mask: torch.Tensor, tgt: torch.Tensor, tgt_mask: torch.Tensor):
-235        return self.decoder(self.tgt_embed(tgt), memory, src_mask, tgt_mask)
+233    def decode(self, memory: torch.Tensor, src_mask: torch.Tensor, tgt: torch.Tensor, tgt_mask: torch.Tensor):
+234        return self.decoder(self.tgt_embed(tgt), memory, src_mask, tgt_mask)
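The `EncoderDecoder` hunks above wire the full model together: `forward` encodes the source once and then decodes the target against that memory, passing both masks through. The data flow, restated as a short sketch over the attributes defined in the diff:

```python
import torch


def encoder_decoder_forward(model, src: torch.Tensor, tgt: torch.Tensor,
                            src_mask: torch.Tensor, tgt_mask: torch.Tensor) -> torch.Tensor:
    # encode(src, src_mask): embed the source and run the encoder stack
    memory = model.encoder(model.src_embed(src), src_mask)
    # decode(memory, src_mask, tgt, tgt_mask): embed the target and attend to the memory
    return model.decoder(model.tgt_embed(tgt), memory, src_mask, tgt_mask)
```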
diff --git a/docs/transformers/positional_encoding.html b/docs/transformers/positional_encoding.html
+27import torch.nn as nn
@@ -92,7 +90,7 @@
-32class PositionalEncoding(Module):
+30class PositionalEncoding(nn.Module):
@@ -103,11 +101,11 @@
-33    def __init__(self, d_model: int, dropout_prob: float, max_len: int = 5000):
-34        super().__init__()
-35        self.dropout = nn.Dropout(dropout_prob)
-36
-37        self.register_buffer('positional_encodings', get_positional_encoding(d_model, max_len), False)
+31    def __init__(self, d_model: int, dropout_prob: float, max_len: int = 5000):
+32        super().__init__()
+33        self.dropout = nn.Dropout(dropout_prob)
+34
+35        self.register_buffer('positional_encodings', get_positional_encoding(d_model, max_len), False)
@@ -118,11 +116,11 @@
-39    def forward(self, x: torch.Tensor):
-40        pe = self.positional_encodings[:x.shape[0]].detach().requires_grad_(False)
-41        x = x + pe
-42        x = self.dropout(x)
-43        return x
+37    def forward(self, x: torch.Tensor):
+38        pe = self.positional_encodings[:x.shape[0]].detach().requires_grad_(False)
+39        x = x + pe
+40        x = self.dropout(x)
+41        return x
@@ -133,7 +131,7 @@
-46def get_positional_encoding(d_model: int, max_len: int = 5000):
+44def get_positional_encoding(d_model: int, max_len: int = 5000):
@@ -145,7 +143,7 @@
-48    encodings = torch.zeros(max_len, d_model)
+46    encodings = torch.zeros(max_len, d_model)
@@ -157,7 +155,7 @@
-50    position = torch.arange(0, max_len, dtype=torch.float32).unsqueeze(1)
+48    position = torch.arange(0, max_len, dtype=torch.float32).unsqueeze(1)
@@ -169,7 +167,7 @@
-52    two_i = torch.arange(0, d_model, 2, dtype=torch.float32)
+50    two_i = torch.arange(0, d_model, 2, dtype=torch.float32)
@@ -181,7 +179,7 @@
-54    div_term = torch.exp(two_i * -(math.log(10000.0) / d_model))
+52    div_term = torch.exp(two_i * -(math.log(10000.0) / d_model))
@@ -193,7 +191,7 @@
-56    encodings[:, 0::2] = torch.sin(position * div_term)
+54    encodings[:, 0::2] = torch.sin(position * div_term)
@@ -205,7 +203,7 @@
-58    encodings[:, 1::2] = torch.cos(position * div_term)
+56    encodings[:, 1::2] = torch.cos(position * div_term)
@@ -217,9 +215,9 @@
-61    encodings = encodings.unsqueeze(1).requires_grad_(False)
-62
-63    return encodings
+59    encodings = encodings.unsqueeze(1).requires_grad_(False)
+60
+61    return encodings
@@ -230,19 +228,19 @@
-66def _test_positional_encoding():
-67    import matplotlib.pyplot as plt
-68
-69    plt.figure(figsize=(15, 5))
-70    pe = get_positional_encoding(20, 100)
-71    plt.plot(np.arange(100), pe[:, 0, 4:8].numpy())
-72    plt.legend(["dim %d" % p for p in [4, 5, 6, 7]])
-73    plt.title("Positional encoding")
-74    plt.show()
-75
-76
-77if __name__ == '__main__':
-78    _test_positional_encoding()
+64def _test_positional_encoding():
+65    import matplotlib.pyplot as plt
+66
+67    plt.figure(figsize=(15, 5))
+68    pe = get_positional_encoding(20, 100)
+69    plt.plot(np.arange(100), pe[:, 0, 4:8].numpy())
+70    plt.legend(["dim %d" % p for p in [4, 5, 6, 7]])
+71    plt.title("Positional encoding")
+72    plt.show()
+73
+74
+75if __name__ == '__main__':
+76    _test_positional_encoding()
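For reference, the `get_positional_encoding` hunks above implement the fixed sinusoidal encoding from "Attention Is All You Need":

$$PE_{(pos,\,2i)} = \sin\!\left(\frac{pos}{10000^{2i/d_{model}}}\right), \qquad PE_{(pos,\,2i+1)} = \cos\!\left(\frac{pos}{10000^{2i/d_{model}}}\right)$$

In the code, `div_term = exp(-2i * ln(10000) / d_model)` is exactly $10000^{-2i/d_{model}}$, the sine terms fill the even columns (`encodings[:, 0::2]`) and the cosine terms the odd columns (`encodings[:, 1::2]`), and the final `unsqueeze(1)` gives shape `[max_len, 1, d_model]` so the encoding broadcasts over the batch dimension.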