diff --git a/docs/transformers/basic/autoregressive_experiment.html b/docs/transformers/basic/autoregressive_experiment.html
index 5a9b637c..fdbd657f 100644
--- a/docs/transformers/basic/autoregressive_experiment.html
+++ b/docs/transformers/basic/autoregressive_experiment.html
@@ -76,10 +76,10 @@
17import torch
-18
-19from labml import experiment
-20from labml.configs import option
-21from labml_helpers.module import Module
+18from torch import nn
+19
+20from labml import experiment
+21from labml.configs import option
22from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs
23from labml_nn.transformers import TransformerConfigs, Encoder
24from labml_nn.transformers.utils import subsequent_mask
27class AutoregressiveTransformer(Module):
27class AutoregressiveTransformer(nn.Module):
31 def __init__(self, encoder: Encoder, src_embed: Module, generator: Module):
31 def __init__(self, encoder: Encoder, src_embed: nn.Module, generator: nn.Module):
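The substantive change in this hunk is that the model classes now subclass torch.nn.Module directly instead of the labml_helpers Module wrapper (hence the removed labml_helpers import and the added `from torch import nn`). A minimal sketch of the resulting pattern; the class name and sizes below are hypothetical, not from this repo:

import torch
from torch import nn

class TinyAutoregressiveModel(nn.Module):
    """Hypothetical stand-in that follows the same nn.Module pattern as the class above."""
    def __init__(self, n_vocab: int, d_model: int):
        super().__init__()
        self.embed = nn.Embedding(n_vocab, d_model)
        self.proj = nn.Linear(d_model, n_vocab)

    def forward(self, x: torch.Tensor):
        # x: [seq_len, batch_size] of token ids -> logits [seq_len, batch_size, n_vocab]
        return self.proj(self.embed(x))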
35class PrepareForMultiHeadAttention(Module):
34class PrepareForMultiHeadAttention(nn.Module):
46 def __init__(self, d_model: int, heads: int, d_k: int, bias: bool):
-47 super().__init__()
45 def __init__(self, d_model: int, heads: int, d_k: int, bias: bool):
+46 super().__init__()
49 self.linear = nn.Linear(d_model, heads * d_k, bias=bias)
48 self.linear = nn.Linear(d_model, heads * d_k, bias=bias)
51 self.heads = heads
50 self.heads = heads
53 self.d_k = d_k
52 self.d_k = d_k
55 def forward(self, x: torch.Tensor):
54 def forward(self, x: torch.Tensor):
59 head_shape = x.shape[:-1]
58 head_shape = x.shape[:-1]
62 x = self.linear(x)
61 x = self.linear(x)
65 x = x.view(*head_shape, self.heads, self.d_k)
64 x = x.view(*head_shape, self.heads, self.d_k)
68 return x
67 return x
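A minimal shape check for PrepareForMultiHeadAttention as defined above: it applies one linear projection and then splits the last dimension into heads. Sizes are hypothetical and the import path is assumed:

import torch
from labml_nn.transformers.mha import PrepareForMultiHeadAttention  # module path assumed

# Hypothetical sizes with d_model = heads * d_k
prep = PrepareForMultiHeadAttention(d_model=512, heads=8, d_k=64, bias=True)
x = torch.randn(10, 2, 512)          # [seq_len, batch_size, d_model]
out = prep(x)                        # linear projection, then view into heads
assert out.shape == (10, 2, 8, 64)   # [seq_len, batch_size, heads, d_k]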
71class MultiHeadAttention(Module):
70class MultiHeadAttention(nn.Module):
92 def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1, bias: bool = True):
91 def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1, bias: bool = True):
98 super().__init__()
97 super().__init__()
101 self.d_k = d_model // heads
100 self.d_k = d_model // heads
103 self.heads = heads
102 self.heads = heads
106 self.query = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=bias)
-107 self.key = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=bias)
-108 self.value = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=True)
105 self.query = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=bias)
+106 self.key = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=bias)
+107 self.value = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=True)
111 self.softmax = nn.Softmax(dim=1)
110 self.softmax = nn.Softmax(dim=1)
114 self.output = nn.Linear(d_model, d_model)
113 self.output = nn.Linear(d_model, d_model)
116 self.dropout = nn.Dropout(dropout_prob)
115 self.dropout = nn.Dropout(dropout_prob)
118 self.scale = 1 / math.sqrt(self.d_k)
117 self.scale = 1 / math.sqrt(self.d_k)
121 self.attn = None
120 self.attn = None
123 def get_scores(self, query: torch.Tensor, key: torch.Tensor):
122 def get_scores(self, query: torch.Tensor, key: torch.Tensor):
131 return torch.einsum('ibhd,jbhd->ijbh', query, key)
130 return torch.einsum('ibhd,jbhd->ijbh', query, key)
133 def prepare_mask(self, mask: torch.Tensor, query_shape: List[int], key_shape: List[int]):
132 def prepare_mask(self, mask: torch.Tensor, query_shape: List[int], key_shape: List[int]):
139 assert mask.shape[0] == 1 or mask.shape[0] == query_shape[0]
-140 assert mask.shape[1] == key_shape[0]
-141 assert mask.shape[2] == 1 or mask.shape[2] == query_shape[1]
138 assert mask.shape[0] == 1 or mask.shape[0] == query_shape[0]
+139 assert mask.shape[1] == key_shape[0]
+140 assert mask.shape[2] == 1 or mask.shape[2] == query_shape[1]
144 mask = mask.unsqueeze(-1)
143 mask = mask.unsqueeze(-1)
147 return mask
146 return mask
149 def forward(self, *,
-150 query: torch.Tensor,
-151 key: torch.Tensor,
-152 value: torch.Tensor,
-153 mask: Optional[torch.Tensor] = None):
148 def forward(self, *,
+149 query: torch.Tensor,
+150 key: torch.Tensor,
+151 value: torch.Tensor,
+152 mask: Optional[torch.Tensor] = None):
165 seq_len, batch_size, _ = query.shape
-166
-167 if mask is not None:
-168 mask = self.prepare_mask(mask, query.shape, key.shape)
164 seq_len, batch_size, _ = query.shape
+165
+166 if mask is not None:
+167 mask = self.prepare_mask(mask, query.shape, key.shape)
172 query = self.query(query)
-173 key = self.key(key)
-174 value = self.value(value)
171 query = self.query(query)
+172 key = self.key(key)
+173 value = self.value(value)
178 scores = self.get_scores(query, key)
177 scores = self.get_scores(query, key)
181 scores *= self.scale
180 scores *= self.scale
184 if mask is not None:
-185 scores = scores.masked_fill(mask == 0, float('-inf'))
183 if mask is not None:
+184 scores = scores.masked_fill(mask == 0, float('-inf'))
189 attn = self.softmax(scores)
188 attn = self.softmax(scores)
192 tracker.debug('attn', attn)
191 tracker.debug('attn', attn)
195 attn = self.dropout(attn)
194 attn = self.dropout(attn)
199 x = torch.einsum("ijbh,jbhd->ibhd", attn, value)
198 x = torch.einsum("ijbh,jbhd->ibhd", attn, value)
202 self.attn = attn.detach()
201 self.attn = attn.detach()
205 x = x.reshape(seq_len, batch_size, -1)
204 x = x.reshape(seq_len, batch_size, -1)
208 return self.output(x)
207 return self.output(x)
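A usage sketch for MultiHeadAttention as a causal self-attention layer. The module path is assumed; subsequent_mask is the helper imported at the top of the diff and is assumed here to return a [seq_len, seq_len, 1] mask that prepare_mask broadcasts over batch and heads:

import torch
from labml_nn.transformers.mha import MultiHeadAttention  # module path assumed
from labml_nn.transformers.utils import subsequent_mask

mha = MultiHeadAttention(heads=8, d_model=512)
x = torch.randn(10, 2, 512)                    # [seq_len, batch_size, d_model]
mask = subsequent_mask(10)                     # causal mask over the sequence
out = mha(query=x, key=x, value=x, mask=mask)  # self-attention
assert out.shape == (10, 2, 512)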
26class EmbeddingsWithPositionalEncoding(Module):
25class EmbeddingsWithPositionalEncoding(nn.Module):
33 def __init__(self, d_model: int, n_vocab: int, max_len: int = 5000):
-34 super().__init__()
-35 self.linear = nn.Embedding(n_vocab, d_model)
-36 self.d_model = d_model
-37 self.register_buffer('positional_encodings', get_positional_encoding(d_model, max_len))
32 def __init__(self, d_model: int, n_vocab: int, max_len: int = 5000):
+33 super().__init__()
+34 self.linear = nn.Embedding(n_vocab, d_model)
+35 self.d_model = d_model
+36 self.register_buffer('positional_encodings', get_positional_encoding(d_model, max_len))
39 def forward(self, x: torch.Tensor):
-40 pe = self.positional_encodings[:x.shape[0]].requires_grad_(False)
-41 return self.linear(x) * math.sqrt(self.d_model) + pe
38 def forward(self, x: torch.Tensor):
+39 pe = self.positional_encodings[:x.shape[0]].requires_grad_(False)
+40 return self.linear(x) * math.sqrt(self.d_model) + pe
44class EmbeddingsWithLearnedPositionalEncoding(Module):
43class EmbeddingsWithLearnedPositionalEncoding(nn.Module):
51 def __init__(self, d_model: int, n_vocab: int, max_len: int = 5000):
-52 super().__init__()
-53 self.linear = nn.Embedding(n_vocab, d_model)
-54 self.d_model = d_model
-55 self.positional_encodings = nn.Parameter(torch.zeros(max_len, 1, d_model), requires_grad=True)
50 def __init__(self, d_model: int, n_vocab: int, max_len: int = 5000):
+51 super().__init__()
+52 self.linear = nn.Embedding(n_vocab, d_model)
+53 self.d_model = d_model
+54 self.positional_encodings = nn.Parameter(torch.zeros(max_len, 1, d_model), requires_grad=True)
57 def forward(self, x: torch.Tensor):
-58 pe = self.positional_encodings[:x.shape[0]]
-59 return self.linear(x) * math.sqrt(self.d_model) + pe
56 def forward(self, x: torch.Tensor):
+57 pe = self.positional_encodings[:x.shape[0]]
+58 return self.linear(x) * math.sqrt(self.d_model) + pe
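Both embedding variants above map token ids to scaled embeddings plus a positional signal; they differ only in whether the positional encodings are fixed (sinusoidal) or learned parameters. A shape sketch with hypothetical sizes, assuming the classes are importable from labml_nn.transformers.models (path assumed):

import torch
from labml_nn.transformers.models import (  # module path assumed
    EmbeddingsWithPositionalEncoding, EmbeddingsWithLearnedPositionalEncoding)

tokens = torch.randint(0, 1000, (10, 2))   # [seq_len, batch_size] of token ids
fixed = EmbeddingsWithPositionalEncoding(d_model=512, n_vocab=1000)
learned = EmbeddingsWithLearnedPositionalEncoding(d_model=512, n_vocab=1000)
assert fixed(tokens).shape == (10, 2, 512)
assert learned(tokens).shape == (10, 2, 512)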
62class TransformerLayer(Module):
61class TransformerLayer(nn.Module):
80 def __init__(self, *,
-81 d_model: int,
-82 self_attn: MultiHeadAttention,
-83 src_attn: MultiHeadAttention = None,
-84 feed_forward: FeedForward,
-85 dropout_prob: float):
79 def __init__(self, *,
+80 d_model: int,
+81 self_attn: MultiHeadAttention,
+82 src_attn: MultiHeadAttention = None,
+83 feed_forward: FeedForward,
+84 dropout_prob: float):
93 super().__init__()
-94 self.size = d_model
-95 self.self_attn = self_attn
-96 self.src_attn = src_attn
-97 self.feed_forward = feed_forward
-98 self.dropout = nn.Dropout(dropout_prob)
-99 self.norm_self_attn = nn.LayerNorm([d_model])
-100 if self.src_attn is not None:
-101 self.norm_src_attn = nn.LayerNorm([d_model])
-102 self.norm_ff = nn.LayerNorm([d_model])
92 super().__init__()
+93 self.size = d_model
+94 self.self_attn = self_attn
+95 self.src_attn = src_attn
+96 self.feed_forward = feed_forward
+97 self.dropout = nn.Dropout(dropout_prob)
+98 self.norm_self_attn = nn.LayerNorm([d_model])
+99 if self.src_attn is not None:
+100 self.norm_src_attn = nn.LayerNorm([d_model])
+101 self.norm_ff = nn.LayerNorm([d_model])
104 self.is_save_ff_input = False
103 self.is_save_ff_input = False
106 def forward(self, *,
-107 x: torch.Tensor,
-108 mask: torch.Tensor,
-109 src: torch.Tensor = None,
-110 src_mask: torch.Tensor = None):
105 def forward(self, *,
+106 x: torch.Tensor,
+107 mask: torch.Tensor,
+108 src: torch.Tensor = None,
+109 src_mask: torch.Tensor = None):
112 z = self.norm_self_attn(x)
111 z = self.norm_self_attn(x)
114 self_attn = self.self_attn(query=z, key=z, value=z, mask=mask)
113 self_attn = self.self_attn(query=z, key=z, value=z, mask=mask)
116 x = x + self.dropout(self_attn)
115 x = x + self.dropout(self_attn)
121 if src is not None:
120 if src is not None:
123 z = self.norm_src_attn(x)
122 z = self.norm_src_attn(x)
125 attn_src = self.src_attn(query=z, key=src, value=src, mask=src_mask)
124 attn_src = self.src_attn(query=z, key=src, value=src, mask=src_mask)
127 x = x + self.dropout(attn_src)
126 x = x + self.dropout(attn_src)
130 z = self.norm_ff(x)
129 z = self.norm_ff(x)
132 if self.is_save_ff_input:
-133 self.ff_input = z.clone()
131 if self.is_save_ff_input:
+132 self.ff_input = z.clone()
135 ff = self.feed_forward(z)
134 ff = self.feed_forward(z)
137 x = x + self.dropout(ff)
-138
-139 return x
136 x = x + self.dropout(ff)
+137
+138 return x
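TransformerLayer is a pre-norm block: LayerNorm, then self-attention (and optional cross-attention when src is given), then the feed-forward network, each wrapped in dropout plus a residual connection. A self-attention-only sketch with hypothetical sizes; the module paths and the FeedForward(d_model, d_ff) signature are assumptions:

import torch
from labml_nn.transformers.mha import MultiHeadAttention        # paths assumed
from labml_nn.transformers.feed_forward import FeedForward
from labml_nn.transformers.models import TransformerLayer
from labml_nn.transformers.utils import subsequent_mask

layer = TransformerLayer(d_model=512,
                         self_attn=MultiHeadAttention(heads=8, d_model=512),
                         src_attn=None,                 # encoder-style: no cross-attention
                         feed_forward=FeedForward(512, 2048),
                         dropout_prob=0.1)
x = torch.randn(10, 2, 512)
out = layer(x=x, mask=subsequent_mask(10))
assert out.shape == (10, 2, 512)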
142class Encoder(Module):
141class Encoder(nn.Module):
149 def __init__(self, layer: TransformerLayer, n_layers: int):
-150 super().__init__()
148 def __init__(self, layer: TransformerLayer, n_layers: int):
+149 super().__init__()
152 self.layers = clone_module_list(layer, n_layers)
151 self.layers = clone_module_list(layer, n_layers)
154 self.norm = nn.LayerNorm([layer.size])
153 self.norm = nn.LayerNorm([layer.size])
156 def forward(self, x: torch.Tensor, mask: torch.Tensor):
155 def forward(self, x: torch.Tensor, mask: torch.Tensor):
158 for layer in self.layers:
-159 x = layer(x=x, mask=mask)
157 for layer in self.layers:
+158 x = layer(x=x, mask=mask)
161 return self.norm(x)
160 return self.norm(x)
164class Decoder(Module):
163class Decoder(nn.Module):
171 def __init__(self, layer: TransformerLayer, n_layers: int):
-172 super().__init__()
170 def __init__(self, layer: TransformerLayer, n_layers: int):
+171 super().__init__()
174 self.layers = clone_module_list(layer, n_layers)
173 self.layers = clone_module_list(layer, n_layers)
176 self.norm = nn.LayerNorm([layer.size])
175 self.norm = nn.LayerNorm([layer.size])
178 def forward(self, x: torch.Tensor, memory: torch.Tensor, src_mask: torch.Tensor, tgt_mask: torch.Tensor):
177 def forward(self, x: torch.Tensor, memory: torch.Tensor, src_mask: torch.Tensor, tgt_mask: torch.Tensor):
180 for layer in self.layers:
-181 x = layer(x=x, mask=tgt_mask, src=memory, src_mask=src_mask)
179 for layer in self.layers:
+180 x = layer(x=x, mask=tgt_mask, src=memory, src_mask=src_mask)
183 return self.norm(x)
182 return self.norm(x)
186class Generator(Module):
185class Generator(nn.Module):
196 def __init__(self, n_vocab: int, d_model: int):
-197 super().__init__()
-198 self.projection = nn.Linear(d_model, n_vocab)
195 def __init__(self, n_vocab: int, d_model: int):
+196 super().__init__()
+197 self.projection = nn.Linear(d_model, n_vocab)
200 def forward(self, x):
-201 return self.projection(x)
199 def forward(self, x):
+200 return self.projection(x)
204class EncoderDecoder(Module):
203class EncoderDecoder(nn.Module):
211 def __init__(self, encoder: Encoder, decoder: Decoder, src_embed: Module, tgt_embed: Module, generator: Module):
-212 super().__init__()
-213 self.encoder = encoder
-214 self.decoder = decoder
-215 self.src_embed = src_embed
-216 self.tgt_embed = tgt_embed
-217 self.generator = generator
210 def __init__(self, encoder: Encoder, decoder: Decoder, src_embed: nn.Module, tgt_embed: nn.Module, generator: nn.Module):
+211 super().__init__()
+212 self.encoder = encoder
+213 self.decoder = decoder
+214 self.src_embed = src_embed
+215 self.tgt_embed = tgt_embed
+216 self.generator = generator
221 for p in self.parameters():
-222 if p.dim() > 1:
-223 nn.init.xavier_uniform_(p)
220 for p in self.parameters():
+221 if p.dim() > 1:
+222 nn.init.xavier_uniform_(p)
225 def forward(self, src: torch.Tensor, tgt: torch.Tensor, src_mask: torch.Tensor, tgt_mask: torch.Tensor):
224 def forward(self, src: torch.Tensor, tgt: torch.Tensor, src_mask: torch.Tensor, tgt_mask: torch.Tensor):
227 enc = self.encode(src, src_mask)
226 enc = self.encode(src, src_mask)
229 return self.decode(enc, src_mask, tgt, tgt_mask)
228 return self.decode(enc, src_mask, tgt, tgt_mask)
231 def encode(self, src: torch.Tensor, src_mask: torch.Tensor):
-232 return self.encoder(self.src_embed(src), src_mask)
230 def encode(self, src: torch.Tensor, src_mask: torch.Tensor):
+231 return self.encoder(self.src_embed(src), src_mask)
234 def decode(self, memory: torch.Tensor, src_mask: torch.Tensor, tgt: torch.Tensor, tgt_mask: torch.Tensor):
-235 return self.decoder(self.tgt_embed(tgt), memory, src_mask, tgt_mask)
233 def decode(self, memory: torch.Tensor, src_mask: torch.Tensor, tgt: torch.Tensor, tgt_mask: torch.Tensor):
+234 return self.decoder(self.tgt_embed(tgt), memory, src_mask, tgt_mask)
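An end-to-end sketch of the EncoderDecoder stack built from the pieces above, with hypothetical sizes. Import paths are assumed; Encoder and Decoder deep-copy the prototype layer via clone_module_list, so passing the same layer to both is only a convenience here:

import torch
from labml_nn.transformers.mha import MultiHeadAttention        # paths assumed
from labml_nn.transformers.feed_forward import FeedForward
from labml_nn.transformers.models import (TransformerLayer, Encoder, Decoder,
                                          EmbeddingsWithPositionalEncoding,
                                          Generator, EncoderDecoder)
from labml_nn.transformers.utils import subsequent_mask

d_model, n_vocab, heads = 512, 1000, 8
layer = TransformerLayer(d_model=d_model,
                         self_attn=MultiHeadAttention(heads, d_model),
                         src_attn=MultiHeadAttention(heads, d_model),
                         feed_forward=FeedForward(d_model, 2048),
                         dropout_prob=0.1)
model = EncoderDecoder(encoder=Encoder(layer, n_layers=6),
                       decoder=Decoder(layer, n_layers=6),
                       src_embed=EmbeddingsWithPositionalEncoding(d_model, n_vocab),
                       tgt_embed=EmbeddingsWithPositionalEncoding(d_model, n_vocab),
                       generator=Generator(n_vocab, d_model))
src = torch.randint(0, n_vocab, (12, 2))               # [src_len, batch_size]
tgt = torch.randint(0, n_vocab, (10, 2))               # [tgt_len, batch_size]
src_mask = torch.ones(1, 12, 1, dtype=torch.bool)      # attend to every source position
tgt_mask = subsequent_mask(10)                         # causal mask over the target
out = model(src, tgt, src_mask, tgt_mask)              # [tgt_len, batch_size, d_model]
logits = model.generator(out)                          # [tgt_len, batch_size, n_vocab]
assert logits.shape == (10, 2, n_vocab)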
32class PositionalEncoding(Module):
30class PositionalEncoding(nn.Module):
33 def __init__(self, d_model: int, dropout_prob: float, max_len: int = 5000):
-34 super().__init__()
-35 self.dropout = nn.Dropout(dropout_prob)
-36
-37 self.register_buffer('positional_encodings', get_positional_encoding(d_model, max_len), False)
31 def __init__(self, d_model: int, dropout_prob: float, max_len: int = 5000):
+32 super().__init__()
+33 self.dropout = nn.Dropout(dropout_prob)
+34
+35 self.register_buffer('positional_encodings', get_positional_encoding(d_model, max_len), False)
39 def forward(self, x: torch.Tensor):
-40 pe = self.positional_encodings[:x.shape[0]].detach().requires_grad_(False)
-41 x = x + pe
-42 x = self.dropout(x)
-43 return x
37 def forward(self, x: torch.Tensor):
+38 pe = self.positional_encodings[:x.shape[0]].detach().requires_grad_(False)
+39 x = x + pe
+40 x = self.dropout(x)
+41 return x
46def get_positional_encoding(d_model: int, max_len: int = 5000):
44def get_positional_encoding(d_model: int, max_len: int = 5000):
48 encodings = torch.zeros(max_len, d_model)
46 encodings = torch.zeros(max_len, d_model)
50 position = torch.arange(0, max_len, dtype=torch.float32).unsqueeze(1)
48 position = torch.arange(0, max_len, dtype=torch.float32).unsqueeze(1)
52 two_i = torch.arange(0, d_model, 2, dtype=torch.float32)
50 two_i = torch.arange(0, d_model, 2, dtype=torch.float32)
54 div_term = torch.exp(two_i * -(math.log(10000.0) / d_model))
52 div_term = torch.exp(two_i * -(math.log(10000.0) / d_model))
56 encodings[:, 0::2] = torch.sin(position * div_term)
54 encodings[:, 0::2] = torch.sin(position * div_term)
58 encodings[:, 1::2] = torch.cos(position * div_term)
56 encodings[:, 1::2] = torch.cos(position * div_term)
61 encodings = encodings.unsqueeze(1).requires_grad_(False)
-62
-63 return encodings
59 encodings = encodings.unsqueeze(1).requires_grad_(False)
+60
+61 return encodings
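The table built above is the usual sinusoidal encoding, PE(pos, 2i) = sin(pos / 10000^(2i/d_model)) and PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model)). A small spot check against that closed form, with hypothetical indices, an assumed import path, and a loose float32 tolerance:

import math
import torch
from labml_nn.transformers.positional_encoding import get_positional_encoding  # path assumed

pe = get_positional_encoding(d_model=64, max_len=128)   # [max_len, 1, d_model]
pos, i = 10, 3                                           # arbitrary position and channel pair
expected = math.sin(pos / 10000 ** (2 * i / 64))
assert torch.isclose(pe[pos, 0, 2 * i], torch.tensor(expected), atol=1e-4)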
66def _test_positional_encoding():
-67 import matplotlib.pyplot as plt
-68
-69 plt.figure(figsize=(15, 5))
-70 pe = get_positional_encoding(20, 100)
-71 plt.plot(np.arange(100), pe[:, 0, 4:8].numpy())
-72 plt.legend(["dim %d" % p for p in [4, 5, 6, 7]])
-73 plt.title("Positional encoding")
-74 plt.show()
-75
-76
-77if __name__ == '__main__':
-78 _test_positional_encoding()
64def _test_positional_encoding():
+65 import matplotlib.pyplot as plt
+66
+67 plt.figure(figsize=(15, 5))
+68 pe = get_positional_encoding(20, 100)
+69 plt.plot(np.arange(100), pe[:, 0, 4:8].numpy())
+70 plt.legend(["dim %d" % p for p in [4, 5, 6, 7]])
+71 plt.title("Positional encoding")
+72 plt.show()
+73
+74
+75if __name__ == '__main__':
+76 _test_positional_encoding()