from typing import Optional, Set, List
import torch.nn as nn
import torch.optim
import torch.utils.data
from torch.cuda import amp
from torch.cuda.amp import GradScaler
from labml import monit, tracker
from labml.configs import BaseConfigs, option
from labml_nn.neox.utils.finetune import FineTuner


def get_trainable_params(model: nn.Module):
    """
    ### Get trainable parameters

    :param model: is the model to train
    :return: a list of parameters for training
    """
    # Get all parameters
    params = list(model.parameters())
    # Filter parameters that require gradients
    trainable_params = [p for p in params if p.requires_grad]
    #
    return trainable_params
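
# An illustrative sketch (not part of the original module): if some layers are
# frozen before training, only the remaining parameters are returned, so an
# optimizer built from this list leaves the frozen weights untouched.
#
#     model = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 2))
#     for p in model[0].parameters():
#         p.requires_grad = False
#     params = get_trainable_params(model)  # parameters of `model[1]` only
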
class TrainerConf(BaseConfigs):
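    """
    ## Trainer configurations

    Configurable components and hyper-parameters for training,
    with optional mixed precision (AMP) and gradient scaling.
    """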
    # Model to train
    model: nn.Module
    # Layers of the model
    layers: List[nn.Module]
    # Optimizer; defaults to the `'Adam'` option defined below
    optimizer: torch.optim.Optimizer = 'Adam'
    # Training data loader
    train_loader: torch.utils.data.DataLoader
    # Optional validation data loader
    valid_loader: Optional[torch.utils.data.DataLoader] = None
    # Device to train on
    device: torch.device = torch.device('cuda:0')
    # Gradient scaler for mixed precision; defaults to the `'Default'` option defined below
    scaler: Optional[GradScaler] = 'Default'
    # Whether to use automatic mixed precision (AMP)
    is_amp: bool = True
    # Data type of the model parameters
    dtype: torch.dtype = torch.float16
    is_clone_layers: bool = True
    # Loss function
    loss_func: nn.Module = nn.CrossEntropyLoss()
    # Number of checkpoints to save per epoch (`0` to disable)
    checkpoints_per_epoch: int = 0
    # Number of samples to generate per epoch (`0` to disable)
    samples_per_epoch: int = 0
    # Gradient clipping norm (`None` to disable clipping)
    grad_norm: Optional[float] = 1.0
    # Learning rate
    learning_rate: float = 3e-4
    max_seq_len: int = 1024
    batch_size: int = 64
    epochs: int = 16
    # Number of available GPUs
    n_gpus: int = torch.cuda.device_count()
    filter_layers: Optional[Set] = None

    def get_loss(self, sample, dataset_split: str):
        """
        ### Calculate the loss

        :param sample: is the sample, a tuple of the input data and the target
        :param dataset_split: is the dataset split name (`train`/`valid`)
        :return: the loss, model output and the target
        """
        data, target = sample
        # Forward pass
        with monit.section('Forward pass'):
            output = self.model(data.to(self.device))
        # Move targets to the same device as output
        target = target.to(output.device)
        # Calculate loss; outputs are flattened to `[n_tokens, n_vocab]`
        # and targets to `[n_tokens]`
        loss = self.loss_func(output.view(target.numel(), -1), target.view(-1))
        #
        return loss, output, target

    def train(self):
        """
        ### Train the model
        """
        for epoch in monit.loop(self.epochs):
            self.train_epoch()
            tracker.new_line()

    def sample(self, idx):
        """Generate samples; a stub meant to be overridden by derived configurations"""
        pass

    def save_checkpoint(self, idx):
        """Save a checkpoint; a stub meant to be overridden by derived configurations"""
        pass

    def get_iterators(self):
        """
        ### Get iterators to mix for an epoch
        """
        # Training data loader
        iterators = [('train', self.train_loader)]
        # Validation data loader, if available
        if self.valid_loader is not None:
            iterators.append(('valid', self.valid_loader))
        # Sampling and checkpointing callbacks
        if self.samples_per_epoch > 0:
            iterators.append((self.sample, [i for i in range(self.samples_per_epoch)]))
        if self.checkpoints_per_epoch > 0:
            iterators.append((self.save_checkpoint, [i for i in range(self.checkpoints_per_epoch)]))

        return iterators
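
    # `monit.mix` interleaves these iterators, so sampling and checkpointing
    # callbacks run in between training (and validation) batches rather than
    # only at the end of the epoch.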

    def train_epoch(self):
        """
        ### Train or validate for an epoch
        """
        # Set model to training mode
        self.model.train()
        # Iterators to mix for the epoch
        iterators = self.get_iterators()

        # Iterate through the mixed batches
        for split_name, sample in monit.mix(1024, *iterators):
            # Sampling and checkpointing entries yield a callable instead of a
            # split name; call it with the index and skip the rest of the step
            if not isinstance(split_name, str):
                split_name(sample)
                continue

            if split_name == 'train':
                # Set gradients to zero
                self.optimizer.zero_grad()
                tracker.add_global_step()

            # Enable gradient computation only for training
            with torch.set_grad_enabled(split_name == 'train'):
                if self.is_amp:
                    # Forward pass with AMP
                    with amp.autocast():
                        loss, output, target = self.get_loss(sample, split_name)
                else:
                    # Forward pass
                    loss, output, target = self.get_loss(sample, split_name)

                # Get predictions
                pred = output.argmax(dim=-1)
                # Calculate accuracy
                accuracy = pred.eq(target).sum().item() / (target != -100).sum().item()
                # Track the loss and accuracy
                tracker.add({f'loss.{split_name}': loss, f'acc.{split_name}': accuracy * 100})

            if split_name == 'train':
                # Scale the loss for mixed precision
                if self.scaler is not None:
                    loss = self.scaler.scale(loss)
                    # tracker.add({'loss.scaled': loss})
                # Backward pass
                with monit.section('Backward pass'):
                    loss.backward()
                # Optimize
                with monit.section('Optimize'):
                    if self.scaler is None:
                        self.optimizer.step()
                    else:
                        # Unscale the gradients before clipping
                        self.scaler.unscale_(self.optimizer)
                        # Clip gradients
                        if self.grad_norm is not None:
                            torch.nn.utils.clip_grad_norm_(get_trainable_params(self.model), self.grad_norm)
                        # Take an optimizer step and update the scaler
                        self.scaler.step(self.optimizer)
                        self.scaler.update()
                # Save tracked stats
                tracker.save()
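
# A minimal usage sketch (illustrative only; `MyTrainerConf` is hypothetical and
# `experiment` refers to `labml.experiment`). A derived configurations class fills
# in `model`, `layers` and `train_loader`, and `train()` runs the epochs:
#
#     class MyTrainerConf(TrainerConf):
#         ...
#
#     conf = MyTrainerConf()
#     experiment.create(name='finetune')
#     experiment.configs(conf, {'learning_rate': 3e-4, 'epochs': 1})
#     with experiment.start():
#         conf.train()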


@option(TrainerConf.optimizer, 'Adam')
def adam_optimizer(c: TrainerConf):
    """
    #### Adam optimizer
    """
    if c.dtype == torch.float32:
        return torch.optim.Adam(get_trainable_params(c.model), lr=c.learning_rate)
    elif c.dtype == torch.float16:
        # Use an Adam variant that supports half-precision parameters
        from labml_nn.optimizers.adam_fp16 import AdamFP16
        return AdamFP16(get_trainable_params(c.model), lr=c.learning_rate)
    else:
        raise NotImplementedError()


@option(TrainerConf.optimizer, 'SGD')
def sgd_optimizer(c: TrainerConf):
    """
    #### SGD optimizer
    """
    return torch.optim.SGD(get_trainable_params(c.model), lr=c.learning_rate)


@option(TrainerConf.scaler, 'Default')
def grad_scaler(c: TrainerConf):
    """
    #### Default gradient scaler
    """
    # No gradient scaling when AMP is disabled
    if not c.is_amp:
        return None

    if c.dtype == torch.float16:
        # Use the scaler that works with the half-precision optimizer
        from labml_nn.optimizers.adam_fp16 import GradScalerFP16
        return GradScalerFP16()
    else:
        return GradScaler()
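
# Note: the string defaults above (`optimizer = 'Adam'`, `scaler = 'Default'`) are
# labml config option names; `BaseConfigs` resolves them by calling the matching
# function registered with `@option`.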


class PipelineParallelTrainerConf(TrainerConf):
    is_checkpointing: bool = False
    # Number of micro-batch chunks for pipeline parallelism
    chunks: int
    fine_tuner: FineTuner