Mirror of https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
Synced 2025-08-16 10:51:23 +08:00
cleanup log activations
@@ -127,8 +127,6 @@ class Configs(MNISTConfigs, SimpleTrainValidConfigs):
         if self.mode.is_train:
             tracker.add_global_step(len(data))

-        # Whether to log activations
-        with self.mode.update(is_log_activations=batch_idx.is_last):
         # Run the model
         caps, reconstructions, pred = self.model(data)
@@ -73,8 +73,6 @@ class MNISTConfigs(MNISTDatasetConfigs, TrainValidConfigs):
         if self.mode.is_train:
             tracker.add_global_step(len(data))

-        # Whether to capture model outputs
-        with self.mode.update(is_log_activations=batch_idx.is_last):
         # Get model outputs.
         output = self.model(data)
@@ -132,8 +132,6 @@ class NLPAutoRegressionConfigs(TrainValidConfigs):
         if self.mode.is_train:
             tracker.add_global_step(data.shape[0] * data.shape[1])

-        # Whether to capture model outputs
-        with self.mode.update(is_log_activations=batch_idx.is_last and self.is_log_model_activations):
         # Get model outputs.
         # It's returning a tuple for states when using RNNs.
         # This is not implemented yet. 😜
@@ -108,8 +108,6 @@ class NLPClassificationConfigs(TrainValidConfigs):
         if self.mode.is_train:
             tracker.add_global_step(data.shape[1])

-        # Whether to capture model outputs
-        with self.mode.update(is_log_activations=batch_idx.is_last and self.is_log_model_activations):
         # Get model outputs.
         # It's returning a tuple for states when using RNNs.
         # This is not implemented yet. 😜
@@ -315,8 +315,6 @@ class Configs(BaseConfigs):

         # Accumulate gradients for `gradient_accumulate_steps`
         for i in range(self.gradient_accumulate_steps):
-            # Update `mode`. Set whether to log activation
-            with self.mode.update(is_log_activations=(idx + 1) % self.log_generated_interval == 0):
             # Sample images from generator
             generated_images, _ = self.generate_images(self.batch_size)
             # Discriminator classification for generated images
@@ -3,12 +3,11 @@ import typing
 from typing import Dict, List, Callable
 from typing import Optional, Tuple, Any, Collection

-import labml.utils.pytorch as pytorch_utils
 import torch.optim
 import torch.optim
 import torch.utils.data
 import torch.utils.data
-from labml import tracker, logger, experiment, monit
+from labml import tracker, logger, monit
 from labml.configs import BaseConfigs, meta_config, option
 from labml.internal.monitor import Loop
 from labml.logger import Text
@@ -204,8 +203,6 @@ class ModeState:
         self._rollback_stack = []

         self.is_train = False
-        self.is_log_activations = False
-        self.is_log_parameters = False
         self.is_optimize = False

     def _enter(self, mode: Dict[str, any]):
@@ -231,13 +228,9 @@ class ModeState:

     def update(self, *,
                is_train: Optional[bool] = None,
-               is_log_parameters: Optional[bool] = None,
-               is_log_activations: Optional[bool] = None,
                is_optimize: Optional[bool] = None):
         return Mode(self,
                     is_train=is_train,
-                    is_log_parameters=is_log_parameters,
-                    is_log_activations=is_log_activations,
                     is_optimize=is_optimize)
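The `is_log_parameters` / `is_log_activations` flags removed above were scoped with a context manager: `ModeState.update(...)` returns a `Mode` object that flips the flags on entry and restores them on exit. A minimal sketch of that pattern, using simplified, hypothetical class names rather than labml's actual implementation:

from typing import Optional


class SimpleModeState:
    def __init__(self):
        self.is_train = False
        self.is_optimize = False

    def update(self, *, is_train: Optional[bool] = None, is_optimize: Optional[bool] = None):
        return _SimpleMode(self, is_train=is_train, is_optimize=is_optimize)


class _SimpleMode:
    def __init__(self, state: SimpleModeState, **flags):
        self.state = state
        # Keep only the flags that were explicitly set
        self.flags = {k: v for k, v in flags.items() if v is not None}
        self.saved = {}

    def __enter__(self):
        # Save the current values and apply the requested ones
        for k, v in self.flags.items():
            self.saved[k] = getattr(self.state, k)
            setattr(self.state, k, v)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Restore the previous values when the block exits
        for k, v in self.saved.items():
            setattr(self.state, k, v)


# Usage: the flag is only changed inside the `with` block
state = SimpleModeState()
with state.update(is_train=True):
    assert state.is_train
assert not state.is_train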
@@ -258,35 +251,6 @@ class Mode:
         self.mode._exit(self.idx)


-class ForwardHook:
-    def __init__(self, mode: ModeState, model_name, name: str, module: torch.nn.Module):
-        self.mode = mode
-        self.model_name = model_name
-        self.name = name
-        self.module = module
-        module.register_forward_hook(self)
-
-    def save(self, name: str, output):
-        if isinstance(output, torch.Tensor):
-            pytorch_utils.store_var(name, output)
-        elif isinstance(output, tuple):
-            for i, o in enumerate(output):
-                self.save(f"{name}.{i}", o)
-
-    def __call__(self, module, i, o):
-        if not self.mode.is_log_activations:
-            return
-
-        self.save(f"module.{self.model_name}.{self.name}", o)
-
-
-def hook_model_outputs(mode: ModeState, model: torch.nn.Module, model_name: str = "model"):
-    for name, module in model.named_modules():
-        if name == '':
-            name = 'full'
-        ForwardHook(mode, model_name, name, module)
-
-
 class Trainer:
     def __init__(self, *,
                  name: str,
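The removed `ForwardHook` / `hook_model_outputs` helpers attached a forward hook to every sub-module and stored its output (via `pytorch_utils.store_var`) whenever `mode.is_log_activations` was set. Below is a rough stand-in using plain PyTorch hooks; `tracker.add` replaces `store_var` and the `enabled` callable replaces the `ModeState` check, both as illustrative assumptions rather than this repository's API:

import torch
import torch.nn as nn
from labml import tracker


def hook_activations(model: nn.Module, model_name: str = "model", enabled=lambda: True):
    """Register a forward hook on every sub-module and log its output."""
    def make_hook(name: str):
        def hook(module: nn.Module, inputs, output):
            if not enabled():
                return
            if isinstance(output, torch.Tensor):
                # Same key scheme as the removed ForwardHook: module.<model>.<sub-module>
                tracker.add(f"module.{model_name}.{name}", output)
            elif isinstance(output, tuple):
                # e.g. RNNs return (output, state); log each tensor element separately
                for i, o in enumerate(output):
                    if isinstance(o, torch.Tensor):
                        tracker.add(f"module.{model_name}.{name}.{i}", o)
        return hook

    for name, module in model.named_modules():
        # The root module is yielded with an empty name; call it "full"
        module.register_forward_hook(make_hook(name if name else "full"))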
@@ -493,10 +457,6 @@ class SimpleTrainValidConfigs(TrainValidConfigs):
         arguments.
         update_batches (int): Number of batches to accumulate before taking an optimizer step.
             Defaults to ``1``.
-        log_params_updates (int): How often (number of batches) to track model parameters and gradients.
-            Defaults to a large number; i.e. logs every epoch.
-        log_activations_batches (int): How often to log model activations.
-            Defaults to a large number; i.e. logs every epoch.
         log_save_batches (int): How often to call :func:`labml.tracker.save`.
     """
     optimizer: torch.optim.Adam
@@ -506,8 +466,6 @@ class SimpleTrainValidConfigs(TrainValidConfigs):
     loss_func: nn.Module

     update_batches: int = 1
-    log_params_updates: int = 2 ** 32  # 0 if not
-    log_activations_batches: int = 2 ** 32  # 0 if not
     log_save_batches: int = 1

     state_modules: List[StateModule] = []
@@ -522,9 +480,7 @@ class SimpleTrainValidConfigs(TrainValidConfigs):
         if self.mode.is_train:
             tracker.add_global_step(len(data))

-        is_log_activations = batch_idx.is_interval(self.log_activations_batches)
         with monit.section("model"):
-            with self.mode.update(is_log_activations=is_log_activations):
             output = self.model(data)

         loss = self.loss_func(output, target)
@@ -537,8 +493,6 @@ class SimpleTrainValidConfigs(TrainValidConfigs):
         if batch_idx.is_interval(self.update_batches):
             with monit.section('optimize'):
                 self.optimizer.step()
-            if batch_idx.is_interval(self.log_params_updates):
-                tracker.add('model', self.model)
             self.optimizer.zero_grad()

         if batch_idx.is_interval(self.log_save_batches):
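For context, `update_batches` (kept by this change) is a gradient-accumulation interval: `loss.backward()` runs on every batch, while `optimizer.step()` and `zero_grad()` only run when `batch_idx.is_interval(self.update_batches)` is true. A generic sketch of that pattern in plain PyTorch, not the labml `Trainer` itself:

from torch import nn
from torch.optim import Optimizer
from torch.utils.data import DataLoader


def train_epoch(model: nn.Module, loss_func: nn.Module, optimizer: Optimizer,
                loader: DataLoader, update_batches: int = 1):
    model.train()
    for i, (data, target) in enumerate(loader):
        loss = loss_func(model(data), target)
        loss.backward()  # gradients accumulate across batches until zero_grad()
        # Roughly what batch_idx.is_interval(self.update_batches) gates in the Trainer
        if (i + 1) % update_batches == 0:
            optimizer.step()
            optimizer.zero_grad()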
@@ -546,8 +500,7 @@ class SimpleTrainValidConfigs(TrainValidConfigs):


 meta_config(SimpleTrainValidConfigs.update_batches,
-            SimpleTrainValidConfigs.log_params_updates,
-            SimpleTrainValidConfigs.log_activations_batches)
+            )


 @option(SimpleTrainValidConfigs.optimizer)
@@ -71,8 +71,7 @@ class Configs(MNISTConfigs, TrainValidConfigs):
         if self.mode.is_train:
             tracker.add_global_step(len(data))

-        # Run the model and specify whether to log the activations
-        with self.mode.update(is_log_activations=batch_idx.is_last):
+        # Run the model
         output = self.model(data)

         # Calculate the loss
@@ -75,8 +75,6 @@ class Configs(TransformerAutoRegressionConfigs):
         if self.mode.is_train:
             tracker.add_global_step(data.shape[0] * data.shape[1])

-        # Whether to capture model outputs
-        with self.mode.update(is_log_activations=batch_idx.is_last and self.is_log_model_activations):
         # Get model outputs.
         # It's returning a tuple for states when using RNNs.
         # This is not implemented yet. 😜
@@ -202,8 +202,6 @@ class Configs(NLPAutoRegressionConfigs):
         if self.mode.is_train:
             tracker.add_global_step(data.shape[0] * data.shape[1])

-        # Whether to capture model outputs
-        with self.mode.update(is_log_activations=batch_idx.is_last):
         # Get memories
         mem = self.memory.get()
         # Run the model
@@ -143,8 +143,6 @@ class Configs(NLPAutoRegressionConfigs):
         with torch.no_grad():
             data, labels = self.mlm(data)

-        # Whether to capture model outputs
-        with self.mode.update(is_log_activations=batch_idx.is_last):
         # Get model outputs.
         # It's returning a tuple for states when using RNNs.
         # This is not implemented yet.
@@ -102,8 +102,6 @@ class Configs(NLPAutoRegressionConfigs):
         if self.mode.is_train:
             tracker.add_global_step(data.shape[0] * data.shape[1])

-        # Whether to capture model outputs
-        with self.mode.update(is_log_activations=batch_idx.is_last):
         # Get model outputs.
         output, counts, route_prob, n_dropped, route_prob_max = self.model(data)
@@ -132,8 +132,6 @@ class Configs(NLPAutoRegressionConfigs):
         if self.mode.is_train:
             tracker.add_global_step(data.shape[0] * data.shape[1])

-        # Whether to capture model outputs
-        with self.mode.update(is_log_activations=batch_idx.is_last):
         # Get memories
         mem = self.memory.get()
         # Run the model