This trains a simple transformer, as introduced in Attention Is All You Need, on an NLP auto-regression task with the Tiny Shakespeare dataset.
import torch

from labml import experiment
from labml.configs import option
from labml_helpers.module import Module
from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs
from labml_nn.transformers import TransformerConfigs, Encoder
from labml_nn.transformers.utils import subsequent_mask
class AutoregressiveTransformer(Module):
    def __init__(self, encoder: Encoder, src_embed: Module, generator: Module):
        super().__init__()
        self.src_embed = src_embed
        self.encoder = encoder
        self.generator = generator

        self.mask = None

    def forward(self, x: torch.Tensor):
        if self.mask is None or self.mask.size(0) != len(x):
            self.mask = subsequent_mask(len(x)).to(x.device)

        x = self.src_embed(x)
        x = self.encoder(x, self.mask)
        x = self.generator(x)

        return x, None
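The mask caching above relies on subsequent_mask producing a causal mask for the whole sequence, so each position can only attend to itself and earlier positions. As a rough illustration only (not the labml_nn implementation), such a mask can be built like this, assuming the convention that non-zero entries mark allowed attention positions:

def causal_mask_sketch(seq_len: int) -> torch.Tensor:
    # Allow query position i to attend to key positions j <= i.
    # Shape [seq_len, seq_len, 1] so it broadcasts over the batch dimension
    # used by the MultiHeadAttention implementation further below.
    return torch.tril(torch.ones(seq_len, seq_len)).bool().unsqueeze(-1)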
class Configs(NLPAutoRegressionConfigs):
    model: AutoregressiveTransformer
    transformer: TransformerConfigs
@option(Configs.transformer, 'Transformer')
def _transformer_configs(c: Configs):
    conf = TransformerConfigs()
    conf.n_src_vocab = c.n_tokens
    conf.n_tgt_vocab = c.n_tokens
    conf.d_model = c.d_model

    return conf
@option(Configs.model)
def _model(c: Configs):
    m = AutoregressiveTransformer(c.transformer.encoder,
                                  c.transformer.src_embed,
                                  c.transformer.generator).to(c.device)

    return m
def main():
    experiment.create(name="transformer")
    conf = Configs()
    experiment.configs(conf, {
        'tokenizer': 'character',
        'prompt_separator': '',
        'prompt': 'It is ',
        'text': 'tiny_shakespeare',

        'seq_len': 512,
        'epochs': 32,
        'batch_size': 16,
        'inner_iterations': 10,

        'd_model': 256,
        'transformer.n_heads': 16,
        'transformer.ffn.d_ff': 1024,

        'optimizer.optimizer': 'Noam',
        'optimizer.learning_rate': 1.,
    })

    experiment.add_pytorch_models({'model': conf.model})

    with experiment.start():
        conf.run()


if __name__ == '__main__':
    main()
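The config selects 'optimizer.optimizer': 'Noam' with a base learning rate of 1.0. With the Noam schedule from Attention Is All You Need, the effective rate typically follows base_lr * d_model^(-0.5) * min(step^(-0.5), step * warmup^(-1.5)). A minimal sketch of that rule; the warmup value of 4000 is an assumption for illustration, not a setting taken from this experiment:

def noam_rate_sketch(step: int, d_model: int = 256, warmup: int = 4000, base_lr: float = 1.0) -> float:
    # Learning-rate rule from "Attention Is All You Need":
    # linear warmup, then decay proportional to 1/sqrt(step).
    step = max(step, 1)
    return base_lr * d_model ** -0.5 * min(step ** -0.5, step * warmup ** -1.5)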
This is a tutorial/implementation of multi-headed attention from the paper Attention Is All You Need, in PyTorch. The implementation is inspired by The Annotated Transformer. The experiment above is the training code that uses a basic transformer with MHA for NLP auto-regression.
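The MultiHeadAttention module below operates on tensors shaped [seq_len, batch_size, d_model]. A minimal usage sketch, assuming the MultiHeadAttention class defined below and purely illustrative shapes:

import torch

# 10 time steps, batch of 2, d_model = 512, 8 heads.
mha = MultiHeadAttention(heads=8, d_model=512)
x = torch.randn(10, 2, 512)
out = mha(query=x, key=x, value=x)  # self-attention; out has shape [10, 2, 512]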
import math
from typing import Optional, List

import torch
from torch import nn as nn

from labml import tracker
from labml_helpers.module import Module
class PrepareForMultiHeadAttention(Module):
    def __init__(self, d_model: int, heads: int, d_k: int, bias: bool):
        super().__init__()
        self.linear = nn.Linear(d_model, heads * d_k, bias=bias)
        self.heads = heads
        self.d_k = d_k

    def forward(self, x: torch.Tensor):
        head_shape = x.shape[:-1]
        x = self.linear(x)
        x = x.view(*head_shape, self.heads, self.d_k)
        return x
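Because head_shape keeps every dimension except the last, this module accepts both [seq_len, batch_size, d_model] and [batch_size, d_model] inputs and only splits the last dimension into heads. A quick shape sketch for intuition:

prep = PrepareForMultiHeadAttention(d_model=512, heads=8, d_k=64, bias=True)
print(prep(torch.randn(10, 2, 512)).shape)  # torch.Size([10, 2, 8, 64])
print(prep(torch.randn(2, 512)).shape)      # torch.Size([2, 8, 64])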
class MultiHeadAttention(Module):
    def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1, bias: bool = True):
        super().__init__()

        self.d_k = d_model // heads
        self.heads = heads

        self.query = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=bias)
        self.key = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=bias)
        self.value = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=True)

        self.softmax = nn.Softmax(dim=1)
        self.output = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout_prob)
        self.scale = 1 / math.sqrt(self.d_k)

        self.attn = None

    def get_scores(self, query: torch.Tensor, key: torch.Tensor):
        return torch.einsum('ibhd,jbhd->ijbh', query, key)

    def prepare_mask(self, mask: torch.Tensor, query_shape: List[int], key_shape: List[int]):
        assert mask.shape[0] == 1 or mask.shape[0] == query_shape[0]
        assert mask.shape[1] == key_shape[0]
        assert mask.shape[2] == 1 or mask.shape[2] == query_shape[1]

        mask = mask.unsqueeze(-1)

        return mask

    def forward(self, *,
                query: torch.Tensor,
                key: torch.Tensor,
                value: torch.Tensor,
                mask: Optional[torch.Tensor] = None):
        seq_len, batch_size, _ = query.shape

        if mask is not None:
            mask = self.prepare_mask(mask, query.shape, key.shape)

        query = self.query(query)
        key = self.key(key)
        value = self.value(value)

        scores = self.get_scores(query, key)

        scores *= self.scale

        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))

        attn = self.softmax(scores)

        tracker.debug('attn', attn)

        attn = self.dropout(attn)

        x = torch.einsum("ijbh,jbhd->ibhd", attn, value)

        self.attn = attn.detach()

        x = x.reshape(seq_len, batch_size, -1)

        return self.output(x)
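For readers less used to einsum: get_scores plus the weighted sum over values is the usual softmax(Q K^T / sqrt(d_k)) V per head, just with the sequence dimension first. A rough per-head equivalent, shown only for intuition; the hypothetical q, k, v are shaped [seq_len, batch, heads, d_k] as produced by PrepareForMultiHeadAttention:

import math
import torch

def attention_per_head_sketch(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    # Move to [batch, heads, seq_len, d_k] so matmul acts over the sequence dimensions.
    q, k, v = (t.permute(1, 2, 0, 3) for t in (q, k, v))
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.shape[-1])  # [batch, heads, q_len, k_len]
    attn = scores.softmax(dim=-1)                              # softmax over key positions
    out = attn @ v                                             # [batch, heads, q_len, d_k]
    return out.permute(2, 0, 1, 3)                             # back to [seq_len, batch, heads, d_k]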
The following module implements the remaining transformer building blocks used above: token embeddings with fixed or learned positional encodings, the transformer layer, the encoder and decoder stacks, the output generator, and the combined encoder-decoder.

import math

import torch
import torch.nn as nn
from labml_helpers.module import Module

from labml_nn.utils import clone_module_list
from .feed_forward import FeedForward
from .mha import MultiHeadAttention
from .positional_encoding import get_positional_encoding
class EmbeddingsWithPositionalEncoding(Module):
    def __init__(self, d_model: int, n_vocab: int, max_len: int = 5000):
        super().__init__()
        self.linear = nn.Embedding(n_vocab, d_model)
        self.d_model = d_model
        self.register_buffer('positional_encodings', get_positional_encoding(d_model, max_len))

    def forward(self, x: torch.Tensor):
        pe = self.positional_encodings[:x.shape[0]].requires_grad_(False)
        return self.linear(x) * math.sqrt(self.d_model) + pe
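get_positional_encoding is expected to return the fixed sinusoidal table from the paper, PE(pos, 2i) = sin(pos / 10000^(2i/d_model)) and PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model)). A sketch of how such a table can be built, with shape [max_len, 1, d_model] so the slicing and broadcasting above work; this is an illustration, not necessarily the labml_nn implementation:

def sinusoidal_encoding_sketch(d_model: int, max_len: int = 5000) -> torch.Tensor:
    # Assumes an even d_model for simplicity.
    position = torch.arange(max_len, dtype=torch.float32).unsqueeze(1)            # [max_len, 1]
    div_term = torch.exp(torch.arange(0, d_model, 2, dtype=torch.float32)
                         * (-math.log(10000.0) / d_model))                        # [d_model / 2]
    pe = torch.zeros(max_len, 1, d_model)
    pe[:, 0, 0::2] = torch.sin(position * div_term)
    pe[:, 0, 1::2] = torch.cos(position * div_term)
    return pe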
class EmbeddingsWithLearnedPositionalEncoding(Module):
    def __init__(self, d_model: int, n_vocab: int, max_len: int = 5000):
        super().__init__()
        self.linear = nn.Embedding(n_vocab, d_model)
        self.d_model = d_model
        self.positional_encodings = nn.Parameter(torch.zeros(max_len, 1, d_model), requires_grad=True)

    def forward(self, x: torch.Tensor):
        pe = self.positional_encodings[:x.shape[0]]
        return self.linear(x) * math.sqrt(self.d_model) + pe
class TransformerLayer(Module):
    def __init__(self, *,
                 d_model: int,
                 self_attn: MultiHeadAttention,
                 src_attn: MultiHeadAttention = None,
                 feed_forward: FeedForward,
                 dropout_prob: float):
        super().__init__()
        self.size = d_model
        self.self_attn = self_attn
        self.src_attn = src_attn
        self.feed_forward = feed_forward
        self.dropout = nn.Dropout(dropout_prob)
        self.norm_self_attn = nn.LayerNorm([d_model])
        if self.src_attn is not None:
            self.norm_src_attn = nn.LayerNorm([d_model])
        self.norm_ff = nn.LayerNorm([d_model])

        self.is_save_ff_input = False

    def forward(self, *,
                x: torch.Tensor,
                mask: torch.Tensor,
                src: torch.Tensor = None,
                src_mask: torch.Tensor = None):
        z = self.norm_self_attn(x)
        self_attn = self.self_attn(query=z, key=z, value=z, mask=mask)
        x = x + self.dropout(self_attn)

        if src is not None:
            z = self.norm_src_attn(x)
            attn_src = self.src_attn(query=z, key=src, value=src, mask=src_mask)
            x = x + self.dropout(attn_src)

        z = self.norm_ff(x)
        if self.is_save_ff_input:
            self.ff_input = z.clone()
        ff = self.feed_forward(z)
        x = x + self.dropout(ff)

        return x
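Each sub-layer above follows the same pre-norm residual pattern: normalize, apply the sub-layer, then add the result back through dropout. Written as a generic helper, purely for illustration:

def prenorm_residual_sketch(x: torch.Tensor, norm: nn.LayerNorm, sublayer, dropout: nn.Dropout) -> torch.Tensor:
    # x <- x + dropout(sublayer(norm(x)))
    return x + dropout(sublayer(norm(x)))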
class Encoder(Module):
    def __init__(self, layer: TransformerLayer, n_layers: int):
        super().__init__()
        self.layers = clone_module_list(layer, n_layers)
        self.norm = nn.LayerNorm([layer.size])

    def forward(self, x: torch.Tensor, mask: torch.Tensor):
        for layer in self.layers:
            x = layer(x=x, mask=mask)
        return self.norm(x)
class Decoder(Module):
    def __init__(self, layer: TransformerLayer, n_layers: int):
        super().__init__()
        self.layers = clone_module_list(layer, n_layers)
        self.norm = nn.LayerNorm([layer.size])

    def forward(self, x: torch.Tensor, memory: torch.Tensor, src_mask: torch.Tensor, tgt_mask: torch.Tensor):
        for layer in self.layers:
            x = layer(x=x, mask=tgt_mask, src=memory, src_mask=src_mask)
        return self.norm(x)
class Generator(Module):
    def __init__(self, n_vocab: int, d_model: int):
        super().__init__()
        self.projection = nn.Linear(d_model, n_vocab)

    def forward(self, x):
        return self.projection(x)
class EncoderDecoder(Module):
    def __init__(self, encoder: Encoder, decoder: Decoder, src_embed: Module, tgt_embed: Module, generator: Module):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.src_embed = src_embed
        self.tgt_embed = tgt_embed
        self.generator = generator

        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(self, src: torch.Tensor, tgt: torch.Tensor, src_mask: torch.Tensor, tgt_mask: torch.Tensor):
        enc = self.encode(src, src_mask)
        return self.decode(enc, src_mask, tgt, tgt_mask)

    def encode(self, src: torch.Tensor, src_mask: torch.Tensor):
        return self.encoder(self.src_embed(src), src_mask)

    def decode(self, memory: torch.Tensor, src_mask: torch.Tensor, tgt: torch.Tensor, tgt_mask: torch.Tensor):
        return self.decoder(self.tgt_embed(tgt), memory, src_mask, tgt_mask)
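To see how these pieces fit together, here is a hedged wiring sketch for a small encoder-decoder built from the classes above. The FeedForward arguments (d_model, d_ff, dropout) are an assumption about labml_nn's feed_forward module; treat this as a sketch, not the library's canonical builder (the experiment at the top uses TransformerConfigs for this instead):

def build_encoder_decoder_sketch(n_src_vocab: int, n_tgt_vocab: int,
                                 d_model: int = 512, heads: int = 8,
                                 d_ff: int = 2048, n_layers: int = 6,
                                 dropout: float = 0.1) -> EncoderDecoder:
    def make_layer(with_src_attn: bool) -> TransformerLayer:
        # Decoder layers get an extra source-attention block; encoder layers do not.
        return TransformerLayer(
            d_model=d_model,
            self_attn=MultiHeadAttention(heads, d_model, dropout),
            src_attn=MultiHeadAttention(heads, d_model, dropout) if with_src_attn else None,
            feed_forward=FeedForward(d_model, d_ff, dropout),  # assumed constructor signature
            dropout_prob=dropout)

    # Encoder/Decoder deep-copy the prototype layer n_layers times via clone_module_list.
    return EncoderDecoder(
        encoder=Encoder(make_layer(with_src_attn=False), n_layers),
        decoder=Decoder(make_layer(with_src_attn=True), n_layers),
        src_embed=EmbeddingsWithPositionalEncoding(d_model, n_src_vocab),
        tgt_embed=EmbeddingsWithPositionalEncoding(d_model, n_tgt_vocab),
        generator=Generator(n_tgt_vocab, d_model))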