cleanup log activations

Varuna Jayasiri
2025-07-20 09:10:05 +05:30
parent a713c92b82
commit 5eecda7e28
12 changed files with 68 additions and 136 deletions

View File

@@ -127,10 +127,8 @@ class Configs(MNISTConfigs, SimpleTrainValidConfigs):
         if self.mode.is_train:
             tracker.add_global_step(len(data))

-        # Whether to log activations
-        with self.mode.update(is_log_activations=batch_idx.is_last):
-            # Run the model
-            caps, reconstructions, pred = self.model(data)
+        # Run the model
+        caps, reconstructions, pred = self.model(data)

         # Calculate the total loss
         loss = self.margin_loss(caps, target) + 0.0005 * self.reconstruction_loss(reconstructions, data)
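Every file in this commit drops the same wrapper: a `with self.mode.update(is_log_activations=...)` block around the forward pass. As the `ModeState` diff later in this commit shows, `update()` returns a context manager that temporarily flips a flag which forward hooks consult. A minimal sketch of that mechanism, with hypothetical names and a simplified save/restore in place of the library's rollback stack:

from contextlib import contextmanager

class ModeState:
    """Simplified stand-in for the library's ModeState (illustrative only)."""

    def __init__(self):
        self.is_log_activations = False

    @contextmanager
    def update(self, *, is_log_activations: bool):
        # Set the flag for the duration of the `with` block, then restore it
        previous = self.is_log_activations
        self.is_log_activations = is_log_activations
        try:
            yield self
        finally:
            self.is_log_activations = previous

Removing the wrapper leaves the forward pass unconditional, which is the point of the cleanup: once the hooks are deleted further down, the flag has no consumers left.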

View File

@@ -73,10 +73,8 @@ class MNISTConfigs(MNISTDatasetConfigs, TrainValidConfigs):
         if self.mode.is_train:
             tracker.add_global_step(len(data))

-        # Whether to capture model outputs
-        with self.mode.update(is_log_activations=batch_idx.is_last):
-            # Get model outputs.
-            output = self.model(data)
+        # Get model outputs.
+        output = self.model(data)

         # Calculate and log loss
         loss = self.loss_func(output, target)

View File

@@ -132,12 +132,10 @@ class NLPAutoRegressionConfigs(TrainValidConfigs):
         if self.mode.is_train:
             tracker.add_global_step(data.shape[0] * data.shape[1])

-        # Whether to capture model outputs
-        with self.mode.update(is_log_activations=batch_idx.is_last and self.is_log_model_activations):
-            # Get model outputs.
-            # It's returning a tuple for states when using RNNs.
-            # This is not implemented yet. 😜
-            output, *_ = self.model(data)
+        # Get model outputs.
+        # It's returning a tuple for states when using RNNs.
+        # This is not implemented yet. 😜
+        output, *_ = self.model(data)

         # Calculate and log loss
         loss = self.loss_func(output, target)
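The `output, *_ = self.model(data)` idiom keeps the first element of the returned tuple (the logits) and discards any extra state, as the comment about RNNs notes. A tiny illustration with hypothetical tensors:

import torch

logits = torch.randn(32, 4, 100)   # hypothetical (seq_len, batch, vocab) output
state = torch.zeros(4, 16)         # hypothetical recurrent state

output, *_ = (logits, state)       # keeps logits, discards the state
assert output is logits

Note that the idiom relies on the model returning a tuple; if it returned a bare tensor, the same line would unpack along the tensor's first dimension instead.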

View File

@@ -108,12 +108,10 @@ class NLPClassificationConfigs(TrainValidConfigs):
         if self.mode.is_train:
             tracker.add_global_step(data.shape[1])

-        # Whether to capture model outputs
-        with self.mode.update(is_log_activations=batch_idx.is_last and self.is_log_model_activations):
-            # Get model outputs.
-            # It's returning a tuple for states when using RNNs.
-            # This is not implemented yet. 😜
-            output, *_ = self.model(data)
+        # Get model outputs.
+        # It's returning a tuple for states when using RNNs.
+        # This is not implemented yet. 😜
+        output, *_ = self.model(data)

         # Calculate and log loss
         loss = self.loss_func(output, target)

View File

@@ -315,38 +315,36 @@ class Configs(BaseConfigs):
         # Accumulate gradients for `gradient_accumulate_steps`
         for i in range(self.gradient_accumulate_steps):
-            # Update `mode`. Set whether to log activation
-            with self.mode.update(is_log_activations=(idx + 1) % self.log_generated_interval == 0):
-                # Sample images from generator
-                generated_images, _ = self.generate_images(self.batch_size)
-                # Discriminator classification for generated images
-                fake_output = self.discriminator(generated_images.detach())
+            # Sample images from generator
+            generated_images, _ = self.generate_images(self.batch_size)
+            # Discriminator classification for generated images
+            fake_output = self.discriminator(generated_images.detach())

-                # Get real images from the data loader
-                real_images = next(self.loader).to(self.device)
-                # We need to calculate gradients w.r.t. real images for gradient penalty
-                if (idx + 1) % self.lazy_gradient_penalty_interval == 0:
-                    real_images.requires_grad_()
-                # Discriminator classification for real images
-                real_output = self.discriminator(real_images)
+            # Get real images from the data loader
+            real_images = next(self.loader).to(self.device)
+            # We need to calculate gradients w.r.t. real images for gradient penalty
+            if (idx + 1) % self.lazy_gradient_penalty_interval == 0:
+                real_images.requires_grad_()
+            # Discriminator classification for real images
+            real_output = self.discriminator(real_images)

-                # Get discriminator loss
-                real_loss, fake_loss = self.discriminator_loss(real_output, fake_output)
-                disc_loss = real_loss + fake_loss
+            # Get discriminator loss
+            real_loss, fake_loss = self.discriminator_loss(real_output, fake_output)
+            disc_loss = real_loss + fake_loss

-                # Add gradient penalty
-                if (idx + 1) % self.lazy_gradient_penalty_interval == 0:
-                    # Calculate and log gradient penalty
-                    gp = self.gradient_penalty(real_images, real_output)
-                    tracker.add('loss.gp', gp)
-                    # Multiply by coefficient and add gradient penalty
-                    disc_loss = disc_loss + 0.5 * self.gradient_penalty_coefficient * gp * self.lazy_gradient_penalty_interval
+            # Add gradient penalty
+            if (idx + 1) % self.lazy_gradient_penalty_interval == 0:
+                # Calculate and log gradient penalty
+                gp = self.gradient_penalty(real_images, real_output)
+                tracker.add('loss.gp', gp)
+                # Multiply by coefficient and add gradient penalty
+                disc_loss = disc_loss + 0.5 * self.gradient_penalty_coefficient * gp * self.lazy_gradient_penalty_interval

-                # Compute gradients
-                disc_loss.backward()
+            # Compute gradients
+            disc_loss.backward()

-                # Log discriminator loss
-                tracker.add('loss.discriminator', disc_loss)
+            # Log discriminator loss
+            tracker.add('loss.discriminator', disc_loss)

         if (idx + 1) % self.log_generated_interval == 0:
             # Log discriminator model parameters occasionally
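For context, `self.gradient_penalty(real_images, real_output)` in the hunk above is an R1-style penalty on the discriminator's gradients at real images; the `0.5 * coefficient * gp * interval` scaling compensates for evaluating the penalty only every `lazy_gradient_penalty_interval` steps (lazy regularization). A rough sketch of such a penalty, not necessarily this repository's exact implementation:

import torch

def gradient_penalty(real_images: torch.Tensor, real_output: torch.Tensor) -> torch.Tensor:
    # Gradients of discriminator outputs w.r.t. inputs; this requires that
    # real_images had requires_grad_() set before the forward pass, as in
    # the hunk above.
    gradients, = torch.autograd.grad(outputs=real_output,
                                     inputs=real_images,
                                     grad_outputs=torch.ones_like(real_output),
                                     create_graph=True)
    # Mean squared L2 norm of the per-sample gradients
    return gradients.reshape(gradients.shape[0], -1).norm(2, dim=-1).pow(2).mean()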

View File

@@ -3,12 +3,11 @@ import typing
 from typing import Dict, List, Callable
 from typing import Optional, Tuple, Any, Collection

-import labml.utils.pytorch as pytorch_utils
-import torch.optim
-import torch.utils.data
-from labml import tracker, logger, experiment, monit
+import torch.optim
+import torch.utils.data
+from labml import tracker, logger, monit
 from labml.configs import BaseConfigs, meta_config, option
 from labml.internal.monitor import Loop
 from labml.logger import Text
@@ -204,8 +203,6 @@ class ModeState:
         self._rollback_stack = []

         self.is_train = False
-        self.is_log_activations = False
-        self.is_log_parameters = False
         self.is_optimize = False

     def _enter(self, mode: Dict[str, any]):
@@ -231,13 +228,9 @@
     def update(self, *,
               is_train: Optional[bool] = None,
-              is_log_parameters: Optional[bool] = None,
-              is_log_activations: Optional[bool] = None,
               is_optimize: Optional[bool] = None):
         return Mode(self,
-                    is_log_parameters=is_log_parameters,
-                    is_log_activations=is_log_activations,
                     is_train=is_train,
                     is_optimize=is_optimize)
@@ -258,35 +251,6 @@
         self.mode._exit(self.idx)


-class ForwardHook:
-    def __init__(self, mode: ModeState, model_name, name: str, module: torch.nn.Module):
-        self.mode = mode
-        self.model_name = model_name
-        self.name = name
-        self.module = module
-        module.register_forward_hook(self)
-
-    def save(self, name: str, output):
-        if isinstance(output, torch.Tensor):
-            pytorch_utils.store_var(name, output)
-        elif isinstance(output, tuple):
-            for i, o in enumerate(output):
-                self.save(f"{name}.{i}", o)
-
-    def __call__(self, module, i, o):
-        if not self.mode.is_log_activations:
-            return
-
-        self.save(f"module.{self.model_name}.{self.name}", o)
-
-
-def hook_model_outputs(mode: ModeState, model: torch.nn.Module, model_name: str = "model"):
-    for name, module in model.named_modules():
-        if name == '':
-            name = 'full'
-
-        ForwardHook(mode, model_name, name, module)
-

 class Trainer:
     def __init__(self, *,
                  name: str,
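With `ForwardHook` and `hook_model_outputs` gone, there is no built-in activation capture left. If you still need it, a small stand-alone substitute (hypothetical, not part of the library) can register plain PyTorch forward hooks and collect outputs into a dict, mirroring the naming scheme of the removed helper:

import torch

def hook_activations(model: torch.nn.Module, store: dict, model_name: str = "model"):
    """Record each sub-module's forward output into `store` (sketch only)."""
    def make_hook(name: str):
        def hook(module, inputs, output):
            if isinstance(output, torch.Tensor):
                store[f"module.{model_name}.{name}"] = output.detach()
        return hook

    for name, module in model.named_modules():
        module.register_forward_hook(make_hook(name if name else 'full'))

Clear `store` at the start of each batch you want to inspect; unlike the removed class, this sketch records unconditionally instead of checking a mode flag.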
@@ -493,10 +457,6 @@ class SimpleTrainValidConfigs(TrainValidConfigs):
            arguments.
        update_batches (int): Number of batches to accumulate before taking an optimizer step.
            Defaults to ``1``.
-       log_params_updates (int): How often (number of batches) to track model parameters and gradients.
-           Defaults to a large number; i.e. logs every epoch.
-       log_activations_batches (int): How often to log model activations.
-           Defaults to a large number; i.e. logs every epoch.
        log_save_batches (int): How often to call :func:`labml.tracker.save`.
    """
    optimizer: torch.optim.Adam
@@ -506,8 +466,6 @@ class SimpleTrainValidConfigs(TrainValidConfigs):
    loss_func: nn.Module

    update_batches: int = 1
-   log_params_updates: int = 2 ** 32  # 0 if not
-   log_activations_batches: int = 2 ** 32  # 0 if not
    log_save_batches: int = 1

    state_modules: List[StateModule] = []
@@ -522,10 +480,8 @@ class SimpleTrainValidConfigs(TrainValidConfigs):
        if self.mode.is_train:
            tracker.add_global_step(len(data))

-       is_log_activations = batch_idx.is_interval(self.log_activations_batches)
        with monit.section("model"):
-           with self.mode.update(is_log_activations=is_log_activations):
-               output = self.model(data)
+           output = self.model(data)

        loss = self.loss_func(output, target)
        tracker.add("loss.", loss)
@@ -537,8 +493,6 @@ class SimpleTrainValidConfigs(TrainValidConfigs):
        if batch_idx.is_interval(self.update_batches):
            with monit.section('optimize'):
                self.optimizer.step()
-           if batch_idx.is_interval(self.log_params_updates):
-               tracker.add('model', self.model)
            self.optimizer.zero_grad()

        if batch_idx.is_interval(self.log_save_batches):
@@ -546,8 +500,7 @@
 meta_config(SimpleTrainValidConfigs.update_batches,
-            SimpleTrainValidConfigs.log_params_updates,
-            SimpleTrainValidConfigs.log_activations_batches)
+            )

 @option(SimpleTrainValidConfigs.optimizer)
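The surviving `log_save_batches` check and the deleted defaults both hinge on `batch_idx.is_interval`. The removed docstring's "Defaults to a large number; i.e. logs every epoch" only makes sense if the interval check also fires on the final batch of an epoch. A sketch of the assumed semantics (the real `BatchIndex` lives elsewhere in the library; this is illustrative only):

class BatchIndex:
    """Assumed behaviour of the batch index used above (hypothetical sketch)."""

    def __init__(self, idx: int, total: int):
        self.idx, self.total = idx, total

    @property
    def is_last(self) -> bool:
        return self.idx + 1 == self.total

    def is_interval(self, interval: int) -> bool:
        # Fires on every `interval`-th completed batch, and always on the
        # last batch -- so an interval of 2 ** 32 degraded to once per epoch.
        return self.is_last or (self.idx + 1) % interval == 0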

View File

@@ -71,9 +71,8 @@ class Configs(MNISTConfigs, TrainValidConfigs):
        if self.mode.is_train:
            tracker.add_global_step(len(data))

-       # Run the model and specify whether to log the activations
-       with self.mode.update(is_log_activations=batch_idx.is_last):
-           output = self.model(data)
+       # Run the model
+       output = self.model(data)

        # Calculate the loss
        loss = self.loss_func(output, target)

View File

@@ -75,12 +75,10 @@ class Configs(TransformerAutoRegressionConfigs):
        if self.mode.is_train:
            tracker.add_global_step(data.shape[0] * data.shape[1])

-       # Whether to capture model outputs
-       with self.mode.update(is_log_activations=batch_idx.is_last and self.is_log_model_activations):
-           # Get model outputs.
-           # It's returning a tuple for states when using RNNs.
-           # This is not implemented yet. 😜
-           output, *_ = self.model(data)
+       # Get model outputs.
+       # It's returning a tuple for states when using RNNs.
+       # This is not implemented yet. 😜
+       output, *_ = self.model(data)

        # Calculate and log loss
        loss = self.loss_func(output, target)

View File

@@ -202,16 +202,14 @@ class Configs(NLPAutoRegressionConfigs):
        if self.mode.is_train:
            tracker.add_global_step(data.shape[0] * data.shape[1])

-       # Whether to capture model outputs
-       with self.mode.update(is_log_activations=batch_idx.is_last):
-           # Get memories
-           mem = self.memory.get()
-           # Run the model
-           output, new_mem = self.model(data, mem)
-           # Merge and compress memory
-           mem, mem_to_compress = self.merge_compress_memory(mem, new_mem)
-           # Update memories
-           self.memory.set(mem)
+       # Get memories
+       mem = self.memory.get()
+       # Run the model
+       output, new_mem = self.model(data, mem)
+       # Merge and compress memory
+       mem, mem_to_compress = self.merge_compress_memory(mem, new_mem)
+       # Update memories
+       self.memory.set(mem)

        # Calculate and log cross entropy loss
        loss = self.loss_func(output, target)
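`merge_compress_memory` is specific to the compressive-transformer experiment: new memories are appended per layer, the most recent steps are kept, and the overflow is returned for compression. A much-simplified sketch of that shape contract (hypothetical; the real method also handles detachment and edge cases):

import torch

def merge_compress_memory(mem, new_mem, mem_len: int = 8):
    # Concatenate per-layer memories along the time dimension
    merged = [torch.cat((m, n), dim=0) for m, n in zip(mem, new_mem)]
    # Keep the most recent `mem_len` steps; hand the overflow back
    kept = [m[-mem_len:] for m in merged]
    to_compress = [m[:-mem_len] for m in merged]
    return kept, to_compress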

View File

@@ -143,12 +143,10 @@ class Configs(NLPAutoRegressionConfigs):
        with torch.no_grad():
            data, labels = self.mlm(data)

-       # Whether to capture model outputs
-       with self.mode.update(is_log_activations=batch_idx.is_last):
-           # Get model outputs.
-           # It's returning a tuple for states when using RNNs.
-           # This is not implemented yet.
-           output, *_ = self.model(data)
+       # Get model outputs.
+       # It's returning a tuple for states when using RNNs.
+       # This is not implemented yet.
+       output, *_ = self.model(data)

        # Calculate and log the loss
        loss = self.loss_func(output.view(-1, output.shape[-1]), labels.view(-1))
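The loss call flattens the sequence and batch dimensions so token-level cross entropy can be applied. Shape bookkeeping with illustrative values:

import torch
import torch.nn.functional as F

output = torch.randn(32, 4, 100)          # hypothetical (seq_len, batch, vocab) logits
labels = torch.randint(0, 100, (32, 4))   # hypothetical (seq_len, batch) targets

# (seq_len, batch, vocab) -> (seq_len * batch, vocab); (seq_len, batch) -> (seq_len * batch,)
loss = F.cross_entropy(output.view(-1, output.shape[-1]), labels.view(-1))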

View File

@@ -102,10 +102,8 @@ class Configs(NLPAutoRegressionConfigs):
        if self.mode.is_train:
            tracker.add_global_step(data.shape[0] * data.shape[1])

-       # Whether to capture model outputs
-       with self.mode.update(is_log_activations=batch_idx.is_last):
-           # Get model outputs.
-           output, counts, route_prob, n_dropped, route_prob_max = self.model(data)
+       # Get model outputs.
+       output, counts, route_prob, n_dropped, route_prob_max = self.model(data)

        # Calculate and cross entropy loss
        cross_entropy_loss = self.loss_func(output, target)
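The extra outputs (`counts`, `route_prob`, ...) feed the switch transformer's load-balancing loss, which in the standard formulation is the number of experts times the sum over experts of (fraction of tokens routed to the expert) times (mean routing probability of the expert). A sketch of that formula, not necessarily this file's exact code:

import torch

def load_balancing_loss(counts: torch.Tensor, route_prob: torch.Tensor, n_tokens: int) -> torch.Tensor:
    route_frac = counts / n_tokens      # f_i: fraction of tokens sent to expert i
    prob_frac = route_prob / n_tokens   # P_i: mean routing probability of expert i
    n_experts = counts.numel()
    return n_experts * (route_frac * prob_frac).sum()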

View File

@@ -132,16 +132,14 @@ class Configs(NLPAutoRegressionConfigs):
        if self.mode.is_train:
            tracker.add_global_step(data.shape[0] * data.shape[1])

-       # Whether to capture model outputs
-       with self.mode.update(is_log_activations=batch_idx.is_last):
-           # Get memories
-           mem = self.memory.get()
-           # Run the model
-           output, new_mem = self.model(data, mem)
-           # Merge memory
-           mem = self.merge_memory(mem, new_mem)
-           # Update memories
-           self.memory.set(mem)
+       # Get memories
+       mem = self.memory.get()
+       # Run the model
+       output, new_mem = self.model(data, mem)
+       # Merge memory
+       mem = self.merge_memory(mem, new_mem)
+       # Update memories
+       self.memory.set(mem)

        # Calculate and log cross entropy loss
        loss = self.loss_func(output, target)
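`merge_memory` here is the plain Transformer-XL variant of the memory update in the compressive-transformer hunk earlier: concatenate, truncate to the memory length, and detach from the graph. A hypothetical sketch:

import torch

def merge_memory(mem, new_mem, mem_len: int = 8):
    # Append new memories per layer, keep only the latest `mem_len` steps,
    # and detach so gradients don't flow into previous segments
    return [torch.cat((m, n), dim=0)[-mem_len:].detach()
            for m, n in zip(mem, new_mem)]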