Mirror of https://github.com/labmlai/annotated_deep_learning_paper_implementations.git, synced 2025-08-14 09:31:42 +08:00
cleanup log activations
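In short, as read from the hunks below: the experiment `step` functions stop wrapping the forward pass in `with self.mode.update(is_log_activations=...)` and call the model directly; the `is_log_activations` and `is_log_parameters` flags are dropped from `ModeState`; the `ForwardHook` class and `hook_model_outputs` helper are deleted; and the `log_params_updates` / `log_activations_batches` options are removed from `SimpleTrainValidConfigs`.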
@@ -127,10 +127,8 @@ class Configs(MNISTConfigs, SimpleTrainValidConfigs):
         if self.mode.is_train:
             tracker.add_global_step(len(data))

-        # Whether to log activations
-        with self.mode.update(is_log_activations=batch_idx.is_last):
-            # Run the model
-            caps, reconstructions, pred = self.model(data)
+        # Run the model
+        caps, reconstructions, pred = self.model(data)

         # Calculate the total loss
         loss = self.margin_loss(caps, target) + 0.0005 * self.reconstruction_loss(reconstructions, data)
@@ -73,10 +73,8 @@ class MNISTConfigs(MNISTDatasetConfigs, TrainValidConfigs):
         if self.mode.is_train:
             tracker.add_global_step(len(data))

-        # Whether to capture model outputs
-        with self.mode.update(is_log_activations=batch_idx.is_last):
-            # Get model outputs.
-            output = self.model(data)
+        # Get model outputs.
+        output = self.model(data)

         # Calculate and log loss
         loss = self.loss_func(output, target)
@@ -132,12 +132,10 @@ class NLPAutoRegressionConfigs(TrainValidConfigs):
         if self.mode.is_train:
             tracker.add_global_step(data.shape[0] * data.shape[1])

-        # Whether to capture model outputs
-        with self.mode.update(is_log_activations=batch_idx.is_last and self.is_log_model_activations):
-            # Get model outputs.
-            # It's returning a tuple for states when using RNNs.
-            # This is not implemented yet. 😜
-            output, *_ = self.model(data)
+        # Get model outputs.
+        # It's returning a tuple for states when using RNNs.
+        # This is not implemented yet. 😜
+        output, *_ = self.model(data)

         # Calculate and log loss
         loss = self.loss_func(output, target)
@@ -108,12 +108,10 @@ class NLPClassificationConfigs(TrainValidConfigs):
         if self.mode.is_train:
             tracker.add_global_step(data.shape[1])

-        # Whether to capture model outputs
-        with self.mode.update(is_log_activations=batch_idx.is_last and self.is_log_model_activations):
-            # Get model outputs.
-            # It's returning a tuple for states when using RNNs.
-            # This is not implemented yet. 😜
-            output, *_ = self.model(data)
+        # Get model outputs.
+        # It's returning a tuple for states when using RNNs.
+        # This is not implemented yet. 😜
+        output, *_ = self.model(data)

         # Calculate and log loss
         loss = self.loss_func(output, target)
@@ -315,38 +315,36 @@ class Configs(BaseConfigs):

         # Accumulate gradients for `gradient_accumulate_steps`
         for i in range(self.gradient_accumulate_steps):
-            # Update `mode`. Set whether to log activation
-            with self.mode.update(is_log_activations=(idx + 1) % self.log_generated_interval == 0):
-                # Sample images from generator
-                generated_images, _ = self.generate_images(self.batch_size)
-                # Discriminator classification for generated images
-                fake_output = self.discriminator(generated_images.detach())
+            # Sample images from generator
+            generated_images, _ = self.generate_images(self.batch_size)
+            # Discriminator classification for generated images
+            fake_output = self.discriminator(generated_images.detach())

-                # Get real images from the data loader
-                real_images = next(self.loader).to(self.device)
-                # We need to calculate gradients w.r.t. real images for gradient penalty
-                if (idx + 1) % self.lazy_gradient_penalty_interval == 0:
-                    real_images.requires_grad_()
-                # Discriminator classification for real images
-                real_output = self.discriminator(real_images)
+            # Get real images from the data loader
+            real_images = next(self.loader).to(self.device)
+            # We need to calculate gradients w.r.t. real images for gradient penalty
+            if (idx + 1) % self.lazy_gradient_penalty_interval == 0:
+                real_images.requires_grad_()
+            # Discriminator classification for real images
+            real_output = self.discriminator(real_images)

-                # Get discriminator loss
-                real_loss, fake_loss = self.discriminator_loss(real_output, fake_output)
-                disc_loss = real_loss + fake_loss
+            # Get discriminator loss
+            real_loss, fake_loss = self.discriminator_loss(real_output, fake_output)
+            disc_loss = real_loss + fake_loss

-                # Add gradient penalty
-                if (idx + 1) % self.lazy_gradient_penalty_interval == 0:
-                    # Calculate and log gradient penalty
-                    gp = self.gradient_penalty(real_images, real_output)
-                    tracker.add('loss.gp', gp)
-                    # Multiply by coefficient and add gradient penalty
-                    disc_loss = disc_loss + 0.5 * self.gradient_penalty_coefficient * gp * self.lazy_gradient_penalty_interval
+            # Add gradient penalty
+            if (idx + 1) % self.lazy_gradient_penalty_interval == 0:
+                # Calculate and log gradient penalty
+                gp = self.gradient_penalty(real_images, real_output)
+                tracker.add('loss.gp', gp)
+                # Multiply by coefficient and add gradient penalty
+                disc_loss = disc_loss + 0.5 * self.gradient_penalty_coefficient * gp * self.lazy_gradient_penalty_interval

-                # Compute gradients
-                disc_loss.backward()
+            # Compute gradients
+            disc_loss.backward()

-                # Log discriminator loss
-                tracker.add('loss.discriminator', disc_loss)
+            # Log discriminator loss
+            tracker.add('loss.discriminator', disc_loss)

         if (idx + 1) % self.log_generated_interval == 0:
             # Log discriminator model parameters occasionally
@@ -3,12 +3,11 @@ import typing
 from typing import Dict, List, Callable
 from typing import Optional, Tuple, Any, Collection

-import labml.utils.pytorch as pytorch_utils
 import torch.optim
 import torch.utils.data
-from labml import tracker, logger, experiment, monit
+from labml import tracker, logger, monit
 from labml.configs import BaseConfigs, meta_config, option
 from labml.internal.monitor import Loop
 from labml.logger import Text
@@ -204,8 +203,6 @@ class ModeState:
         self._rollback_stack = []

         self.is_train = False
-        self.is_log_activations = False
-        self.is_log_parameters = False
         self.is_optimize = False

     def _enter(self, mode: Dict[str, any]):
@@ -231,13 +228,9 @@ class ModeState:

     def update(self, *,
                is_train: Optional[bool] = None,
-               is_log_parameters: Optional[bool] = None,
-               is_log_activations: Optional[bool] = None,
                is_optimize: Optional[bool] = None):
         return Mode(self,
                     is_train=is_train,
-                    is_log_parameters=is_log_parameters,
-                    is_log_activations=is_log_activations,
                     is_optimize=is_optimize)


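For context, the flag being removed above was toggled through a context manager returned by `ModeState.update(...)`. The following is a minimal sketch of that pattern written for this note; it is an assumption for illustration, not the labml implementation (which returns a `Mode` object with a rollback stack, as shown in the hunk above):

from contextlib import contextmanager


class ModeState:
    # Boolean flags consulted elsewhere in the training loop.
    def __init__(self):
        self.is_train = False
        self.is_optimize = False
        self.is_log_activations = False

    @contextmanager
    def update(self, **flags):
        # Remember current values, apply the requested ones, and restore
        # the old values when the `with` block exits.
        previous = {k: getattr(self, k) for k in flags}
        for k, v in flags.items():
            if v is not None:
                setattr(self, k, v)
        try:
            yield self
        finally:
            for k, v in previous.items():
                setattr(self, k, v)


mode = ModeState()
with mode.update(is_log_activations=True):
    assert mode.is_log_activations      # flag is on inside the block
assert not mode.is_log_activations      # restored afterwards

The real class tracks `is_train` and `is_optimize` the same way; this commit only drops the activation- and parameter-logging flags.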
@@ -258,35 +251,6 @@ class Mode:
         self.mode._exit(self.idx)


-class ForwardHook:
-    def __init__(self, mode: ModeState, model_name, name: str, module: torch.nn.Module):
-        self.mode = mode
-        self.model_name = model_name
-        self.name = name
-        self.module = module
-        module.register_forward_hook(self)
-
-    def save(self, name: str, output):
-        if isinstance(output, torch.Tensor):
-            pytorch_utils.store_var(name, output)
-        elif isinstance(output, tuple):
-            for i, o in enumerate(output):
-                self.save(f"{name}.{i}", o)
-
-    def __call__(self, module, i, o):
-        if not self.mode.is_log_activations:
-            return
-
-        self.save(f"module.{self.model_name}.{self.name}", o)
-
-
-def hook_model_outputs(mode: ModeState, model: torch.nn.Module, model_name: str = "model"):
-    for name, module in model.named_modules():
-        if name == '':
-            name = 'full'
-        ForwardHook(mode, model_name, name, module)
-
-
 class Trainer:
     def __init__(self, *,
                  name: str,
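If activation capture is still wanted after this cleanup, the same effect can be had with plain PyTorch forward hooks, which is the mechanism the deleted `ForwardHook` used. The sketch below is illustrative only; `ActivationRecorder` is a name invented for this note, and it stores outputs in a local dict instead of labml's tracker:

import torch
import torch.nn as nn


class ActivationRecorder:
    # Register a forward hook on every sub-module; store outputs while
    # `enabled` is True.
    def __init__(self, model: nn.Module):
        self.enabled = False
        self.activations = {}
        for name, module in model.named_modules():
            key = name if name else 'full'
            module.register_forward_hook(self._make_hook(key))

    def _make_hook(self, name: str):
        def hook(module, inputs, output):
            if not self.enabled:
                return
            if isinstance(output, torch.Tensor):
                self.activations[name] = output.detach()
            elif isinstance(output, tuple):
                for i, o in enumerate(output):
                    if isinstance(o, torch.Tensor):
                        self.activations[f'{name}.{i}'] = o.detach()
        return hook


model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
recorder = ActivationRecorder(model)
recorder.enabled = True  # e.g. only on the last batch, like the old `batch_idx.is_last` gating
_ = model(torch.randn(1, 4))
print(sorted(recorder.activations.keys()))  # ['0', '1', '2', 'full']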
@@ -493,10 +457,6 @@ class SimpleTrainValidConfigs(TrainValidConfigs):
         arguments.
         update_batches (int): Number of batches to accumulate before taking an optimizer step.
             Defaults to ``1``.
-        log_params_updates (int): How often (number of batches) to track model parameters and gradients.
-            Defaults to a large number; i.e. logs every epoch.
-        log_activations_batches (int): How often to log model activations.
-            Defaults to a large number; i.e. logs every epoch.
         log_save_batches (int): How often to call :func:`labml.tracker.save`.
     """
     optimizer: torch.optim.Adam
@@ -506,8 +466,6 @@ class SimpleTrainValidConfigs(TrainValidConfigs):
     loss_func: nn.Module

     update_batches: int = 1
-    log_params_updates: int = 2 ** 32  # 0 if not
-    log_activations_batches: int = 2 ** 32  # 0 if not
     log_save_batches: int = 1

     state_modules: List[StateModule] = []
@@ -522,10 +480,8 @@ class SimpleTrainValidConfigs(TrainValidConfigs):
         if self.mode.is_train:
             tracker.add_global_step(len(data))

-        is_log_activations = batch_idx.is_interval(self.log_activations_batches)
         with monit.section("model"):
-            with self.mode.update(is_log_activations=is_log_activations):
-                output = self.model(data)
+            output = self.model(data)

         loss = self.loss_func(output, target)
         tracker.add("loss.", loss)
@@ -537,8 +493,6 @@ class SimpleTrainValidConfigs(TrainValidConfigs):
         if batch_idx.is_interval(self.update_batches):
             with monit.section('optimize'):
                 self.optimizer.step()
-            if batch_idx.is_interval(self.log_params_updates):
-                tracker.add('model', self.model)
             self.optimizer.zero_grad()

         if batch_idx.is_interval(self.log_save_batches):
@@ -546,8 +500,7 @@ class SimpleTrainValidConfigs(TrainValidConfigs):


 meta_config(SimpleTrainValidConfigs.update_batches,
-            SimpleTrainValidConfigs.log_params_updates,
-            SimpleTrainValidConfigs.log_activations_batches)
+            )


 @option(SimpleTrainValidConfigs.optimizer)
@@ -71,9 +71,8 @@ class Configs(MNISTConfigs, TrainValidConfigs):
         if self.mode.is_train:
             tracker.add_global_step(len(data))

-        # Run the model and specify whether to log the activations
-        with self.mode.update(is_log_activations=batch_idx.is_last):
-            output = self.model(data)
+        # Run the model
+        output = self.model(data)

         # Calculate the loss
         loss = self.loss_func(output, target)
@@ -75,12 +75,10 @@ class Configs(TransformerAutoRegressionConfigs):
         if self.mode.is_train:
             tracker.add_global_step(data.shape[0] * data.shape[1])

-        # Whether to capture model outputs
-        with self.mode.update(is_log_activations=batch_idx.is_last and self.is_log_model_activations):
-            # Get model outputs.
-            # It's returning a tuple for states when using RNNs.
-            # This is not implemented yet. 😜
-            output, *_ = self.model(data)
+        # Get model outputs.
+        # It's returning a tuple for states when using RNNs.
+        # This is not implemented yet. 😜
+        output, *_ = self.model(data)

         # Calculate and log loss
         loss = self.loss_func(output, target)
@@ -202,16 +202,14 @@ class Configs(NLPAutoRegressionConfigs):
         if self.mode.is_train:
             tracker.add_global_step(data.shape[0] * data.shape[1])

-        # Whether to capture model outputs
-        with self.mode.update(is_log_activations=batch_idx.is_last):
-            # Get memories
-            mem = self.memory.get()
-            # Run the model
-            output, new_mem = self.model(data, mem)
-            # Merge and compress memory
-            mem, mem_to_compress = self.merge_compress_memory(mem, new_mem)
-            # Update memories
-            self.memory.set(mem)
+        # Get memories
+        mem = self.memory.get()
+        # Run the model
+        output, new_mem = self.model(data, mem)
+        # Merge and compress memory
+        mem, mem_to_compress = self.merge_compress_memory(mem, new_mem)
+        # Update memories
+        self.memory.set(mem)

         # Calculate and log cross entropy loss
         loss = self.loss_func(output, target)
@@ -143,12 +143,10 @@ class Configs(NLPAutoRegressionConfigs):
         with torch.no_grad():
             data, labels = self.mlm(data)

-        # Whether to capture model outputs
-        with self.mode.update(is_log_activations=batch_idx.is_last):
-            # Get model outputs.
-            # It's returning a tuple for states when using RNNs.
-            # This is not implemented yet.
-            output, *_ = self.model(data)
+        # Get model outputs.
+        # It's returning a tuple for states when using RNNs.
+        # This is not implemented yet.
+        output, *_ = self.model(data)

         # Calculate and log the loss
         loss = self.loss_func(output.view(-1, output.shape[-1]), labels.view(-1))
@@ -102,10 +102,8 @@ class Configs(NLPAutoRegressionConfigs):
         if self.mode.is_train:
             tracker.add_global_step(data.shape[0] * data.shape[1])

-        # Whether to capture model outputs
-        with self.mode.update(is_log_activations=batch_idx.is_last):
-            # Get model outputs.
-            output, counts, route_prob, n_dropped, route_prob_max = self.model(data)
+        # Get model outputs.
+        output, counts, route_prob, n_dropped, route_prob_max = self.model(data)

         # Calculate and cross entropy loss
         cross_entropy_loss = self.loss_func(output, target)
@@ -132,16 +132,14 @@ class Configs(NLPAutoRegressionConfigs):
         if self.mode.is_train:
             tracker.add_global_step(data.shape[0] * data.shape[1])

-        # Whether to capture model outputs
-        with self.mode.update(is_log_activations=batch_idx.is_last):
-            # Get memories
-            mem = self.memory.get()
-            # Run the model
-            output, new_mem = self.model(data, mem)
-            # Merge memory
-            mem = self.merge_memory(mem, new_mem)
-            # Update memories
-            self.memory.set(mem)
+        # Get memories
+        mem = self.memory.get()
+        # Run the model
+        output, new_mem = self.model(data, mem)
+        # Merge memory
+        mem = self.merge_memory(mem, new_mem)
+        # Update memories
+        self.memory.set(mem)

         # Calculate and log cross entropy loss
         loss = self.loss_func(output, target)