cleanup hook model outputs

This commit is contained in:
Varuna Jayasiri
2025-07-20 09:02:34 +05:30
parent 5bdedcffec
commit a713c92b82
12 changed files with 36 additions and 142 deletions

View File

@@ -16,7 +16,7 @@ from labml.configs import option
from labml_nn.helpers.datasets import MNISTConfigs as MNISTDatasetConfigs
from labml_nn.helpers.device import DeviceConfigs
from labml_nn.helpers.metrics import Accuracy
from labml_nn.helpers.trainer import TrainValidConfigs, BatchIndex, hook_model_outputs
from labml_nn.helpers.trainer import TrainValidConfigs, BatchIndex
from labml_nn.optimizers.configs import OptimizerConfigs
@@ -52,8 +52,6 @@ class MNISTConfigs(MNISTDatasetConfigs, TrainValidConfigs):
# Set tracker configurations
tracker.set_scalar("loss.*", True)
tracker.set_scalar("accuracy.*", True)
# Add a hook to log module outputs
hook_model_outputs(self.mode, self.model, 'model')
# Add accuracy as a state module.
# The name is probably confusing, since it's meant to store
# states between training and validation for RNNs.

View File

@@ -12,16 +12,15 @@ from typing import Callable
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, RandomSampler
from labml import lab, monit, logger, tracker
from labml.configs import option
from labml.logger import Text
from labml_nn.helpers.datasets import TextDataset, SequentialDataLoader, SequentialUnBatchedDataset, TextFileDataset
from labml_nn.helpers.device import DeviceConfigs
from labml_nn.helpers.metrics import Accuracy
from labml_nn.helpers.trainer import TrainValidConfigs, hook_model_outputs, BatchIndex
from labml_nn.helpers.device import DeviceConfigs
from labml_nn.helpers.metrics import Accuracy
from labml_nn.helpers.trainer import TrainValidConfigs, BatchIndex
from labml_nn.optimizers.configs import OptimizerConfigs
from torch.utils.data import DataLoader, RandomSampler
class CrossEntropyLoss(nn.Module):
@@ -108,8 +107,6 @@ class NLPAutoRegressionConfigs(TrainValidConfigs):
tracker.set_scalar("accuracy.*", True)
tracker.set_scalar("loss.*", True)
tracker.set_text("sampled", False)
# Add a hook to log module outputs
hook_model_outputs(self.mode, self.model, 'model')
# Add accuracy as a state module.
# The name is probably confusing, since it's meant to store
# states between training and validation for RNNs.

View File

@@ -11,19 +11,19 @@ summary: >
from collections import Counter
from typing import Callable
import torch
import torchtext
from torch import nn
from torch.utils.data import DataLoader
import torchtext.vocab
from torchtext.vocab import Vocab
import torch
from labml import lab, tracker, monit
from labml.configs import option
from labml_nn.helpers.device import DeviceConfigs
from labml_nn.helpers.metrics import Accuracy
from labml_nn.helpers.trainer import TrainValidConfigs, hook_model_outputs, BatchIndex
from labml_nn.helpers.device import DeviceConfigs
from labml_nn.helpers.metrics import Accuracy
from labml_nn.helpers.trainer import TrainValidConfigs, BatchIndex
from labml_nn.optimizers.configs import OptimizerConfigs
from torch import nn
from torch.utils.data import DataLoader
class NLPClassificationConfigs(TrainValidConfigs):
@@ -90,8 +90,6 @@ class NLPClassificationConfigs(TrainValidConfigs):
# Set tracker configurations
tracker.set_scalar("accuracy.*", True)
tracker.set_scalar("loss.*", True)
# Add a hook to log module outputs
hook_model_outputs(self.mode, self.model, 'model')
# Add accuracy as a state module.
# The name is probably confusing, since it's meant to store
# states between training and validation for RNNs.