from typing import Optional, Set, List

import torch.nn as nn
import torch.optim
import torch.utils.data
from torch.cuda import amp
from torch.cuda.amp import GradScaler

from labml import monit, tracker
from labml.configs import BaseConfigs, option
from labml_nn.neox.utils.finetune import FineTuner
def get_trainable_params(model: nn.Module):
    """
    ### Get trainable parameters

    :param model: is the model to train
    :return: a list of parameters for training
    """
    # Get all parameters
    params = list(model.parameters())
    # Filter parameters that require gradients
    trainable_params = [p for p in params if p.requires_grad]
    #
    return trainable_params
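

# A minimal usage sketch (illustrative only, not part of the trainer): freezing a
# sub-module with `requires_grad_(False)` removes its parameters from the returned
# list. The two-layer model below is hypothetical.
#
#     model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 2))
#     model[0].requires_grad_(False)
#     params = get_trainable_params(model)
#     assert all(p.requires_grad for p in params)
#     assert len(params) == 2  # weight and bias of the second layer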


class TrainerConf(BaseConfigs):
    """
    ## Trainer configurations
    """
    model: nn.Module
    layers: List[nn.Module]
    # Optimizer; `'Adam'` selects the option defined below
    optimizer: torch.optim.Optimizer = 'Adam'
    # Training data loader
    train_loader: torch.utils.data.DataLoader
    # Optional validation data loader
    valid_loader: Optional[torch.utils.data.DataLoader] = None
    # Device to train on
    device: torch.device = torch.device('cuda:0')
    # Gradient scaler for mixed precision; `'Default'` selects the option defined below
    scaler: Optional[GradScaler] = 'Default'
    # Whether to use automatic mixed precision (AMP)
    is_amp: bool = True
    # Data type of the model parameters
    dtype: torch.dtype = torch.float16

    is_clone_layers: bool = True

    # Loss function
    loss_func: nn.Module = nn.CrossEntropyLoss()
    # Number of checkpoint saves per epoch (`0` to disable)
    checkpoints_per_epoch: int = 0
    # Number of sampling calls per epoch (`0` to disable)
    samples_per_epoch: int = 0

    # Gradient clipping norm (`None` to disable clipping)
    grad_norm: Optional[float] = 1.0
    # Learning rate
    learning_rate: float = 3e-4
    # Maximum sequence length
    max_seq_len: int = 1024
    # Batch size
    batch_size: int = 64
    # Number of epochs to train for
    epochs: int = 16

    # Number of available GPUs
    n_gpus: int = torch.cuda.device_count()

    filter_layers: Optional[Set] = None

    def get_loss(self, sample, dataset_split: str):
        """
        ### Calculate the loss

        :param sample: is the sample, a tuple of input data and target
        :param dataset_split: is the dataset split name (`train`/`valid`), used for logging
        :return: the loss, output and the target
        """
        data, target = sample

        # Forward pass
        with monit.section('Forward pass'):
            output = self.model(data.to(self.device))
        # Move targets to the same device as output
        target = target.to(output.device)
        # Calculate loss
        loss = self.loss_func(output.view(target.numel(), -1), target.view(-1))

        return loss, output, target
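
    # Shape note (illustrative; assumes a language-model head): if `output` is
    # `[batch_size, seq_len, n_vocab]` and `target` is `[batch_size, seq_len]`,
    # then `output.view(target.numel(), -1)` gives `[batch_size * seq_len, n_vocab]`
    # and `target.view(-1)` gives `[batch_size * seq_len]`, which is what
    # `nn.CrossEntropyLoss` expects for a per-token loss.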

    def train(self):
        """
        ### Training loop
        """
        for epoch in monit.loop(self.epochs):
            self.train_epoch()
            tracker.new_line()

    def sample(self, idx):
        """
        ### Sampling

        A no-op placeholder; trainers can override this to generate samples
        during training.
        """
        pass

    def save_checkpoint(self, idx):
        """
        ### Checkpoint saving

        A no-op placeholder; trainers can override this to save model checkpoints.
        """
        pass
    def get_iterators(self):
        """
        ### Collect the iterators to mix during an epoch
        """
        # Training batches
        iterators = [('train', self.train_loader)]
        # Validation batches, if a validation data loader is given
        if self.valid_loader is not None:
            iterators.append(('valid', self.valid_loader))

        # Sampling calls
        if self.samples_per_epoch > 0:
            iterators.append((self.sample, [i for i in range(self.samples_per_epoch)]))

        # Checkpoint saving calls
        if self.checkpoints_per_epoch > 0:
            iterators.append((self.save_checkpoint, [i for i in range(self.checkpoints_per_epoch)]))

        return iterators
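
    # `monit.mix` (from labml) is used in `train_epoch` below to interleave these
    # iterators across the epoch, so validation batches, sampling calls and
    # checkpoint saves are spread between training batches rather than run
    # back-to-back. (This description is inferred from how it is used here; see
    # labml for the exact semantics.)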

    def train_epoch(self):
        """
        ### Train a single epoch
        """
        # Set model to training mode
        self.model.train()

        # Iterators to mix during the epoch
        iterators = self.get_iterators()
        for split_name, sample in monit.mix(1024, *iterators):
            if split_name == 'train':
                # Set gradients to zero
                self.optimizer.zero_grad()
                # Increment the global step
                tracker.add_global_step()

            # Enable gradients only for training batches
            with torch.set_grad_enabled(split_name == 'train'):
                if self.is_amp:
                    # Forward pass with automatic mixed precision
                    with amp.autocast():
                        loss, output, target = self.get_loss(sample, split_name)
                else:
                    # Forward pass
                    loss, output, target = self.get_loss(sample, split_name)

            # Get predictions
            pred = output.argmax(dim=-1)
            # Calculate accuracy; positions labeled `-100` are ignored
            # (the default `ignore_index` of `nn.CrossEntropyLoss`)
            accuracy = pred.eq(target).sum().item() / (target != -100).sum()

            # Log the loss and accuracy
            tracker.add({f'loss.{split_name}': loss, f'acc.{split_name}': accuracy * 100})

            if split_name == 'train':
                if self.scaler is not None:
                    # Scale the loss before the backward pass
                    loss = self.scaler.scale(loss)
                    # tracker.add({'loss.scaled': loss})

                # Backward pass
                with monit.section('Backward pass'):
                    loss.backward()

                # Optimize
                with monit.section('Optimize'):
                    if self.scaler is None:
                        self.optimizer.step()
                    else:
                        # Unscale gradients before clipping
                        self.scaler.unscale_(self.optimizer)
                        # Clip gradients
                        if self.grad_norm is not None:
                            torch.nn.utils.clip_grad_norm_(get_trainable_params(self.model), self.grad_norm)
                        # Optimizer step and scaler update
                        self.scaler.step(self.optimizer)
                        self.scaler.update()

                # Save the tracked metrics
                tracker.save()


@option(TrainerConf.optimizer, 'Adam')
def adam_optimizer(c: TrainerConf):
    """
    ### Adam optimizer
    """
    if c.dtype == torch.float32:
        return torch.optim.Adam(get_trainable_params(c.model), lr=c.learning_rate)
    elif c.dtype == torch.float16:
        from labml_nn.optimizers.adam_fp16 import AdamFP16
        return AdamFP16(get_trainable_params(c.model), lr=c.learning_rate)
    else:
        raise NotImplementedError()


@option(TrainerConf.optimizer, 'SGD')
def sgd_optimizer(c: TrainerConf):
    """
    ### SGD optimizer
    """
    return torch.optim.SGD(get_trainable_params(c.model), lr=c.learning_rate)


@option(TrainerConf.scaler, 'Default')
def grad_scaler(c: TrainerConf):
    """
    ### Default gradient scaler

    Returns `None` when AMP is disabled.
    """
    if not c.is_amp:
        return None

    if c.dtype == torch.float16:
        from labml_nn.optimizers.adam_fp16 import GradScalerFP16
        return GradScalerFP16()
    else:
        return GradScaler()


class PipelineParallelTrainerConf(TrainerConf):
    """
    ## Pipeline parallel trainer configurations
    """
    is_checkpointing: bool = False
    # Number of chunks (micro-batches) for pipeline parallelism
    chunks: int

    # Fine tuner
    fine_tuner: FineTuner
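

# A hedged end-to-end sketch of wiring up a `TrainerConf` with labml's experiment
# API. The experiment name, the `'SGD'` override and the hyper-parameter values are
# illustrative; experiment scripts in this repository supply `model`,
# `train_loader`, etc. through their own `@option` definitions.
#
#     from labml import experiment
#
#     conf = TrainerConf()
#     experiment.create(name='neox_trainer_sketch')
#     experiment.configs(conf, {
#         'optimizer': 'SGD',        # selects `sgd_optimizer` via `@option`
#         'learning_rate': 1e-4,
#         'epochs': 1,
#     })
#     # `model`, `train_loader` and friends must be provided by other `@option`
#     # definitions or set directly before training.
#     with experiment.start():
#         conf.train()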