diff --git a/docs/transformers/feed_forward.html b/docs/transformers/feed_forward.html
index 145ec521..9420233f 100644
--- a/docs/transformers/feed_forward.html
+++ b/docs/transformers/feed_forward.html
@@ -226,7 +226,7 @@ be multiplied by the gate, parameterized by weight $V$ and bias $c$
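For context, the gating described here computes, in the formulation of Shazeer's "GLU Variants Improve Transformer" (a reconstruction from the surrounding definitions; $f$ is the gate activation and $\otimes$ is element-wise multiplication):
$$\mathrm{FFN}_{gated}(x) = \big(f(x W_1 + b_1) \otimes (x V + c)\big) W_2 + b_2$$
with $f$ being sigmoid for GLU, identity for Bilinear, ReLU for ReGLU, GELU for GEGLU, and SiLU for SwiGLU.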
- 81 def __call__(self, x: torch.Tensor):
+ 81 def forward(self, x: torch.Tensor):
- 44 def __call__(self, x: torch.Tensor):
+ 44 def forward(self, x: torch.Tensor):
- 158 def __call__(self, *,
+ 158 def forward(self, *,
159 query: torch.Tensor,
160 key: torch.Tensor,
161 value: torch.Tensor):
@@ -609,7 +609,7 @@ Results in a tensor of shape [seq_len, batch_size, heads]
- 229 def __call__(self, *,
+ 229 def forward(self, *,
230 x: torch.Tensor,
231 key: Optional[torch.Tensor],
232 value: Optional[torch.Tensor]):
@@ -794,7 +794,7 @@ This is the weights parameter for that.
- 275 def __call__(self, x_seq: torch.Tensor):
+ 275 def forward(self, x_seq: torch.Tensor):
- 473 def __call__(self, x_seq: torch.Tensor):
+ 473 def forward(self, x_seq: torch.Tensor):
- 44 def __call__(self, src: torch.Tensor):
+ 44 def forward(self, src: torch.Tensor):
20import dataclasses
21
22import torch
-23from torch import nn
-24from torch.utils.data import Dataset, DataLoader
-25
-26from labml import experiment, lab, tracker, monit, logger
-27from labml.logger import Text
-28from labml.utils.download import download_file
-29from labml_nn.experiments.nlp_autoregression import transpose_batch
-30from labml_nn.optimizers.noam import Noam
-31from labml_nn.transformers import Encoder, MultiHeadAttention
-32from labml_nn.transformers.feed_forward import FeedForward
-33from labml_nn.transformers.models import EmbeddingsWithPositionalEncoding, TransformerLayer
-34from labml_nn.transformers.utils import subsequent_mask
37class AutoregressiveModel(nn.Module):
38class AutoregressiveModel(Module):
-42 def __init__(self, src_embed: nn.Module, encoder: Encoder, generator: nn.Module):
-43 super().__init__()
+43 def __init__(self, src_embed: nn.Module, encoder: Encoder, generator: nn.Module):
+44 super().__init__()
Token embedding module
45 self.src_embed = src_embed
46 self.src_embed = src_embed
Transformer based encoder
47 self.encoder = encoder
48 self.encoder = encoder
50 self.generator = generator
51 self.generator = generator
This will be initialized on the first call
52 self.src_mask = None
53 self.src_mask = None
-54 def __call__(self, src: torch.Tensor):
+55 def forward(self, src: torch.Tensor):
Create subsequent mask, so that the transformer can only pay attention to past tokens.
-56 if self.src_mask is None or self.src_mask.size(0) != len(src):
-57 self.src_mask = subsequent_mask(len(src)).to(src.device)
+57 if self.src_mask is None or self.src_mask.size(0) != len(src):
+58 self.src_mask = subsequent_mask(len(src)).to(src.device)
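For reference, a minimal sketch of what a causal (subsequent) mask helper can look like; the exact shape and dtype conventions of labml_nn's subsequent_mask may differ:

import torch

def causal_mask(seq_len: int) -> torch.Tensor:
    # True where query position i is allowed to attend to key position j (j <= i)
    return torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))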
Embed the tokens (src) and run it through the transformer
59 res = self.encoder(self.src_embed(src), self.src_mask)
60 res = self.encoder(self.src_embed(src), self.src_mask)
Generate logits of the next token
61 return self.generator(res)
62 return self.generator(res)
-64@dataclasses.dataclass
-65class Configs:
+65@dataclasses.dataclass
+66class Configs:
-69 d_model: int = 512
-70 seq_len: int = 128
-71 batch_size: int = 32
-72 n_layers: int = 6
-73 n_heads: int = 8
-74 dropout: float = 0.1
-75 d_ff: int = 2048
-76 glu_variant: str = 'GLU'
-77 epochs: int = 5
-78 grad_norm_clip: float = 0.5
+70 d_model: int = 512
+71 seq_len: int = 128
+72 batch_size: int = 32
+73 n_layers: int = 6
+74 n_heads: int = 8
+75 dropout: float = 0.1
+76 d_ff: int = 2048
+77 glu_variant: str = 'GLU'
+78 epochs: int = 5
+79 grad_norm_clip: float = 0.5
81class TinyShakespeareDataset(Dataset):
82class TinyShakespeareDataset(Dataset):
86 def __init__(self, seq_len: int):
87 def __init__(self, seq_len: int):
Location of the text file
88 path = lab.get_data_path() / 'tiny_shakespeare.txt'
89 path = lab.get_data_path() / 'tiny_shakespeare.txt'
Download the file
90 download_file('https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt', path)
91 download_file('https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt', path)
Read the downloaded file
-92 with open(str(path), 'r') as f:
-93 text = f.read()
+93 with open(str(path), 'r') as f:
+94 text = f.read()
Extract the characters
96 chars = list(set(text))
97 chars = list(set(text))
Character to id (integer) map
98 self.stoi = {c: i for i, c in enumerate(chars)}
99 self.stoi = {c: i for i, c in enumerate(chars)}
Id to character map
100 self.itos = {i: c for i, c in enumerate(chars)}
101 self.itos = {i: c for i, c in enumerate(chars)}
Length of a training sample
102 self.seq_len = seq_len
103 self.seq_len = seq_len
Data in the form of a tensor of ids
104 self.data = self.text_to_i(text)
105 self.data = self.text_to_i(text)
Transform the text into a tensor of ids
106 def text_to_i(self, text: str):
107 def text_to_i(self, text: str):
110 return torch.tensor([self.stoi[c] for c in text], dtype=torch.long)
111 return torch.tensor([self.stoi[c] for c in text], dtype=torch.long)
This will read the dataset seq_len times in a single epoch.
112 def __len__(self):
113 def __len__(self):
118 return len(self.data) - self.seq_len - 1
119 return len(self.data) - self.seq_len - 1
Return a sample
120 def __getitem__(self, idx):
121 def __getitem__(self, idx):
124 return self.data[idx:idx + self.seq_len], self.data[idx + 1:idx + self.seq_len + 1]
125 return self.data[idx:idx + self.seq_len], self.data[idx + 1:idx + self.seq_len + 1]
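A small illustration of the slicing above, using a toy tensor in place of the encoded text (values are made up):

import torch

data = torch.arange(10)                  # stand-in for the encoded text
seq_len, idx = 4, 2
x = data[idx:idx + seq_len]              # tensor([2, 3, 4, 5])
y = data[idx + 1:idx + seq_len + 1]      # tensor([3, 4, 5, 6]); y[t] is the character that follows x[t]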
127class Trainer:
128class Trainer:
132 def __init__(self, configs: Configs):
133 def __init__(self, configs: Configs):
Get the device
-134 self.device = torch.device('cpu')
-135 if torch.cuda.is_available():
-136 self.device = torch.device('cuda:0')
+135 self.device = torch.device('cpu')
+136 if torch.cuda.is_available():
+137 self.device = torch.device('cuda:0')
Initialize the dataset
138 self.dataset = TinyShakespeareDataset(configs.seq_len)
139 self.dataset = TinyShakespeareDataset(configs.seq_len)
Initialize the dataloader
-140 self.dataloader = DataLoader(self.dataset,
-141 batch_size=configs.batch_size,
-142 collate_fn=transpose_batch,
-143 shuffle=True)
+141 self.dataloader = DataLoader(self.dataset,
+142 batch_size=configs.batch_size,
+143 collate_fn=transpose_batch,
+144 shuffle=True)
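transpose_batch is used as the collate_fn so that batches come out sequence-first rather than batch-first; a rough sketch of such a collate function (the actual implementation in labml_nn.experiments.nlp_autoregression may differ):

import torch

def transpose_batch_sketch(batch):
    # batch: list of (input, target) pairs, each of shape [seq_len]
    data = torch.stack([item[0] for item in batch], dim=1)    # [seq_len, batch_size]
    target = torch.stack([item[1] for item in batch], dim=1)  # [seq_len, batch_size]
    return data, target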
-147 if configs.glu_variant == 'GLU':
-148 ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.Sigmoid(), True, False, False, False)
+148 if configs.glu_variant == 'GLU':
+149 ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.Sigmoid(), True, False, False, False)
-151 elif configs.glu_variant == 'Bilinear':
-152 ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.Identity(), True, False, False, False)
+152 elif configs.glu_variant == 'Bilinear':
+153 ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.Identity(), True, False, False, False)
-155 elif configs.glu_variant == 'ReGLU':
-156 ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.ReLU(), True, False, False, False)
+156 elif configs.glu_variant == 'ReGLU':
+157 ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.ReLU(), True, False, False, False)
-159 elif configs.glu_variant == 'GEGLU':
-160 ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.GELU(), True, False, False, False)
+160 elif configs.glu_variant == 'GEGLU':
+161 ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.GELU(), True, False, False, False)
-164 elif configs.glu_variant == 'SwiGLU':
-165 ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.SiLU(), True, False, False, False)
+165 elif configs.glu_variant == 'SwiGLU':
+166 ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.SiLU(), True, False, False, False)
-168 elif configs.glu_variant == 'ReLU':
-169 ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.ReLU())
+169 elif configs.glu_variant == 'ReLU':
+170 ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.ReLU())
-172 elif configs.glu_variant == 'GELU':
-173 ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.GELU())
-174 else:
-175 raise ValueError(f'Unknown variant {configs.glu_variant}')
+173 elif configs.glu_variant == 'GELU':
+174 ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.GELU())
+175 else:
+176 raise ValueError(f'Unknown variant {configs.glu_variant}')
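The chain above only swaps the gate activation; every gated variant performs the same computation. A minimal sketch of that gated feed-forward block with hypothetical names (the real FeedForward class is configured here with extra positional flags for gating and bias options):

import torch
from torch import nn

class GatedFFNSketch(nn.Module):
    def __init__(self, d_model: int, d_ff: int, activation: nn.Module):
        super().__init__()
        self.w1 = nn.Linear(d_model, d_ff)     # activated path
        self.v = nn.Linear(d_model, d_ff)      # gate, parameterized by weight V and bias c
        self.w2 = nn.Linear(d_ff, d_model)
        self.activation = activation

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # e.g. GEGLU when activation is nn.GELU(), SwiGLU when it is nn.SiLU()
        return self.w2(self.activation(self.w1(x)) * self.v(x))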
Number of different characters
178 n_chars = len(self.dataset.stoi)
179 n_chars = len(self.dataset.stoi)
Initialize Multi-Head Attention module
181 mha = MultiHeadAttention(configs.n_heads, configs.d_model, configs.dropout)
182 mha = MultiHeadAttention(configs.n_heads, configs.d_model, configs.dropout)
Initialize the Transformer Block
-183 transformer_layer = TransformerLayer(d_model=configs.d_model, self_attn=mha, src_attn=None,
-184 feed_forward=ffn, dropout_prob=configs.dropout)
+184 transformer_layer = TransformerLayer(d_model=configs.d_model, self_attn=mha, src_attn=None,
+185 feed_forward=ffn, dropout_prob=configs.dropout)
-190 self.model = AutoregressiveModel(EmbeddingsWithPositionalEncoding(configs.d_model, n_chars),
-191 Encoder(transformer_layer, configs.n_layers),
-192 nn.Linear(configs.d_model, n_chars))
+191 self.model = AutoregressiveModel(EmbeddingsWithPositionalEncoding(configs.d_model, n_chars),
+192 Encoder(transformer_layer, configs.n_layers),
+193 nn.Linear(configs.d_model, n_chars))
Move the model to the current device
195 self.model.to(self.device)
196 self.model.to(self.device)
Initialize Noam optimizer
198 self.optimizer = Noam(self.model.parameters(), lr=1.0, warmup=2_000, d_model=configs.d_model)
199 self.optimizer = Noam(self.model.parameters(), lr=1.0, warmup=2_000, d_model=configs.d_model)
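With lr=1.0 acting as a scale factor, the effective learning rate follows the warmup schedule from "Attention Is All You Need"; a sketch assuming the usual Noam parameterization (the exact implementation of labml_nn's Noam class is not reproduced here):

def noam_lr(step: int, d_model: int = 512, warmup: int = 2_000, factor: float = 1.0) -> float:
    # Rises linearly for `warmup` steps, then decays proportionally to step ** -0.5
    step = max(step, 1)
    return factor * d_model ** -0.5 * min(step ** -0.5, step * warmup ** -1.5)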
Cross-entropy loss
201 self.loss_func = nn.CrossEntropyLoss()
202 self.loss_func = nn.CrossEntropyLoss()
Number of training epochs; note that the dataset is read seq_len times in a single epoch
204 self.epochs = configs.epochs
205 self.epochs = configs.epochs
Gradient clipping norm
206 self.grad_norm_clip = configs.grad_norm_clip
207 self.grad_norm_clip = configs.grad_norm_clip
Set tracker configurations
209 tracker.set_scalar("loss.*", True)
210 tracker.set_scalar("loss.*", True)
211 def sample(self):
212 def sample(self):
Starting prompt
217 prompt = 'It is'
218 prompt = 'It is'
Collect output for printing
219 log = [(prompt, Text.subtle)]
220 log = [(prompt, Text.subtle)]
Sample 25 tokens
221 for i in monit.iterate('Sample', 25):
222 for i in monit.iterate('Sample', 25):
Tokenize the prompt
-223 data = self.dataset.text_to_i(prompt).unsqueeze(-1)
-224 data = data.to(self.device)
+224 data = self.dataset.text_to_i(prompt).unsqueeze(-1)
+225 data = data.to(self.device)
Get the model output
226 output = self.model(data)
227 output = self.model(data)
Get the model prediction (greedy)
228 output = output.argmax(dim=-1).squeeze()
229 output = output.argmax(dim=-1).squeeze()
Add the prediction to prompt
230 prompt += self.dataset.itos[output[-1].item()]
231 prompt += self.dataset.itos[output[-1].item()]
Add the prediction for logging
232 log += [(self.dataset.itos[output[-1].item()], Text.value)]
233 log += [(self.dataset.itos[output[-1].item()], Text.value)]
Print the sampled output
235 logger.log(log)
236 logger.log(log)
237 def train(self):
238 def train(self):
Loop for the given number of epochs
243 for _ in monit.loop(self.epochs):
244 for _ in monit.loop(self.epochs):
Iterate over the minibatches
245 for i, batch in monit.enum('Train', self.dataloader):
246 for i, batch in monit.enum('Train', self.dataloader):
Move data to the device
247 data, target = batch[0].to(self.device), batch[1].to(self.device)
248 data, target = batch[0].to(self.device), batch[1].to(self.device)
Set tracker step, as the number of characters trained on
250 tracker.add_global_step(data.shape[0] * data.shape[1])
251 tracker.add_global_step(data.shape[0] * data.shape[1])
Set model state to training
253 self.model.train()
254 self.model.train()
Evaluate the model
255 output = self.model(data)
256 output = self.model(data)
Calculate loss
258 loss = self.loss_func(output.view(-1, output.shape[-1]), target.view(-1))
259 loss = self.loss_func(output.view(-1, output.shape[-1]), target.view(-1))
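The reshape flattens the sequence and batch dimensions so that every position becomes one classification; a toy check of the shapes involved (n_chars = 65 is only illustrative):

import torch
from torch import nn

loss_func = nn.CrossEntropyLoss()
output = torch.randn(128, 32, 65)          # [seq_len, batch_size, n_chars]
target = torch.randint(0, 65, (128, 32))   # [seq_len, batch_size]
loss = loss_func(output.view(-1, 65), target.view(-1))   # 128 * 32 = 4096 independent predictions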
Log the loss
260 tracker.add("loss.train", loss)
261 tracker.add("loss.train", loss)
Calculate gradients
263 loss.backward()
264 loss.backward()
Clip gradients
265 torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.grad_norm_clip)
266 torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.grad_norm_clip)
Take optimizer step
267 self.optimizer.step()
268 self.optimizer.step()
Log the model parameters and gradients
-269 if (i + 1) % 100 == 0:
-270 tracker.add('model', self.model)
+270 if (i + 1) % 100 == 0:
+271 tracker.add('model', self.model)
Clear the gradients
272 self.optimizer.zero_grad()
273 self.optimizer.zero_grad()
Generate a sample
-275 if (i + 1) % 100 == 0:
-276 self.model.eval()
-277 with torch.no_grad():
-278 self.sample()
+276 if (i + 1) % 100 == 0:
+277 self.model.eval()
+278 with torch.no_grad():
+279 self.sample()
Save the tracked metrics
-281 if (i + 1) % 10 == 0:
-282 tracker.save()
+282 if (i + 1) % 10 == 0:
+283 tracker.save()
Save the model
285 experiment.save_checkpoint()
286 experiment.save_checkpoint()
288def main():
289def main():
Create experiment
290 experiment.create(name="glu_variants")
291 experiment.create(name="glu_variants")
Create configs
292 configs = Configs()
293 configs = Configs()
Load configurations
294 experiment.configs(dataclasses.asdict(configs))
295 experiment.configs(dataclasses.asdict(configs))
Create trainer
297 trainer = Trainer(configs)
298 trainer = Trainer(configs)
Set models for training and loading
299 experiment.add_pytorch_models({'model': trainer.model})
300 experiment.add_pytorch_models({'model': trainer.model})
Start the experiment
302 with experiment.start():
303 with experiment.start():
Train the model
-304 trainer.train()
-305
-306
-307if __name__ == '__main__':
-308 main()
+305 trainer.train()
+306
+307
+308if __name__ == '__main__':
+309 main()
- 70 def __call__(self, x: torch.Tensor):
+ 70 def forward(self, x: torch.Tensor):
- 52 def __call__(self, src: torch.Tensor):
+ 52 def forward(self, src: torch.Tensor):
- 29 def __call__(self, x: torch.Tensor, target: torch.Tensor):
+ 29 def forward(self, x: torch.Tensor, target: torch.Tensor):
30 assert x.shape[1] == self.size
31 true_dist = x.clone()
32 true_dist.fill_(self.smoothing / (self.size - 2))
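The fragment above fills the smoothed target distribution; for context, a minimal sketch of a full label-smoothing loss in the style of the Annotated Transformer (names such as padding_idx and confidence are assumptions about the surrounding class):

import torch
from torch import nn

class LabelSmoothingSketch(nn.Module):
    def __init__(self, size: int, padding_idx: int, smoothing: float = 0.1):
        super().__init__()
        self.criterion = nn.KLDivLoss(reduction='sum')
        self.size = size
        self.padding_idx = padding_idx
        self.confidence = 1.0 - smoothing
        self.smoothing = smoothing

    def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        # x: log-probabilities [n, size]; target: class indices [n]
        true_dist = torch.full_like(x, self.smoothing / (self.size - 2))
        true_dist.scatter_(1, target.unsqueeze(1), self.confidence)
        true_dist[:, self.padding_idx] = 0          # never assign mass to the padding class
        true_dist[target == self.padding_idx] = 0   # ignore padded positions entirely
        return self.criterion(x, true_dist.detach())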
diff --git a/docs/transformers/mha.html b/docs/transformers/mha.html
index 6c19e8f1..9f3f1478 100644
--- a/docs/transformers/mha.html
+++ b/docs/transformers/mha.html
@@ -155,7 +155,7 @@ This is used to transform key, query, and
- 45 def __call__(self, x: torch.Tensor):
+ 45 def forward(self, x: torch.Tensor):
@@ -377,7 +377,7 @@ They have shape [seq_len, batch_size, d_model].
query at position i has access to key-value at position j.
- 121 def __call__(self, *,
+ 121 def forward(self, *,
122 query: torch.Tensor,
123 key: torch.Tensor,
124 value: torch.Tensor,
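A small, self-contained illustration of that mask convention, assuming labml_nn's [seq_len, batch_size, heads, d_k] layout for queries and keys (layout and names inferred from the surrounding docs, not verified here):

import torch

seq_q, seq_k, batch, heads, d_k = 4, 4, 2, 8, 16
query = torch.randn(seq_q, batch, heads, d_k)
key = torch.randn(seq_k, batch, heads, d_k)
scores = torch.einsum('ibhd,jbhd->ijbh', query, key) / d_k ** 0.5   # [seq_q, seq_k, batch, heads]
# mask[i, j, b]: may query position i attend to key position j in batch b?
mask = torch.tril(torch.ones(seq_q, seq_k, dtype=torch.bool)).unsqueeze(-1).expand(seq_q, seq_k, batch)
scores = scores.masked_fill(~mask.unsqueeze(-1), float('-inf'))
attn = scores.softmax(dim=1)   # normalize over key positions j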
diff --git a/docs/transformers/models.html b/docs/transformers/models.html
index 5ea0ba11..15f793cb 100644
--- a/docs/transformers/models.html
+++ b/docs/transformers/models.html
@@ -122,7 +122,7 @@
- 36 def __call__(self, x: torch.Tensor):
+ 36 def forward(self, x: torch.Tensor):
37 pe = self.positional_encodings[:x.shape[0]].requires_grad_(False)
38 return self.linear(x) * math.sqrt(self.d_model) + pe
@@ -163,7 +163,7 @@
- 54 def __call__(self, x: torch.Tensor):
+ 54 def forward(self, x: torch.Tensor):
55 pe = self.positional_encodings[:x.shape[0]]
56 return self.linear(x) * math.sqrt(self.d_model) + pe
@@ -251,7 +251,7 @@ We found a detailed discussion about this in the paper
- 103 def __call__(self, *,
+ 103 def forward(self, *,
104 x: torch.Tensor,
105 mask: torch.Tensor,
106 src: torch.Tensor = None,
@@ -439,7 +439,7 @@ encoder outputs
- 153 def __call__(self, x: torch.Tensor, mask: torch.Tensor):
+ 153 def forward(self, x: torch.Tensor, mask: torch.Tensor):
@@ -520,7 +520,7 @@ encoder outputs
- 175 def __call__(self, x: torch.Tensor, memory: torch.Tensor, src_mask: torch.Tensor, tgt_mask: torch.Tensor):
+ 175 def forward(self, x: torch.Tensor, memory: torch.Tensor, src_mask: torch.Tensor, tgt_mask: torch.Tensor):
@@ -582,7 +582,7 @@ You don’t need this if you are using nn.CrossEntropyLoss.
- 197 def __call__(self, x):
+ 197 def forward(self, x):
198 return self.projection(x)
@@ -638,7 +638,7 @@ Initialize parameters with Glorot / fan_avg.
- 222 def __call__(self, src: torch.Tensor, tgt: torch.Tensor, src_mask: torch.Tensor, tgt_mask: torch.Tensor):
+ 222 def forward(self, src: torch.Tensor, tgt: torch.Tensor, src_mask: torch.Tensor, tgt_mask: torch.Tensor):
diff --git a/docs/transformers/positional_encoding.html b/docs/transformers/positional_encoding.html
index 9f1996dc..584ee620 100644
--- a/docs/transformers/positional_encoding.html
+++ b/docs/transformers/positional_encoding.html
@@ -127,7 +127,7 @@ PE_{p,2i + 1} &= cos\Bigg(\frac{p}{10000^{\frac{2i}{d_{model}}}}\Bigg)
- 39 def __call__(self, x: torch.Tensor):
+ 39 def forward(self, x: torch.Tensor):
40 pe = self.positional_encodings[:x.shape[0]].detach().requires_grad_(False)
41 x = x + pe
42 x = self.dropout(x)
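For reference, a standard construction of the sinusoidal encodings added above (a sketch; labml_nn's helper is assumed to be equivalent up to naming and buffer handling):

import math
import torch

def sinusoidal_encodings(d_model: int, max_len: int = 5000) -> torch.Tensor:
    # PE[p, 2i] = sin(p / 10000^(2i / d_model)), PE[p, 2i + 1] = cos(p / 10000^(2i / d_model))
    position = torch.arange(max_len, dtype=torch.float32).unsqueeze(1)
    div_term = torch.exp(torch.arange(0, d_model, 2, dtype=torch.float32) * (-math.log(10000.0) / d_model))
    pe = torch.zeros(max_len, d_model)
    pe[:, 0::2] = torch.sin(position * div_term)
    pe[:, 1::2] = torch.cos(position * div_term)
    return pe.unsqueeze(1)   # [max_len, 1, d_model], broadcasts over the batch dimension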
diff --git a/docs/transformers/switch/experiment.html b/docs/transformers/switch/experiment.html
index fa82fc68..ba225c83 100644
--- a/docs/transformers/switch/experiment.html
+++ b/docs/transformers/switch/experiment.html
@@ -151,7 +151,7 @@
- 37 def __call__(self, x: torch.Tensor):
+ 37 def forward(self, x: torch.Tensor):
diff --git a/docs/transformers/switch/index.html b/docs/transformers/switch/index.html
index bb4d2ab6..23e6481b 100644
--- a/docs/transformers/switch/index.html
+++ b/docs/transformers/switch/index.html
@@ -192,7 +192,7 @@ discusses dropping tokens when routing is not balanced.
- 83 def __call__(self, x: torch.Tensor):
+ 83 def forward(self, x: torch.Tensor):
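Dropping tokens here means that when an expert receives more tokens than its capacity, the overflow tokens skip the expert and pass through on the residual path; a rough sketch of top-1 (switch) routing with a capacity limit, with hypothetical names rather than labml_nn's API:

import torch

def route_top1(gate_logits: torch.Tensor, capacity: int):
    # gate_logits: [n_tokens, n_experts]
    probs = gate_logits.softmax(dim=-1)
    route_prob, routes = probs.max(dim=-1)              # chosen expert and its gate probability per token
    kept, dropped = [], []
    for e in range(gate_logits.shape[-1]):
        idx = torch.nonzero(routes == e).squeeze(-1)    # tokens routed to expert e
        kept.append(idx[:capacity])                     # within capacity: processed by expert e
        dropped.append(idx[capacity:])                  # over capacity: fall back to the skip connection
    return kept, dropped, route_prob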
@@ -525,7 +525,7 @@ with handling extra outputs of switch feedforward module.
- 192 def __call__(self, *,
+ 192 def forward(self, *,
193 x: torch.Tensor,
194 mask: torch.Tensor):
@@ -651,7 +651,7 @@ with handling extra outputs of switch feedforward module.
- 224 def __call__(self, x: torch.Tensor, mask: torch.Tensor):
+ 224 def forward(self, x: torch.Tensor, mask: torch.Tensor):