Mirror of https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
Synced 2025-10-31 02:39:16 +08:00
auto regression common exp
@@ -44,13 +44,14 @@ class NLPAutoRegressionConfigs(TrainValidConfigs):
     is_save_models = True

-    loss_func: CrossEntropyLoss()
+    loss_func = CrossEntropyLoss()
+    accuracy = Accuracy()

     def init(self):
         tracker.set_queue("loss.*", 20, True)
         tracker.set_scalar("accuracy.*", True)
         hook_model_outputs(self.mode, self.model, 'model')
-        self.state_modules = [Accuracy()]
+        self.state_modules = [self.accuracy]

     def step(self, batch: any, batch_idx: BatchIndex):
         data, target = batch[0].to(self.device), batch[1].to(self.device)
@@ -62,8 +63,8 @@ class NLPAutoRegressionConfigs(TrainValidConfigs):
         output, *_ = self.model(data)

         loss = self.loss_func(output, target)
-        self.accuracy_func(output, target)
-        self.accuracy_func.track()
+        self.accuracy(output, target)
+        self.accuracy.track()
         tracker.add("loss.", loss)

         if self.mode.is_train:
@@ -88,7 +89,7 @@ class NLPAutoRegressionConfigs(TrainValidConfigs):
            data = self.text.text_to_i(prompt).unsqueeze(-1)
            data = data.to(self.device)
            # Get the model output
-           output = self.model(data)
+           output, *_ = self.model(data)
            # Get the model prediction (greedy)
            output = output.argmax(dim=-1).squeeze()
            # Add the prediction to prompt
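The hunks above touch the shared NLPAutoRegressionConfigs, the base class this commit makes the HyperLSTM experiment reuse. The config now owns a single Accuracy() instance, so the metric called in step() is the same state module that gets saved and reset, and both step() and the sampling loop unpack the model output with `output, *_ = self.model(data)`, so models that return extra values (such as recurrent state) work unchanged. A minimal standalone sketch of that calling convention, using a toy model of my own rather than the repo's AutoregressiveModel:

import torch
import torch.nn as nn

class ToyAutoregressive(nn.Module):
    """Toy character-level model that returns (logits, state), like HyperLSTM-style models."""
    def __init__(self, n_tokens: int = 65, d: int = 32):
        super().__init__()
        self.emb = nn.Embedding(n_tokens, d)
        self.rnn = nn.LSTM(d, d)
        self.head = nn.Linear(d, n_tokens)

    def forward(self, x, state=None):
        h, state = self.rnn(self.emb(x), state)
        # logits are sequence-first: (seq_len, batch, n_tokens)
        return self.head(h), state

model = ToyAutoregressive()
data = torch.randint(0, 65, (10, 4))   # (seq_len, batch)
output, *_ = model(data)                # works whether or not extra values are returned
print(output.shape)                     # torch.Size([10, 4, 65])

Greedy sampling then takes argmax over the last dimension of the (seq_len, batch, n_tokens) logits, exactly as in the hunk above.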
@@ -1,18 +1,11 @@
-from typing import Callable, Any
-
 import torch
 import torch.nn as nn
-from labml import lab, experiment, monit, tracker, logger
+from labml import experiment
 from labml.configs import option
-from labml.logger import Text
 from labml.utils.pytorch import get_modules
-from labml_helpers.datasets.text import TextDataset, TextFileDataset, SequentialUnBatchedDataset
-from labml_helpers.metrics.accuracy import Accuracy
 from labml_helpers.module import Module
-from labml_helpers.optimizer import OptimizerConfigs
-from labml_helpers.train_valid import SimpleTrainValidConfigs, BatchIndex
-from torch.utils.data import DataLoader
-
+
+from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs
 from labml_nn.hypernetworks.hyper_lstm import HyperLSTM
@@ -36,28 +29,7 @@ class AutoregressiveModel(Module):
         return self.generator(res), state


-class CrossEntropyLoss(Module):
-    """
-    Cross entropy loss
-    """
-    def __init__(self):
-        super().__init__()
-        self.loss = nn.CrossEntropyLoss()
-
-    def __call__(self, outputs, targets):
-        return self.loss(outputs.view(-1, outputs.shape[-1]), targets.view(-1))
-
-
-def transpose_batch(batch):
-    transposed_data = list(zip(*batch))
-    src = torch.stack(transposed_data[0], 1)
-    tgt = torch.stack(transposed_data[1], 1)
-
-    return src, tgt
-
-
-class Configs(SimpleTrainValidConfigs):
+class Configs(NLPAutoRegressionConfigs):
     """
     ## Configurations
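This hunk deletes the experiment's private CrossEntropyLoss wrapper and transpose_batch helper and rebases Configs on NLPAutoRegressionConfigs, which now supplies the loss. The deleted wrapper's only job was reshaping: nn.CrossEntropyLoss expects (N, n_tokens) logits and (N,) targets, while the model emits (seq_len, batch, n_tokens). A self-contained sketch of that reshaping (the class name SeqCrossEntropyLoss is mine; the view calls mirror the removed code):

import torch
import torch.nn as nn

class SeqCrossEntropyLoss(nn.Module):
    """Cross entropy over (seq_len, batch, n_tokens) outputs and (seq_len, batch) targets."""
    def __init__(self):
        super().__init__()
        self.loss = nn.CrossEntropyLoss()

    def forward(self, outputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        # Flatten sequence and batch dimensions so nn.CrossEntropyLoss sees
        # a plain (N, n_tokens) prediction matrix and an (N,) target vector.
        return self.loss(outputs.view(-1, outputs.shape[-1]), targets.view(-1))

outputs = torch.randn(32, 20, 65)            # seq_len=32, batch=20, 65 tokens
targets = torch.randint(0, 65, (32, 20))
print(SeqCrossEntropyLoss()(outputs, targets))  # scalar loss tensor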
@@ -65,119 +37,6 @@ class Configs(SimpleTrainValidConfigs):
     """
     model: AutoregressiveModel
-    text: TextDataset
-    batch_size: int = 20
-    seq_len: int = 512
-    n_tokens: int
-    tokenizer: Callable = 'character'
-
-    is_save_models = True
-
-    optimizer: torch.optim.Adam = 'transformer_optimizer'
-
-    accuracy = Accuracy()
-    loss_func = CrossEntropyLoss()
-
-    def init(self):
-        # Create a configurable optimizer.
-        # Parameters like learning rate can be changed by passing a dictionary when starting the experiment.
-        optimizer = OptimizerConfigs()
-        optimizer.parameters = self.model.parameters()
-        optimizer.optimizer = 'Adam'
-        self.optimizer = optimizer
-
-        # Create a sequential data loader for training
-        self.train_loader = DataLoader(SequentialUnBatchedDataset(text=self.text.train,
-                                                                  dataset=self.text,
-                                                                  seq_len=self.seq_len),
-                                       batch_size=self.batch_size,
-                                       collate_fn=transpose_batch,
-                                       shuffle=True)
-
-        # Create a sequential data loader for validation
-        self.valid_loader = DataLoader(SequentialUnBatchedDataset(text=self.text.valid,
-                                                                  dataset=self.text,
-                                                                  seq_len=self.seq_len),
-                                       batch_size=self.batch_size,
-                                       collate_fn=transpose_batch,
-                                       shuffle=True)
-
-        self.state_modules = [self.accuracy]
-
-    def sample(self):
-        """
-        Sampling function to generate samples periodically while training
-        """
-        prompt = 'It is'
-        log = [(prompt, Text.subtle)]
-        # Sample 25 tokens
-        for i in monit.iterate('Sample', 25):
-            # Tokenize the prompt
-            data = self.text.text_to_i(prompt).unsqueeze(-1)
-            data = data.to(self.device)
-            # Get the model output
-            output, state = self.model(data)
-            output = output.cpu()
-            # Get the model prediction (greedy)
-            output = output.argmax(dim=-1).squeeze()
-            # Add the prediction to prompt
-            prompt += self.text.itos[output[-1]]
-            # Add the prediction for logging
-            log += [(self.text.itos[output[-1]], Text.value)]
-
-        logger.log(log)
-
-    def step(self, batch: Any, batch_idx: BatchIndex):
-        """
-        This method is called for each batch
-        """
-        self.model.train(self.mode.is_train)
-
-        # Get data and target labels
-        data, target = batch[0].to(self.device), batch[1].to(self.device)
-
-        if self.mode.is_train:
-            tracker.add_global_step(data.shape[0] * data.shape[1])
-
-        # Run the model
-        output, state = self.model(data)
-
-        # Calculate loss
-        loss = self.loss_func(output, target)
-        # Calculate accuracy
-        self.accuracy(output, target)
-
-        # Log the loss
-        tracker.add("loss.", loss)
-
-        # If we are in training mode, calculate the gradients
-        if self.mode.is_train:
-            loss.backward()
-            self.optimizer.step()
-            if batch_idx.is_last:
-                tracker.add('model', self.model)
-            self.optimizer.zero_grad()
-
-        tracker.save()
-
-
-def character_tokenizer(x: str):
-    return list(x)
-
-
-@option(Configs.tokenizer)
-def character():
-    """
-    Character level tokenizer
-    """
-    return character_tokenizer
-
-
-@option(Configs.text)
-def tiny_shakespeare(c: Configs):
-    return TextFileDataset(
-        lab.get_data_path() / 'tiny_shakespeare.txt', c.tokenizer,
-        url='https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt')


 @option(Configs.model)
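The bulk of this removal is configuration and training code that the shared NLPAutoRegressionConfigs already provides: the optimizer setup, the DataLoaders built with collate_fn=transpose_batch, sample() and step(). The config dict in the last hunk switches to the 'shuffled_train_loader' / 'shuffled_valid_loader' options, presumably defined alongside the shared config. For reference, the removed collate function simply stacked samples along dim 1 to produce sequence-first batches; a runnable copy:

import torch

def transpose_batch(batch):
    # batch is a list of (src, tgt) pairs, each of shape (seq_len,);
    # stacking along dim=1 yields sequence-first tensors of shape (seq_len, batch).
    transposed_data = list(zip(*batch))
    src = torch.stack(transposed_data[0], 1)
    tgt = torch.stack(transposed_data[1], 1)
    return src, tgt

samples = [(torch.arange(8), torch.arange(1, 9)) for _ in range(4)]
src, tgt = transpose_batch(samples)
print(src.shape, tgt.shape)   # torch.Size([8, 4]) torch.Size([8, 4])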
@@ -191,7 +50,7 @@ def autoregressive_model(c: Configs):

 def main():
     # Create experiment
-    experiment.create(name="knn_lm", comment='')
+    experiment.create(name="hyper_lstm", comment='')
     # Create configs
     conf = Configs()
     # Load configurations
@@ -200,6 +59,12 @@ def main():
                        {'tokenizer': 'character',
                         'text': 'tiny_shakespeare',
                         'optimizer.learning_rate': 2.5e-4,
+                        'optimizer.optimizer': 'Adam',
+                        'prompt': 'It is',
+                        'prompt_separator': '',
+
+                        'train_loader': 'shuffled_train_loader',
+                        'valid_loader': 'shuffled_valid_loader',
                         'seq_len': 512,
                         'epochs': 128,
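With the experiment renamed from knn_lm to hyper_lstm and prompt, optimizer, and loader keys added, main() only has to populate the shared config. The surrounding calls are not part of this hunk; the sketch below assumes the usual labml wiring (experiment.configs, experiment.add_pytorch_models, experiment.start, conf.run) and repeats only the dictionary keys visible above:

from labml import experiment
from labml.utils.pytorch import get_modules
# Configs is the NLPAutoRegressionConfigs subclass defined earlier in the same file.

def main():
    # Create experiment
    experiment.create(name="hyper_lstm", comment='')
    # Create configs
    conf = Configs()
    # Load configurations (keys beyond those shown in the hunk are omitted here)
    experiment.configs(conf,
                       {'tokenizer': 'character',
                        'text': 'tiny_shakespeare',
                        'optimizer.learning_rate': 2.5e-4,
                        'optimizer.optimizer': 'Adam',
                        'prompt': 'It is',
                        'prompt_separator': '',

                        'train_loader': 'shuffled_train_loader',
                        'valid_loader': 'shuffled_valid_loader',

                        'seq_len': 512,
                        'epochs': 128})

    # Register the model for checkpointing and run the training loop
    experiment.add_pytorch_models(get_modules(conf))
    with experiment.start():
        conf.run()

if __name__ == '__main__':
    main()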