Mirror of https://github.com/labmlai/annotated_deep_learning_paper_implementations.git (synced 2025-10-31 02:39:16 +08:00)

	auto regression common exp
@@ -44,13 +44,14 @@ class NLPAutoRegressionConfigs(TrainValidConfigs):
 
     is_save_models = True
 
-    loss_func: CrossEntropyLoss()
+    loss_func = CrossEntropyLoss()
+    accuracy = Accuracy()
 
     def init(self):
         tracker.set_queue("loss.*", 20, True)
         tracker.set_scalar("accuracy.*", True)
         hook_model_outputs(self.mode, self.model, 'model')
-        self.state_modules = [Accuracy()]
+        self.state_modules = [self.accuracy]
 
     def step(self, batch: any, batch_idx: BatchIndex):
         data, target = batch[0].to(self.device), batch[1].to(self.device)
@@ -62,8 +63,8 @@ class NLPAutoRegressionConfigs(TrainValidConfigs):
             output, *_ = self.model(data)
 
         loss = self.loss_func(output, target)
-        self.accuracy_func(output, target)
-        self.accuracy_func.track()
+        self.accuracy(output, target)
+        self.accuracy.track()
         tracker.add("loss.", loss)
 
         if self.mode.is_train:
@@ -88,7 +89,7 @@ class NLPAutoRegressionConfigs(TrainValidConfigs):
             data = self.text.text_to_i(prompt).unsqueeze(-1)
             data = data.to(self.device)
             # Get the model output
-            output = self.model(data)
+            output, *_ = self.model(data)
             # Get the model prediction (greedy)
             output = output.argmax(dim=-1).squeeze()
             # Add the prediction to prompt
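
Note: these hunks are in the shared base class (NLPAutoRegressionConfigs, imported below from labml_nn.experiments.nlp_autoregression). It now keeps loss_func and accuracy as shared class-level values; the old annotation loss_func: CrossEntropyLoss() was a typo that never assigned anything. For readers outside the repo, here is a minimal, self-contained sketch of the kind of sequence cross-entropy wrapper loss_func points to, modeled on the CrossEntropyLoss helper this same commit deletes from the HyperLSTM experiment further down; the class name SeqCrossEntropyLoss is ours, and plain nn.Module stands in for labml_helpers' Module.

import torch
import torch.nn as nn


class SeqCrossEntropyLoss(nn.Module):
    """Cross entropy over [seq_len, batch, n_tokens] model outputs (sketch)."""

    def __init__(self):
        super().__init__()
        self.loss = nn.CrossEntropyLoss()

    def forward(self, outputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        # Flatten the time and batch dimensions so each token position
        # becomes one classification sample.
        return self.loss(outputs.view(-1, outputs.shape[-1]), targets.view(-1))


if __name__ == '__main__':
    logits = torch.randn(8, 4, 65)           # [seq_len, batch, n_tokens]
    targets = torch.randint(0, 65, (8, 4))   # [seq_len, batch]
    print(SeqCrossEntropyLoss()(logits, targets))
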
@@ -1,18 +1,11 @@
-from typing import Callable, Any
-
 import torch
 import torch.nn as nn
-from labml import lab, experiment, monit, tracker, logger
+from labml import experiment
 from labml.configs import option
-from labml.logger import Text
 from labml.utils.pytorch import get_modules
-from labml_helpers.datasets.text import TextDataset, TextFileDataset, SequentialUnBatchedDataset
-from labml_helpers.metrics.accuracy import Accuracy
 from labml_helpers.module import Module
-from labml_helpers.optimizer import OptimizerConfigs
-from labml_helpers.train_valid import SimpleTrainValidConfigs, BatchIndex
-from torch.utils.data import DataLoader
 
+from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs
 from labml_nn.hypernetworks.hyper_lstm import HyperLSTM
 
 
@@ -36,28 +29,7 @@ class AutoregressiveModel(Module):
         return self.generator(res), state
 
 
-class CrossEntropyLoss(Module):
-    """
-    Cross entropy loss
-    """
-
-    def __init__(self):
-        super().__init__()
-        self.loss = nn.CrossEntropyLoss()
-
-    def __call__(self, outputs, targets):
-        return self.loss(outputs.view(-1, outputs.shape[-1]), targets.view(-1))
-
-
-def transpose_batch(batch):
-    transposed_data = list(zip(*batch))
-    src = torch.stack(transposed_data[0], 1)
-    tgt = torch.stack(transposed_data[1], 1)
-
-    return src, tgt
-
-
-class Configs(SimpleTrainValidConfigs):
+class Configs(NLPAutoRegressionConfigs):
     """
     ## Configurations
 
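
In the HyperLSTM experiment, Configs now derives from the shared NLPAutoRegressionConfigs, so the local CrossEntropyLoss wrapper and the transpose_batch collate function are dropped; the 'shuffled_train_loader' / 'shuffled_valid_loader' options selected in main() below are expected to produce the same time-first batches. For reference, a runnable copy of the deleted collate helper with a toy usage, to make the shape contract explicit:

import torch


def transpose_batch(batch):
    """Collate (src, tgt) pairs into time-first tensors of shape [seq_len, batch]."""
    transposed_data = list(zip(*batch))
    # Stacking along dim=1 keeps the sequence dimension first, which is the
    # layout the LSTM-style models in this experiment consume.
    src = torch.stack(transposed_data[0], 1)
    tgt = torch.stack(transposed_data[1], 1)
    return src, tgt


if __name__ == '__main__':
    items = [(torch.arange(5), torch.arange(1, 6)) for _ in range(3)]
    src, tgt = transpose_batch(items)
    print(src.shape, tgt.shape)  # torch.Size([5, 3]) torch.Size([5, 3])
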
@@ -65,119 +37,6 @@ class Configs(SimpleTrainValidConfigs):
     """
 
     model: AutoregressiveModel
-    text: TextDataset
-    batch_size: int = 20
-    seq_len: int = 512
-    n_tokens: int
-    tokenizer: Callable = 'character'
-
-    is_save_models = True
-
-    optimizer: torch.optim.Adam = 'transformer_optimizer'
-
-    accuracy = Accuracy()
-    loss_func = CrossEntropyLoss()
-
-    def init(self):
-        # Create a configurable optimizer.
-        # Parameters like learning rate can be changed by passing a dictionary when starting the experiment.
-        optimizer = OptimizerConfigs()
-        optimizer.parameters = self.model.parameters()
-        optimizer.optimizer = 'Adam'
-        self.optimizer = optimizer
-
-        # Create a sequential data loader for training
-        self.train_loader = DataLoader(SequentialUnBatchedDataset(text=self.text.train,
-                                                                  dataset=self.text,
-                                                                  seq_len=self.seq_len),
-                                       batch_size=self.batch_size,
-                                       collate_fn=transpose_batch,
-                                       shuffle=True)
-
-        # Create a sequential data loader for validation
-        self.valid_loader = DataLoader(SequentialUnBatchedDataset(text=self.text.valid,
-                                                                  dataset=self.text,
-                                                                  seq_len=self.seq_len),
-                                       batch_size=self.batch_size,
-                                       collate_fn=transpose_batch,
-                                       shuffle=True)
-
-        self.state_modules = [self.accuracy]
-
-    def sample(self):
-        """
-        Sampling function to generate samples periodically while training
-        """
-        prompt = 'It is'
-        log = [(prompt, Text.subtle)]
-        # Sample 25 tokens
-        for i in monit.iterate('Sample', 25):
-            # Tokenize the prompt
-            data = self.text.text_to_i(prompt).unsqueeze(-1)
-            data = data.to(self.device)
-            # Get the model output
-            output, state = self.model(data)
-            output = output.cpu()
-            # Get the model prediction (greedy)
-            output = output.argmax(dim=-1).squeeze()
-            # Add the prediction to prompt
-            prompt += self.text.itos[output[-1]]
-            # Add the prediction for logging
-            log += [(self.text.itos[output[-1]], Text.value)]
-
-        logger.log(log)
-
-    def step(self, batch: Any, batch_idx: BatchIndex):
-        """
-        This method is called for each batch
-        """
-        self.model.train(self.mode.is_train)
-
-        # Get data and target labels
-        data, target = batch[0].to(self.device), batch[1].to(self.device)
-
-        if self.mode.is_train:
-            tracker.add_global_step(data.shape[0] * data.shape[1])
-
-        # Run the model
-        output, state = self.model(data)
-
-        # Calculate loss
-        loss = self.loss_func(output, target)
-        # Calculate accuracy
-        self.accuracy(output, target)
-
-        # Log the loss
-        tracker.add("loss.", loss)
-
-        #  If we are in training mode, calculate the gradients
-        if self.mode.is_train:
-            loss.backward()
-            self.optimizer.step()
-            if batch_idx.is_last:
-                tracker.add('model', self.model)
-            self.optimizer.zero_grad()
-
-        tracker.save()
-
-
-def character_tokenizer(x: str):
-    return list(x)
-
-
-@option(Configs.tokenizer)
-def character():
-    """
-    Character level tokenizer
-    """
-    return character_tokenizer
-
-
-@option(Configs.text)
-def tiny_shakespeare(c: Configs):
-    return TextFileDataset(
-        lab.get_data_path() / 'tiny_shakespeare.txt', c.tokenizer,
-        url='https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt')
 
 
 @option(Configs.model)
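
Everything deleted in this hunk, the optimizer setup, the sequential data loaders, the greedy sampling loop, the per-batch step, the character tokenizer and the Tiny Shakespeare option, now comes from the shared NLPAutoRegressionConfigs. For reference, a rough, self-contained sketch of the greedy sampling loop with a stand-in model; the real code uses self.text.text_to_i / itos, the configured prompt, and the HyperLSTM-based model, none of which are reproduced here.

import torch
import torch.nn as nn

# Toy character vocabulary and a stand-in model; placeholders only,
# not the repo's TextDataset or HyperLSTM.
vocab = sorted(set("It is a sample "))
stoi = {c: i for i, c in enumerate(vocab)}
itos = {i: c for c, i in stoi.items()}


class DummyModel(nn.Module):
    def __init__(self, n_tokens: int):
        super().__init__()
        self.emb = nn.Embedding(n_tokens, 16)
        self.head = nn.Linear(16, n_tokens)

    def forward(self, x):
        # x: [seq_len, 1] -> logits of shape [seq_len, 1, n_tokens]
        return self.head(self.emb(x))


model = DummyModel(len(vocab))
prompt = 'It is'
for _ in range(25):
    # Tokenize the current prompt into a [seq_len, 1] tensor
    data = torch.tensor([stoi[c] for c in prompt]).unsqueeze(-1)
    # Greedy prediction: arg-max token at every position
    output = model(data).argmax(dim=-1).squeeze()
    # Append the prediction for the last position and feed it back in
    prompt += itos[int(output[-1])]
print(prompt)
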
@@ -191,7 +50,7 @@ def autoregressive_model(c: Configs):
 
 def main():
     # Create experiment
-    experiment.create(name="knn_lm", comment='')
+    experiment.create(name="hyper_lstm", comment='')
     # Create configs
     conf = Configs()
     # Load configurations
@@ -200,6 +59,12 @@ def main():
                        {'tokenizer': 'character',
                         'text': 'tiny_shakespeare',
                         'optimizer.learning_rate': 2.5e-4,
+                        'optimizer.optimizer': 'Adam',
+                        'prompt': 'It is',
+                        'prompt_separator': '',
+
+                        'train_loader': 'shuffled_train_loader',
+                        'valid_loader': 'shuffled_valid_loader',
+
                         'seq_len': 512,
                         'epochs': 128,
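
The new keys in this override dict are resolved by name against @option-registered calculators in the shared config class, the same mechanism this file used for 'tokenizer': 'character' and 'text': 'tiny_shakespeare' before the removal above. A tiny sketch of that mechanism with labml; the class and option names here are made up for the demo, while experiment.create / experiment.configs are the same calls main() uses, so treat the exact resolution details as an assumption rather than this commit's code.

from labml import experiment
from labml.configs import BaseConfigs, option


class DemoConfigs(BaseConfigs):
    n_tokens: int = 65
    model: str


@option(DemoConfigs.model)
def tiny_model(c: DemoConfigs):
    # Stand-in for building AutoregressiveModel(HyperLSTM(...), ...)
    return f'a model over {c.n_tokens} tokens'


def demo():
    experiment.create(name='config_demo', comment='')
    conf = DemoConfigs()
    # The string 'tiny_model' picks the @option calculator above by name,
    # just like 'train_loader': 'shuffled_train_loader' does in main().
    experiment.configs(conf, {'model': 'tiny_model'})
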
Varuna Jayasiri