import signal
import typing
from typing import Dict, List, Callable
from typing import Optional, Tuple, Any, Collection

import torch.optim
import torch.utils.data

from labml import tracker, logger, monit
from labml.configs import BaseConfigs, meta_config, option
from labml.internal.monitor import Loop
from labml.logger import Text
from torch import nn

from .device import DeviceConfigs
from .metrics import StateModule


class TrainingLoopIterator(Collection):
    """Iterator that yields the global step for each iteration of the training loop."""

    def __init__(self, start: int, total: int, step: Optional[int]):
        self.step = step
        self.total = total
        self.start = start
        self.i = None

    def __iter__(self):
        self.i = None
        return self

    def __next__(self):
        if self.step is not None:
            if self.i is None:
                self.i = self.start
            else:
                self.i += self.step
        else:
            if self.i is None:
                self.i = 0
            else:
                self.i += 1

        if self.i >= self.total:
            raise StopIteration()

        if self.step is None:
            return tracker.get_global_step()
        else:
            return self.i

    def __len__(self) -> int:
        if self.step is not None:
            return (self.total - self.start) // self.step
        else:
            return self.total

    def __contains__(self, x: object) -> bool:
        return False


class TrainingLoop:
    """
    Iterable training loop that tracks the global step, periodically saves tracker
    stats and starts new log lines, and handles ``SIGINT`` so that an interrupt can
    be delayed until the current iteration finishes.
    """
    _iter: Optional[TrainingLoopIterator]
    __loop: Loop
    __signal_received: Optional[Tuple[Any, Any]]

    def __init__(self, *,
                 loop_count: int,
                 loop_step: Optional[int],
                 log_new_line_interval: int,
                 log_write_interval: int,
                 is_loop_on_interrupt: bool):
        self.__loop_count = loop_count
        self.__loop_step = loop_step
        self.__log_new_line_interval = log_new_line_interval
        self.__log_write_interval = log_write_interval
        self.__last_write_step = 0
        self.__last_new_line_step = 0
        self.__last_save_step = 0
        self.__signal_received = None
        self.__is_loop_on_interrupt = is_loop_on_interrupt
        self._iter = None

    def __iter__(self):
        self._iter = TrainingLoopIterator(tracker.get_global_step(),
                                          self.__loop_count,
                                          self.__loop_step)

        self.__loop = monit.loop(typing.cast(Collection, self._iter))

        iter(self.__loop)
        try:
            self.old_handler = signal.signal(signal.SIGINT, self.__handler)
        except ValueError:
            pass
        return self

    @property
    def idx(self):
        if not self._iter:
            return 0
        if not self._iter.i:
            return 0
        if self.__loop_step is None:
            return self._iter.i
        return self._iter.i / self.__loop_step

    def __finish(self):
        try:
            signal.signal(signal.SIGINT, self.old_handler)
        except ValueError:
            pass
        tracker.save()
        tracker.new_line()

    def __next__(self):
        if self.__signal_received is not None:
            logger.log('\nKilling Loop.', Text.danger)
            monit.finish_loop()
            self.__finish()
            raise StopIteration("SIGINT")

        try:
            global_step = next(self.__loop)
        except StopIteration as e:
            self.__finish()
            raise e

        tracker.set_global_step(global_step)

        if global_step - self.__last_write_step >= self.__log_write_interval:
            tracker.save()
            self.__last_write_step = global_step
        if global_step - self.__last_new_line_step >= self.__log_new_line_interval:
            tracker.new_line()
            self.__last_new_line_step = global_step

        return global_step

    def __handler(self, sig, frame):
        # Pass the second interrupt through without delaying
        if self.__signal_received is not None:
            logger.log('\nSIGINT received twice. Stopping...', Text.danger)
            self.old_handler(*self.__signal_received)
            return

        if self.__is_loop_on_interrupt:
            # Store the interrupt signal for later
            self.__signal_received = (sig, frame)
            logger.log('\nSIGINT received. Delaying KeyboardInterrupt.', Text.danger)
        else:
            self.__finish()
            logger.log('Killing loop...', Text.danger)
            self.old_handler(sig, frame)

    def __str__(self):
        return "LabTrainingLoop"


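# Illustrative use of `TrainingLoop` on its own (a sketch; normally it is created
# through `TrainingLoopConfigs` below rather than constructed directly):
#
#     loop = TrainingLoop(loop_count=1_000, loop_step=10, log_new_line_interval=100,
#                         log_write_interval=10, is_loop_on_interrupt=False)
#     for step in loop:
#         ...  # one training step; `step` is the current global step
#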
class TrainingLoopConfigs(BaseConfigs):
    r"""
    This is a configurable training loop. You can extend this class for your configurations
    if they involve a training loop.

    >>> for step in conf.training_loop:
    >>>     ...

    Arguments:
        loop_count (int): Total number of steps. Defaults to ``10``.
        loop_step (int): Number of steps to increment per iteration. Defaults to ``1``.
        log_new_line_interval (int): The interval (in steps) at which to print a new line to the screen.
         Defaults to ``1``.
        log_write_interval (int): The interval (in steps) at which to call :func:`labml.tracker.save`.
         Defaults to ``1``.
        is_loop_on_interrupt (bool): Whether to handle keyboard interrupts and wait until an iteration is complete.
         Defaults to ``False``.
    """
    loop_count: int = 10
    loop_step: int = 1
    log_new_line_interval: int = 1
    log_write_interval: int = 1
    is_loop_on_interrupt: bool = False

    training_loop: TrainingLoop


@option(TrainingLoopConfigs.training_loop)
def _loop_configs(c: TrainingLoopConfigs):
    return TrainingLoop(loop_count=c.loop_count,
                        loop_step=c.loop_step,
                        log_new_line_interval=c.log_new_line_interval,
                        log_write_interval=c.log_write_interval,
                        is_loop_on_interrupt=c.is_loop_on_interrupt)


meta_config(TrainingLoopConfigs.loop_step,
            TrainingLoopConfigs.loop_count,
            TrainingLoopConfigs.log_new_line_interval,
            TrainingLoopConfigs.log_write_interval,
            TrainingLoopConfigs.is_loop_on_interrupt)


class ModeState:
    """
    Holds the current training/optimization mode flags, with a rollback stack so
    that nested mode changes can be undone in order.
    """

    def __init__(self):
        self._rollback_stack = []

        self.is_train = False
        self.is_optimize = False

    def _enter(self, mode: Dict[str, any]):
        rollback = {}
        for k, v in mode.items():
            if v is None:
                continue
            rollback[k] = getattr(self, k)
            setattr(self, k, v)

        self._rollback_stack.append(rollback)

        return len(self._rollback_stack)

    def _exit(self, n: int):
        assert n == len(self._rollback_stack)

        rollback = self._rollback_stack[-1]
        self._rollback_stack.pop(-1)

        for k, v in rollback.items():
            setattr(self, k, v)

    def update(self, *,
               is_train: Optional[bool] = None,
               is_optimize: Optional[bool] = None):
        return Mode(self,
                    is_train=is_train,
                    is_optimize=is_optimize)


class Mode:
    """Context manager that applies a set of mode-flag updates on entry and rolls them back on exit."""

    def __init__(self, mode: ModeState, **kwargs: any):
        self.mode = mode
        self.update = {}
        for k, v in kwargs.items():
            if v is not None:
                self.update[k] = v

        self.idx = -1

    def __enter__(self):
        self.idx = self.mode._enter(self.update)

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.mode._exit(self.idx)


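# Usage sketch (illustrative, not part of the library API): `ModeState.update` returns
# a `Mode` context manager, so nested blocks restore the outer flags on exit.
#
#     mode = ModeState()
#     with mode.update(is_train=True):
#         assert mode.is_train
#         with mode.update(is_train=False):
#             assert not mode.is_train
#         assert mode.is_train
#     assert not mode.is_train
#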
class Trainer:
    """
    Runs one inner iteration over a data loader: restores the state modules, iterates
    batches with gradients enabled only in training mode, calls ``step`` on each batch,
    and fires the epoch start/end hooks of the state modules.
    """

    def __init__(self, *,
                 name: str,
                 mode: ModeState,
                 data_loader: torch.utils.data.DataLoader,
                 inner_iterations: int,
                 state_modules: List[StateModule],
                 is_track_time: bool,
                 step: Callable[[any, 'BatchIndex'], None]):
        self.is_track_time = is_track_time
        self.mode = mode
        self.name = name
        self.step = step
        self.state_modules = state_modules
        self.__iterable = None
        self.__states = [sm.create_state() for sm in self.state_modules]
        self.inner_iterations = inner_iterations
        self.data_loader = data_loader
        self._batch_index = BatchIndex(len(self.data_loader), self.inner_iterations)

    def set_data_loader(self, data_loader: torch.utils.data.DataLoader):
        self.data_loader = data_loader
        self._batch_index = BatchIndex(len(data_loader), self.inner_iterations)
        self.__iterable = None

    def __call__(self):
        for sm, s in zip(self.state_modules, self.__states):
            sm.set_state(s)

        if self.__iterable is None or self._batch_index.completed:
            self.__iterable = iter(self.data_loader)
            self._batch_index.reset(len(self.data_loader), self.inner_iterations)
            for sm in self.state_modules:
                sm.on_epoch_start()
        with torch.set_grad_enabled(self.mode.is_train):
            self.__iterate()

        if self._batch_index.completed:
            for sm in self.state_modules:
                sm.on_epoch_end()

    def __iterate(self):
        with monit.section(self.name, is_partial=True, is_track=self.is_track_time):
            if self._batch_index.idx == 0:
                monit.progress(0)
            while not self._batch_index.iteration_completed:
                batch = next(self.__iterable)

                self.step(batch, self._batch_index)

                self._batch_index.step()
                monit.progress(self._batch_index.epoch_progress)

        self._batch_index.step_inner()


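# Note (descriptive, not normative): each call to a `Trainer` instance runs roughly
# `len(data_loader) / inner_iterations` batches, so `TrainValidConfigs.run_step` below
# can interleave training and validation several times within a single epoch.
#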
class BatchIndex:
    """
    Tracks the batch index within an epoch (``idx`` of ``total``) and the inner-iteration
    index (``iteration`` of ``total_iterations``), and reports when an inner iteration or
    the whole epoch has completed.
    """
    idx: int
    total: int
    iteration: int
    total_iterations: int

    def __init__(self, total: int, total_iterations: int):
        self.total_iterations = total_iterations
        self.total = total

    def is_interval(self, interval: int):
        if interval <= 0:
            return False
        if self.idx + 1 == self.total:
            return True
        else:
            return (self.idx + 1) % interval == 0

    @property
    def is_last(self):
        return self.idx + 1 == self.total

    @property
    def completed(self):
        return self.iteration >= self.total_iterations

    @property
    def iteration_completed(self):
        # // is important so that the last step happens on the last iteration
        return self.idx >= (self.iteration + 1) * self.total // self.total_iterations

    @property
    def epoch_progress(self):
        return self.idx / self.total

    def step(self):
        self.idx += 1

    def step_inner(self):
        self.iteration += 1

    def reset(self, total: int, total_iterations: int):
        self.total = total
        self.total_iterations = total_iterations
        self.idx = 0
        self.iteration = 0


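# Worked example (illustrative): with `total=100` batches and `total_iterations=4`,
# `(iteration + 1) * total // total_iterations` evaluates to 25, 50, 75 and 100, so each
# inner iteration covers 25 batches (indices 0-24, 25-49, ...) and `iteration_completed`
# becomes true once `idx` reaches 25, 50, 75 and 100. `is_interval(n)` is true on every
# `n`-th batch (1-based) and always on the last batch of the epoch.
#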
class TrainValidConfigs(TrainingLoopConfigs):
    r"""
    This is a configurable module that you can extend for experiments that involve
    training and validation datasets (i.e. most DL experiments).

    Arguments:
        epochs (int): Number of epochs to train on. Defaults to ``10``.
        train_loader (torch.utils.data.DataLoader): Training data loader.
        valid_loader (torch.utils.data.DataLoader): Validation data loader.
        inner_iterations (int): Number of times to switch between training and validation
         within an epoch. Defaults to ``1``.

    You can override the ``init`` and ``step`` functions. There is also a ``sample`` function
    that you can override to generate samples every time it switches between training and validation.
    """
    state_modules: List[StateModule]

    mode: ModeState

    epochs: int = 10

    trainer: Trainer
    validator: Trainer
    train_loader: torch.utils.data.DataLoader
    valid_loader: torch.utils.data.DataLoader

    loop_count = '_data_loop_count'
    loop_step = None

    inner_iterations: int = 1

    is_track_time: bool = False

    def init(self):
        pass

    def step(self, batch: Any, batch_idx: BatchIndex):
        raise NotImplementedError

    def run_step(self):
        for i in range(self.inner_iterations):
            with tracker.namespace('sample'):
                self.sample()
            with self.mode.update(is_train=True):
                with tracker.namespace('train'):
                    self.trainer()
            if self.validator:
                with tracker.namespace('valid'):
                    self.validator()
            tracker.save()

    def run(self):
        with monit.section("Initialize"):
            self.init()
        _ = self.validator
        _ = self.trainer
        for _ in self.training_loop:
            self.run_step()

    def sample(self):
        pass


@option(TrainValidConfigs.trainer)
def _default_trainer(c: TrainValidConfigs):
    return Trainer(name='Train',
                   mode=c.mode,
                   data_loader=c.train_loader,
                   inner_iterations=c.inner_iterations,
                   state_modules=c.state_modules,
                   is_track_time=c.is_track_time,
                   step=c.step)


@option(TrainValidConfigs.validator)
def _default_validator(c: TrainValidConfigs):
    return Trainer(name='Valid',
                   mode=c.mode,
                   data_loader=c.valid_loader,
                   inner_iterations=c.inner_iterations,
                   state_modules=c.state_modules,
                   is_track_time=c.is_track_time,
                   step=c.step)


@option(TrainValidConfigs.loop_count)
def _data_loop_count(c: TrainValidConfigs):
    return c.epochs


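# Sketch of extending `TrainValidConfigs` (illustrative only; `my_loss_func` is a
# hypothetical name, not part of this module, and the optimizer would still have to be
# provided by the user). `SimpleTrainValidConfigs` below is the concrete built-in
# version of this pattern.
#
#     class MyConfigs(TrainValidConfigs):
#         model: nn.Module
#         optimizer: torch.optim.Adam
#
#         def step(self, batch, batch_idx: BatchIndex):
#             data, target = batch
#             loss = my_loss_func(self.model(data), target)
#             tracker.add("loss.", loss)
#             if self.mode.is_train:
#                 loss.backward()
#                 if batch_idx.is_interval(1):
#                     self.optimizer.step()
#                     self.optimizer.zero_grad()
#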
class SimpleTrainValidConfigs(TrainValidConfigs):
    r"""
    This is a configurable module that works for many standard DL experiments.

    Arguments:
        model: A PyTorch model.
        optimizer: A PyTorch optimizer to update the model.
        device: The device to train the model on. This defaults to a configurable device.
        loss_func: A function to calculate the loss. This should accept ``model_output, target`` as
         arguments.
        update_batches (int): Number of batches to accumulate before taking an optimizer step.
         Defaults to ``1``.
        log_save_batches (int): How often (in batches) to call :func:`labml.tracker.save`.
         Defaults to ``1``.
    """
    optimizer: torch.optim.Adam
    model: nn.Module
    device: torch.device = DeviceConfigs()

    loss_func: nn.Module

    update_batches: int = 1
    log_save_batches: int = 1

    state_modules: List[StateModule] = []

    def init(self):
        pass

    def step(self, batch: Any, batch_idx: BatchIndex):
        self.model.train(self.mode.is_train)
        data, target = batch[0].to(self.device), batch[1].to(self.device)

        if self.mode.is_train:
            tracker.add_global_step(len(data))

        with monit.section("model"):
            output = self.model(data)

        loss = self.loss_func(output, target)
        tracker.add("loss.", loss)

        if self.mode.is_train:
            with monit.section('backward'):
                loss.backward()

            if batch_idx.is_interval(self.update_batches):
                with monit.section('optimize'):
                    self.optimizer.step()
                self.optimizer.zero_grad()

            if batch_idx.is_interval(self.log_save_batches):
                tracker.save()


meta_config(SimpleTrainValidConfigs.update_batches)


@option(SimpleTrainValidConfigs.optimizer)
def _default_optimizer(c: SimpleTrainValidConfigs):
    from .optimizer import OptimizerConfigs
    opt_conf = OptimizerConfigs()
    opt_conf.parameters = c.model.parameters()
    return opt_conf
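
# End-to-end usage sketch (illustrative; assumes labml's `experiment` API, and `MyModel`,
# `train_loader` and `valid_loader` are hypothetical names not defined in this module):
#
#     from labml import experiment
#
#     class Configs(SimpleTrainValidConfigs):
#         pass
#
#     conf = Configs()
#     experiment.create(name='my_experiment')
#     experiment.configs(conf, {'model': MyModel(), 'loss_func': nn.CrossEntropyLoss(),
#                               'train_loader': train_loader, 'valid_loader': valid_loader,
#                               'epochs': 5})
#     with experiment.start():
#         conf.run()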
