import signal
import typing
from typing import Dict, List, Callable
from typing import Optional, Tuple, Any, Collection
import torch.optim
import torch.utils.data
from labml import tracker, logger, monit
from labml.configs import BaseConfigs, meta_config, option
from labml.internal.monitor import Loop
from labml.logger import Text
from torch import nn
from .device import DeviceConfigs
from .metrics import StateModule


class TrainingLoopIterator(Collection):
    def __init__(self, start: int, total: int, step: Optional[int]):
        self.step = step
        self.total = total
        self.start = start
        self.i = None

    def __iter__(self):
        self.i = None
        return self

    def __next__(self):
        if self.step is not None:
            if self.i is None:
                self.i = self.start
            else:
                self.i += self.step
        else:
            if self.i is None:
                self.i = 0
            else:
                self.i += 1

        if self.i >= self.total:
            raise StopIteration()

        if self.step is None:
            return tracker.get_global_step()
        else:
            return self.i

    def __len__(self) -> int:
        if self.step is not None:
            return (self.total - self.start) // self.step
        else:
            return self.total

    def __contains__(self, x: object) -> bool:
        return False


class TrainingLoop:
    _iter: Optional[TrainingLoopIterator]
    __loop: Loop
    __signal_received: Optional[Tuple[Any, Any]]

    def __init__(self, *,
                 loop_count: int,
                 loop_step: Optional[int],
                 log_new_line_interval: int,
                 log_write_interval: int,
                 is_loop_on_interrupt: bool):
        self.__loop_count = loop_count
        self.__loop_step = loop_step
        self.__log_new_line_interval = log_new_line_interval
        self.__log_write_interval = log_write_interval
        self.__last_write_step = 0
        self.__last_new_line_step = 0
        self.__last_save_step = 0
        self.__signal_received = None
        self.__is_loop_on_interrupt = is_loop_on_interrupt
        self._iter = None

    def __iter__(self):
        self._iter = TrainingLoopIterator(tracker.get_global_step(),
                                          self.__loop_count,
                                          self.__loop_step)

        self.__loop = monit.loop(typing.cast(Collection, self._iter))

        iter(self.__loop)
        try:
            self.old_handler = signal.signal(signal.SIGINT, self.__handler)
        except ValueError:
            pass
        return self

    @property
    def idx(self):
        if not self._iter:
            return 0
        if not self._iter.i:
            return 0
        if self.__loop_step is None:
            return self._iter.i
        return self._iter.i / self.__loop_step

    def __finish(self):
        try:
            signal.signal(signal.SIGINT, self.old_handler)
        except ValueError:
            pass
        tracker.save()
        tracker.new_line()

    def __next__(self):
        if self.__signal_received is not None:
            logger.log('\nKilling Loop.', Text.danger)
            monit.finish_loop()
            self.__finish()
            raise StopIteration("SIGINT")

        try:
            global_step = next(self.__loop)
        except StopIteration as e:
            self.__finish()
            raise e

        tracker.set_global_step(global_step)

        if global_step - self.__last_write_step >= self.__log_write_interval:
            tracker.save()
            self.__last_write_step = global_step
        if global_step - self.__last_new_line_step >= self.__log_new_line_interval:
            tracker.new_line()
            self.__last_new_line_step = global_step

        return global_step

    def __handler(self, sig, frame):
        # Pass second interrupt without delaying
        if self.__signal_received is not None:
            logger.log('\nSIGINT received twice. Stopping...', Text.danger)
            self.old_handler(*self.__signal_received)
            return

        if self.__is_loop_on_interrupt:
            # Store the interrupt signal for later
            self.__signal_received = (sig, frame)
            logger.log('\nSIGINT received. Delaying KeyboardInterrupt.', Text.danger)
        else:
            self.__finish()
            logger.log('Killing loop...', Text.danger)
            self.old_handler(sig, frame)

    def __str__(self):
        return "LabTrainingLoop"


class TrainingLoopConfigs(BaseConfigs):
    r"""
    This is a configurable training loop. You can extend this class for your configurations
    if it involves a training loop.

    >>> for step in conf.training_loop:
    >>>     ...

    Arguments:
        loop_count (int): Total number of steps. Defaults to ``10``.
        loop_step (int): Number of steps to increment per iteration. Defaults to ``1``.
        log_new_line_interval (int): The interval (in steps) to print a new line to the screen.
            Defaults to ``1``.
        log_write_interval (int): The interval (in steps) to call :func:`labml.tracker.save`.
            Defaults to ``1``.
        is_loop_on_interrupt (bool): Whether to handle keyboard interrupts and wait until an
            iteration is complete. Defaults to ``False``.
    """
    loop_count: int = 10
    loop_step: int = 1
    log_new_line_interval: int = 1
    log_write_interval: int = 1
    is_loop_on_interrupt: bool = False

    training_loop: TrainingLoop


@option(TrainingLoopConfigs.training_loop)
def _loop_configs(c: TrainingLoopConfigs):
    return TrainingLoop(loop_count=c.loop_count,
                        loop_step=c.loop_step,
                        log_new_line_interval=c.log_new_line_interval,
                        log_write_interval=c.log_write_interval,
                        is_loop_on_interrupt=c.is_loop_on_interrupt)


meta_config(TrainingLoopConfigs.loop_step,
            TrainingLoopConfigs.loop_count,
            TrainingLoopConfigs.log_new_line_interval,
            TrainingLoopConfigs.log_write_interval,
            TrainingLoopConfigs.is_loop_on_interrupt)


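# Illustrative usage sketch (not part of the original module): how the training loop above
# is typically consumed. The experiment name and the dummy metric are hypothetical, and the
# usual labml setup of `experiment.create` / `experiment.configs` / `experiment.start` is
# assumed here, since option values such as `training_loop` are resolved by the config
# processor rather than on plain attribute access.
def _training_loop_example():
    from labml import experiment

    conf = TrainingLoopConfigs()
    experiment.create(name='training_loop_example')
    experiment.configs(conf, {'loop_count': 100, 'log_new_line_interval': 10})
    with experiment.start():
        # Iterating `training_loop` yields the global step and takes care of
        # tracker.save(), new-line logging and SIGINT handling as configured above.
        for step in conf.training_loop:
            tracker.add('loss.train', 1.0 / (step + 1))  # placeholder per-step work

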
class ModeState:
    def __init__(self):
        self._rollback_stack = []

        self.is_train = False
        self.is_optimize = False

    def _enter(self, mode: Dict[str, Any]):
        rollback = {}
        for k, v in mode.items():
            if v is None:
                continue
            rollback[k] = getattr(self, k)
            setattr(self, k, v)

        self._rollback_stack.append(rollback)

        return len(self._rollback_stack)

    def _exit(self, n: int):
        assert n == len(self._rollback_stack)

        rollback = self._rollback_stack[-1]
        self._rollback_stack.pop(-1)

        for k, v in rollback.items():
            setattr(self, k, v)

    def update(self, *,
               is_train: Optional[bool] = None,
               is_optimize: Optional[bool] = None):
        return Mode(self,
                    is_train=is_train,
                    is_optimize=is_optimize)


class Mode:
    def __init__(self, mode: ModeState, **kwargs: Any):
        self.mode = mode
        self.update = {}
        for k, v in kwargs.items():
            if v is not None:
                self.update[k] = v

        self.idx = -1

    def __enter__(self):
        self.idx = self.mode._enter(self.update)

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.mode._exit(self.idx)


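# Illustrative sketch (helper not in the original module): `ModeState.update()` returns a
# `Mode` context manager that flips the requested flags for the duration of the block and
# rolls them back on exit. `Trainer` below relies on this to switch gradient tracking.
def _mode_example():
    mode = ModeState()
    assert not mode.is_train
    with mode.update(is_train=True):
        assert mode.is_train  # flag set inside the block
    assert not mode.is_train  # restored from the rollback stack on exit

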
class Trainer:
    def __init__(self, *,
                 name: str,
                 mode: ModeState,
                 data_loader: torch.utils.data.DataLoader,
                 inner_iterations: int,
                 state_modules: List[StateModule],
                 is_track_time: bool,
                 step: Callable[[Any, 'BatchIndex'], None]):
        self.is_track_time = is_track_time
        self.mode = mode
        self.name = name
        self.step = step
        self.state_modules = state_modules
        self.__iterable = None
        self.__states = [sm.create_state() for sm in self.state_modules]
        self.inner_iterations = inner_iterations
        self.data_loader = data_loader
        self._batch_index = BatchIndex(len(self.data_loader), self.inner_iterations)

    def set_data_loader(self, data_loader: torch.utils.data.DataLoader):
        self.data_loader = data_loader
        self._batch_index = BatchIndex(len(data_loader), self.inner_iterations)
        self.__iterable = None

    def __call__(self):
        for sm, s in zip(self.state_modules, self.__states):
            sm.set_state(s)

        if self.__iterable is None or self._batch_index.completed:
            self.__iterable = iter(self.data_loader)
            self._batch_index.reset(len(self.data_loader), self.inner_iterations)
            for sm in self.state_modules:
                sm.on_epoch_start()

        with torch.set_grad_enabled(self.mode.is_train):
            self.__iterate()

        if self._batch_index.completed:
            for sm in self.state_modules:
                sm.on_epoch_end()

    def __iterate(self):
        with monit.section(self.name, is_partial=True, is_track=self.is_track_time):
            if self._batch_index.idx == 0:
                monit.progress(0)
            while not self._batch_index.iteration_completed:
                batch = next(self.__iterable)

                self.step(batch, self._batch_index)

                self._batch_index.step()
                monit.progress(self._batch_index.epoch_progress)

        self._batch_index.step_inner()


class BatchIndex:
    idx: int
    total: int
    iteration: int
    total_iterations: int

    def __init__(self, total: int, total_iterations: int):
        self.total_iterations = total_iterations
        self.total = total

    def is_interval(self, interval: int):
        if interval <= 0:
            return False
        if self.idx + 1 == self.total:
            return True
        else:
            return (self.idx + 1) % interval == 0

    @property
    def is_last(self):
        return self.idx + 1 == self.total

    @property
    def completed(self):
        return self.iteration >= self.total_iterations

    @property
    def iteration_completed(self):
        # // is important so that the last step happens on the last iteration
        return self.idx >= (self.iteration + 1) * self.total // self.total_iterations

    @property
    def epoch_progress(self):
        return self.idx / self.total

    def step(self):
        self.idx += 1

    def step_inner(self):
        self.iteration += 1

    def reset(self, total: int, total_iterations: int):
        self.total = total
        self.total_iterations = total_iterations
        self.idx = 0
        self.iteration = 0


class TrainValidConfigs(TrainingLoopConfigs):
    r"""
    This is a configurable module that you can extend for experiments that involve
    training and validation datasets (i.e. most DL experiments).

    Arguments:
        epochs (int): Number of epochs to train on. Defaults to ``10``.
        train_loader (torch.utils.data.DataLoader): Training data loader.
        valid_loader (torch.utils.data.DataLoader): Validation data loader.
        inner_iterations (int): Number of times to switch between training and validation
            within an epoch. Defaults to ``1``.

    You can override the ``init`` and ``step`` functions. There is also a ``sample`` function
    that you can override to generate samples every time it switches between training and
    validation.
    """
    state_modules: List[StateModule]

    mode: ModeState

    epochs: int = 10

    trainer: Trainer
    validator: Trainer

    train_loader: torch.utils.data.DataLoader
    valid_loader: torch.utils.data.DataLoader

    loop_count = '_data_loop_count'
    loop_step = None

    inner_iterations: int = 1

    is_track_time: bool = False

    def init(self):
        pass

    def step(self, batch: Any, batch_idx: BatchIndex):
        raise NotImplementedError

    def run_step(self):
        for i in range(self.inner_iterations):
            with tracker.namespace('sample'):
                self.sample()
            with self.mode.update(is_train=True):
                with tracker.namespace('train'):
                    self.trainer()
            if self.validator:
                with tracker.namespace('valid'):
                    self.validator()
            tracker.save()

    def run(self):
        with monit.section("Initialize"):
            self.init()
        _ = self.validator
        _ = self.trainer
        for _ in self.training_loop:
            self.run_step()

    def sample(self):
        pass


@option(TrainValidConfigs.trainer)
def _default_trainer(c: TrainValidConfigs):
    return Trainer(name='Train',
                   mode=c.mode,
                   data_loader=c.train_loader,
                   inner_iterations=c.inner_iterations,
                   state_modules=c.state_modules,
                   is_track_time=c.is_track_time,
                   step=c.step)


@option(TrainValidConfigs.validator)
def _default_validator(c: TrainValidConfigs):
    return Trainer(name='Valid',
                   mode=c.mode,
                   data_loader=c.valid_loader,
                   inner_iterations=c.inner_iterations,
                   state_modules=c.state_modules,
                   is_track_time=c.is_track_time,
                   step=c.step)


@option(TrainValidConfigs.loop_count)
def _data_loop_count(c: TrainValidConfigs):
    return c.epochs


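# Illustrative sketch (hypothetical class, not part of the original module): the minimum an
# experiment has to add on top of TrainValidConfigs is a `step` that consumes one batch; the
# trainer/validator wiring and the epoch loop come from the options above. It assumes the
# data loaders already yield tensors on the training device. `SimpleTrainValidConfigs` below
# packages this same pattern together with device handling and gradient accumulation.
class _ExampleTrainValidConfigs(TrainValidConfigs):
    model: nn.Module
    optimizer: torch.optim.Optimizer

    def step(self, batch: Any, batch_idx: BatchIndex):
        data, target = batch
        output = self.model(data)
        loss = nn.functional.cross_entropy(output, target)
        tracker.add("loss.", loss)

        if self.mode.is_train:
            tracker.add_global_step(len(data))
            loss.backward()
            if batch_idx.is_interval(1):
                self.optimizer.step()
                self.optimizer.zero_grad()

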
class SimpleTrainValidConfigs(TrainValidConfigs):
    r"""
    This is a configurable module that works for many standard DL experiments.

    Arguments:
        model: A PyTorch model.
        optimizer: A PyTorch optimizer to update the model.
        device: The device to train the model on. This defaults to a configurable device.
        loss_func: A function to calculate the loss. This should accept ``model_output, target``
            as arguments.
        update_batches (int): Number of batches to accumulate before taking an optimizer step.
            Defaults to ``1``.
        log_save_batches (int): How often (in batches) to call :func:`labml.tracker.save`.
            Defaults to ``1``.
    """
    optimizer: torch.optim.Adam
    model: nn.Module
    device: torch.device = DeviceConfigs()

    loss_func: nn.Module

    update_batches: int = 1
    log_save_batches: int = 1

    state_modules: List[StateModule] = []

    def init(self):
        pass

    def step(self, batch: Any, batch_idx: BatchIndex):
        self.model.train(self.mode.is_train)
        data, target = batch[0].to(self.device), batch[1].to(self.device)

        if self.mode.is_train:
            tracker.add_global_step(len(data))

        with monit.section("model"):
            output = self.model(data)

        loss = self.loss_func(output, target)
        tracker.add("loss.", loss)

        if self.mode.is_train:
            with monit.section('backward'):
                loss.backward()

            if batch_idx.is_interval(self.update_batches):
                with monit.section('optimize'):
                    self.optimizer.step()
                self.optimizer.zero_grad()

            if batch_idx.is_interval(self.log_save_batches):
                tracker.save()


meta_config(SimpleTrainValidConfigs.update_batches)


@option(SimpleTrainValidConfigs.optimizer)
def _default_optimizer(c: SimpleTrainValidConfigs):
    from .optimizer import OptimizerConfigs
    opt_conf = OptimizerConfigs()
    opt_conf.parameters = c.model.parameters()
    return opt_conf


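# Illustrative usage sketch (hypothetical names, not part of the original module): wiring a
# tiny classifier into SimpleTrainValidConfigs. The `@option` functions supply the model,
# loss and mode in the same way the defaults above are supplied; the optimizer, device and
# training loop come from the options already defined in this file. The usual labml
# experiment setup is assumed, and assigning the data loaders directly on the configs
# instance is assumed to act as an override.
class _ExampleMNISTConfigs(SimpleTrainValidConfigs):
    epochs: int = 5


@option(_ExampleMNISTConfigs.model)
def _example_model(c: _ExampleMNISTConfigs):
    # A tiny classifier, moved to the configured device
    return nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10)).to(c.device)


@option(_ExampleMNISTConfigs.loss_func)
def _example_loss_func(c: _ExampleMNISTConfigs):
    return nn.CrossEntropyLoss()


@option(_ExampleMNISTConfigs.mode)
def _example_mode(c: _ExampleMNISTConfigs):
    # Supply a ModeState explicitly in case no default option for `mode` is registered elsewhere
    return ModeState()


def _example_main(train_loader: torch.utils.data.DataLoader,
                  valid_loader: torch.utils.data.DataLoader):
    from labml import experiment

    conf = _ExampleMNISTConfigs()
    conf.train_loader = train_loader
    conf.valid_loader = valid_loader

    experiment.create(name='simple_train_valid_example')
    experiment.configs(conf)
    with experiment.start():
        conf.run()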