diff --git a/docs/transformers/glu_variants/simple.html b/docs/transformers/glu_variants/simple.html
index 207f5c42..ba408fec 100644
--- a/docs/transformers/glu_variants/simple.html
+++ b/docs/transformers/glu_variants/simple.html
@@ -76,23 +76,25 @@

We try different variants for the position-wise feedforward network.
This is a simpler implementation that doesn't use the labml.configs module.
We wrote it this way to make the code easier to follow for readers who are not familiar with that module.
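All of the GLU variants below gate the hidden layer of the position-wise feed-forward network. As a rough sketch of the idea in plain PyTorch (the class name and structure here are illustrative; the actual implementation used below is labml_nn's FeedForward module, configured through its extra flags):

import torch
from torch import nn

class GatedFeedForwardSketch(nn.Module):
    # FFN_variant(x) = (activation(x W1) * (x V)) W2, where * is element-wise.
    # GLU uses sigmoid, Bilinear uses identity, ReGLU uses ReLU,
    # GEGLU uses GELU and SwiGLU uses SiLU (Swish) as the gate activation.
    def __init__(self, d_model: int, d_ff: int, activation: nn.Module):
        super().__init__()
        self.w1 = nn.Linear(d_model, d_ff)  # gate projection
        self.v = nn.Linear(d_model, d_ff)   # linear (value) projection
        self.w2 = nn.Linear(d_ff, d_model)  # output projection
        self.activation = activation

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.w2(self.activation(self.w1(x)) * self.v(x))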
import dataclasses

import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader

from labml import experiment, lab, tracker, monit, logger
from labml.logger import Text
from labml.utils.download import download_file
from labml_nn.experiments.nlp_autoregression import transpose_batch
from labml_nn.optimizers.noam import Noam
from labml_nn.transformers import Encoder, MultiHeadAttention
from labml_nn.transformers.feed_forward import FeedForward
from labml_nn.transformers.models import EmbeddingsWithPositionalEncoding, TransformerLayer
from labml_nn.transformers.utils import subsequent_mask

class AutoregressiveModel(nn.Module):
    def __init__(self, src_embed: nn.Module, encoder: Encoder, generator: nn.Module):
        super().__init__()
        # Token embedding module
        self.src_embed = src_embed
        # Transformer based encoder
        self.encoder = encoder
        # Final linear layer that gives logits of the next token
        self.generator = generator
        # This will be initialized on the first call
        self.src_mask = None

    def __call__(self, src: torch.Tensor):
        # Create subsequent mask, so that the transformer can only pay attention to past tokens.
        if self.src_mask is None or self.src_mask.size(0) != len(src):
            self.src_mask = subsequent_mask(len(src)).to(src.device)
        # Embed the tokens (src) and run it through the transformer
        res = self.encoder(self.src_embed(src), self.src_mask)
        # Generate logits of the next token
        return self.generator(res)
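subsequent_mask comes from labml_nn.transformers.utils. A minimal sketch of the same idea in plain PyTorch, just to show what "only pay attention to past tokens" means (the exact shape returned by the library function may differ):

import torch

def causal_mask_sketch(seq_len: int) -> torch.Tensor:
    # True where the query position may attend to the key position, i.e. key <= query
    return torch.tril(torch.ones(seq_len, seq_len)).bool()

print(causal_mask_sketch(4))
# tensor([[ True, False, False, False],
#         [ True,  True, False, False],
#         [ True,  True,  True, False],
#         [ True,  True,  True,  True]])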
@dataclasses.dataclass
class Configs:
    d_model: int = 512
    seq_len: int = 128
    batch_size: int = 32
    n_layers: int = 6
    n_heads: int = 8
    dropout: float = 0.1
    d_ff: int = 2048
    glu_variant: str = 'GLU'
    epochs: int = 5
    grad_norm_clip: float = 0.5
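Since Configs is a plain dataclass, a different variant or model size can be selected by overriding the defaults at construction time; the values below are illustrative rather than from the original experiment:

# e.g. train the SwiGLU variant with a smaller model
configs = Configs(glu_variant='SwiGLU', d_model=256, n_layers=4)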
class TinyShakespeareDataset(Dataset):
    def __init__(self, seq_len: int):
        # Location of the text file
        path = lab.get_data_path() / 'tiny_shakespeare.txt'
        # Download the file
        download_file('https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt', path)
        # Read the downloaded file
        with open(str(path), 'r') as f:
            text = f.read()

        # Extract the characters
        chars = list(set(text))
        # Character to id (integer) map
        self.stoi = {c: i for i, c in enumerate(chars)}
        # Id to character map
        self.itos = {i: c for i, c in enumerate(chars)}
        # Length of a training sample
        self.seq_len = seq_len
        # Data in the form of a tensor of ids
        self.data = self.text_to_i(text)

    # Transform the text into a tensor of ids
    def text_to_i(self, text: str):
        return torch.tensor([self.stoi[c] for c in text], dtype=torch.long)

    # This will read the dataset seq_len times in a single epoch.
    def __len__(self):
        return len(self.data) - self.seq_len - 1

    # Return a sample
    def __getitem__(self, idx):
        return self.data[idx:idx + self.seq_len], self.data[idx + 1:idx + self.seq_len + 1]
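A quick way to see how __getitem__ pairs inputs with targets (running this snippet downloads Tiny Shakespeare to the labml data path on first use; the sequence length of 8 is just for illustration):

ds = TinyShakespeareDataset(seq_len=8)
x, y = ds[0]
# The target is the input shifted left by one character,
# so the model learns to predict the next character at every position.
assert x.shape == y.shape == (8,)
assert torch.equal(x[1:], y[:-1])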
class Trainer:
    def __init__(self, configs: Configs):
        # Get the device
        self.device = torch.device('cpu')
        if torch.cuda.is_available():
            self.device = torch.device('cuda:0')
        # Initialize the dataset
        self.dataset = TinyShakespeareDataset(configs.seq_len)
        # Initialize the dataloader
        self.dataloader = DataLoader(self.dataset,
                                     batch_size=configs.batch_size,
                                     collate_fn=transpose_batch,
                                     shuffle=True)
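        # Note: transpose_batch (from labml_nn.experiments.nlp_autoregression) collates a
        # list of (input, target) pairs by stacking them along dim=1, giving tensors of
        # shape [seq_len, batch_size] -- the layout the transformer modules used here
        # expect. (This is a summary of its behaviour as understood here, not the
        # library code itself.)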
        # Create the position-wise feed-forward network with the gating / activation
        # selected by configs.glu_variant
        if configs.glu_variant == 'GLU':
            ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.Sigmoid(), True, False, False, False)
        elif configs.glu_variant == 'Bilinear':
            ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.Identity(), True, False, False, False)
        elif configs.glu_variant == 'ReGLU':
            ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.ReLU(), True, False, False, False)
        elif configs.glu_variant == 'GEGLU':
            ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.GELU(), True, False, False, False)
        elif configs.glu_variant == 'SwiGLU':
            ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.SiLU(), True, False, False, False)
        elif configs.glu_variant == 'ReLU':
            ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.ReLU())
        elif configs.glu_variant == 'GELU':
            ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.GELU())
        else:
            raise ValueError(f'Unknown variant {configs.glu_variant}')
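        # All of the gated variants above compute (activation(x W1) * (x V)) W2, as
        # sketched near the top of this page; they differ only in the gate activation:
        # GLU -> sigmoid, Bilinear -> identity, ReGLU -> ReLU, GEGLU -> GELU,
        # SwiGLU -> SiLU (Swish). The plain 'ReLU' and 'GELU' options use the standard
        # two-layer feed-forward network without gating. (The boolean flags passed to
        # FeedForward are assumed to select the gated path.)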
        # Number of different characters
        n_chars = len(self.dataset.stoi)

        # Initialize Multi-Head Attention module
        mha = MultiHeadAttention(configs.n_heads, configs.d_model, configs.dropout)
        # Initialize the Transformer Block
        transformer_layer = TransformerLayer(d_model=configs.d_model, self_attn=mha, src_attn=None,
                                             feed_forward=ffn, dropout_prob=configs.dropout)
        # The full model: embeddings with positional encoding, the encoder stack,
        # and a final linear layer that gives logits over the characters
        self.model = AutoregressiveModel(EmbeddingsWithPositionalEncoding(configs.d_model, n_chars),
                                         Encoder(transformer_layer, configs.n_layers),
                                         nn.Linear(configs.d_model, n_chars))
        # Move the model to the current device
        self.model.to(self.device)

        # Initialize Noam optimizer
        self.optimizer = Noam(self.model.parameters(), lr=1.0, warmup=2_000, d_model=configs.d_model)
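        # The Noam optimizer applies the learning-rate schedule from
        # "Attention Is All You Need" on top of an Adam-style update (an assumption
        # about the labml implementation):
        #   lr = 1.0 * d_model^{-0.5} * min(step^{-0.5}, step * warmup^{-1.5})
        # i.e. a linear warm-up over the first 2,000 steps followed by
        # inverse-square-root decay.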
        # Cross-entropy loss
        self.loss_func = nn.CrossEntropyLoss()
        # Number of training epochs; note that the dataset definition above repeats
        # the data seq_len times in a single epoch
        self.epochs = configs.epochs
        # Gradient clipping norm
        self.grad_norm_clip = configs.grad_norm_clip

        # Set tracker configurations
        tracker.set_scalar("loss.*", True)

    def sample(self):
        # Starting prompt
        prompt = 'It is'
        # Collect output for printing
        log = [(prompt, Text.subtle)]
        # Sample 25 tokens
        for i in monit.iterate('Sample', 25):
            # Tokenize the prompt
            data = self.dataset.text_to_i(prompt).unsqueeze(-1)
            data = data.to(self.device)
            # Get the model output
            output = self.model(data)
            # Get the model prediction (greedy)
            output = output.argmax(dim=-1).squeeze()
            # Add the prediction to prompt
            prompt += self.dataset.itos[output[-1].item()]
            # Add the prediction for logging
            log += [(self.dataset.itos[output[-1].item()], Text.value)]

        # Print the sampled output
        logger.log(log)
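    # Note on shapes in sample(): `data` is [seq_len, 1] (a batch of one sequence), so,
    # assuming the [seq_len, batch, d_model] layout used throughout labml_nn, the model
    # output is [seq_len, 1, n_chars]. After argmax over the last dimension and squeeze
    # it becomes [seq_len], and output[-1] is the greedy prediction for the character
    # that follows the prompt.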
    def train(self):
        # Loop for the given number of epochs
        for _ in monit.loop(self.epochs):
            # Iterate over the mini-batches
            for i, batch in monit.enum('Train', self.dataloader):
                # Move data to the device
                data, target = batch[0].to(self.device), batch[1].to(self.device)

                # Set tracker step, as the number of characters trained on
                tracker.add_global_step(data.shape[0] * data.shape[1])

                # Set model state to training
                self.model.train()
                # Run the model to get its outputs
                output = self.model(data)

                # Calculate loss; flatten [seq_len, batch_size, n_chars] to
                # [seq_len * batch_size, n_chars] for cross-entropy
                loss = self.loss_func(output.view(-1, output.shape[-1]), target.view(-1))
                # Log the loss
                tracker.add("loss.train", loss)

                # Calculate gradients
                loss.backward()
                # Clip gradients
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.grad_norm_clip)
                # Take optimizer step
                self.optimizer.step()
                # Log the model parameters and gradients
                if (i + 1) % 100 == 0:
                    tracker.add('model', self.model)
                # Clear the gradients
                self.optimizer.zero_grad()

                # Generate a sample every 100 iterations
                if (i + 1) % 100 == 0:
                    self.model.eval()
                    with torch.no_grad():
                        self.sample()

                # Save the tracked metrics every 10 iterations
                if (i + 1) % 10 == 0:
                    tracker.save()

            # Save the model
            experiment.save_checkpoint()

def main():
    # Create experiment
    experiment.create(name="glu_variants")
    # Create configs
    configs = Configs()
    # Load configurations
    experiment.configs(dataclasses.asdict(configs))

    # Create trainer
    trainer = Trainer(configs)
    # Set models for training and loading
    experiment.add_pytorch_models({'model': trainer.model})

    # Start the experiment
    with experiment.start():
        # Train the model
        trainer.train()

if __name__ == '__main__':
    main()