mirror of
https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
synced 2025-08-26 08:41:23 +08:00
cleanup hook model outputs
This commit is contained in:
@ -16,7 +16,7 @@ from labml.configs import option
|
||||
from labml_nn.helpers.datasets import MNISTConfigs as MNISTDatasetConfigs
|
||||
from labml_nn.helpers.device import DeviceConfigs
|
||||
from labml_nn.helpers.metrics import Accuracy
|
||||
from labml_nn.helpers.trainer import TrainValidConfigs, BatchIndex, hook_model_outputs
|
||||
from labml_nn.helpers.trainer import TrainValidConfigs, BatchIndex
|
||||
from labml_nn.optimizers.configs import OptimizerConfigs
|
||||
|
||||
|
||||
@ -52,8 +52,6 @@ class MNISTConfigs(MNISTDatasetConfigs, TrainValidConfigs):
|
||||
# Set tracker configurations
|
||||
tracker.set_scalar("loss.*", True)
|
||||
tracker.set_scalar("accuracy.*", True)
|
||||
# Add a hook to log module outputs
|
||||
hook_model_outputs(self.mode, self.model, 'model')
|
||||
# Add accuracy as a state module.
|
||||
# The name is probably confusing, since it's meant to store
|
||||
# states between training and validation for RNNs.
|
||||
|
@ -12,16 +12,15 @@ from typing import Callable
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.utils.data import DataLoader, RandomSampler
|
||||
|
||||
from labml import lab, monit, logger, tracker
|
||||
from labml.configs import option
|
||||
from labml.logger import Text
|
||||
from labml_nn.helpers.datasets import TextDataset, SequentialDataLoader, SequentialUnBatchedDataset, TextFileDataset
|
||||
from labml_nn.helpers.device import DeviceConfigs
|
||||
from labml_nn.helpers.metrics import Accuracy
|
||||
from labml_nn.helpers.trainer import TrainValidConfigs, hook_model_outputs, BatchIndex
|
||||
from labml_nn.helpers.device import DeviceConfigs
|
||||
from labml_nn.helpers.metrics import Accuracy
|
||||
from labml_nn.helpers.trainer import TrainValidConfigs, BatchIndex
|
||||
from labml_nn.optimizers.configs import OptimizerConfigs
|
||||
from torch.utils.data import DataLoader, RandomSampler
|
||||
|
||||
|
||||
class CrossEntropyLoss(nn.Module):
|
||||
@ -108,8 +107,6 @@ class NLPAutoRegressionConfigs(TrainValidConfigs):
|
||||
tracker.set_scalar("accuracy.*", True)
|
||||
tracker.set_scalar("loss.*", True)
|
||||
tracker.set_text("sampled", False)
|
||||
# Add a hook to log module outputs
|
||||
hook_model_outputs(self.mode, self.model, 'model')
|
||||
# Add accuracy as a state module.
|
||||
# The name is probably confusing, since it's meant to store
|
||||
# states between training and validation for RNNs.
|
||||
|
@ -11,19 +11,19 @@ summary: >
|
||||
from collections import Counter
|
||||
from typing import Callable
|
||||
|
||||
import torch
|
||||
import torchtext
|
||||
from torch import nn
|
||||
from torch.utils.data import DataLoader
|
||||
import torchtext.vocab
|
||||
from torchtext.vocab import Vocab
|
||||
|
||||
import torch
|
||||
from labml import lab, tracker, monit
|
||||
from labml.configs import option
|
||||
from labml_nn.helpers.device import DeviceConfigs
|
||||
from labml_nn.helpers.metrics import Accuracy
|
||||
from labml_nn.helpers.trainer import TrainValidConfigs, hook_model_outputs, BatchIndex
|
||||
from labml_nn.helpers.device import DeviceConfigs
|
||||
from labml_nn.helpers.metrics import Accuracy
|
||||
from labml_nn.helpers.trainer import TrainValidConfigs, BatchIndex
|
||||
from labml_nn.optimizers.configs import OptimizerConfigs
|
||||
from torch import nn
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
|
||||
class NLPClassificationConfigs(TrainValidConfigs):
|
||||
@ -90,8 +90,6 @@ class NLPClassificationConfigs(TrainValidConfigs):
|
||||
# Set tracker configurations
|
||||
tracker.set_scalar("accuracy.*", True)
|
||||
tracker.set_scalar("loss.*", True)
|
||||
# Add a hook to log module outputs
|
||||
hook_model_outputs(self.mode, self.model, 'model')
|
||||
# Add accuracy as a state module.
|
||||
# The name is probably confusing, since it's meant to store
|
||||
# states between training and validation for RNNs.
|
||||
|
Reference in New Issue
Block a user