diff --git a/labml_nn/transformers/knn/train_model.py b/labml_nn/transformers/knn/train_model.py
index 2e015415..94b04a16 100644
--- a/labml_nn/transformers/knn/train_model.py
+++ b/labml_nn/transformers/knn/train_model.py
@@ -9,21 +9,13 @@ summary: This is training code with notes for a basic auto-regressive transformer
 
 This trains a simple [transformer](../../) model for auto-regression.
 """
-from typing import Callable, Any
-
 import torch
-import torch.nn as nn
-from torchtext.data.utils import get_tokenizer
-
-from labml import lab, experiment, monit, tracker, logger
+from labml import experiment
 from labml.configs import option
-from labml.logger import Text
 from labml.utils.pytorch import get_modules
-from labml_helpers.datasets.text import TextDataset, SequentialDataLoader, TextFileDataset
-from labml_helpers.metrics.accuracy import Accuracy
 from labml_helpers.module import Module
-from labml_helpers.optimizer import OptimizerConfigs
-from labml_helpers.train_valid import SimpleTrainValidConfigs, BatchIndex
+
+from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs
 from labml_nn.transformers import Encoder, Generator, TransformerConfigs
 from labml_nn.transformers.utils import subsequent_mask
 
@@ -64,23 +56,10 @@ class AutoregressiveModel(Module):
         # Embed the tokens (`src`) and run it through the the transformer
         res = self.encoder(self.src_embed(src), self.src_mask)
         # Generate logits of the next token
-        return self.generator(res)
+        return self.generator(res), None
 
 
-class CrossEntropyLoss(Module):
-    """
-    Cross entropy loss
-    """
-
-    def __init__(self):
-        super().__init__()
-        self.loss = nn.CrossEntropyLoss()
-
-    def __call__(self, outputs, targets):
-        return self.loss(outputs.view(-1, outputs.shape[-1]), targets.view(-1))
-
-
-class Configs(SimpleTrainValidConfigs):
+class Configs(NLPAutoRegressionConfigs):
     """
     ## Configurations
 
@@ -89,141 +68,8 @@ class Configs(SimpleTrainValidConfigs):
 
     transformer: TransformerConfigs
     model: AutoregressiveModel
-    text: TextDataset
-    batch_size: int = 20
-    seq_len: int = 32
-    n_tokens: int
-    tokenizer: Callable = 'character'
-
-    is_save_models = True
-    prompt: str
-    prompt_separator: str
 
     is_save_ff_input = False
-    optimizer: torch.optim.Adam = 'transformer_optimizer'
-
-    accuracy = Accuracy()
-    loss_func = CrossEntropyLoss()
-
-    def init(self):
-        # Create a configurable optimizer.
-        # Parameters like learning rate can be changed by passing a dictionary when starting the experiment.
-        optimizer = OptimizerConfigs()
-        optimizer.parameters = self.model.parameters()
-        optimizer.d_model = self.transformer.d_model
-        optimizer.optimizer = 'Noam'
-        self.optimizer = optimizer
-
-        # Create a sequential data loader for training
-        self.train_loader = SequentialDataLoader(text=self.text.train,
-                                                 dataset=self.text,
-                                                 batch_size=self.batch_size,
-                                                 seq_len=self.seq_len)
-
-        # Create a sequential data loader for validation
-        self.valid_loader = SequentialDataLoader(text=self.text.valid,
-                                                 dataset=self.text,
-                                                 batch_size=self.batch_size,
-                                                 seq_len=self.seq_len)
-
-        self.state_modules = [self.accuracy]
-
-    def sample(self):
-        """
-        Sampling function to generate samples periodically while training
-        """
-        prompt = self.prompt
-        log = [(prompt, Text.subtle)]
-        # Sample 25 tokens
-        for i in monit.iterate('Sample', 25):
-            # Tokenize the prompt
-            data = self.text.text_to_i(prompt).unsqueeze(-1)
-            data = data.to(self.device)
-            # Get the model output
-            output = self.model(data)
-            # Get the model prediction (greedy)
-            output = output.argmax(dim=-1).squeeze()
-            # Add the prediction to prompt
-            prompt += self.prompt_separator + self.text.itos[output[-1]]
-            # Add the prediction for logging
-            log += [(self.prompt_separator + self.text.itos[output[-1]], Text.value)]
-
-        logger.log(log)
-
-    def step(self, batch: Any, batch_idx: BatchIndex):
-        """
-        This method is called for each batch
-        """
-        self.model.train(self.mode.is_train)
-
-        # Get data and target labels
-        data, target = batch[0].to(self.model.device), batch[1].to(self.model.device)
-
-        if self.mode.is_train:
-            tracker.add_global_step(data.shape[0] * data.shape[1])
-
-        # Run the model
-        output = self.model(data)
-
-        # Calculate loss
-        loss = self.loss_func(output, target)
-        # Calculate accuracy
-        self.accuracy(output, target)
-
-        # Log the loss
-        tracker.add("loss.", loss)
-
-        # If we are in training mode, calculate the gradients
-        if self.mode.is_train:
-            loss.backward()
-            self.optimizer.step()
-            if batch_idx.is_last:
-                tracker.add('model', self.model)
-            self.optimizer.zero_grad()
-
-        tracker.save()
-
-
-@option(Configs.tokenizer)
-def basic_english():
-    """
-    Basic english tokenizer
-
-    We use character level tokenizer in this experiment.
-    You can switch by setting,
-
-    ```
-    'tokenizer': 'basic_english',
-    ```
-
-    as the configurations dictionary when starting the experiment.
-
-    """
-    return get_tokenizer('basic_english')
-
-
-def character_tokenizer(x: str):
-    return list(x)
-
-
-@option(Configs.tokenizer)
-def character():
-    """
-    Character level tokenizer
-    """
-    return character_tokenizer
-
-
-@option(Configs.text)
-def tiny_shakespeare(c: Configs):
-    """
-    Initialize/load tiny shakespeare dataset
-
-    This dataset is from Andrej Karpathy's [char-rnn](https://github.com/karpathy/char-rnn) project.
-    """
-    return TextFileDataset(
-        lab.get_data_path() / 'tiny_shakespeare.txt', c.tokenizer,
-        url='https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt')
 
 
 @option(Configs.model)
@@ -256,7 +102,7 @@ def transformer_c(c: Configs):
 
 def main():
     # Create experiment
-    experiment.create(name="knn_lm", comment='', writers={'tensorboard', 'sqlite', 'screen'})
+    experiment.create(name="knn_lm")
     # Create configs
     conf = Configs()
     # Load configurations
@@ -267,6 +113,10 @@
                        'prompt': 'It is ',
                        'text': 'tiny_shakespeare',
 
+                       'optimizer.optimizer': 'Noam',
+                       'optimizer.learning_rate': 1.,
+                       'optimizer.d_model': 256,
+
                        'seq_len': 1024,
                        'epochs': 128,
                        'batch_size': 6,