Mirror of https://github.com/labmlai/annotated_deep_learning_paper_implementations.git (synced 2025-08-26 16:50:39 +08:00)

Commit: cleanup hook model outputs

@@ -16,7 +16,7 @@ from labml.configs import option
 from labml_nn.helpers.datasets import MNISTConfigs as MNISTDatasetConfigs
 from labml_nn.helpers.device import DeviceConfigs
 from labml_nn.helpers.metrics import Accuracy
-from labml_nn.helpers.trainer import TrainValidConfigs, BatchIndex, hook_model_outputs
+from labml_nn.helpers.trainer import TrainValidConfigs, BatchIndex
 from labml_nn.optimizers.configs import OptimizerConfigs
 
 

@@ -52,8 +52,6 @@ class MNISTConfigs(MNISTDatasetConfigs, TrainValidConfigs):
         # Set tracker configurations
         tracker.set_scalar("loss.*", True)
         tracker.set_scalar("accuracy.*", True)
-        # Add a hook to log module outputs
-        hook_model_outputs(self.mode, self.model, 'model')
         # Add accuracy as a state module.
         # The name is probably confusing, since it's meant to store
         # states between training and validation for RNNs.

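Note: `hook_model_outputs` is removed from `labml_nn.helpers.trainer` by this commit, so the `init` methods in these experiments no longer attach output-logging hooks. Where that logging is still wanted, a plain PyTorch forward hook can feed the same tracker; a minimal sketch (the `log_module_output` helper is hypothetical, not part of the repository):

    import torch
    import torch.nn as nn

    from labml import tracker


    def log_module_output(module: nn.Module, name: str):
        """Track the mean and std of a module's output on every forward pass."""
        def hook(_module, _inputs, output):
            if isinstance(output, torch.Tensor):
                tracker.add(f'{name}.mean', output.detach().mean().item())
                tracker.add(f'{name}.std', output.detach().std().item())
        return module.register_forward_hook(hook)

The returned handle can be kept and `remove()`d once the logging is no longer needed.
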
@@ -12,16 +12,15 @@ from typing import Callable
 
 import torch
 import torch.nn as nn
-from torch.utils.data import DataLoader, RandomSampler
 
 from labml import lab, monit, logger, tracker
 from labml.configs import option
 from labml.logger import Text
 from labml_nn.helpers.datasets import TextDataset, SequentialDataLoader, SequentialUnBatchedDataset, TextFileDataset
-from labml_nn.helpers.device import DeviceConfigs
-from labml_nn.helpers.metrics import Accuracy
-from labml_nn.helpers.trainer import TrainValidConfigs, hook_model_outputs, BatchIndex
+from labml_nn.helpers.device import DeviceConfigs
+from labml_nn.helpers.metrics import Accuracy
+from labml_nn.helpers.trainer import TrainValidConfigs, BatchIndex
 from labml_nn.optimizers.configs import OptimizerConfigs
+from torch.utils.data import DataLoader, RandomSampler
 
 
 class CrossEntropyLoss(nn.Module):

@@ -108,8 +107,6 @@ class NLPAutoRegressionConfigs(TrainValidConfigs):
         tracker.set_scalar("accuracy.*", True)
         tracker.set_scalar("loss.*", True)
         tracker.set_text("sampled", False)
-        # Add a hook to log module outputs
-        hook_model_outputs(self.mode, self.model, 'model')
         # Add accuracy as a state module.
         # The name is probably confusing, since it's meant to store
         # states between training and validation for RNNs.

@@ -11,19 +11,19 @@ summary: >
 from collections import Counter
 from typing import Callable
 
+import torch
 import torchtext
-from torch import nn
-from torch.utils.data import DataLoader
 import torchtext.vocab
 from torchtext.vocab import Vocab
 
-import torch
 from labml import lab, tracker, monit
 from labml.configs import option
-from labml_nn.helpers.device import DeviceConfigs
-from labml_nn.helpers.metrics import Accuracy
-from labml_nn.helpers.trainer import TrainValidConfigs, hook_model_outputs, BatchIndex
+from labml_nn.helpers.device import DeviceConfigs
+from labml_nn.helpers.metrics import Accuracy
+from labml_nn.helpers.trainer import TrainValidConfigs, BatchIndex
 from labml_nn.optimizers.configs import OptimizerConfigs
+from torch import nn
+from torch.utils.data import DataLoader
 
 
 class NLPClassificationConfigs(TrainValidConfigs):

@@ -90,8 +90,6 @@ class NLPClassificationConfigs(TrainValidConfigs):
         # Set tracker configurations
         tracker.set_scalar("accuracy.*", True)
         tracker.set_scalar("loss.*", True)
-        # Add a hook to log module outputs
-        hook_model_outputs(self.mode, self.model, 'model')
         # Add accuracy as a state module.
         # The name is probably confusing, since it's meant to store
         # states between training and validation for RNNs.

@@ -9,18 +9,18 @@ summary: This experiment generates MNIST images using multi-layer perceptron.
 
 from typing import Any
 
-from torchvision import transforms
-
 import torch
 import torch.nn as nn
 import torch.utils.data
+from torchvision import transforms
+
 from labml import tracker, monit, experiment
 from labml.configs import option, calculate
-from labml_nn.helpers.datasets import MNISTConfigs
-from labml_nn.helpers.device import DeviceConfigs
-from labml_nn.helpers.optimizer import OptimizerConfigs
-from labml_nn.helpers.trainer import TrainValidConfigs, hook_model_outputs, BatchIndex
 from labml_nn.gan.original import DiscriminatorLogitsLoss, GeneratorLogitsLoss
+from labml_nn.helpers.datasets import MNISTConfigs
+from labml_nn.helpers.device import DeviceConfigs
+from labml_nn.helpers.optimizer import OptimizerConfigs
+from labml_nn.helpers.trainer import TrainValidConfigs, BatchIndex
 
 
 def weights_init(m):

@@ -110,8 +110,6 @@ class Configs(MNISTConfigs, TrainValidConfigs):
         """
         self.state_modules = []
 
-        hook_model_outputs(self.mode, self.generator, 'generator')
-        hook_model_outputs(self.mode, self.discriminator, 'discriminator')
         tracker.set_scalar("loss.generator.*", True)
         tracker.set_scalar("loss.discriminator.*", True)
         tracker.set_image("generated", True, 1 / 100)

@@ -187,7 +185,7 @@ class Configs(MNISTConfigs, TrainValidConfigs):
         """
         Calculate generator loss
         """
-        latent = self.sample_z(batch_size)
+        latent = self.sample_z(batch_size)
         generated_images = self.generator(latent)
         logits = self.discriminator(generated_images)
         loss = self.generator_loss(logits)

@@ -199,8 +197,6 @@ class Configs(MNISTConfigs, TrainValidConfigs):
         return loss
 
 
-
-
 @option(Configs.dataset_transforms)
 def mnist_gan_transforms():
     return transforms.Compose([

@@ -32,17 +32,17 @@ import math
 from pathlib import Path
 from typing import Iterator, Tuple
 
+import torch
+import torch.utils.data
 import torchvision
 from PIL import Image
 
-import torch
-import torch.utils.data
 from labml import tracker, lab, monit, experiment
 from labml.configs import BaseConfigs
-from labml_nn.helpers.device import DeviceConfigs
-from labml_nn.helpers.trainer import ModeState, hook_model_outputs
 from labml_nn.gan.stylegan import Discriminator, Generator, MappingNetwork, GradientPenalty, PathLengthPenalty
 from labml_nn.gan.wasserstein import DiscriminatorLoss, GeneratorLoss
+from labml_nn.helpers.device import DeviceConfigs
+from labml_nn.helpers.trainer import ModeState
 from labml_nn.utils import cycle_dataloader
 
 

@@ -164,8 +164,6 @@ class Configs(BaseConfigs):
 
     # Training mode state for logging activations
     mode: ModeState
-    # Whether to log model layer outputs
-    log_layer_outputs: bool = False
 
     # <a id="dataset_path"></a>
     # We trained this on [CelebA-HQ dataset](https://github.com/tkarras/progressive_growing_of_gans).

@@ -199,12 +197,6 @@ class Configs(BaseConfigs):
         # Create path length penalty loss
         self.path_length_penalty = PathLengthPenalty(0.99).to(self.device)
 
-        # Add model hooks to monitor layer outputs
-        if self.log_layer_outputs:
-            hook_model_outputs(self.mode, self.discriminator, 'discriminator')
-            hook_model_outputs(self.mode, self.generator, 'generator')
-            hook_model_outputs(self.mode, self.mapping_network, 'mapping_network')
-
         # Discriminator and generator losses
         self.discriminator_loss = DiscriminatorLoss().to(self.device)
         self.generator_loss = GeneratorLoss().to(self.device)

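Note: with `log_layer_outputs` gone there is no config switch left for layer-output monitoring in this experiment. If the intermediate outputs are needed while debugging, forward hooks can be attached temporarily and detached through the handles PyTorch returns; a self-contained sketch with stand-in modules (illustration only, not repository code):

    import torch
    import torch.nn as nn

    # Stand-ins; in the experiment these would be the mapping network and generator.
    mapping_network = nn.Sequential(nn.Linear(512, 512), nn.ReLU())
    generator = nn.Sequential(nn.Linear(512, 512), nn.ReLU())


    def shape_logger(name):
        def hook(_module, _inputs, output):
            print(name, tuple(output.shape))
        return hook


    # Keep the handles so the hooks can be removed after a few debug steps.
    handles = [mapping_network.register_forward_hook(shape_logger('mapping_network')),
               generator.register_forward_hook(shape_logger('generator'))]

    w = mapping_network(torch.randn(4, 512))
    images = generator(w)

    for handle in handles:
        handle.remove()
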
@@ -14,7 +14,7 @@ from torch.utils.data import DataLoader
 from torch.utils.data import IterableDataset, Dataset
 
 
-def _dataset(is_train, transform):
+def _mnist_dataset(is_train, transform):
     return datasets.MNIST(str(lab.get_data_path()),
                           train=is_train,
                           download=True,

@@ -66,12 +66,12 @@ def mnist_transforms():
 
 @option(MNISTConfigs.train_dataset)
 def mnist_train_dataset(c: MNISTConfigs):
-    return _dataset(True, c.dataset_transforms)
+    return _mnist_dataset(True, c.dataset_transforms)
 
 
 @option(MNISTConfigs.valid_dataset)
 def mnist_valid_dataset(c: MNISTConfigs):
-    return _dataset(False, c.dataset_transforms)
+    return _mnist_dataset(False, c.dataset_transforms)
 
 
 @option(MNISTConfigs.train_loader)

@@ -96,7 +96,7 @@ aggregate(MNISTConfigs.dataset_name, 'MNIST',
           (MNISTConfigs.valid_loader, 'mnist_valid_loader'))
 
 
-def _dataset(is_train, transform):
+def _cifar_dataset(is_train, transform):
     return datasets.CIFAR10(str(lab.get_data_path()),
                             train=is_train,
                             download=True,

@@ -147,12 +147,12 @@ def cifar10_transforms():
 
 @CIFAR10Configs.calc(CIFAR10Configs.train_dataset)
 def cifar10_train_dataset(c: CIFAR10Configs):
-    return _dataset(True, c.dataset_transforms)
+    return _cifar_dataset(True, c.dataset_transforms)
 
 
 @CIFAR10Configs.calc(CIFAR10Configs.valid_dataset)
 def cifar10_valid_dataset(c: CIFAR10Configs):
-    return _dataset(False, c.dataset_transforms)
+    return _cifar_dataset(False, c.dataset_transforms)
 
 
 @CIFAR10Configs.calc(CIFAR10Configs.train_loader)

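Note: the `_dataset` -> `_mnist_dataset` / `_cifar_dataset` rename matters because the MNIST and CIFAR10 helpers appear to live in the same module (the hunk headers above all point into one file): with two `def _dataset` statements, the second silently shadows the first by the time the lazily resolved dataset options call it. A stand-alone illustration of that late-binding behaviour (not repository code):

    def _dataset():
        return 'MNIST'


    def make_train_dataset():
        # The name `_dataset` is looked up when this runs, not when it is defined.
        return _dataset()


    def _dataset():  # redefinition: shadows the MNIST helper above
        return 'CIFAR10'


    print(make_train_dataset())  # prints 'CIFAR10', not 'MNIST'
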
@@ -75,43 +75,6 @@ class Accuracy(Metric):
         tracker.add("accuracy.", self.data.correct / self.data.samples)
 
 
-class AccuracyMovingAvg(Metric):
-    def __init__(self, ignore_index: int = -1, queue_size: int = 5):
-        super().__init__()
-        self.ignore_index = ignore_index
-        tracker.set_queue('accuracy.*', queue_size, is_print=True)
-
-    def __call__(self, output: torch.Tensor, target: torch.Tensor):
-        output = output.view(-1, output.shape[-1])
-        target = target.view(-1)
-        pred = output.argmax(dim=-1)
-        mask = target == self.ignore_index
-        pred.masked_fill_(mask, self.ignore_index)
-        n_masked = mask.sum().item()
-        if len(target) - n_masked > 0:
-            tracker.add('accuracy.', (pred.eq(target).sum().item() - n_masked) / (len(target) - n_masked))
-
-    def create_state(self):
-        return None
-
-    def set_state(self, data: any):
-        pass
-
-    def on_epoch_start(self):
-        pass
-
-    def on_epoch_end(self):
-        pass
-
-
-class BinaryAccuracy(Accuracy):
-    def __call__(self, output: torch.Tensor, target: torch.Tensor):
-        pred = output.view(-1) > 0
-        target = target.view(-1)
-        self.data.correct += pred.eq(target).sum().item()
-        self.data.samples += len(target)
-
-
 class AccuracyDirect(Accuracy):
     data: AccuracyState
 

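Note: the deleted `AccuracyMovingAvg` computed token accuracy while skipping positions whose target equals `ignore_index`. For reference, an equivalent stand-alone function covering the same computation (a sketch, not repository code):

    import torch


    def masked_accuracy(output: torch.Tensor, target: torch.Tensor, ignore_index: int = -1) -> float:
        """Accuracy over positions whose target is not `ignore_index`."""
        output = output.view(-1, output.shape[-1])
        target = target.view(-1)
        pred = output.argmax(dim=-1)
        mask = target != ignore_index
        if mask.sum().item() == 0:
            return 0.0
        return (pred[mask] == target[mask]).float().mean().item()


    # Example: batch of 2 sequences, 3 tokens each, vocabulary of 4; last token is padding.
    logits = torch.randn(2, 3, 4)
    labels = torch.tensor([[1, 2, -1], [0, 3, -1]])
    print(masked_accuracy(logits, labels))
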
@@ -66,19 +66,15 @@ class TrainingLoop:
     def __init__(self, *,
                  loop_count: int,
                  loop_step: Optional[int],
-                 is_save_models: bool,
                  log_new_line_interval: int,
                  log_write_interval: int,
-                 save_models_interval: int,
                  is_loop_on_interrupt: bool):
         self.__loop_count = loop_count
         self.__loop_step = loop_step
-        self.__is_save_models = is_save_models
         self.__log_new_line_interval = log_new_line_interval
         self.__log_write_interval = log_write_interval
         self.__last_write_step = 0
         self.__last_new_line_step = 0
-        self.__save_models_interval = save_models_interval
         self.__last_save_step = 0
         self.__signal_received = None
         self.__is_loop_on_interrupt = is_loop_on_interrupt

@@ -115,21 +111,6 @@ class TrainingLoop:
             pass
         tracker.save()
         tracker.new_line()
-        if self.__is_save_models:
-            logger.log("Saving model...")
-            experiment.save_checkpoint()
-
-    # def is_interval(self, interval: int, global_step: Optional[int] = None):
-    #     if global_step is None:
-    #         global_step = tracker.get_global_step()
-    #
-    #     if global_step - self.__loop_step < 0:
-    #         return False
-    #
-    #     if global_step // interval > (global_step - self.__loop_step) // interval:
-    #         return True
-    #     else:
-    #         return False
 
     def __next__(self):
         if self.__signal_received is not None:

@@ -152,18 +133,6 @@ class TrainingLoop:
         if global_step - self.__last_new_line_step >= self.__log_new_line_interval:
             tracker.new_line()
             self.__last_new_line_step = global_step
-        # if self.is_interval(self.__log_write_interval, global_step):
-        #     tracker.save()
-        # if self.is_interval(self.__log_new_line_interval, global_step):
-        #     logger.log()
-
-        # if (self.__is_save_models and
-        #         self.is_interval(self.__save_models_interval, global_step)):
-        #     experiment.save_checkpoint()
-        if (self.__is_save_models and
-                global_step - self.__last_save_step >= self.__save_models_interval):
-            experiment.save_checkpoint()
-            self.__last_save_step = global_step
 
         return global_step
 

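Note: the commented-out `is_interval` helper is dropped, and what remains of the per-step bookkeeping in `__next__` is the last-step pattern shown in the hunk above. A plain illustration of how that interval check behaves (not repository code):

    log_new_line_interval = 5
    last_new_line_step = 0

    for global_step in range(1, 13):
        if global_step - last_new_line_step >= log_new_line_interval:
            print(f'new line at step {global_step}')
            last_new_line_step = global_step
    # Fires at steps 5 and 10: the check triggers once at least `interval`
    # steps have passed since it last fired, even if steps advance unevenly.
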
@@ -198,9 +167,6 @@ class TrainingLoopConfigs(BaseConfigs):
     Arguments:
         loop_count (int): Total number of steps. Defaults to ``10``.
         loop_step (int): Number of steps to increment per iteration. Defaults to ``1``.
-        is_save_models (bool): Whether to call :func:`labml.experiment.save_checkpoint` on each iteration.
-            Defaults to ``False``.
-        save_models_interval (int): The interval (in steps) to save models. Defaults to ``1``.
         log_new_line_interval (int): The interval (in steps) to print a new line to the screen.
             Defaults to ``1``.
         log_write_interval (int): The interval (in steps) to call :func:`labml.tracker.save`.

@@ -210,10 +176,8 @@ class TrainingLoopConfigs(BaseConfigs):
     """
     loop_count: int = 10
     loop_step: int = 1
-    is_save_models: bool = False
     log_new_line_interval: int = 1
     log_write_interval: int = 1
-    save_models_interval: int = 1
     is_loop_on_interrupt: bool = False
 
     training_loop: TrainingLoop

@@ -223,19 +187,15 @@
 def _loop_configs(c: TrainingLoopConfigs):
     return TrainingLoop(loop_count=c.loop_count,
                         loop_step=c.loop_step,
-                        is_save_models=c.is_save_models,
                         log_new_line_interval=c.log_new_line_interval,
                         log_write_interval=c.log_write_interval,
-                        save_models_interval=c.save_models_interval,
                         is_loop_on_interrupt=c.is_loop_on_interrupt)
 
 
 meta_config(TrainingLoopConfigs.loop_step,
             TrainingLoopConfigs.loop_count,
-            TrainingLoopConfigs.is_save_models,
             TrainingLoopConfigs.log_new_line_interval,
             TrainingLoopConfigs.log_write_interval,
-            TrainingLoopConfigs.save_models_interval,
             TrainingLoopConfigs.is_loop_on_interrupt)
 
 

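Note: for context, `TrainingLoopConfigs` is normally consumed through labml's option mechanism rather than constructed directly. A rough usage sketch, assuming it is importable from `labml_nn.helpers.trainer` like the other trainer configs in this diff and that `training_loop` is iterable as the experiments elsewhere in the repository use it:

    from labml import experiment, tracker
    from labml_nn.helpers.trainer import TrainingLoopConfigs  # assumed export


    class LoopDemoConfigs(TrainingLoopConfigs):
        pass


    def main():
        conf = LoopDemoConfigs()
        experiment.create(name='training_loop_demo')
        # Override the defaults listed in the class above
        experiment.configs(conf, {'loop_count': 100,
                                  'log_new_line_interval': 10})
        with experiment.start():
            for step in conf.training_loop:
                tracker.add('dummy.metric', step)  # training code would go here


    if __name__ == '__main__':
        main()
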
@@ -14,7 +14,7 @@ from labml.configs import option
 from labml_nn.helpers.datasets import MNISTConfigs
 from labml_nn.helpers.device import DeviceConfigs
 from labml_nn.helpers.metrics import Accuracy
-from labml_nn.helpers.trainer import TrainValidConfigs, BatchIndex, hook_model_outputs
+from labml_nn.helpers.trainer import TrainValidConfigs, BatchIndex
 from labml_nn.optimizers.configs import OptimizerConfigs
 
 

@@ -22,6 +22,7 @@ class Model(nn.Module):
     """
     ## The model
     """
+
     def __init__(self):
         super().__init__()
         self.conv1 = nn.Conv2d(1, 20, 5, 1)

@@ -60,7 +61,6 @@ class Configs(MNISTConfigs, TrainValidConfigs):
     def init(self):
         tracker.set_queue("loss.*", 20, True)
         tracker.set_scalar("accuracy.*", True)
-        hook_model_outputs(self.mode, self.model, 'model')
         self.state_modules = [self.accuracy_func]
 
     def step(self, batch: any, batch_idx: BatchIndex):

@@ -41,7 +41,7 @@ import torch.nn as nn
 from labml import lab, experiment, tracker, monit
 from labml_nn.helpers.device import DeviceConfigs
 from labml_nn.helpers.optimizer import OptimizerConfigs
-from labml_nn.helpers.trainer import TrainValidConfigs, hook_model_outputs, BatchIndex
+from labml_nn.helpers.trainer import TrainValidConfigs, BatchIndex
 from torch import optim
 from torch.utils.data import Dataset, DataLoader
 

@@ -530,10 +530,6 @@ class Configs(TrainValidConfigs):
         # Create validation data loader
         self.valid_loader = DataLoader(self.valid_dataset, self.batch_size)
 
-        # Add hooks to monitor layer outputs on Tensorboard
-        hook_model_outputs(self.mode, self.encoder, 'encoder')
-        hook_model_outputs(self.mode, self.decoder, 'decoder')
-
         # Configure the tracker to print the total train/validation loss
         tracker.set_scalar("loss.total.*", True)
 

@@ -12,13 +12,12 @@ from typing import List, Tuple, NamedTuple
 
 import torch
 import torch.nn as nn
 
 from labml import experiment, tracker, monit, logger
 from labml.configs import option
 from labml.logger import Text
-from labml_nn.helpers.metrics import SimpleStateModule
-from labml_nn.helpers.trainer import BatchIndex, hook_model_outputs
 from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs
+from labml_nn.helpers.metrics import SimpleStateModule
+from labml_nn.helpers.trainer import BatchIndex
 from labml_nn.transformers.compressive import CompressiveTransformer, AttentionReconstructionLoss, \
     CompressiveTransformerLayer, Conv1dCompression
 

@@ -119,8 +118,6 @@ class Configs(NLPAutoRegressionConfigs):
         tracker.set_scalar("loss.*", True)
         # Do not print the attention reconstruction loss in the terminal
         tracker.set_scalar("ar_loss.*", False)
-        # Add a hook to log module outputs
-        hook_model_outputs(self.mode, self.model, 'model')
         # This will keep the accuracy metric stats and memories separate for training and validation.
         self.state_modules = [self.accuracy, self.memory]
 

@@ -12,13 +12,12 @@ from typing import List
 
 import torch
 import torch.nn as nn
-from labml.logger import Text
 
 from labml import experiment, tracker, monit, logger
 from labml.configs import option
-from labml_nn.helpers.metrics import SimpleStateModule
-from labml_nn.helpers.trainer import BatchIndex, hook_model_outputs
+from labml.logger import Text
 from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs
+from labml_nn.helpers.metrics import SimpleStateModule
+from labml_nn.helpers.trainer import BatchIndex
 from labml_nn.transformers.xl import TransformerXL, TransformerXLLayer
 
 

@@ -95,8 +94,6 @@ class Configs(NLPAutoRegressionConfigs):
         # Set tracker configurations
         tracker.set_scalar("accuracy.*", True)
         tracker.set_scalar("loss.*", True)
-        # Add a hook to log module outputs
-        hook_model_outputs(self.mode, self.model, 'model')
         # This will keep the accuracy metric stats and memories separate for training and validation.
         self.state_modules = [self.accuracy, self.memory]
 
