auto regression common exp

Varuna Jayasiri
2020-12-27 07:35:35 +05:30
parent 4fe8392a8d
commit 52d5b5fbf6


@@ -9,21 +9,13 @@ summary: This is training code with notes for a basic auto-regressive transformer
This trains a simple [transformer](../../) model for auto-regression.
"""
from typing import Callable, Any
import torch
import torch.nn as nn
from torchtext.data.utils import get_tokenizer
from labml import lab, experiment, monit, tracker, logger
from labml import experiment
from labml.configs import option
from labml.logger import Text
from labml.utils.pytorch import get_modules
from labml_helpers.datasets.text import TextDataset, SequentialDataLoader, TextFileDataset
from labml_helpers.metrics.accuracy import Accuracy
from labml_helpers.module import Module
from labml_helpers.optimizer import OptimizerConfigs
from labml_helpers.train_valid import SimpleTrainValidConfigs, BatchIndex
from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs
from labml_nn.transformers import Encoder, Generator, TransformerConfigs
from labml_nn.transformers.utils import subsequent_mask
@@ -64,23 +56,10 @@ class AutoregressiveModel(Module):
# Embed the tokens (`src`) and run them through the transformer
res = self.encoder(self.src_embed(src), self.src_mask)
# Generate logits of the next token
return self.generator(res)
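# Also return `None` for the model state; the shared trainer expects an `(output, state)` pair, where the state slot is meant for models with memory and is unused here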
return self.generator(res), None
class CrossEntropyLoss(Module):
"""
Cross entropy loss
"""
def __init__(self):
super().__init__()
self.loss = nn.CrossEntropyLoss()
def __call__(self, outputs, targets):
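# Flatten the batch and sequence dimensions so the cross entropy loss is computed per token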
return self.loss(outputs.view(-1, outputs.shape[-1]), targets.view(-1))
class Configs(SimpleTrainValidConfigs):
class Configs(NLPAutoRegressionConfigs):
"""
## Configurations
@@ -89,141 +68,8 @@ class Configs(SimpleTrainValidConfigs):
transformer: TransformerConfigs
model: AutoregressiveModel
text: TextDataset
batch_size: int = 20
seq_len: int = 32
n_tokens: int
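# Defaults given as strings (such as `'character'` below) select the config option function of that name, registered with `@option`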
tokenizer: Callable = 'character'
is_save_models = True
prompt: str
prompt_separator: str
is_save_ff_input = False
optimizer: torch.optim.Adam = 'transformer_optimizer'
accuracy = Accuracy()
loss_func = CrossEntropyLoss()
def init(self):
# Create a configurable optimizer.
# Parameters like learning rate can be changed by passing a dictionary when starting the experiment.
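# Such a dictionary is passed to the experiment in `main()` below.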
optimizer = OptimizerConfigs()
optimizer.parameters = self.model.parameters()
optimizer.d_model = self.transformer.d_model
optimizer.optimizer = 'Noam'
self.optimizer = optimizer
# Create a sequential data loader for training
self.train_loader = SequentialDataLoader(text=self.text.train,
dataset=self.text,
batch_size=self.batch_size,
seq_len=self.seq_len)
# Create a sequential data loader for validation
self.valid_loader = SequentialDataLoader(text=self.text.valid,
dataset=self.text,
batch_size=self.batch_size,
seq_len=self.seq_len)
self.state_modules = [self.accuracy]
def sample(self):
"""
Sampling function to generate samples periodically while training
"""
prompt = self.prompt
log = [(prompt, Text.subtle)]
# Sample 25 tokens
for i in monit.iterate('Sample', 25):
# Tokenize the prompt
data = self.text.text_to_i(prompt).unsqueeze(-1)
data = data.to(self.device)
# Get the model output
output = self.model(data)
# Get the model prediction (greedy)
output = output.argmax(dim=-1).squeeze()
# Add the prediction to prompt
prompt += self.prompt_separator + self.text.itos[output[-1]]
# Add the prediction for logging
log += [(self.prompt_separator + self.text.itos[output[-1]], Text.value)]
logger.log(log)
def step(self, batch: Any, batch_idx: BatchIndex):
"""
This method is called for each batch
"""
self.model.train(self.mode.is_train)
# Get data and target labels
data, target = batch[0].to(self.model.device), batch[1].to(self.model.device)
if self.mode.is_train:
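# Advance the global step by the number of tokens processed in this batch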
tracker.add_global_step(data.shape[0] * data.shape[1])
# Run the model
output = self.model(data)
# Calculate loss
loss = self.loss_func(output, target)
# Calculate accuracy
self.accuracy(output, target)
# Log the loss
tracker.add("loss.", loss)
# If we are in training mode, calculate the gradients
if self.mode.is_train:
loss.backward()
self.optimizer.step()
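# Log model parameter and gradient statistics on the last batch of each epoch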
if batch_idx.is_last:
tracker.add('model', self.model)
self.optimizer.zero_grad()
tracker.save()
@option(Configs.tokenizer)
def basic_english():
"""
Basic English tokenizer
We use a character-level tokenizer in this experiment.
You can switch to this tokenizer by setting
```
'tokenizer': 'basic_english',
```
in the configurations dictionary when starting the experiment.
"""
return get_tokenizer('basic_english')
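# Character-level tokenization: simply split the text into a list of characters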
def character_tokenizer(x: str):
return list(x)
@option(Configs.tokenizer)
def character():
"""
Character-level tokenizer
"""
return character_tokenizer
@option(Configs.text)
def tiny_shakespeare(c: Configs):
"""
Initialize/load the Tiny Shakespeare dataset
This dataset is from Andrej Karpathy's [char-rnn](https://github.com/karpathy/char-rnn) project.
"""
return TextFileDataset(
lab.get_data_path() / 'tiny_shakespeare.txt', c.tokenizer,
url='https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt')
@option(Configs.model)
@@ -256,7 +102,7 @@ def transformer_c(c: Configs):
def main():
# Create experiment
experiment.create(name="knn_lm", comment='', writers={'tensorboard', 'sqlite', 'screen'})
experiment.create(name="knn_lm")
# Create configs
conf = Configs()
# Load configurations
@@ -267,6 +113,10 @@ def main():
'prompt': 'It is ',
'text': 'tiny_shakespeare',
'optimizer.optimizer': 'Noam',
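# With the Noam scheme the learning rate below acts as a scale factor; the effective rate follows the warm-up schedule derived from `d_model`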
'optimizer.learning_rate': 1.,
'optimizer.d_model': 256,
'seq_len': 1024,
'epochs': 128,
'batch_size': 6,