Mirror of https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
Synced 2025-10-31 10:48:49 +08:00

Commit: auto regression common exp (Varuna Jayasiri)
@@ -44,13 +44,14 @@ class NLPAutoRegressionConfigs(TrainValidConfigs):
 
     is_save_models = True
 
-    loss_func: CrossEntropyLoss()
+    loss_func = CrossEntropyLoss()
+    accuracy = Accuracy()
 
     def init(self):
         tracker.set_queue("loss.*", 20, True)
         tracker.set_scalar("accuracy.*", True)
         hook_model_outputs(self.mode, self.model, 'model')
-        self.state_modules = [Accuracy()]
+        self.state_modules = [self.accuracy]
 
     def step(self, batch: any, batch_idx: BatchIndex):
         data, target = batch[0].to(self.device), batch[1].to(self.device)
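The `CrossEntropyLoss` assigned above is the small wrapper that flattens the time and batch dimensions before applying `nn.CrossEntropyLoss` (the identical wrapper removed from the experiment file further down shows the idea). A minimal, self-contained sketch of that flattening, with made-up shapes:

```python
import torch
import torch.nn as nn

# Made-up shapes for illustration: 32 time steps, batch of 4, 65-character vocabulary
seq_len, batch_size, n_tokens = 32, 4, 65
outputs = torch.randn(seq_len, batch_size, n_tokens)          # model logits
targets = torch.randint(0, n_tokens, (seq_len, batch_size))   # next-token labels

# nn.CrossEntropyLoss expects [N, C] logits and [N] class indices,
# so both tensors are flattened across the time and batch dimensions.
loss = nn.CrossEntropyLoss()(outputs.view(-1, n_tokens), targets.view(-1))
print(loss.item())
```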
@@ -62,8 +63,8 @@ class NLPAutoRegressionConfigs(TrainValidConfigs):
             output, *_ = self.model(data)
 
         loss = self.loss_func(output, target)
-        self.accuracy_func(output, target)
-        self.accuracy_func.track()
+        self.accuracy(output, target)
+        self.accuracy.track()
         tracker.add("loss.", loss)
 
         if self.mode.is_train:
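The `accuracy` config used here is a stateful metric: it is called on every batch, accumulates counts, and `track()` reports the running value; registering it in `state_modules` lets the trainer reset and checkpoint it with the rest of the training state (an assumption about labml_helpers' `Accuracy`, based on how it is used in this diff). An illustrative stand-in with the same call/track shape:

```python
import torch

class RunningAccuracy:
    """Illustrative stand-in, not the labml_helpers implementation."""

    def __init__(self):
        self.correct, self.total = 0, 0

    def __call__(self, output: torch.Tensor, target: torch.Tensor):
        # Greedy prediction over the vocabulary dimension
        pred = output.argmax(dim=-1)
        self.correct += (pred == target).sum().item()
        self.total += target.numel()

    def track(self):
        # The real metric would log through labml's tracker instead of printing
        print(f"accuracy.: {self.correct / max(self.total, 1):.4f}")

acc = RunningAccuracy()
acc(torch.randn(32, 4, 65), torch.randint(0, 65, (32, 4)))
acc.track()
```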
@@ -88,7 +89,7 @@ class NLPAutoRegressionConfigs(TrainValidConfigs):
             data = self.text.text_to_i(prompt).unsqueeze(-1)
             data = data.to(self.device)
             # Get the model output
-            output = self.model(data)
+            output, *_ = self.model(data)
             # Get the model prediction (greedy)
             output = output.argmax(dim=-1).squeeze()
             # Add the prediction to prompt

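The sampling loop above (and the fuller `sample()` method removed further down) generates text greedily, one character per step, by re-encoding the growing prompt and taking the argmax at the last position. A sketch of the same loop as a free function; `model`, `text_to_i` and `itos` are assumed to follow the interfaces used in this diff:

```python
import torch

def greedy_sample(model, text_to_i, itos, prompt: str, steps: int = 25, device: str = 'cpu') -> str:
    """Greedy character-by-character sampling in the style of sample() above.

    Assumptions: text_to_i(prompt) returns a [seq_len] tensor of token ids,
    the model returns (logits, ...) with logits of shape [seq_len, 1, vocab],
    and itos maps a token id back to a character."""
    for _ in range(steps):
        data = text_to_i(prompt).unsqueeze(-1).to(device)  # [seq_len, 1]
        output, *_ = model(data)
        pred = output.argmax(dim=-1).squeeze(-1)           # [seq_len] greedy picks
        prompt += itos[pred[-1]]                           # append the last prediction
    return prompt
```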
@@ -1,18 +1,11 @@
-from typing import Callable, Any
 
 import torch
 import torch.nn as nn
-from labml import lab, experiment, monit, tracker, logger
+from labml import experiment
 from labml.configs import option
-from labml.logger import Text
-from labml.utils.pytorch import get_modules
-from labml_helpers.datasets.text import TextDataset, TextFileDataset, SequentialUnBatchedDataset
-from labml_helpers.metrics.accuracy import Accuracy
 from labml_helpers.module import Module
-from labml_helpers.optimizer import OptimizerConfigs
-from labml_helpers.train_valid import SimpleTrainValidConfigs, BatchIndex
-from torch.utils.data import DataLoader
 
+from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs
 from labml_nn.hypernetworks.hyper_lstm import HyperLSTM
 
 
@@ -36,28 +29,7 @@ class AutoregressiveModel(Module):
         return self.generator(res), state
 
 
-class CrossEntropyLoss(Module):
-    """
-    Cross entropy loss
-    """
-
-    def __init__(self):
-        super().__init__()
-        self.loss = nn.CrossEntropyLoss()
-
-    def __call__(self, outputs, targets):
-        return self.loss(outputs.view(-1, outputs.shape[-1]), targets.view(-1))
-
-
-def transpose_batch(batch):
-    transposed_data = list(zip(*batch))
-    src = torch.stack(transposed_data[0], 1)
-    tgt = torch.stack(transposed_data[1], 1)
-
-    return src, tgt
-
-
-class Configs(SimpleTrainValidConfigs):
+class Configs(NLPAutoRegressionConfigs):
     """
     ## Configurations
 
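The deleted `transpose_batch` was a `DataLoader` collate function producing time-major batches of shape `[seq_len, batch_size]`; the `shuffled_train_loader`/`shuffled_valid_loader` options selected in `main()` below presumably provide equivalent batching from the shared config. A standalone sketch of what it does:

```python
import torch

def transpose_batch(batch):
    # `batch` is a list of (src, tgt) pairs, each a [seq_len] tensor
    transposed_data = list(zip(*batch))
    src = torch.stack(transposed_data[0], dim=1)  # -> [seq_len, batch_size]
    tgt = torch.stack(transposed_data[1], dim=1)
    return src, tgt

# Illustrative use: three samples of length 5
samples = [(torch.arange(5), torch.arange(1, 6)) for _ in range(3)]
src, tgt = transpose_batch(samples)
print(src.shape)  # torch.Size([5, 3])
```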
@@ -65,119 +37,6 @@ class Configs(SimpleTrainValidConfigs):
     """
 
     model: AutoregressiveModel
-    text: TextDataset
-    batch_size: int = 20
-    seq_len: int = 512
-    n_tokens: int
-    tokenizer: Callable = 'character'
-
-    is_save_models = True
-
-    optimizer: torch.optim.Adam = 'transformer_optimizer'
-
-    accuracy = Accuracy()
-    loss_func = CrossEntropyLoss()
-
-    def init(self):
-        # Create a configurable optimizer.
-        # Parameters like learning rate can be changed by passing a dictionary when starting the experiment.
-        optimizer = OptimizerConfigs()
-        optimizer.parameters = self.model.parameters()
-        optimizer.optimizer = 'Adam'
-        self.optimizer = optimizer
-
-        # Create a sequential data loader for training
-        self.train_loader = DataLoader(SequentialUnBatchedDataset(text=self.text.train,
-                                                                  dataset=self.text,
-                                                                  seq_len=self.seq_len),
-                                       batch_size=self.batch_size,
-                                       collate_fn=transpose_batch,
-                                       shuffle=True)
-
-        # Create a sequential data loader for validation
-        self.valid_loader = DataLoader(SequentialUnBatchedDataset(text=self.text.valid,
-                                                                  dataset=self.text,
-                                                                  seq_len=self.seq_len),
-                                       batch_size=self.batch_size,
-                                       collate_fn=transpose_batch,
-                                       shuffle=True)
-
-        self.state_modules = [self.accuracy]
-
-    def sample(self):
-        """
-        Sampling function to generate samples periodically while training
-        """
-        prompt = 'It is'
-        log = [(prompt, Text.subtle)]
-        # Sample 25 tokens
-        for i in monit.iterate('Sample', 25):
-            # Tokenize the prompt
-            data = self.text.text_to_i(prompt).unsqueeze(-1)
-            data = data.to(self.device)
-            # Get the model output
-            output, state = self.model(data)
-            output = output.cpu()
-            # Get the model prediction (greedy)
-            output = output.argmax(dim=-1).squeeze()
-            # Add the prediction to prompt
-            prompt += self.text.itos[output[-1]]
-            # Add the prediction for logging
-            log += [(self.text.itos[output[-1]], Text.value)]
-
-        logger.log(log)
-
-    def step(self, batch: Any, batch_idx: BatchIndex):
-        """
-        This method is called for each batch
-        """
-        self.model.train(self.mode.is_train)
-
-        # Get data and target labels
-        data, target = batch[0].to(self.device), batch[1].to(self.device)
-
-        if self.mode.is_train:
-            tracker.add_global_step(data.shape[0] * data.shape[1])
-
-        # Run the model
-        output, state = self.model(data)
-
-        # Calculate loss
-        loss = self.loss_func(output, target)
-        # Calculate accuracy
-        self.accuracy(output, target)
-
-        # Log the loss
-        tracker.add("loss.", loss)
-
-        #  If we are in training mode, calculate the gradients
-        if self.mode.is_train:
-            loss.backward()
-            self.optimizer.step()
-            if batch_idx.is_last:
-                tracker.add('model', self.model)
-            self.optimizer.zero_grad()
-
-        tracker.save()
-
-
-def character_tokenizer(x: str):
-    return list(x)
-
-
-@option(Configs.tokenizer)
-def character():
-    """
-    Character level tokenizer
-    """
-    return character_tokenizer
-
-
-@option(Configs.text)
-def tiny_shakespeare(c: Configs):
-    return TextFileDataset(
-        lab.get_data_path() / 'tiny_shakespeare.txt', c.tokenizer,
-        url='https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt')
 
 
 @option(Configs.model)
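`@option(Configs.model)` above, like the removed `@option(Configs.tokenizer)` and `@option(Configs.text)`, registers a provider function for a labml config item; the provider is selected by name when the experiment sets that config to the matching string. A minimal sketch of the pattern, using hypothetical names:

```python
from labml.configs import BaseConfigs, option


class ExampleConfigs(BaseConfigs):
    # Hypothetical config item; its default value names the provider to use
    dataset_name: str = 'tiny'


@option(ExampleConfigs.dataset_name)
def tiny(c: ExampleConfigs):
    # Providers receive the configs object, so they can read other config items
    return 'tiny-dataset'
```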
@@ -191,7 +50,7 @@ def autoregressive_model(c: Configs):
 
 def main():
     # Create experiment
-    experiment.create(name="knn_lm", comment='')
+    experiment.create(name="hyper_lstm", comment='')
     # Create configs
     conf = Configs()
     # Load configurations
@@ -200,6 +59,12 @@ def main():
                        {'tokenizer': 'character',
                         'text': 'tiny_shakespeare',
                         'optimizer.learning_rate': 2.5e-4,
+                        'optimizer.optimizer': 'Adam',
+                        'prompt': 'It is',
+                        'prompt_separator': '',
+
+                        'train_loader': 'shuffled_train_loader',
+                        'valid_loader': 'shuffled_valid_loader',
 
                         'seq_len': 512,
                         'epochs': 128,
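The dictionary above is the override set passed to `experiment.configs`: plain keys select option providers by name (e.g. `'text': 'tiny_shakespeare'`), while dotted keys such as `'optimizer.learning_rate'` reach into nested config sections. A sketch of how `main()` plausibly wires this together; the `experiment.start()` / `conf.run()` tail is an assumption, not shown in this diff, and `Configs` refers to the class defined in this experiment file:

```python
from labml import experiment


def main():
    # Create and register the experiment, as in the diff above
    experiment.create(name="hyper_lstm", comment='')
    conf = Configs()
    # Flat override dict: plain keys pick option providers by name,
    # dotted keys reach into nested sections such as the optimizer configs
    experiment.configs(conf, {
        'tokenizer': 'character',
        'text': 'tiny_shakespeare',
        'optimizer.optimizer': 'Adam',
        'optimizer.learning_rate': 2.5e-4,
        'seq_len': 512,
        'epochs': 128,
    })
    # Assumed tail: start the experiment and run the training loop
    with experiment.start():
        conf.run()


if __name__ == '__main__':
    main()
```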