mirror of
https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
synced 2025-08-18 03:41:07 +08:00
import signal
import typing
from typing import Dict, List, Callable
from typing import Optional, Tuple, Any, Collection

import torch.optim
import torch.utils.data
from torch import nn

import labml.utils.pytorch as pytorch_utils
from labml import tracker, logger, experiment, monit
from labml.configs import BaseConfigs, meta_config, option
from labml.internal.monitor import Loop
from labml.logger import Text

from .device import DeviceConfigs
from .metrics import StateModule


class TrainingLoopIterator(Collection):
    """
    Iterator that yields the global step for each iteration of the training loop.
    """

    def __init__(self, start: int, total: int, step: Optional[int]):
        self.step = step
        self.total = total
        self.start = start
        self.i = None

    def __iter__(self):
        self.i = None
        return self

    def __next__(self):
        if self.step is not None:
            if self.i is None:
                self.i = self.start
            else:
                self.i += self.step
        else:
            if self.i is None:
                self.i = 0
            else:
                self.i += 1

        if self.i >= self.total:
            raise StopIteration()

        if self.step is None:
            return tracker.get_global_step()
        else:
            return self.i

    def __len__(self) -> int:
        if self.step is not None:
            return (self.total - self.start) // self.step
        else:
            return self.total

    def __contains__(self, x: object) -> bool:
        return False
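

# For example, `TrainingLoopIterator(start=0, total=1_000, step=100)` yields
# 0, 100, ..., 900 (ten iterations). With `step=None` it yields the current
# `tracker.get_global_step()` value, `total` times, so the training code is expected
# to advance the global step itself (e.g. via `tracker.add_global_step`).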


class TrainingLoop:
    """
    Training loop that drives :class:`TrainingLoopIterator`, handles logging intervals,
    and delays keyboard interrupts until an iteration completes.
    """
    _iter: Optional[TrainingLoopIterator]
    __loop: Loop
    __signal_received: Optional[Tuple[Any, Any]]

    def __init__(self, *,
                 loop_count: int,
                 loop_step: Optional[int],
                 log_new_line_interval: int,
                 log_write_interval: int,
                 is_loop_on_interrupt: bool):
        self.__loop_count = loop_count
        self.__loop_step = loop_step
        self.__log_new_line_interval = log_new_line_interval
        self.__log_write_interval = log_write_interval
        self.__last_write_step = 0
        self.__last_new_line_step = 0
        self.__last_save_step = 0
        self.__signal_received = None
        self.__is_loop_on_interrupt = is_loop_on_interrupt
        self._iter = None

    def __iter__(self):
        self._iter = TrainingLoopIterator(tracker.get_global_step(),
                                          self.__loop_count,
                                          self.__loop_step)

        self.__loop = monit.loop(typing.cast(Collection, self._iter))

        iter(self.__loop)
        try:
            self.old_handler = signal.signal(signal.SIGINT, self.__handler)
        except ValueError:
            pass
        return self

    @property
    def idx(self):
        if not self._iter:
            return 0
        if not self._iter.i:
            return 0
        if self.__loop_step is None:
            return self._iter.i
        return self._iter.i / self.__loop_step

    def __finish(self):
        try:
            signal.signal(signal.SIGINT, self.old_handler)
        except ValueError:
            pass
        tracker.save()
        tracker.new_line()

    def __next__(self):
        if self.__signal_received is not None:
            logger.log('\nKilling Loop.', Text.danger)
            monit.finish_loop()
            self.__finish()
            raise StopIteration("SIGINT")

        try:
            global_step = next(self.__loop)
        except StopIteration as e:
            self.__finish()
            raise e

        tracker.set_global_step(global_step)

        if global_step - self.__last_write_step >= self.__log_write_interval:
            tracker.save()
            self.__last_write_step = global_step
        if global_step - self.__last_new_line_step >= self.__log_new_line_interval:
            tracker.new_line()
            self.__last_new_line_step = global_step

        return global_step

    def __handler(self, sig, frame):
        # Pass second interrupt without delaying
        if self.__signal_received is not None:
            logger.log('\nSIGINT received twice. Stopping...', Text.danger)
            self.old_handler(*self.__signal_received)
            return

        if self.__is_loop_on_interrupt:
            # Store the interrupt signal for later
            self.__signal_received = (sig, frame)
            logger.log('\nSIGINT received. Delaying KeyboardInterrupt.', Text.danger)
        else:
            self.__finish()
            logger.log('Killing loop...', Text.danger)
            self.old_handler(sig, frame)

    def __str__(self):
        return "LabTrainingLoop"


class TrainingLoopConfigs(BaseConfigs):
    r"""
    This is a configurable training loop. You can extend this class for your configurations
    if it involves a training loop.

    >>> for step in conf.training_loop:
    >>>     ...

    Arguments:
        loop_count (int): Total number of steps. Defaults to ``10``.
        loop_step (int): Number of steps to increment per iteration. Defaults to ``1``.
        log_new_line_interval (int): The interval (in steps) to print a new line to the screen.
            Defaults to ``1``.
        log_write_interval (int): The interval (in steps) to call :func:`labml.tracker.save`.
            Defaults to ``1``.
        is_loop_on_interrupt (bool): Whether to handle keyboard interrupts and wait until an iteration is complete.
            Defaults to ``False``.
    """
    loop_count: int = 10
    loop_step: int = 1
    log_new_line_interval: int = 1
    log_write_interval: int = 1
    is_loop_on_interrupt: bool = False

    training_loop: TrainingLoop


@option(TrainingLoopConfigs.training_loop)
def _loop_configs(c: TrainingLoopConfigs):
    return TrainingLoop(loop_count=c.loop_count,
                        loop_step=c.loop_step,
                        log_new_line_interval=c.log_new_line_interval,
                        log_write_interval=c.log_write_interval,
                        is_loop_on_interrupt=c.is_loop_on_interrupt)


meta_config(TrainingLoopConfigs.loop_step,
            TrainingLoopConfigs.loop_count,
            TrainingLoopConfigs.log_new_line_interval,
            TrainingLoopConfigs.log_write_interval,
            TrainingLoopConfigs.is_loop_on_interrupt)
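

# A minimal usage sketch (`MyConfigs` and `train_step` are hypothetical names, not part
# of this module): extend `TrainingLoopConfigs`, override the defaults you need, let
# `experiment.configs` compute the options, and iterate over `training_loop`.
#
#     class MyConfigs(TrainingLoopConfigs):
#         loop_count: int = 1_000
#         log_write_interval: int = 10
#
#     conf = MyConfigs()
#     experiment.create(name='my_experiment')
#     experiment.configs(conf)
#     with experiment.start():
#         for step in conf.training_loop:
#             train_step(step)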


class ModeState:
    """
    Holds the current mode flags (training, logging, optimization) and supports nested,
    temporary updates via :class:`Mode`.
    """

    def __init__(self):
        self._rollback_stack = []

        self.is_train = False
        self.is_log_activations = False
        self.is_log_parameters = False
        self.is_optimize = False

    def _enter(self, mode: Dict[str, Any]):
        rollback = {}
        for k, v in mode.items():
            if v is None:
                continue
            rollback[k] = getattr(self, k)
            setattr(self, k, v)

        self._rollback_stack.append(rollback)

        return len(self._rollback_stack)

    def _exit(self, n: int):
        assert n == len(self._rollback_stack)

        rollback = self._rollback_stack[-1]
        self._rollback_stack.pop(-1)

        for k, v in rollback.items():
            setattr(self, k, v)

    def update(self, *,
               is_train: Optional[bool] = None,
               is_log_parameters: Optional[bool] = None,
               is_log_activations: Optional[bool] = None,
               is_optimize: Optional[bool] = None):
        return Mode(self,
                    is_train=is_train,
                    is_log_parameters=is_log_parameters,
                    is_log_activations=is_log_activations,
                    is_optimize=is_optimize)


class Mode:
    """
    Context manager that temporarily applies a set of flag updates to a :class:`ModeState`
    and rolls them back on exit.
    """

    def __init__(self, mode: ModeState, **kwargs: Any):
        self.mode = mode
        self.update = {}
        for k, v in kwargs.items():
            if v is not None:
                self.update[k] = v

        self.idx = -1

    def __enter__(self):
        self.idx = self.mode._enter(self.update)

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.mode._exit(self.idx)
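

# A small sketch of how the mode flags are meant to be used: `update` returns a `Mode`
# context manager that flips the given flags and restores the previous values on exit.
#
#     mode = ModeState()
#     with mode.update(is_train=True, is_log_activations=False):
#         ...  # the flags apply here; nested `update` calls stack and roll back in order
#     # the previous flag values are restored here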


class ForwardHook:
    def __init__(self, mode: ModeState, model_name, name: str, module: torch.nn.Module):
        self.mode = mode
        self.model_name = model_name
        self.name = name
        self.module = module
        module.register_forward_hook(self)

    def save(self, name: str, output):
        if isinstance(output, torch.Tensor):
            pytorch_utils.store_var(name, output)
        elif isinstance(output, tuple):
            for i, o in enumerate(output):
                self.save(f"{name}.{i}", o)

    def __call__(self, module, i, o):
        if not self.mode.is_log_activations:
            return

        self.save(f"module.{self.model_name}.{self.name}", o)


def hook_model_outputs(mode: ModeState, model: torch.nn.Module, model_name: str = "model"):
    for name, module in model.named_modules():
        if name == '':
            name = 'full'
        ForwardHook(mode, model_name, name, module)
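

# Sketch (the model and input below are placeholders): register forward hooks on every
# sub-module so activations are tracked whenever `is_log_activations` is set.
#
#     mode = ModeState()
#     model = torch.nn.Sequential(torch.nn.Linear(4, 8), torch.nn.ReLU())
#     hook_model_outputs(mode, model, 'model')
#     with mode.update(is_log_activations=True):
#         model(torch.randn(2, 4))  # outputs stored under `module.model.*`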


class Trainer:
    """
    Runs a ``step`` function over a data loader, splitting each epoch into
    ``inner_iterations`` sections so that training and validation can be interleaved.
    """

    def __init__(self, *,
                 name: str,
                 mode: ModeState,
                 data_loader: torch.utils.data.DataLoader,
                 inner_iterations: int,
                 state_modules: List[StateModule],
                 is_track_time: bool,
                 step: Callable[[Any, 'BatchIndex'], None]):
        self.is_track_time = is_track_time
        self.mode = mode
        self.name = name
        self.step = step
        self.state_modules = state_modules
        self.__iterable = None
        self.__states = [sm.create_state() for sm in self.state_modules]
        self.inner_iterations = inner_iterations
        self.data_loader = data_loader
        self._batch_index = BatchIndex(len(self.data_loader), self.inner_iterations)

    def set_data_loader(self, data_loader: torch.utils.data.DataLoader):
        self.data_loader = data_loader
        self._batch_index = BatchIndex(len(data_loader), self.inner_iterations)
        self.__iterable = None

    def __call__(self):
        for sm, s in zip(self.state_modules, self.__states):
            sm.set_state(s)

        if self.__iterable is None or self._batch_index.completed:
            self.__iterable = iter(self.data_loader)
            self._batch_index.reset(len(self.data_loader), self.inner_iterations)
            for sm in self.state_modules:
                sm.on_epoch_start()
        with torch.set_grad_enabled(self.mode.is_train):
            self.__iterate()

        if self._batch_index.completed:
            for sm in self.state_modules:
                sm.on_epoch_end()

    def __iterate(self):
        with monit.section(self.name, is_partial=True, is_track=self.is_track_time):
            if self._batch_index.idx == 0:
                monit.progress(0)
            while not self._batch_index.iteration_completed:
                batch = next(self.__iterable)

                self.step(batch, self._batch_index)

                self._batch_index.step()
                monit.progress(self._batch_index.epoch_progress)

            self._batch_index.step_inner()


class BatchIndex:
    idx: int
    total: int
    iteration: int
    total_iterations: int

    def __init__(self, total: int, total_iterations: int):
        self.total_iterations = total_iterations
        self.total = total

    def is_interval(self, interval: int):
        if interval <= 0:
            return False
        if self.idx + 1 == self.total:
            return True
        else:
            return (self.idx + 1) % interval == 0

    @property
    def is_last(self):
        return self.idx + 1 == self.total

    @property
    def completed(self):
        return self.iteration >= self.total_iterations

    @property
    def iteration_completed(self):
        # // is important so that the last step happens on the last iteration
        return self.idx >= (self.iteration + 1) * self.total // self.total_iterations

    @property
    def epoch_progress(self):
        return self.idx / self.total

    def step(self):
        self.idx += 1

    def step_inner(self):
        self.iteration += 1

    def reset(self, total: int, total_iterations: int):
        self.total = total
        self.total_iterations = total_iterations
        self.idx = 0
        self.iteration = 0
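

# Worked example: with `total=100` batches and `total_iterations=4`, `iteration_completed`
# becomes true once `idx` reaches 25, 50, 75 and 100, so each call to `Trainer.__call__`
# consumes a quarter of the data loader before handing control back (e.g. to validation).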


class TrainValidConfigs(TrainingLoopConfigs):
    r"""
    This is a configurable module that you can extend for experiments that involve
    training and validation datasets (i.e. most DL experiments).

    Arguments:
        epochs (int): Number of epochs to train on. Defaults to ``10``.
        train_loader (torch.utils.data.DataLoader): Training data loader.
        valid_loader (torch.utils.data.DataLoader): Validation data loader.
        inner_iterations (int): Number of times to switch between training and validation
            within an epoch. Defaults to ``1``.

    You can override the ``init`` and ``step`` functions. There is also a ``sample`` function
    that you can override to generate samples every time it switches between training and validation.
    """
    state_modules: List[StateModule]

    mode: ModeState

    epochs: int = 10

    trainer: Trainer
    validator: Trainer
    train_loader: torch.utils.data.DataLoader
    valid_loader: torch.utils.data.DataLoader

    loop_count = '_data_loop_count'
    loop_step = None

    inner_iterations: int = 1

    is_track_time: bool = False

    def init(self):
        pass

    def step(self, batch: Any, batch_idx: BatchIndex):
        raise NotImplementedError

    def run_step(self):
        for i in range(self.inner_iterations):
            with tracker.namespace('sample'):
                self.sample()
            with self.mode.update(is_train=True):
                with tracker.namespace('train'):
                    self.trainer()
            if self.validator:
                with tracker.namespace('valid'):
                    self.validator()
            tracker.save()

    def run(self):
        with monit.section("Initialize"):
            self.init()
        _ = self.validator
        _ = self.trainer
        for _ in self.training_loop:
            self.run_step()

    def sample(self):
        pass


@option(TrainValidConfigs.trainer)
def _default_trainer(c: TrainValidConfigs):
    return Trainer(name='Train',
                   mode=c.mode,
                   data_loader=c.train_loader,
                   inner_iterations=c.inner_iterations,
                   state_modules=c.state_modules,
                   is_track_time=c.is_track_time,
                   step=c.step)


@option(TrainValidConfigs.validator)
def _default_validator(c: TrainValidConfigs):
    return Trainer(name='Valid',
                   mode=c.mode,
                   data_loader=c.valid_loader,
                   inner_iterations=c.inner_iterations,
                   state_modules=c.state_modules,
                   is_track_time=c.is_track_time,
                   step=c.step)


@option(TrainValidConfigs.loop_count)
def _data_loop_count(c: TrainValidConfigs):
    return c.epochs
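

# Sketch of what `run()` does per epoch: for each of the `inner_iterations` sections it
# calls `sample()`, runs the trainer with `is_train=True` (gradients enabled), then the
# validator with the previous mode restored (gradients disabled by default), and finally
# `tracker.save()`. Subclasses typically provide `step`, `train_loader` and `valid_loader`.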


class SimpleTrainValidConfigs(TrainValidConfigs):
    r"""
    This is a configurable module that works for many standard DL experiments.

    Arguments:
        model: A PyTorch model.
        optimizer: A PyTorch optimizer to update the model.
        device: The device to train the model on. This defaults to a configurable device.
        loss_func: A function to calculate the loss. This should accept ``model_output, target`` as
            arguments.
        update_batches (int): Number of batches to accumulate before taking an optimizer step.
            Defaults to ``1``.
        log_params_updates (int): How often (number of batches) to track model parameters and gradients.
            Defaults to a large number; i.e. logs every epoch.
        log_activations_batches (int): How often to log model activations.
            Defaults to a large number; i.e. logs every epoch.
        log_save_batches (int): How often to call :func:`labml.tracker.save`.
    """
    optimizer: torch.optim.Adam
    model: nn.Module
    device: torch.device = DeviceConfigs()

    loss_func: nn.Module

    update_batches: int = 1
    log_params_updates: int = 2 ** 32  # set to 0 to disable
    log_activations_batches: int = 2 ** 32  # set to 0 to disable
    log_save_batches: int = 1

    state_modules: List[StateModule] = []

    def init(self):
        pass

    def step(self, batch: Any, batch_idx: BatchIndex):
        self.model.train(self.mode.is_train)
        data, target = batch[0].to(self.device), batch[1].to(self.device)

        if self.mode.is_train:
            tracker.add_global_step(len(data))

        is_log_activations = batch_idx.is_interval(self.log_activations_batches)
        with monit.section("model"):
            with self.mode.update(is_log_activations=is_log_activations):
                output = self.model(data)

        loss = self.loss_func(output, target)
        tracker.add("loss.", loss)

        if self.mode.is_train:
            with monit.section('backward'):
                loss.backward()

            if batch_idx.is_interval(self.update_batches):
                with monit.section('optimize'):
                    self.optimizer.step()
                if batch_idx.is_interval(self.log_params_updates):
                    tracker.add('model', self.model)
                self.optimizer.zero_grad()

        if batch_idx.is_interval(self.log_save_batches):
            tracker.save()


meta_config(SimpleTrainValidConfigs.update_batches,
            SimpleTrainValidConfigs.log_params_updates,
            SimpleTrainValidConfigs.log_activations_batches)


@option(SimpleTrainValidConfigs.optimizer)
def _default_optimizer(c: SimpleTrainValidConfigs):
    from .optimizer import OptimizerConfigs
    opt_conf = OptimizerConfigs()
    opt_conf.parameters = c.model.parameters()
    return opt_conf
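

# End-to-end usage sketch. Everything below (`MyModel`, the data loaders, the experiment
# name and the config override keys) is illustrative rather than part of this module; it
# assumes the usual `labml.experiment` create/configs/start flow.
#
#     class Configs(SimpleTrainValidConfigs):
#         epochs: int = 5
#         loss_func = nn.CrossEntropyLoss()
#
#     @option(Configs.model)
#     def _model(c: Configs):
#         return MyModel().to(c.device)
#
#     conf = Configs()
#     conf.train_loader = ...  # torch.utils.data.DataLoader
#     conf.valid_loader = ...  # torch.utils.data.DataLoader
#     experiment.create(name='my_experiment')
#     experiment.configs(conf, {'optimizer.optimizer': 'Adam'})
#     with experiment.start():
#         conf.run()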