mirror of
https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
synced 2025-11-01 12:01:45 +08:00
auto regression common exp
@@ -9,21 +9,13 @@ summary: This is training code with notes for a basic auto-regressive transformer
This trains a simple [transformer](../../) model for auto-regression.
"""

from typing import Callable, Any

import torch
import torch.nn as nn
from torchtext.data.utils import get_tokenizer

from labml import lab, experiment, monit, tracker, logger
from labml import experiment
from labml.configs import option
from labml.logger import Text
from labml.utils.pytorch import get_modules
from labml_helpers.datasets.text import TextDataset, SequentialDataLoader, TextFileDataset
from labml_helpers.metrics.accuracy import Accuracy
from labml_helpers.module import Module
from labml_helpers.optimizer import OptimizerConfigs
from labml_helpers.train_valid import SimpleTrainValidConfigs, BatchIndex

from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs
from labml_nn.transformers import Encoder, Generator, TransformerConfigs
from labml_nn.transformers.utils import subsequent_mask

@@ -64,23 +56,10 @@ class AutoregressiveModel(Module):
        # Embed the tokens (`src`) and run it through the transformer
        res = self.encoder(self.src_embed(src), self.src_mask)
        # Generate logits of the next token
        return self.generator(res)
        return self.generator(res), None
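The new `forward` returns `(logits, None)` instead of a bare logits tensor. Presumably this matches the interface of the shared experiment, where a model may also return extra state (such as memories) after the logits. A minimal sketch of that calling convention, assuming the shared training step unpacks the tuple and ignores everything after the logits:

```
import torch
import torch.nn as nn


def run_model(model: nn.Module, data: torch.Tensor) -> torch.Tensor:
    # The model returns (logits, state); this model has no extra state, so state is None.
    # Unpack the logits and drop the rest, as a shared trainer would.
    output, *_ = model(data)
    return output
```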
class CrossEntropyLoss(Module):
    """
    Cross entropy loss
    """

    def __init__(self):
        super().__init__()
        self.loss = nn.CrossEntropyLoss()

    def __call__(self, outputs, targets):
        return self.loss(outputs.view(-1, outputs.shape[-1]), targets.view(-1))
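`CrossEntropyLoss` only flattens the sequence and batch dimensions before calling `nn.CrossEntropyLoss`, which expects `[N, C]` logits and `[N]` class indices. A quick shape check with the default config sizes, assuming sequence-first batches (the sizes here are illustrative):

```
import torch
import torch.nn as nn

seq_len, batch_size, n_tokens = 32, 20, 65                 # illustrative sizes
outputs = torch.randn(seq_len, batch_size, n_tokens)       # model logits
targets = torch.randint(n_tokens, (seq_len, batch_size))   # next-token ids

flat_outputs = outputs.view(-1, outputs.shape[-1])  # [640, 65]
flat_targets = targets.view(-1)                     # [640]
loss = nn.CrossEntropyLoss()(flat_outputs, flat_targets)
```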
class Configs(SimpleTrainValidConfigs):
class Configs(NLPAutoRegressionConfigs):
    """
    ## Configurations

@@ -89,141 +68,8 @@ class Configs(SimpleTrainValidConfigs):

    transformer: TransformerConfigs
    model: AutoregressiveModel
    text: TextDataset
    batch_size: int = 20
    seq_len: int = 32
    n_tokens: int
    tokenizer: Callable = 'character'

    is_save_models = True
    prompt: str
    prompt_separator: str

    is_save_ff_input = False
    optimizer: torch.optim.Adam = 'transformer_optimizer'

    accuracy = Accuracy()
    loss_func = CrossEntropyLoss()

    def init(self):
        # Create a configurable optimizer.
        # Parameters like learning rate can be changed by passing a dictionary when starting the experiment.
        optimizer = OptimizerConfigs()
        optimizer.parameters = self.model.parameters()
        optimizer.d_model = self.transformer.d_model
        optimizer.optimizer = 'Noam'
        self.optimizer = optimizer

        # Create a sequential data loader for training
        self.train_loader = SequentialDataLoader(text=self.text.train,
                                                 dataset=self.text,
                                                 batch_size=self.batch_size,
                                                 seq_len=self.seq_len)

        # Create a sequential data loader for validation
        self.valid_loader = SequentialDataLoader(text=self.text.valid,
                                                 dataset=self.text,
                                                 batch_size=self.batch_size,
                                                 seq_len=self.seq_len)

        self.state_modules = [self.accuracy]
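`init` wires the optimizer to the 'Noam' schedule with `d_model` taken from the transformer. That is the learning-rate schedule from "Attention Is All You Need": a linear warm-up followed by inverse-square-root decay, scaled by `d_model ** -0.5`. A standalone sketch of the rate (the warm-up value here is illustrative, not necessarily the `OptimizerConfigs` default):

```
def noam_rate(step: int, d_model: int, warmup: int = 4000, factor: float = 1.0) -> float:
    # Learning rate grows linearly for `warmup` steps, then decays as step ** -0.5
    step = max(step, 1)
    return factor * d_model ** -0.5 * min(step ** -0.5, step * warmup ** -1.5)


# Example: with d_model=256 the rate peaks around step == warmup
print(noam_rate(step=4000, d_model=256))
```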
    def sample(self):
        """
        Sampling function to generate samples periodically while training
        """
        prompt = self.prompt
        log = [(prompt, Text.subtle)]
        # Sample 25 tokens
        for i in monit.iterate('Sample', 25):
            # Tokenize the prompt
            data = self.text.text_to_i(prompt).unsqueeze(-1)
            data = data.to(self.device)
            # Get the model output
            output = self.model(data)
            # Get the model prediction (greedy)
            output = output.argmax(dim=-1).squeeze()
            # Add the prediction to prompt
            prompt += self.prompt_separator + self.text.itos[output[-1]]
            # Add the prediction for logging
            log += [(self.prompt_separator + self.text.itos[output[-1]], Text.value)]

        logger.log(log)
    def step(self, batch: Any, batch_idx: BatchIndex):
        """
        This method is called for each batch
        """
        self.model.train(self.mode.is_train)

        # Get data and target labels
        data, target = batch[0].to(self.model.device), batch[1].to(self.model.device)

        if self.mode.is_train:
            tracker.add_global_step(data.shape[0] * data.shape[1])

        # Run the model
        output = self.model(data)

        # Calculate loss
        loss = self.loss_func(output, target)
        # Calculate accuracy
        self.accuracy(output, target)

        # Log the loss
        tracker.add("loss.", loss)

        # If we are in training mode, calculate the gradients
        if self.mode.is_train:
            loss.backward()
            self.optimizer.step()
            if batch_idx.is_last:
                tracker.add('model', self.model)
            self.optimizer.zero_grad()

        tracker.save()
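Note that `tracker.add_global_step` advances the global step by the number of tokens in the batch, not by one. Assuming sequence-first batches of shape `[seq_len, batch_size]`:

```
import torch

# Default config sizes: seq_len=32, batch_size=20
data = torch.zeros(32, 20, dtype=torch.long)
tokens_per_batch = data.shape[0] * data.shape[1]  # 640
# With the overrides in main() below (seq_len=1024, batch_size=6) it is 6144
```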
@option(Configs.tokenizer)
def basic_english():
    """
    Basic English tokenizer

    We use a character-level tokenizer in this experiment.
    You can switch by setting

    ```
    'tokenizer': 'basic_english',
    ```

    in the configurations dictionary when starting the experiment.
    """
    return get_tokenizer('basic_english')


def character_tokenizer(x: str):
    return list(x)


@option(Configs.tokenizer)
def character():
    """
    Character level tokenizer
    """
    return character_tokenizer
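The two tokenizer options differ only in granularity: `character` splits the text into single characters, while torchtext's `basic_english` lowercases and splits into word-level tokens. Roughly (the exact `basic_english` output depends on torchtext's normalization rules):

```
from torchtext.data.utils import get_tokenizer

list("It is")                            # ['I', 't', ' ', 'i', 's']  (character tokenizer)
get_tokenizer('basic_english')("It is")  # roughly ['it', 'is']       (word-level, lowercased)
```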
@option(Configs.text)
def tiny_shakespeare(c: Configs):
    """
    Initialize/load the Tiny Shakespeare dataset

    This dataset is from Andrej Karpathy's [char-rnn](https://github.com/karpathy/char-rnn) project.
    """
    return TextFileDataset(
        lab.get_data_path() / 'tiny_shakespeare.txt', c.tokenizer,
        url='https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt')


@option(Configs.model)
@@ -256,7 +102,7 @@ def transformer_c(c: Configs):

def main():
    # Create experiment
    experiment.create(name="knn_lm", comment='', writers={'tensorboard', 'sqlite', 'screen'})
    experiment.create(name="knn_lm")
    # Create configs
    conf = Configs()
    # Load configurations
@@ -267,6 +113,10 @@ def main():
        'prompt': 'It is ',
        'text': 'tiny_shakespeare',

        'optimizer.optimizer': 'Noam',
        'optimizer.learning_rate': 1.,
        'optimizer.d_model': 256,

        'seq_len': 1024,
        'epochs': 128,
        'batch_size': 6,
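The hunk ends inside the configurations dictionary, so the rest of `main()` is not shown here. A hedged sketch of how such a labml experiment is typically finished once the dictionary is closed (`configs_dict` stands for the dictionary above; the exact call order is an assumption, not part of this diff):

```
    # Set the configurations on the experiment
    experiment.configs(conf, configs_dict)
    # Register the model so it is saved with experiment checkpoints
    experiment.add_pytorch_models(get_modules(conf))
    # Start the experiment and run the training loop
    with experiment.start():
        conf.run()


if __name__ == '__main__':
    main()
```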