cleanup hook model outputs

Varuna Jayasiri
2025-07-20 09:02:34 +05:30
parent 5bdedcffec
commit a713c92b82
12 changed files with 36 additions and 142 deletions

View File

@@ -16,7 +16,7 @@ from labml.configs import option
 from labml_nn.helpers.datasets import MNISTConfigs as MNISTDatasetConfigs
 from labml_nn.helpers.device import DeviceConfigs
 from labml_nn.helpers.metrics import Accuracy
-from labml_nn.helpers.trainer import TrainValidConfigs, BatchIndex, hook_model_outputs
+from labml_nn.helpers.trainer import TrainValidConfigs, BatchIndex
 from labml_nn.optimizers.configs import OptimizerConfigs
@@ -52,8 +52,6 @@ class MNISTConfigs(MNISTDatasetConfigs, TrainValidConfigs):
         # Set tracker configurations
         tracker.set_scalar("loss.*", True)
         tracker.set_scalar("accuracy.*", True)
-        # Add a hook to log module outputs
-        hook_model_outputs(self.mode, self.model, 'model')
         # Add accuracy as a state module.
         # The name is probably confusing, since it's meant to store
         # states between training and validation for RNNs.
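For reference, the removed `hook_model_outputs` helper registered forward hooks on a model's sub-modules so that their outputs could be logged through `labml.tracker` during training and validation. A rough sketch of that behaviour using plain PyTorch forward hooks follows; the function name `log_module_outputs` and the exact logging call are illustrative assumptions, not the repository's implementation:

```python
import torch
import torch.nn as nn

from labml import tracker


def log_module_outputs(model: nn.Module, prefix: str = 'model'):
    """Log each sub-module's output through the tracker (illustrative sketch)."""
    def make_hook(name: str):
        def hook(module: nn.Module, inputs, output):
            # Only log plain tensor outputs; modules may also return tuples etc.
            if isinstance(output, torch.Tensor):
                tracker.add(f'{prefix}.{name}', output)
        return hook

    # named_modules() yields ('', root) first; skip the root and hook the rest
    for name, module in model.named_modules():
        if name:
            module.register_forward_hook(make_hook(name))
```

After this commit the experiment `init` methods below only configure the scalar trackers; no per-layer output logging is set up.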

View File

@@ -12,16 +12,15 @@ from typing import Callable
 import torch
 import torch.nn as nn
+from torch.utils.data import DataLoader, RandomSampler
 from labml import lab, monit, logger, tracker
 from labml.configs import option
 from labml.logger import Text
 from labml_nn.helpers.datasets import TextDataset, SequentialDataLoader, SequentialUnBatchedDataset, TextFileDataset
 from labml_nn.helpers.device import DeviceConfigs
 from labml_nn.helpers.metrics import Accuracy
-from labml_nn.helpers.trainer import TrainValidConfigs, hook_model_outputs, BatchIndex
+from labml_nn.helpers.trainer import TrainValidConfigs, BatchIndex
 from labml_nn.optimizers.configs import OptimizerConfigs
-from torch.utils.data import DataLoader, RandomSampler
 class CrossEntropyLoss(nn.Module):
@@ -108,8 +107,6 @@ class NLPAutoRegressionConfigs(TrainValidConfigs):
         tracker.set_scalar("accuracy.*", True)
         tracker.set_scalar("loss.*", True)
         tracker.set_text("sampled", False)
-        # Add a hook to log module outputs
-        hook_model_outputs(self.mode, self.model, 'model')
         # Add accuracy as a state module.
         # The name is probably confusing, since it's meant to store
         # states between training and validation for RNNs.

View File

@@ -11,19 +11,19 @@ summary: >
 from collections import Counter
 from typing import Callable
+import torch
 import torchtext
+from torch import nn
+from torch.utils.data import DataLoader
 import torchtext.vocab
 from torchtext.vocab import Vocab
-import torch
 from labml import lab, tracker, monit
 from labml.configs import option
 from labml_nn.helpers.device import DeviceConfigs
 from labml_nn.helpers.metrics import Accuracy
-from labml_nn.helpers.trainer import TrainValidConfigs, hook_model_outputs, BatchIndex
+from labml_nn.helpers.trainer import TrainValidConfigs, BatchIndex
 from labml_nn.optimizers.configs import OptimizerConfigs
-from torch import nn
-from torch.utils.data import DataLoader
 class NLPClassificationConfigs(TrainValidConfigs):
@@ -90,8 +90,6 @@ class NLPClassificationConfigs(TrainValidConfigs):
         # Set tracker configurations
         tracker.set_scalar("accuracy.*", True)
         tracker.set_scalar("loss.*", True)
-        # Add a hook to log module outputs
-        hook_model_outputs(self.mode, self.model, 'model')
         # Add accuracy as a state module.
         # The name is probably confusing, since it's meant to store
         # states between training and validation for RNNs.

View File

@@ -9,18 +9,18 @@ summary: This experiment generates MNIST images using multi-layer perceptron.
 from typing import Any
-from torchvision import transforms
 import torch
 import torch.nn as nn
 import torch.utils.data
+from torchvision import transforms
 from labml import tracker, monit, experiment
 from labml.configs import option, calculate
-from labml_nn.helpers.datasets import MNISTConfigs
-from labml_nn.helpers.device import DeviceConfigs
-from labml_nn.helpers.optimizer import OptimizerConfigs
-from labml_nn.helpers.trainer import TrainValidConfigs, hook_model_outputs, BatchIndex
 from labml_nn.gan.original import DiscriminatorLogitsLoss, GeneratorLogitsLoss
+from labml_nn.helpers.datasets import MNISTConfigs
+from labml_nn.helpers.device import DeviceConfigs
+from labml_nn.helpers.optimizer import OptimizerConfigs
+from labml_nn.helpers.trainer import TrainValidConfigs, BatchIndex
 def weights_init(m):
@@ -110,8 +110,6 @@ class Configs(MNISTConfigs, TrainValidConfigs):
         """
         self.state_modules = []
-        hook_model_outputs(self.mode, self.generator, 'generator')
-        hook_model_outputs(self.mode, self.discriminator, 'discriminator')
         tracker.set_scalar("loss.generator.*", True)
         tracker.set_scalar("loss.discriminator.*", True)
         tracker.set_image("generated", True, 1 / 100)
@@ -187,7 +185,7 @@ class Configs(MNISTConfigs, TrainValidConfigs):
         """
         Calculate generator loss
         """
         latent = self.sample_z(batch_size)
         generated_images = self.generator(latent)
         logits = self.discriminator(generated_images)
         loss = self.generator_loss(logits)
@@ -199,8 +197,6 @@ class Configs(MNISTConfigs, TrainValidConfigs):
         return loss
 @option(Configs.dataset_transforms)
 def mnist_gan_transforms():
     return transforms.Compose([

View File

@@ -32,17 +32,17 @@ import math
 from pathlib import Path
 from typing import Iterator, Tuple
+import torch
+import torch.utils.data
 import torchvision
 from PIL import Image
-import torch
-import torch.utils.data
 from labml import tracker, lab, monit, experiment
 from labml.configs import BaseConfigs
-from labml_nn.helpers.device import DeviceConfigs
-from labml_nn.helpers.trainer import ModeState, hook_model_outputs
 from labml_nn.gan.stylegan import Discriminator, Generator, MappingNetwork, GradientPenalty, PathLengthPenalty
 from labml_nn.gan.wasserstein import DiscriminatorLoss, GeneratorLoss
+from labml_nn.helpers.device import DeviceConfigs
+from labml_nn.helpers.trainer import ModeState
 from labml_nn.utils import cycle_dataloader
@@ -164,8 +164,6 @@ class Configs(BaseConfigs):
     # Training mode state for logging activations
     mode: ModeState
-    # Whether to log model layer outputs
-    log_layer_outputs: bool = False
     # <a id="dataset_path"></a>
     # We trained this on [CelebA-HQ dataset](https://github.com/tkarras/progressive_growing_of_gans).
@@ -199,12 +197,6 @@ class Configs(BaseConfigs):
         # Create path length penalty loss
         self.path_length_penalty = PathLengthPenalty(0.99).to(self.device)
-        # Add model hooks to monitor layer outputs
-        if self.log_layer_outputs:
-            hook_model_outputs(self.mode, self.discriminator, 'discriminator')
-            hook_model_outputs(self.mode, self.generator, 'generator')
-            hook_model_outputs(self.mode, self.mapping_network, 'mapping_network')
         # Discriminator and generator losses
         self.discriminator_loss = DiscriminatorLoss().to(self.device)
         self.generator_loss = GeneratorLoss().to(self.device)

View File

@@ -14,7 +14,7 @@ from torch.utils.data import DataLoader
 from torch.utils.data import IterableDataset, Dataset
-def _dataset(is_train, transform):
+def _mnist_dataset(is_train, transform):
     return datasets.MNIST(str(lab.get_data_path()),
                           train=is_train,
                           download=True,
@@ -66,12 +66,12 @@ def mnist_transforms():
 @option(MNISTConfigs.train_dataset)
 def mnist_train_dataset(c: MNISTConfigs):
-    return _dataset(True, c.dataset_transforms)
+    return _mnist_dataset(True, c.dataset_transforms)
 @option(MNISTConfigs.valid_dataset)
 def mnist_valid_dataset(c: MNISTConfigs):
-    return _dataset(False, c.dataset_transforms)
+    return _mnist_dataset(False, c.dataset_transforms)
 @option(MNISTConfigs.train_loader)
@@ -96,7 +96,7 @@ aggregate(MNISTConfigs.dataset_name, 'MNIST',
           (MNISTConfigs.valid_loader, 'mnist_valid_loader'))
-def _dataset(is_train, transform):
+def _cifar_dataset(is_train, transform):
     return datasets.CIFAR10(str(lab.get_data_path()),
                             train=is_train,
                             download=True,
@@ -147,12 +147,12 @@ def cifar10_transforms():
 @CIFAR10Configs.calc(CIFAR10Configs.train_dataset)
 def cifar10_train_dataset(c: CIFAR10Configs):
-    return _dataset(True, c.dataset_transforms)
+    return _cifar_dataset(True, c.dataset_transforms)
 @CIFAR10Configs.calc(CIFAR10Configs.valid_dataset)
 def cifar10_valid_dataset(c: CIFAR10Configs):
-    return _dataset(False, c.dataset_transforms)
+    return _cifar_dataset(False, c.dataset_transforms)
 @CIFAR10Configs.calc(CIFAR10Configs.train_loader)
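Beyond readability, the `_dataset` → `_mnist_dataset` / `_cifar_dataset` renames fix a shadowing problem: the module defined two helpers with the same name, so the second `def` silently rebound the name and every later call resolved to the CIFAR-10 version. A standalone sketch (not taken from the repository) of the failure mode:

```python
# Two module-level helpers sharing one name: the second definition rebinds
# `_dataset`, so calls made after import always hit the CIFAR-10 version.
def _dataset(is_train, transform):
    return 'MNIST'


def _dataset(is_train, transform):  # noqa: F811 -- shadows the MNIST helper above
    return 'CIFAR-10'


def mnist_train_dataset():
    # The name is looked up at call time, i.e. after the rebinding
    return _dataset(True, None)


print(mnist_train_dataset())  # prints 'CIFAR-10', not 'MNIST'
```

Giving each helper a distinct name removes the ambiguity regardless of definition order.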

View File

@@ -75,43 +75,6 @@ class Accuracy(Metric):
         tracker.add("accuracy.", self.data.correct / self.data.samples)
-class AccuracyMovingAvg(Metric):
-    def __init__(self, ignore_index: int = -1, queue_size: int = 5):
-        super().__init__()
-        self.ignore_index = ignore_index
-        tracker.set_queue('accuracy.*', queue_size, is_print=True)
-    def __call__(self, output: torch.Tensor, target: torch.Tensor):
-        output = output.view(-1, output.shape[-1])
-        target = target.view(-1)
-        pred = output.argmax(dim=-1)
-        mask = target == self.ignore_index
-        pred.masked_fill_(mask, self.ignore_index)
-        n_masked = mask.sum().item()
-        if len(target) - n_masked > 0:
-            tracker.add('accuracy.', (pred.eq(target).sum().item() - n_masked) / (len(target) - n_masked))
-    def create_state(self):
-        return None
-    def set_state(self, data: any):
-        pass
-    def on_epoch_start(self):
-        pass
-    def on_epoch_end(self):
-        pass
-class BinaryAccuracy(Accuracy):
-    def __call__(self, output: torch.Tensor, target: torch.Tensor):
-        pred = output.view(-1) > 0
-        target = target.view(-1)
-        self.data.correct += pred.eq(target).sum().item()
-        self.data.samples += len(target)
 class AccuracyDirect(Accuracy):
     data: AccuracyState

View File

@@ -66,19 +66,15 @@ class TrainingLoop:
     def __init__(self, *,
                  loop_count: int,
                  loop_step: Optional[int],
-                 is_save_models: bool,
                  log_new_line_interval: int,
                  log_write_interval: int,
-                 save_models_interval: int,
                  is_loop_on_interrupt: bool):
         self.__loop_count = loop_count
         self.__loop_step = loop_step
-        self.__is_save_models = is_save_models
         self.__log_new_line_interval = log_new_line_interval
         self.__log_write_interval = log_write_interval
         self.__last_write_step = 0
         self.__last_new_line_step = 0
-        self.__save_models_interval = save_models_interval
         self.__last_save_step = 0
         self.__signal_received = None
         self.__is_loop_on_interrupt = is_loop_on_interrupt
@@ -115,21 +111,6 @@ class TrainingLoop:
             pass
         tracker.save()
         tracker.new_line()
-        if self.__is_save_models:
-            logger.log("Saving model...")
-            experiment.save_checkpoint()
-    # def is_interval(self, interval: int, global_step: Optional[int] = None):
-    #     if global_step is None:
-    #         global_step = tracker.get_global_step()
-    #
-    #     if global_step - self.__loop_step < 0:
-    #         return False
-    #
-    #     if global_step // interval > (global_step - self.__loop_step) // interval:
-    #         return True
-    #     else:
-    #         return False
     def __next__(self):
         if self.__signal_received is not None:
@@ -152,18 +133,6 @@ class TrainingLoop:
         if global_step - self.__last_new_line_step >= self.__log_new_line_interval:
             tracker.new_line()
             self.__last_new_line_step = global_step
-        # if self.is_interval(self.__log_write_interval, global_step):
-        #     tracker.save()
-        # if self.is_interval(self.__log_new_line_interval, global_step):
-        #     logger.log()
-        # if (self.__is_save_models and
-        #         self.is_interval(self.__save_models_interval, global_step)):
-        #     experiment.save_checkpoint()
-        if (self.__is_save_models and
-                global_step - self.__last_save_step >= self.__save_models_interval):
-            experiment.save_checkpoint()
-            self.__last_save_step = global_step
         return global_step
@@ -198,9 +167,6 @@ class TrainingLoopConfigs(BaseConfigs):
     Arguments:
         loop_count (int): Total number of steps. Defaults to ``10``.
        loop_step (int): Number of steps to increment per iteration. Defaults to ``1``.
-        is_save_models (bool): Whether to call :func:`labml.experiment.save_checkpoint` on each iteration.
-            Defaults to ``False``.
-        save_models_interval (int): The interval (in steps) to save models. Defaults to ``1``.
         log_new_line_interval (int): The interval (in steps) to print a new line to the screen.
             Defaults to ``1``.
         log_write_interval (int): The interval (in steps) to call :func:`labml.tracker.save`.
@@ -210,10 +176,8 @@ class TrainingLoopConfigs(BaseConfigs):
     """
     loop_count: int = 10
     loop_step: int = 1
-    is_save_models: bool = False
     log_new_line_interval: int = 1
     log_write_interval: int = 1
-    save_models_interval: int = 1
     is_loop_on_interrupt: bool = False
     training_loop: TrainingLoop
@@ -223,19 +187,15 @@ class TrainingLoopConfigs(BaseConfigs):
 def _loop_configs(c: TrainingLoopConfigs):
     return TrainingLoop(loop_count=c.loop_count,
                         loop_step=c.loop_step,
-                        is_save_models=c.is_save_models,
                         log_new_line_interval=c.log_new_line_interval,
                         log_write_interval=c.log_write_interval,
-                        save_models_interval=c.save_models_interval,
                         is_loop_on_interrupt=c.is_loop_on_interrupt)
 meta_config(TrainingLoopConfigs.loop_step,
             TrainingLoopConfigs.loop_count,
-            TrainingLoopConfigs.is_save_models,
             TrainingLoopConfigs.log_new_line_interval,
             TrainingLoopConfigs.log_write_interval,
-            TrainingLoopConfigs.save_models_interval,
             TrainingLoopConfigs.is_loop_on_interrupt)
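With `is_save_models` and `save_models_interval` gone, `TrainingLoop` no longer calls `experiment.save_checkpoint()` on its own, so checkpointing becomes the caller's responsibility. A minimal sketch of doing it from a training script, assuming a labml experiment run has already been created; the import path for `TrainingLoop`, the interval value, and the `train_one_step` stub are assumptions for illustration:

```python
from labml import experiment
# Assumed import path for TrainingLoop; adjust to wherever it lives in the repository
from labml_nn.helpers.trainer import TrainingLoop

SAVE_INTERVAL = 1_000  # assumed checkpoint interval, in global steps


def train_one_step():
    """Placeholder for the real training step."""


# Constructor arguments match the signature after this change
loop = TrainingLoop(loop_count=10_000,
                    loop_step=1,
                    log_new_line_interval=1,
                    log_write_interval=1,
                    is_loop_on_interrupt=False)

for step in loop:
    train_one_step()
    if step > 0 and step % SAVE_INTERVAL == 0:
        experiment.save_checkpoint()
```

This roughly mirrors what the removed interval logic did, but the saving policy is now explicit in the experiment code instead of being hidden inside the loop.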

View File

@@ -14,7 +14,7 @@ from labml.configs import option
 from labml_nn.helpers.datasets import MNISTConfigs
 from labml_nn.helpers.device import DeviceConfigs
 from labml_nn.helpers.metrics import Accuracy
-from labml_nn.helpers.trainer import TrainValidConfigs, BatchIndex, hook_model_outputs
+from labml_nn.helpers.trainer import TrainValidConfigs, BatchIndex
 from labml_nn.optimizers.configs import OptimizerConfigs
@@ -22,6 +22,7 @@ class Model(nn.Module):
     """
     ## The model
     """
+
     def __init__(self):
         super().__init__()
         self.conv1 = nn.Conv2d(1, 20, 5, 1)
@@ -60,7 +61,6 @@ class Configs(MNISTConfigs, TrainValidConfigs):
     def init(self):
         tracker.set_queue("loss.*", 20, True)
         tracker.set_scalar("accuracy.*", True)
-        hook_model_outputs(self.mode, self.model, 'model')
         self.state_modules = [self.accuracy_func]
     def step(self, batch: any, batch_idx: BatchIndex):
def step(self, batch: any, batch_idx: BatchIndex): def step(self, batch: any, batch_idx: BatchIndex):

View File

@@ -41,7 +41,7 @@ import torch.nn as nn
 from labml import lab, experiment, tracker, monit
 from labml_nn.helpers.device import DeviceConfigs
 from labml_nn.helpers.optimizer import OptimizerConfigs
-from labml_nn.helpers.trainer import TrainValidConfigs, hook_model_outputs, BatchIndex
+from labml_nn.helpers.trainer import TrainValidConfigs, BatchIndex
 from torch import optim
 from torch.utils.data import Dataset, DataLoader
@@ -530,10 +530,6 @@ class Configs(TrainValidConfigs):
         # Create validation data loader
         self.valid_loader = DataLoader(self.valid_dataset, self.batch_size)
-        # Add hooks to monitor layer outputs on Tensorboard
-        hook_model_outputs(self.mode, self.encoder, 'encoder')
-        hook_model_outputs(self.mode, self.decoder, 'decoder')
         # Configure the tracker to print the total train/validation loss
         tracker.set_scalar("loss.total.*", True)

View File

@@ -12,13 +12,12 @@ from typing import List, Tuple, NamedTuple
 import torch
 import torch.nn as nn
 from labml import experiment, tracker, monit, logger
 from labml.configs import option
 from labml.logger import Text
-from labml_nn.helpers.metrics import SimpleStateModule
-from labml_nn.helpers.trainer import BatchIndex, hook_model_outputs
 from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs
+from labml_nn.helpers.metrics import SimpleStateModule
+from labml_nn.helpers.trainer import BatchIndex
 from labml_nn.transformers.compressive import CompressiveTransformer, AttentionReconstructionLoss, \
     CompressiveTransformerLayer, Conv1dCompression
@@ -119,8 +118,6 @@ class Configs(NLPAutoRegressionConfigs):
         tracker.set_scalar("loss.*", True)
         # Do not print the attention reconstruction loss in the terminal
         tracker.set_scalar("ar_loss.*", False)
-        # Add a hook to log module outputs
-        hook_model_outputs(self.mode, self.model, 'model')
         # This will keep the accuracy metric stats and memories separate for training and validation.
         self.state_modules = [self.accuracy, self.memory]

View File

@@ -12,13 +12,12 @@ from typing import List
 import torch
 import torch.nn as nn
-from labml.logger import Text
 from labml import experiment, tracker, monit, logger
 from labml.configs import option
-from labml_nn.helpers.metrics import SimpleStateModule
-from labml_nn.helpers.trainer import BatchIndex, hook_model_outputs
+from labml.logger import Text
 from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs
+from labml_nn.helpers.metrics import SimpleStateModule
+from labml_nn.helpers.trainer import BatchIndex
 from labml_nn.transformers.xl import TransformerXL, TransformerXLLayer
@@ -95,8 +94,6 @@ class Configs(NLPAutoRegressionConfigs):
         # Set tracker configurations
         tracker.set_scalar("accuracy.*", True)
         tracker.set_scalar("loss.*", True)
-        # Add a hook to log module outputs
-        hook_model_outputs(self.mode, self.model, 'model')
         # This will keep the accuracy metric stats and memories separate for training and validation.
         self.state_modules = [self.accuracy, self.memory]