cleanup log activations

Varuna Jayasiri
2025-07-20 09:10:05 +05:30
parent a713c92b82
commit 5eecda7e28
12 changed files with 68 additions and 136 deletions

View File

@@ -127,8 +127,6 @@ class Configs(MNISTConfigs, SimpleTrainValidConfigs):
         if self.mode.is_train:
             tracker.add_global_step(len(data))
-        # Whether to log activations
-        with self.mode.update(is_log_activations=batch_idx.is_last):
         # Run the model
         caps, reconstructions, pred = self.model(data)
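Every call site touched by this commit drops a `with self.mode.update(is_log_activations=...)` wrapper. That wrapper comes from labml's `ModeState`/`Mode` pair (see the helpers hunks further down): `update()` returns a context manager that flips the requested flags for the duration of the `with` block and restores them on exit. A minimal, simplified sketch of that pattern follows; it is an illustration only, not labml's actual implementation.

from typing import Dict, Optional


class ModeState:
    def __init__(self):
        self.is_train = False
        self.is_log_activations = False

    def update(self, *, is_train: Optional[bool] = None,
               is_log_activations: Optional[bool] = None):
        # Only flags that are explicitly passed get overridden
        return Mode(self, is_train=is_train, is_log_activations=is_log_activations)


class Mode:
    def __init__(self, state: ModeState, **flags):
        self.state = state
        self.flags = {k: v for k, v in flags.items() if v is not None}
        self.saved: Dict[str, bool] = {}

    def __enter__(self):
        # Save current values, then apply the requested ones
        for k, v in self.flags.items():
            self.saved[k] = getattr(self.state, k)
            setattr(self.state, k, v)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Roll back to the saved values when the `with` block ends
        for k, v in self.saved.items():
            setattr(self.state, k, v)


mode = ModeState()
with mode.update(is_log_activations=True):
    assert mode.is_log_activations      # flag is on only inside the block
assert not mode.is_log_activations      # restored afterwards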

View File

@@ -73,8 +73,6 @@ class MNISTConfigs(MNISTDatasetConfigs, TrainValidConfigs):
         if self.mode.is_train:
             tracker.add_global_step(len(data))
-        # Whether to capture model outputs
-        with self.mode.update(is_log_activations=batch_idx.is_last):
         # Get model outputs.
         output = self.model(data)

View File

@@ -132,8 +132,6 @@ class NLPAutoRegressionConfigs(TrainValidConfigs):
         if self.mode.is_train:
             tracker.add_global_step(data.shape[0] * data.shape[1])
-        # Whether to capture model outputs
-        with self.mode.update(is_log_activations=batch_idx.is_last and self.is_log_model_activations):
         # Get model outputs.
         # It's returning a tuple for states when using RNNs.
         # This is not implemented yet. 😜

View File

@@ -108,8 +108,6 @@ class NLPClassificationConfigs(TrainValidConfigs):
         if self.mode.is_train:
             tracker.add_global_step(data.shape[1])
-        # Whether to capture model outputs
-        with self.mode.update(is_log_activations=batch_idx.is_last and self.is_log_model_activations):
         # Get model outputs.
         # It's returning a tuple for states when using RNNs.
         # This is not implemented yet. 😜

View File

@@ -315,8 +315,6 @@ class Configs(BaseConfigs):
         # Accumulate gradients for `gradient_accumulate_steps`
         for i in range(self.gradient_accumulate_steps):
-            # Update `mode`. Set whether to log activation
-            with self.mode.update(is_log_activations=(idx + 1) % self.log_generated_interval == 0):
             # Sample images from generator
             generated_images, _ = self.generate_images(self.batch_size)
             # Discriminator classification for generated images

View File

@@ -3,12 +3,11 @@ import typing
 from typing import Dict, List, Callable
 from typing import Optional, Tuple, Any, Collection
-import labml.utils.pytorch as pytorch_utils
 import torch.optim
 import torch.optim
 import torch.utils.data
 import torch.utils.data
-from labml import tracker, logger, experiment, monit
+from labml import tracker, logger, monit
 from labml.configs import BaseConfigs, meta_config, option
 from labml.internal.monitor import Loop
 from labml.logger import Text
@@ -204,8 +203,6 @@ class ModeState:
         self._rollback_stack = []
 
         self.is_train = False
-        self.is_log_activations = False
-        self.is_log_parameters = False
         self.is_optimize = False
 
     def _enter(self, mode: Dict[str, any]):
@@ -231,13 +228,9 @@ class ModeState:
     def update(self, *,
                is_train: Optional[bool] = None,
-               is_log_parameters: Optional[bool] = None,
-               is_log_activations: Optional[bool] = None,
                is_optimize: Optional[bool] = None):
         return Mode(self,
                     is_train=is_train,
-                    is_log_parameters=is_log_parameters,
-                    is_log_activations=is_log_activations,
                     is_optimize=is_optimize)
@@ -258,35 +251,6 @@
         self.mode._exit(self.idx)
 
 
-class ForwardHook:
-    def __init__(self, mode: ModeState, model_name, name: str, module: torch.nn.Module):
-        self.mode = mode
-        self.model_name = model_name
-        self.name = name
-        self.module = module
-        module.register_forward_hook(self)
-
-    def save(self, name: str, output):
-        if isinstance(output, torch.Tensor):
-            pytorch_utils.store_var(name, output)
-        elif isinstance(output, tuple):
-            for i, o in enumerate(output):
-                self.save(f"{name}.{i}", o)
-
-    def __call__(self, module, i, o):
-        if not self.mode.is_log_activations:
-            return
-
-        self.save(f"module.{self.model_name}.{self.name}", o)
-
-
-def hook_model_outputs(mode: ModeState, model: torch.nn.Module, model_name: str = "model"):
-    for name, module in model.named_modules():
-        if name == '':
-            name = 'full'
-
-        ForwardHook(mode, model_name, name, module)
-
-
 class Trainer:
     def __init__(self, *,
                  name: str,
@@ -493,10 +457,6 @@ class SimpleTrainValidConfigs(TrainValidConfigs):
             arguments.
         update_batches (int): Number of batches to accumulate before taking an optimizer step.
             Defaults to ``1``.
-        log_params_updates (int): How often (number of batches) to track model parameters and gradients.
-            Defaults to a large number; i.e. logs every epoch.
-        log_activations_batches (int): How often to log model activations.
-            Defaults to a large number; i.e. logs every epoch.
         log_save_batches (int): How often to call :func:`labml.tracker.save`.
     """
     optimizer: torch.optim.Adam
@@ -506,8 +466,6 @@ class SimpleTrainValidConfigs(TrainValidConfigs):
     loss_func: nn.Module
 
     update_batches: int = 1
-    log_params_updates: int = 2 ** 32  # 0 if not
-    log_activations_batches: int = 2 ** 32  # 0 if not
     log_save_batches: int = 1
 
     state_modules: List[StateModule] = []
@@ -522,9 +480,7 @@ class SimpleTrainValidConfigs(TrainValidConfigs):
         if self.mode.is_train:
             tracker.add_global_step(len(data))
 
-        is_log_activations = batch_idx.is_interval(self.log_activations_batches)
         with monit.section("model"):
-            with self.mode.update(is_log_activations=is_log_activations):
             output = self.model(data)
 
         loss = self.loss_func(output, target)
@@ -537,8 +493,6 @@ class SimpleTrainValidConfigs(TrainValidConfigs):
         if batch_idx.is_interval(self.update_batches):
             with monit.section('optimize'):
                 self.optimizer.step()
-            if batch_idx.is_interval(self.log_params_updates):
-                tracker.add('model', self.model)
             self.optimizer.zero_grad()
 
         if batch_idx.is_interval(self.log_save_batches):
@@ -546,8 +500,7 @@ class SimpleTrainValidConfigs(TrainValidConfigs):
 meta_config(SimpleTrainValidConfigs.update_batches,
-            SimpleTrainValidConfigs.log_params_updates,
-            SimpleTrainValidConfigs.log_activations_batches)
+            )
 
 
 @option(SimpleTrainValidConfigs.optimizer)
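The `ForwardHook` class and `hook_model_outputs` removed above were thin wrappers around PyTorch's `register_forward_hook`, storing each module's output through labml's tracker whenever `is_log_activations` was set. A minimal sketch of the same mechanism, assuming a plain dict in place of the tracker (the names `activations` and `make_hook` are illustrative, not part of labml):

import torch
import torch.nn as nn

activations = {}


def make_hook(name: str):
    def hook(module: nn.Module, inputs, output):
        # register_forward_hook calls this after every forward pass of `module`;
        # keep a detached copy of the output under the module's name
        if isinstance(output, torch.Tensor):
            activations[name] = output.detach()
    return hook


model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))

# Attach a hook to every named sub-module, as hook_model_outputs did;
# the root module has an empty name, so label it 'full'
for name, module in model.named_modules():
    module.register_forward_hook(make_hook(name or 'full'))

model(torch.randn(1, 4))
print(sorted(activations))  # e.g. ['0', '1', '2', 'full']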

View File

@@ -71,8 +71,7 @@ class Configs(MNISTConfigs, TrainValidConfigs):
         if self.mode.is_train:
             tracker.add_global_step(len(data))
-        # Run the model and specify whether to log the activations
-        with self.mode.update(is_log_activations=batch_idx.is_last):
+        # Run the model
         output = self.model(data)
 
         # Calculate the loss

View File

@@ -75,8 +75,6 @@ class Configs(TransformerAutoRegressionConfigs):
         if self.mode.is_train:
             tracker.add_global_step(data.shape[0] * data.shape[1])
-        # Whether to capture model outputs
-        with self.mode.update(is_log_activations=batch_idx.is_last and self.is_log_model_activations):
         # Get model outputs.
         # It's returning a tuple for states when using RNNs.
         # This is not implemented yet. 😜

View File

@@ -202,8 +202,6 @@ class Configs(NLPAutoRegressionConfigs):
         if self.mode.is_train:
             tracker.add_global_step(data.shape[0] * data.shape[1])
-        # Whether to capture model outputs
-        with self.mode.update(is_log_activations=batch_idx.is_last):
         # Get memories
         mem = self.memory.get()
         # Run the model

View File

@@ -143,8 +143,6 @@ class Configs(NLPAutoRegressionConfigs):
         with torch.no_grad():
             data, labels = self.mlm(data)
-        # Whether to capture model outputs
-        with self.mode.update(is_log_activations=batch_idx.is_last):
         # Get model outputs.
         # It's returning a tuple for states when using RNNs.
         # This is not implemented yet.

View File

@@ -102,8 +102,6 @@ class Configs(NLPAutoRegressionConfigs):
         if self.mode.is_train:
             tracker.add_global_step(data.shape[0] * data.shape[1])
-        # Whether to capture model outputs
-        with self.mode.update(is_log_activations=batch_idx.is_last):
         # Get model outputs.
         output, counts, route_prob, n_dropped, route_prob_max = self.model(data)

View File

@@ -132,8 +132,6 @@ class Configs(NLPAutoRegressionConfigs):
         if self.mode.is_train:
             tracker.add_global_step(data.shape[0] * data.shape[1])
-        # Whether to capture model outputs
-        with self.mode.update(is_log_activations=batch_idx.is_last):
         # Get memories
         mem = self.memory.get()
         # Run the model