"""
---
title: NLP classification trainer
summary: >
  This is a reusable trainer for classification tasks
---

# NLP model trainer for classification
"""

from collections import Counter
from typing import Callable

import torchtext
import torchtext.vocab
from torchtext.vocab import Vocab

import torch
from labml import lab, tracker, monit
from labml.configs import option
from labml_nn.helpers.device import DeviceConfigs
from labml_nn.helpers.metrics import Accuracy
from labml_nn.helpers.trainer import TrainValidConfigs, BatchIndex
from labml_nn.optimizers.configs import OptimizerConfigs
from torch import nn
from torch.utils.data import DataLoader


class NLPClassificationConfigs(TrainValidConfigs):
    """
    <a id="NLPClassificationConfigs"></a>

    ## Trainer configurations

    This has the basic configurations for NLP classification task training.
    All the properties are configurable.
    """
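
    # All of these can be overridden from an experiment's configuration dictionary;
    # for instance (an illustrative assumption, not a prescribed set of values):
    # `{'tokenizer': 'basic_english', 'batch_size': 32, 'optimizer.learning_rate': 2.5e-4}`.
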
    # Optimizer
    optimizer: torch.optim.Adam
    # Training device
    device: torch.device = DeviceConfigs()

    # Classification model
    model: nn.Module
    # Batch size
    batch_size: int = 16
    # Length of the sequence, or context size
    seq_len: int = 512
    # Vocabulary
    vocab: Vocab = 'ag_news'
    # Number of tokens in the vocabulary
    n_tokens: int
    # Number of classes
    n_classes: int = 'ag_news'
    # Tokenizer
    tokenizer: Callable = 'character'

    # Whether to periodically save models
    is_save_models = True

    # Loss function
    loss_func = nn.CrossEntropyLoss()
    # Accuracy function
    accuracy = Accuracy()
    # Model embedding size
    d_model: int = 512
    # Gradient clipping
    grad_norm_clip: float = 1.0

    # Training data loader
    train_loader: DataLoader = 'ag_news'
    # Validation data loader
    valid_loader: DataLoader = 'ag_news'

    # Whether to log model parameters and gradients (once per epoch).
    # These are summarized stats per layer, but it could still lead
    # to many indicators for very deep networks.
    is_log_model_params_grads: bool = False

    # Whether to log model activations (once per epoch).
    # These are summarized stats per layer, but it could still lead
    # to many indicators for very deep networks.
    is_log_model_activations: bool = False

    def init(self):
        """
        ### Initialization
        """
        # Set tracker configurations
        tracker.set_scalar("accuracy.*", True)
        tracker.set_scalar("loss.*", True)
        # Add accuracy as a state module.
        # The name is probably confusing, since it's meant to store
        # states between training and validation for RNNs.
        # This will keep the accuracy metric stats separate for training and validation.
        self.state_modules = [self.accuracy]

    def step(self, batch: any, batch_idx: BatchIndex):
        """
        ### Training or validation step
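
        With the data loaders defined below, `batch[0]` is a `[seq_len, batch_size]`
        tensor of token indices produced by `CollateFunc`, and `batch[1]` is a
        `[batch_size]` tensor of class labels.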
        """

        # Move data to the device
        data, target = batch[0].to(self.device), batch[1].to(self.device)

        # Update global step (number of sequences processed) when in training mode
        if self.mode.is_train:
            tracker.add_global_step(data.shape[1])

        # Get model outputs.
        # It's returning a tuple for states when using RNNs.
        # This is not implemented yet. 😜
        output, *_ = self.model(data)

        # Calculate and log loss
        loss = self.loss_func(output, target)
        tracker.add("loss.", loss)

        # Calculate and log accuracy
        self.accuracy(output, target)
        self.accuracy.track()

        # Train the model
        if self.mode.is_train:
            # Calculate gradients
            loss.backward()
            # Clip gradients
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.grad_norm_clip)
            # Take optimizer step
            self.optimizer.step()
            # Log the model parameters and gradients on last batch of every epoch
            if batch_idx.is_last and self.is_log_model_params_grads:
                tracker.add('model', self.model)
            # Clear the gradients
            self.optimizer.zero_grad()

        # Save the tracked metrics
        tracker.save()


@option(NLPClassificationConfigs.optimizer)
def _optimizer(c: NLPClassificationConfigs):
    """
    ### Default [optimizer configurations](../optimizers/configs.html)
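
    This creates an `OptimizerConfigs` object, so settings such as
    `optimizer.optimizer` and `optimizer.learning_rate` can be overridden from
    the experiment's configuration dictionary (an assumption based on how these
    optimizer configurations are used elsewhere in this repository).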
    """

    optimizer = OptimizerConfigs()
    optimizer.parameters = c.model.parameters()
    optimizer.optimizer = 'Adam'
    optimizer.d_model = c.d_model

    return optimizer


@option(NLPClassificationConfigs.tokenizer)
def basic_english():
    """
    ### Basic English tokenizer

    We use a character-level tokenizer by default in this experiment.
    You can switch to the basic English tokenizer by setting,

    ```
        'tokenizer': 'basic_english',
    ```

    in the configurations dictionary when starting the experiment.
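
    For reference (an assumption about `torchtext`'s `basic_english` tokenizer,
    which roughly lowercases the text and splits out punctuation),
    `"Hello, World!"` would become `['hello', ',', 'world', '!']`.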

    """
    from torchtext.data import get_tokenizer
    return get_tokenizer('basic_english')


def character_tokenizer(x: str):
    """
    ### Character level tokenizer
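
    For example, `character_tokenizer("abc")` returns `['a', 'b', 'c']`.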
    """
    return list(x)


@option(NLPClassificationConfigs.tokenizer)
def character():
    """
    Character level tokenizer configuration
    """
    return character_tokenizer


@option(NLPClassificationConfigs.n_tokens)
def _n_tokens(c: NLPClassificationConfigs):
    """
    Get the number of tokens: the vocabulary size plus two, to account for the
    padding and `[CLS]` tokens added by `CollateFunc`.
    """
    return len(c.vocab) + 2


class CollateFunc:
    """
    ## Function to load data into batches
    """

    def __init__(self, tokenizer, vocab: Vocab, seq_len: int, padding_token: int, classifier_token: int):
        """
        * `tokenizer` is the tokenizer function
        * `vocab` is the vocabulary
        * `seq_len` is the length of the sequence
        * `padding_token` is the token used for padding when the `seq_len` is larger than the text length
        * `classifier_token` is the `[CLS]` token which we set at the end of the input
        """
        self.classifier_token = classifier_token
        self.padding_token = padding_token
        self.seq_len = seq_len
        self.vocab = vocab
        self.tokenizer = tokenizer

    def __call__(self, batch):
        """
        * `batch` is the batch of data collected by the `DataLoader`
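
        Returns a `[seq_len, batch_size]` tensor of token indices (padded with
        `padding_token`, with the final position set to `classifier_token`) and a
        `[batch_size]` tensor of zero-based labels.

        As a rough illustration with made-up values: with a character tokenizer,
        `vocab = {'a': 0, 'b': 1, 'c': 2}`, `seq_len = 5`, `padding_token = 3` and
        `classifier_token = 4`, the batch `[("3", "ab"), ("1", "abc")]` becomes
        `data = [[0, 0], [1, 1], [3, 2], [3, 3], [4, 4]]` and `labels = [2, 0]`.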
        """

        # Input data tensor, initialized with `padding_token`
        data = torch.full((self.seq_len, len(batch)), self.padding_token, dtype=torch.long)
        # Empty labels tensor
        labels = torch.zeros(len(batch), dtype=torch.long)

        # Loop through the samples
        for (i, (_label, _text)) in enumerate(batch):
            # Set the label
            labels[i] = int(_label) - 1
            # Tokenize the input text
            _text = [self.vocab[token] for token in self.tokenizer(_text)]
            # Truncate up to `seq_len`
            _text = _text[:self.seq_len]
            # Add the sequence to the `i`-th column of `data`
            data[:len(_text), i] = data.new_tensor(_text)

        # Set the final token in the sequence to `[CLS]`
        data[-1, :] = self.classifier_token

        #
        return data, labels


@option([NLPClassificationConfigs.n_classes,
         NLPClassificationConfigs.vocab,
         NLPClassificationConfigs.train_loader,
         NLPClassificationConfigs.valid_loader])
def ag_news(c: NLPClassificationConfigs):
    """
    ### AG News dataset

    This loads the AG News dataset and sets the values for
    `n_classes`, `vocab`, `train_loader`, and `valid_loader`.
    """

    # Get training and validation datasets
    train, valid = torchtext.datasets.AG_NEWS(root=str(lab.get_data_path() / 'ag_news'), split=('train', 'test'))

    # Load data to memory
    with monit.section('Load data'):
        from labml_nn.utils import MapStyleDataset

        # Create [map-style datasets](../utils.html#map_style_dataset)
        train, valid = MapStyleDataset(train), MapStyleDataset(valid)

    # Get tokenizer
    tokenizer = c.tokenizer

    # Create a counter
    counter = Counter()
    # Collect tokens from training dataset
    for (label, line) in train:
        counter.update(tokenizer(line))
    # Collect tokens from validation dataset
    for (label, line) in valid:
        counter.update(tokenizer(line))
    # Create vocabulary
    vocab = torchtext.vocab.vocab(counter, min_freq=1)

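    # The padding token id is `len(vocab)` and the `[CLS]` token id is `len(vocab) + 1`,
    # which is why `n_tokens` is defined above as `len(vocab) + 2`.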
    # Create training data loader
    train_loader = DataLoader(train, batch_size=c.batch_size, shuffle=True,
                              collate_fn=CollateFunc(tokenizer, vocab, c.seq_len, len(vocab), len(vocab) + 1))
    # Create validation data loader
    valid_loader = DataLoader(valid, batch_size=c.batch_size, shuffle=True,
                              collate_fn=CollateFunc(tokenizer, vocab, c.seq_len, len(vocab), len(vocab) + 1))

    # Return `n_classes`, `vocab`, `train_loader`, and `valid_loader`
    return 4, vocab, train_loader, valid_loader
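

# A minimal usage sketch (illustrative only, not part of the original trainer).
# It assumes the `labml` experiment API (`experiment.create`, `experiment.configs`,
# `experiment.start`) as used by other experiments in this repository.
def _usage_sketch():
    from labml import experiment

    # Create an experiment and the configurations
    experiment.create(name='ag_news_classification')
    conf = NLPClassificationConfigs()
    # Override defaults from a configuration dictionary;
    # `'optimizer.learning_rate'` assumes the optimizer configurations expose a
    # `learning_rate` option
    experiment.configs(conf, {'tokenizer': 'basic_english',
                              'batch_size': 32,
                              'optimizer.learning_rate': 2.5e-4})
    # A real experiment must also register a model before running, e.g. (hypothetical):
    #     @option(NLPClassificationConfigs.model)
    #     def _model(c: NLPClassificationConfigs):
    #         return MyClassifier(c.n_tokens, c.d_model, c.n_classes).to(c.device)
    # Start the experiment and run the training/validation loop
    with experiment.start():
        conf.run()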