mirror of
https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
synced 2025-11-02 04:37:46 +08:00
auto regression common exp
This commit is contained in:
@ -9,21 +9,13 @@ summary: This is training code with notes for a basic auto-regressive transforme
|
|||||||
This trains a simple [transformer](../../) model for auto-regression.
|
This trains a simple [transformer](../../) model for auto-regression.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Callable, Any
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
from labml import experiment
|
||||||
from torchtext.data.utils import get_tokenizer
|
|
||||||
|
|
||||||
from labml import lab, experiment, monit, tracker, logger
|
|
||||||
from labml.configs import option
|
from labml.configs import option
|
||||||
from labml.logger import Text
|
|
||||||
from labml.utils.pytorch import get_modules
|
from labml.utils.pytorch import get_modules
|
||||||
from labml_helpers.datasets.text import TextDataset, SequentialDataLoader, TextFileDataset
|
|
||||||
from labml_helpers.metrics.accuracy import Accuracy
|
|
||||||
from labml_helpers.module import Module
|
from labml_helpers.module import Module
|
||||||
from labml_helpers.optimizer import OptimizerConfigs
|
|
||||||
from labml_helpers.train_valid import SimpleTrainValidConfigs, BatchIndex
|
from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs
|
||||||
from labml_nn.transformers import Encoder, Generator, TransformerConfigs
|
from labml_nn.transformers import Encoder, Generator, TransformerConfigs
|
||||||
from labml_nn.transformers.utils import subsequent_mask
|
from labml_nn.transformers.utils import subsequent_mask
|
||||||
|
|
||||||
@ -64,23 +56,10 @@ class AutoregressiveModel(Module):
|
|||||||
# Embed the tokens (`src`) and run it through the the transformer
|
# Embed the tokens (`src`) and run it through the the transformer
|
||||||
res = self.encoder(self.src_embed(src), self.src_mask)
|
res = self.encoder(self.src_embed(src), self.src_mask)
|
||||||
# Generate logits of the next token
|
# Generate logits of the next token
|
||||||
return self.generator(res)
|
return self.generator(res), None
|
||||||
|
|
||||||
|
|
||||||
class CrossEntropyLoss(Module):
|
class Configs(NLPAutoRegressionConfigs):
|
||||||
"""
|
|
||||||
Cross entropy loss
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__()
|
|
||||||
self.loss = nn.CrossEntropyLoss()
|
|
||||||
|
|
||||||
def __call__(self, outputs, targets):
|
|
||||||
return self.loss(outputs.view(-1, outputs.shape[-1]), targets.view(-1))
|
|
||||||
|
|
||||||
|
|
||||||
class Configs(SimpleTrainValidConfigs):
|
|
||||||
"""
|
"""
|
||||||
## Configurations
|
## Configurations
|
||||||
|
|
||||||
@ -89,141 +68,8 @@ class Configs(SimpleTrainValidConfigs):
|
|||||||
|
|
||||||
transformer: TransformerConfigs
|
transformer: TransformerConfigs
|
||||||
model: AutoregressiveModel
|
model: AutoregressiveModel
|
||||||
text: TextDataset
|
|
||||||
batch_size: int = 20
|
|
||||||
seq_len: int = 32
|
|
||||||
n_tokens: int
|
|
||||||
tokenizer: Callable = 'character'
|
|
||||||
|
|
||||||
is_save_models = True
|
|
||||||
prompt: str
|
|
||||||
prompt_separator: str
|
|
||||||
|
|
||||||
is_save_ff_input = False
|
is_save_ff_input = False
|
||||||
optimizer: torch.optim.Adam = 'transformer_optimizer'
|
|
||||||
|
|
||||||
accuracy = Accuracy()
|
|
||||||
loss_func = CrossEntropyLoss()
|
|
||||||
|
|
||||||
def init(self):
|
|
||||||
# Create a configurable optimizer.
|
|
||||||
# Parameters like learning rate can be changed by passing a dictionary when starting the experiment.
|
|
||||||
optimizer = OptimizerConfigs()
|
|
||||||
optimizer.parameters = self.model.parameters()
|
|
||||||
optimizer.d_model = self.transformer.d_model
|
|
||||||
optimizer.optimizer = 'Noam'
|
|
||||||
self.optimizer = optimizer
|
|
||||||
|
|
||||||
# Create a sequential data loader for training
|
|
||||||
self.train_loader = SequentialDataLoader(text=self.text.train,
|
|
||||||
dataset=self.text,
|
|
||||||
batch_size=self.batch_size,
|
|
||||||
seq_len=self.seq_len)
|
|
||||||
|
|
||||||
# Create a sequential data loader for validation
|
|
||||||
self.valid_loader = SequentialDataLoader(text=self.text.valid,
|
|
||||||
dataset=self.text,
|
|
||||||
batch_size=self.batch_size,
|
|
||||||
seq_len=self.seq_len)
|
|
||||||
|
|
||||||
self.state_modules = [self.accuracy]
|
|
||||||
|
|
||||||
def sample(self):
|
|
||||||
"""
|
|
||||||
Sampling function to generate samples periodically while training
|
|
||||||
"""
|
|
||||||
prompt = self.prompt
|
|
||||||
log = [(prompt, Text.subtle)]
|
|
||||||
# Sample 25 tokens
|
|
||||||
for i in monit.iterate('Sample', 25):
|
|
||||||
# Tokenize the prompt
|
|
||||||
data = self.text.text_to_i(prompt).unsqueeze(-1)
|
|
||||||
data = data.to(self.device)
|
|
||||||
# Get the model output
|
|
||||||
output = self.model(data)
|
|
||||||
# Get the model prediction (greedy)
|
|
||||||
output = output.argmax(dim=-1).squeeze()
|
|
||||||
# Add the prediction to prompt
|
|
||||||
prompt += self.prompt_separator + self.text.itos[output[-1]]
|
|
||||||
# Add the prediction for logging
|
|
||||||
log += [(self.prompt_separator + self.text.itos[output[-1]], Text.value)]
|
|
||||||
|
|
||||||
logger.log(log)
|
|
||||||
|
|
||||||
def step(self, batch: Any, batch_idx: BatchIndex):
|
|
||||||
"""
|
|
||||||
This method is called for each batch
|
|
||||||
"""
|
|
||||||
self.model.train(self.mode.is_train)
|
|
||||||
|
|
||||||
# Get data and target labels
|
|
||||||
data, target = batch[0].to(self.model.device), batch[1].to(self.model.device)
|
|
||||||
|
|
||||||
if self.mode.is_train:
|
|
||||||
tracker.add_global_step(data.shape[0] * data.shape[1])
|
|
||||||
|
|
||||||
# Run the model
|
|
||||||
output = self.model(data)
|
|
||||||
|
|
||||||
# Calculate loss
|
|
||||||
loss = self.loss_func(output, target)
|
|
||||||
# Calculate accuracy
|
|
||||||
self.accuracy(output, target)
|
|
||||||
|
|
||||||
# Log the loss
|
|
||||||
tracker.add("loss.", loss)
|
|
||||||
|
|
||||||
# If we are in training mode, calculate the gradients
|
|
||||||
if self.mode.is_train:
|
|
||||||
loss.backward()
|
|
||||||
self.optimizer.step()
|
|
||||||
if batch_idx.is_last:
|
|
||||||
tracker.add('model', self.model)
|
|
||||||
self.optimizer.zero_grad()
|
|
||||||
|
|
||||||
tracker.save()
|
|
||||||
|
|
||||||
|
|
||||||
@option(Configs.tokenizer)
|
|
||||||
def basic_english():
|
|
||||||
"""
|
|
||||||
Basic english tokenizer
|
|
||||||
|
|
||||||
We use character level tokenizer in this experiment.
|
|
||||||
You can switch by setting,
|
|
||||||
|
|
||||||
```
|
|
||||||
'tokenizer': 'basic_english',
|
|
||||||
```
|
|
||||||
|
|
||||||
as the configurations dictionary when starting the experiment.
|
|
||||||
|
|
||||||
"""
|
|
||||||
return get_tokenizer('basic_english')
|
|
||||||
|
|
||||||
|
|
||||||
def character_tokenizer(x: str):
|
|
||||||
return list(x)
|
|
||||||
|
|
||||||
|
|
||||||
@option(Configs.tokenizer)
|
|
||||||
def character():
|
|
||||||
"""
|
|
||||||
Character level tokenizer
|
|
||||||
"""
|
|
||||||
return character_tokenizer
|
|
||||||
|
|
||||||
|
|
||||||
@option(Configs.text)
|
|
||||||
def tiny_shakespeare(c: Configs):
|
|
||||||
"""
|
|
||||||
Initialize/load tiny shakespeare dataset
|
|
||||||
|
|
||||||
This dataset is from Andrej Karpathy's [char-rnn](https://github.com/karpathy/char-rnn) project.
|
|
||||||
"""
|
|
||||||
return TextFileDataset(
|
|
||||||
lab.get_data_path() / 'tiny_shakespeare.txt', c.tokenizer,
|
|
||||||
url='https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt')
|
|
||||||
|
|
||||||
|
|
||||||
@option(Configs.model)
|
@option(Configs.model)
|
||||||
@ -256,7 +102,7 @@ def transformer_c(c: Configs):
|
|||||||
|
|
||||||
def main():
|
def main():
|
||||||
# Create experiment
|
# Create experiment
|
||||||
experiment.create(name="knn_lm", comment='', writers={'tensorboard', 'sqlite', 'screen'})
|
experiment.create(name="knn_lm")
|
||||||
# Create configs
|
# Create configs
|
||||||
conf = Configs()
|
conf = Configs()
|
||||||
# Load configurations
|
# Load configurations
|
||||||
@ -267,6 +113,10 @@ def main():
|
|||||||
'prompt': 'It is ',
|
'prompt': 'It is ',
|
||||||
'text': 'tiny_shakespeare',
|
'text': 'tiny_shakespeare',
|
||||||
|
|
||||||
|
'optimizer.optimizer': 'Noam',
|
||||||
|
'optimizer.learning_rate': 1.,
|
||||||
|
'optimizer.d_model': 256,
|
||||||
|
|
||||||
'seq_len': 1024,
|
'seq_len': 1024,
|
||||||
'epochs': 128,
|
'epochs': 128,
|
||||||
'batch_size': 6,
|
'batch_size': 6,
|
||||||
|
|||||||
Reference in New Issue
Block a user