cleanup log activations

Author: Varuna Jayasiri
Date: 2025-07-20 09:10:05 +05:30
parent a713c92b82
commit 5eecda7e28
12 changed files with 68 additions and 136 deletions
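Every experiment file below follows the same pattern: the `with self.mode.update(is_log_activations=...)` wrapper around the model call is removed and the call is dedented, since the activation-logging mode flag and its forward hooks are deleted from the trainer. A representative before/after, taken directly from the hunks that follow:

# Before
with self.mode.update(is_log_activations=batch_idx.is_last):
    output = self.model(data)

# After
output = self.model(data)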

View File

@@ -127,10 +127,8 @@ class Configs(MNISTConfigs, SimpleTrainValidConfigs):
         if self.mode.is_train:
             tracker.add_global_step(len(data))
 
-        # Whether to log activations
-        with self.mode.update(is_log_activations=batch_idx.is_last):
-            # Run the model
-            caps, reconstructions, pred = self.model(data)
+        # Run the model
+        caps, reconstructions, pred = self.model(data)
 
         # Calculate the total loss
         loss = self.margin_loss(caps, target) + 0.0005 * self.reconstruction_loss(reconstructions, data)

View File

@@ -73,10 +73,8 @@ class MNISTConfigs(MNISTDatasetConfigs, TrainValidConfigs):
         if self.mode.is_train:
             tracker.add_global_step(len(data))
 
-        # Whether to capture model outputs
-        with self.mode.update(is_log_activations=batch_idx.is_last):
-            # Get model outputs.
-            output = self.model(data)
+        # Get model outputs.
+        output = self.model(data)
 
         # Calculate and log loss
         loss = self.loss_func(output, target)

View File

@@ -132,12 +132,10 @@ class NLPAutoRegressionConfigs(TrainValidConfigs):
         if self.mode.is_train:
             tracker.add_global_step(data.shape[0] * data.shape[1])
 
-        # Whether to capture model outputs
-        with self.mode.update(is_log_activations=batch_idx.is_last and self.is_log_model_activations):
-            # Get model outputs.
-            # It's returning a tuple for states when using RNNs.
-            # This is not implemented yet. 😜
-            output, *_ = self.model(data)
+        # Get model outputs.
+        # It's returning a tuple for states when using RNNs.
+        # This is not implemented yet. 😜
+        output, *_ = self.model(data)
 
         # Calculate and log loss
         loss = self.loss_func(output, target)

View File

@@ -108,12 +108,10 @@ class NLPClassificationConfigs(TrainValidConfigs):
         if self.mode.is_train:
             tracker.add_global_step(data.shape[1])
 
-        # Whether to capture model outputs
-        with self.mode.update(is_log_activations=batch_idx.is_last and self.is_log_model_activations):
-            # Get model outputs.
-            # It's returning a tuple for states when using RNNs.
-            # This is not implemented yet. 😜
-            output, *_ = self.model(data)
+        # Get model outputs.
+        # It's returning a tuple for states when using RNNs.
+        # This is not implemented yet. 😜
+        output, *_ = self.model(data)
 
         # Calculate and log loss
         loss = self.loss_func(output, target)

View File

@@ -315,38 +315,36 @@ class Configs(BaseConfigs):
         # Accumulate gradients for `gradient_accumulate_steps`
         for i in range(self.gradient_accumulate_steps):
-            # Update `mode`. Set whether to log activation
-            with self.mode.update(is_log_activations=(idx + 1) % self.log_generated_interval == 0):
             # Sample images from generator
             generated_images, _ = self.generate_images(self.batch_size)
             # Discriminator classification for generated images
             fake_output = self.discriminator(generated_images.detach())
 
             # Get real images from the data loader
             real_images = next(self.loader).to(self.device)
             # We need to calculate gradients w.r.t. real images for gradient penalty
             if (idx + 1) % self.lazy_gradient_penalty_interval == 0:
                 real_images.requires_grad_()
             # Discriminator classification for real images
             real_output = self.discriminator(real_images)
 
             # Get discriminator loss
             real_loss, fake_loss = self.discriminator_loss(real_output, fake_output)
             disc_loss = real_loss + fake_loss
 
             # Add gradient penalty
             if (idx + 1) % self.lazy_gradient_penalty_interval == 0:
                 # Calculate and log gradient penalty
                 gp = self.gradient_penalty(real_images, real_output)
                 tracker.add('loss.gp', gp)
                 # Multiply by coefficient and add gradient penalty
                 disc_loss = disc_loss + 0.5 * self.gradient_penalty_coefficient * gp * self.lazy_gradient_penalty_interval
 
             # Compute gradients
             disc_loss.backward()
 
             # Log discriminator loss
             tracker.add('loss.discriminator', disc_loss)
 
         if (idx + 1) % self.log_generated_interval == 0:
             # Log discriminator model parameters occasionally

View File

@@ -3,12 +3,11 @@ import typing
 from typing import Dict, List, Callable
 from typing import Optional, Tuple, Any, Collection
 
-import labml.utils.pytorch as pytorch_utils
 import torch.optim
 import torch.optim
 import torch.utils.data
 import torch.utils.data
-from labml import tracker, logger, experiment, monit
+from labml import tracker, logger, monit
 from labml.configs import BaseConfigs, meta_config, option
 from labml.internal.monitor import Loop
 from labml.logger import Text
@@ -204,8 +203,6 @@ class ModeState:
         self._rollback_stack = []
 
         self.is_train = False
-        self.is_log_activations = False
-        self.is_log_parameters = False
         self.is_optimize = False
 
     def _enter(self, mode: Dict[str, any]):
@@ -231,13 +228,9 @@ class ModeState:
     def update(self, *,
               is_train: Optional[bool] = None,
-              is_log_parameters: Optional[bool] = None,
-              is_log_activations: Optional[bool] = None,
              is_optimize: Optional[bool] = None):
         return Mode(self,
                     is_train=is_train,
-                    is_log_parameters=is_log_parameters,
-                    is_log_activations=is_log_activations,
                     is_optimize=is_optimize)
@@ -258,35 +251,6 @@ class Mode:
         self.mode._exit(self.idx)
 
 
-class ForwardHook:
-    def __init__(self, mode: ModeState, model_name, name: str, module: torch.nn.Module):
-        self.mode = mode
-        self.model_name = model_name
-        self.name = name
-        self.module = module
-
-        module.register_forward_hook(self)
-
-    def save(self, name: str, output):
-        if isinstance(output, torch.Tensor):
-            pytorch_utils.store_var(name, output)
-        elif isinstance(output, tuple):
-            for i, o in enumerate(output):
-                self.save(f"{name}.{i}", o)
-
-    def __call__(self, module, i, o):
-        if not self.mode.is_log_activations:
-            return
-        self.save(f"module.{self.model_name}.{self.name}", o)
-
-
-def hook_model_outputs(mode: ModeState, model: torch.nn.Module, model_name: str = "model"):
-    for name, module in model.named_modules():
-        if name == '':
-            name = 'full'
-        ForwardHook(mode, model_name, name, module)
-
-
 class Trainer:
     def __init__(self, *,
                  name: str,
@@ -493,10 +457,6 @@ class SimpleTrainValidConfigs(TrainValidConfigs):
             arguments.
         update_batches (int): Number of batches to accumulate before taking an optimizer step.
             Defaults to ``1``.
-        log_params_updates (int): How often (number of batches) to track model parameters and gradients.
-            Defaults to a large number; i.e. logs every epoch.
-        log_activations_batches (int): How often to log model activations.
-            Defaults to a large number; i.e. logs every epoch.
         log_save_batches (int): How often to call :func:`labml.tracker.save`.
     """
     optimizer: torch.optim.Adam
@@ -506,8 +466,6 @@ class SimpleTrainValidConfigs(TrainValidConfigs):
     loss_func: nn.Module
 
     update_batches: int = 1
-    log_params_updates: int = 2 ** 32  # 0 if not
-    log_activations_batches: int = 2 ** 32  # 0 if not
     log_save_batches: int = 1
 
     state_modules: List[StateModule] = []
@@ -522,10 +480,8 @@ class SimpleTrainValidConfigs(TrainValidConfigs):
         if self.mode.is_train:
             tracker.add_global_step(len(data))
 
-        is_log_activations = batch_idx.is_interval(self.log_activations_batches)
         with monit.section("model"):
-            with self.mode.update(is_log_activations=is_log_activations):
-                output = self.model(data)
+            output = self.model(data)
 
         loss = self.loss_func(output, target)
         tracker.add("loss.", loss)
@@ -537,8 +493,6 @@ class SimpleTrainValidConfigs(TrainValidConfigs):
             if batch_idx.is_interval(self.update_batches):
                 with monit.section('optimize'):
                     self.optimizer.step()
-                if batch_idx.is_interval(self.log_params_updates):
-                    tracker.add('model', self.model)
                 self.optimizer.zero_grad()
 
             if batch_idx.is_interval(self.log_save_batches):
@@ -546,8 +500,7 @@
 meta_config(SimpleTrainValidConfigs.update_batches,
-            SimpleTrainValidConfigs.log_params_updates,
-            SimpleTrainValidConfigs.log_activations_batches)
+            )
 
 
 @option(SimpleTrainValidConfigs.optimizer)
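The deleted ForwardHook helper was a thin wrapper over PyTorch's standard module hooks. For reference, per-layer activation capture can still be done outside the library with `torch.nn.Module.register_forward_hook`; the sketch below is a minimal stand-alone version of the removed helper, where `store_activation` is a hypothetical sink standing in for the old `pytorch_utils.store_var` call.

import torch
import torch.nn as nn


def store_activation(name: str, value: torch.Tensor):
    # Hypothetical sink: just report the shape; a real setup could feed labml's tracker instead.
    print(name, tuple(value.shape))


def hook_model_outputs(model: nn.Module, model_name: str = "model"):
    # Register a forward hook on every sub-module, mirroring the removed helper.
    def save(name: str, output):
        # Tensors are stored directly; tuples (e.g. RNN states) are stored element-wise.
        if isinstance(output, torch.Tensor):
            store_activation(name, output)
        elif isinstance(output, tuple):
            for i, o in enumerate(output):
                save(f"{name}.{i}", o)

    def make_hook(name: str):
        def hook(module, inputs, output):
            save(f"module.{model_name}.{name}", output)
        return hook

    for name, module in model.named_modules():
        if name == '':
            name = 'full'
        module.register_forward_hook(make_hook(name))

Hooks registered this way fire on every forward pass, so in practice they would be gated on a batch-index check, much as the removed `is_log_activations` flag gated the old hooks.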

View File

@@ -71,9 +71,8 @@ class Configs(MNISTConfigs, TrainValidConfigs):
         if self.mode.is_train:
             tracker.add_global_step(len(data))
 
-        # Run the model and specify whether to log the activations
-        with self.mode.update(is_log_activations=batch_idx.is_last):
-            output = self.model(data)
+        # Run the model
+        output = self.model(data)
 
         # Calculate the loss
         loss = self.loss_func(output, target)

View File

@@ -75,12 +75,10 @@ class Configs(TransformerAutoRegressionConfigs):
         if self.mode.is_train:
             tracker.add_global_step(data.shape[0] * data.shape[1])
 
-        # Whether to capture model outputs
-        with self.mode.update(is_log_activations=batch_idx.is_last and self.is_log_model_activations):
-            # Get model outputs.
-            # It's returning a tuple for states when using RNNs.
-            # This is not implemented yet. 😜
-            output, *_ = self.model(data)
+        # Get model outputs.
+        # It's returning a tuple for states when using RNNs.
+        # This is not implemented yet. 😜
+        output, *_ = self.model(data)
 
         # Calculate and log loss
         loss = self.loss_func(output, target)

View File

@@ -202,16 +202,14 @@ class Configs(NLPAutoRegressionConfigs):
         if self.mode.is_train:
             tracker.add_global_step(data.shape[0] * data.shape[1])
 
-        # Whether to capture model outputs
-        with self.mode.update(is_log_activations=batch_idx.is_last):
-            # Get memories
-            mem = self.memory.get()
-            # Run the model
-            output, new_mem = self.model(data, mem)
-            # Merge and compress memory
-            mem, mem_to_compress = self.merge_compress_memory(mem, new_mem)
-            # Update memories
-            self.memory.set(mem)
+        # Get memories
+        mem = self.memory.get()
+        # Run the model
+        output, new_mem = self.model(data, mem)
+        # Merge and compress memory
+        mem, mem_to_compress = self.merge_compress_memory(mem, new_mem)
+        # Update memories
+        self.memory.set(mem)
 
         # Calculate and log cross entropy loss
         loss = self.loss_func(output, target)

View File

@@ -143,12 +143,10 @@ class Configs(NLPAutoRegressionConfigs):
         with torch.no_grad():
             data, labels = self.mlm(data)
 
-        # Whether to capture model outputs
-        with self.mode.update(is_log_activations=batch_idx.is_last):
-            # Get model outputs.
-            # It's returning a tuple for states when using RNNs.
-            # This is not implemented yet.
-            output, *_ = self.model(data)
+        # Get model outputs.
+        # It's returning a tuple for states when using RNNs.
+        # This is not implemented yet.
+        output, *_ = self.model(data)
 
         # Calculate and log the loss
         loss = self.loss_func(output.view(-1, output.shape[-1]), labels.view(-1))

View File

@@ -102,10 +102,8 @@ class Configs(NLPAutoRegressionConfigs):
         if self.mode.is_train:
             tracker.add_global_step(data.shape[0] * data.shape[1])
 
-        # Whether to capture model outputs
-        with self.mode.update(is_log_activations=batch_idx.is_last):
-            # Get model outputs.
-            output, counts, route_prob, n_dropped, route_prob_max = self.model(data)
+        # Get model outputs.
+        output, counts, route_prob, n_dropped, route_prob_max = self.model(data)
 
         # Calculate and cross entropy loss
         cross_entropy_loss = self.loss_func(output, target)

View File

@@ -132,16 +132,14 @@ class Configs(NLPAutoRegressionConfigs):
         if self.mode.is_train:
             tracker.add_global_step(data.shape[0] * data.shape[1])
 
-        # Whether to capture model outputs
-        with self.mode.update(is_log_activations=batch_idx.is_last):
-            # Get memories
-            mem = self.memory.get()
-            # Run the model
-            output, new_mem = self.model(data, mem)
-            # Merge memory
-            mem = self.merge_memory(mem, new_mem)
-            # Update memories
-            self.memory.set(mem)
+        # Get memories
+        mem = self.memory.get()
+        # Run the model
+        output, new_mem = self.model(data, mem)
+        # Merge memory
+        mem = self.merge_memory(mem, new_mem)
+        # Update memories
+        self.memory.set(mem)
 
         # Calculate and log cross entropy loss
         loss = self.loss_func(output, target)