mirror of https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
synced 2025-08-26 08:41:23 +08:00

Commit: cleanup hook model outputs

This commit removes the `hook_model_outputs` helper and its call sites, drops the `is_save_models` / `save_models_interval` options from the training loop, deletes the `AccuracyMovingAvg` and `BinaryAccuracy` metrics, and renames the two module-level `_dataset` helpers to `_mnist_dataset` and `_cifar_dataset` so they no longer shadow each other.
@@ -16,7 +16,7 @@ from labml.configs import option
 from labml_nn.helpers.datasets import MNISTConfigs as MNISTDatasetConfigs
 from labml_nn.helpers.device import DeviceConfigs
 from labml_nn.helpers.metrics import Accuracy
-from labml_nn.helpers.trainer import TrainValidConfigs, BatchIndex, hook_model_outputs
+from labml_nn.helpers.trainer import TrainValidConfigs, BatchIndex
 from labml_nn.optimizers.configs import OptimizerConfigs
 
 
@@ -52,8 +52,6 @@ class MNISTConfigs(MNISTDatasetConfigs, TrainValidConfigs):
         # Set tracker configurations
         tracker.set_scalar("loss.*", True)
         tracker.set_scalar("accuracy.*", True)
-        # Add a hook to log module outputs
-        hook_model_outputs(self.mode, self.model, 'model')
         # Add accuracy as a state module.
         # The name is probably confusing, since it's meant to store
         # states between training and validation for RNNs.
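
Note: `hook_model_outputs`, removed throughout this commit, registered forward hooks that logged each sub-module's output to the tracker. Below is a minimal sketch of such a helper, inferred from the call sites in this diff; it is not the actual labml implementation, which also took a `ModeState` argument (the `self.mode` passed above) to gate logging between training and validation.

    import torch.nn as nn
    from labml import tracker


    def hook_model_outputs_sketch(model: nn.Module, model_name: str = 'model'):
        # Illustrative stand-in for the removed helper, not the labml original.
        def create_hook(layer_name: str):
            def hook(module: nn.Module, inputs, output):
                # Log the layer's output tensor, e.g. under 'module.model.conv1'
                tracker.add(f'module.{model_name}.{layer_name}', output)
            return hook

        for layer_name, module in model.named_modules():
            if layer_name:  # skip the root module itself
                module.register_forward_hook(create_hook(layer_name))
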
@@ -12,16 +12,15 @@ from typing import Callable
 
 import torch
 import torch.nn as nn
-from torch.utils.data import DataLoader, RandomSampler
 
 from labml import lab, monit, logger, tracker
 from labml.configs import option
 from labml.logger import Text
 from labml_nn.helpers.datasets import TextDataset, SequentialDataLoader, SequentialUnBatchedDataset, TextFileDataset
 from labml_nn.helpers.device import DeviceConfigs
 from labml_nn.helpers.metrics import Accuracy
-from labml_nn.helpers.trainer import TrainValidConfigs, hook_model_outputs, BatchIndex
+from labml_nn.helpers.trainer import TrainValidConfigs, BatchIndex
 from labml_nn.optimizers.configs import OptimizerConfigs
+from torch.utils.data import DataLoader, RandomSampler
 
 
 class CrossEntropyLoss(nn.Module):
@@ -108,8 +107,6 @@ class NLPAutoRegressionConfigs(TrainValidConfigs):
         tracker.set_scalar("accuracy.*", True)
         tracker.set_scalar("loss.*", True)
         tracker.set_text("sampled", False)
-        # Add a hook to log module outputs
-        hook_model_outputs(self.mode, self.model, 'model')
         # Add accuracy as a state module.
         # The name is probably confusing, since it's meant to store
         # states between training and validation for RNNs.
@@ -11,19 +11,19 @@ summary: >
 from collections import Counter
 from typing import Callable
 
-import torch
 import torchtext
-from torch import nn
-from torch.utils.data import DataLoader
 import torchtext.vocab
 from torchtext.vocab import Vocab
 
+import torch
 from labml import lab, tracker, monit
 from labml.configs import option
 from labml_nn.helpers.device import DeviceConfigs
 from labml_nn.helpers.metrics import Accuracy
-from labml_nn.helpers.trainer import TrainValidConfigs, hook_model_outputs, BatchIndex
+from labml_nn.helpers.trainer import TrainValidConfigs, BatchIndex
 from labml_nn.optimizers.configs import OptimizerConfigs
+from torch import nn
+from torch.utils.data import DataLoader
 
 
 class NLPClassificationConfigs(TrainValidConfigs):
@@ -90,8 +90,6 @@ class NLPClassificationConfigs(TrainValidConfigs):
         # Set tracker configurations
         tracker.set_scalar("accuracy.*", True)
         tracker.set_scalar("loss.*", True)
-        # Add a hook to log module outputs
-        hook_model_outputs(self.mode, self.model, 'model')
         # Add accuracy as a state module.
         # The name is probably confusing, since it's meant to store
         # states between training and validation for RNNs.
@@ -9,18 +9,18 @@ summary: This experiment generates MNIST images using multi-layer perceptron.
 
 from typing import Any
 
+from torchvision import transforms
 
 import torch
 import torch.nn as nn
 import torch.utils.data
-from torchvision import transforms
 
 from labml import tracker, monit, experiment
 from labml.configs import option, calculate
-from labml_nn.helpers.datasets import MNISTConfigs
-from labml_nn.helpers.device import DeviceConfigs
-from labml_nn.helpers.optimizer import OptimizerConfigs
-from labml_nn.helpers.trainer import TrainValidConfigs, hook_model_outputs, BatchIndex
 from labml_nn.gan.original import DiscriminatorLogitsLoss, GeneratorLogitsLoss
+from labml_nn.helpers.datasets import MNISTConfigs
+from labml_nn.helpers.device import DeviceConfigs
+from labml_nn.helpers.optimizer import OptimizerConfigs
+from labml_nn.helpers.trainer import TrainValidConfigs, BatchIndex
 
 
 def weights_init(m):
@@ -110,8 +110,6 @@ class Configs(MNISTConfigs, TrainValidConfigs):
         """
         self.state_modules = []
 
-        hook_model_outputs(self.mode, self.generator, 'generator')
-        hook_model_outputs(self.mode, self.discriminator, 'discriminator')
         tracker.set_scalar("loss.generator.*", True)
         tracker.set_scalar("loss.discriminator.*", True)
         tracker.set_image("generated", True, 1 / 100)
@@ -187,7 +185,7 @@ class Configs(MNISTConfigs, TrainValidConfigs):
         """
         Calculate generator loss
         """
         latent = self.sample_z(batch_size)
         generated_images = self.generator(latent)
         logits = self.discriminator(generated_images)
         loss = self.generator_loss(logits)
@@ -199,8 +197,6 @@ class Configs(MNISTConfigs, TrainValidConfigs):
         return loss
 
 
-
-
 @option(Configs.dataset_transforms)
 def mnist_gan_transforms():
     return transforms.Compose([
@@ -32,17 +32,17 @@ import math
 from pathlib import Path
 from typing import Iterator, Tuple
 
-import torch
-import torch.utils.data
 import torchvision
 from PIL import Image
 
+import torch
+import torch.utils.data
 from labml import tracker, lab, monit, experiment
 from labml.configs import BaseConfigs
-from labml_nn.helpers.device import DeviceConfigs
-from labml_nn.helpers.trainer import ModeState, hook_model_outputs
 from labml_nn.gan.stylegan import Discriminator, Generator, MappingNetwork, GradientPenalty, PathLengthPenalty
 from labml_nn.gan.wasserstein import DiscriminatorLoss, GeneratorLoss
+from labml_nn.helpers.device import DeviceConfigs
+from labml_nn.helpers.trainer import ModeState
 from labml_nn.utils import cycle_dataloader
 
 
@@ -164,8 +164,6 @@ class Configs(BaseConfigs):
 
     # Training mode state for logging activations
     mode: ModeState
-    # Whether to log model layer outputs
-    log_layer_outputs: bool = False
 
     # <a id="dataset_path"></a>
     # We trained this on [CelebA-HQ dataset](https://github.com/tkarras/progressive_growing_of_gans).
@@ -199,12 +197,6 @@ class Configs(BaseConfigs):
         # Create path length penalty loss
         self.path_length_penalty = PathLengthPenalty(0.99).to(self.device)
 
-        # Add model hooks to monitor layer outputs
-        if self.log_layer_outputs:
-            hook_model_outputs(self.mode, self.discriminator, 'discriminator')
-            hook_model_outputs(self.mode, self.generator, 'generator')
-            hook_model_outputs(self.mode, self.mapping_network, 'mapping_network')
-
         # Discriminator and generator losses
         self.discriminator_loss = DiscriminatorLoss().to(self.device)
         self.generator_loss = GeneratorLoss().to(self.device)
@@ -14,7 +14,7 @@ from torch.utils.data import DataLoader
 from torch.utils.data import IterableDataset, Dataset
 
 
-def _dataset(is_train, transform):
+def _mnist_dataset(is_train, transform):
     return datasets.MNIST(str(lab.get_data_path()),
                           train=is_train,
                           download=True,
@@ -66,12 +66,12 @@ def mnist_transforms():
 
 @option(MNISTConfigs.train_dataset)
 def mnist_train_dataset(c: MNISTConfigs):
-    return _dataset(True, c.dataset_transforms)
+    return _mnist_dataset(True, c.dataset_transforms)
 
 
 @option(MNISTConfigs.valid_dataset)
 def mnist_valid_dataset(c: MNISTConfigs):
-    return _dataset(False, c.dataset_transforms)
+    return _mnist_dataset(False, c.dataset_transforms)
 
 
 @option(MNISTConfigs.train_loader)
@@ -96,7 +96,7 @@ aggregate(MNISTConfigs.dataset_name, 'MNIST',
           (MNISTConfigs.valid_loader, 'mnist_valid_loader'))
 
 
-def _dataset(is_train, transform):
+def _cifar_dataset(is_train, transform):
     return datasets.CIFAR10(str(lab.get_data_path()),
                             train=is_train,
                             download=True,
@@ -147,12 +147,12 @@ def cifar10_transforms():
 
 @CIFAR10Configs.calc(CIFAR10Configs.train_dataset)
 def cifar10_train_dataset(c: CIFAR10Configs):
-    return _dataset(True, c.dataset_transforms)
+    return _cifar_dataset(True, c.dataset_transforms)
 
 
 @CIFAR10Configs.calc(CIFAR10Configs.valid_dataset)
 def cifar10_valid_dataset(c: CIFAR10Configs):
-    return _dataset(False, c.dataset_transforms)
+    return _cifar_dataset(False, c.dataset_transforms)
 
 
 @CIFAR10Configs.calc(CIFAR10Configs.train_loader)
@@ -75,43 +75,6 @@ class Accuracy(Metric):
         tracker.add("accuracy.", self.data.correct / self.data.samples)
 
 
-class AccuracyMovingAvg(Metric):
-    def __init__(self, ignore_index: int = -1, queue_size: int = 5):
-        super().__init__()
-        self.ignore_index = ignore_index
-        tracker.set_queue('accuracy.*', queue_size, is_print=True)
-
-    def __call__(self, output: torch.Tensor, target: torch.Tensor):
-        output = output.view(-1, output.shape[-1])
-        target = target.view(-1)
-        pred = output.argmax(dim=-1)
-        mask = target == self.ignore_index
-        pred.masked_fill_(mask, self.ignore_index)
-        n_masked = mask.sum().item()
-        if len(target) - n_masked > 0:
-            tracker.add('accuracy.', (pred.eq(target).sum().item() - n_masked) / (len(target) - n_masked))
-
-    def create_state(self):
-        return None
-
-    def set_state(self, data: any):
-        pass
-
-    def on_epoch_start(self):
-        pass
-
-    def on_epoch_end(self):
-        pass
-
-
-class BinaryAccuracy(Accuracy):
-    def __call__(self, output: torch.Tensor, target: torch.Tensor):
-        pred = output.view(-1) > 0
-        target = target.view(-1)
-        self.data.correct += pred.eq(target).sum().item()
-        self.data.samples += len(target)
-
-
 class AccuracyDirect(Accuracy):
     data: AccuracyState
 
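
Note: the deleted `AccuracyMovingAvg.__call__` computes accuracy while excluding positions whose target equals `ignore_index` (e.g. padding tokens). The same logic restated as a standalone function, with a small usage example; the function name and tensor shapes here are illustrative, not part of the diff.

    import torch


    def masked_accuracy(output: torch.Tensor, target: torch.Tensor, ignore_index: int = -1) -> float:
        # Flatten to (n_tokens, n_classes) and (n_tokens,)
        output = output.view(-1, output.shape[-1])
        target = target.view(-1)
        pred = output.argmax(dim=-1)
        mask = target == ignore_index
        # Turn every ignored position into a guaranteed "match", then subtract
        # those matches from both the numerator and the denominator.
        pred = pred.masked_fill(mask, ignore_index)
        n_masked = int(mask.sum().item())
        n_valid = len(target) - n_masked
        if n_valid == 0:
            return 0.0
        return (pred.eq(target).sum().item() - n_masked) / n_valid


    logits = torch.randn(4, 10, 5)         # (batch, seq, n_classes)
    labels = torch.randint(0, 5, (4, 10))
    labels[:, -2:] = -1                    # treat the last two positions as padding
    print(masked_accuracy(logits, labels))
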
@@ -66,19 +66,15 @@ class TrainingLoop:
     def __init__(self, *,
                  loop_count: int,
                  loop_step: Optional[int],
-                 is_save_models: bool,
                  log_new_line_interval: int,
                  log_write_interval: int,
-                 save_models_interval: int,
                  is_loop_on_interrupt: bool):
         self.__loop_count = loop_count
         self.__loop_step = loop_step
-        self.__is_save_models = is_save_models
         self.__log_new_line_interval = log_new_line_interval
         self.__log_write_interval = log_write_interval
         self.__last_write_step = 0
         self.__last_new_line_step = 0
-        self.__save_models_interval = save_models_interval
         self.__last_save_step = 0
         self.__signal_received = None
         self.__is_loop_on_interrupt = is_loop_on_interrupt
@@ -115,21 +111,6 @@ class TrainingLoop:
             pass
         tracker.save()
         tracker.new_line()
-        if self.__is_save_models:
-            logger.log("Saving model...")
-            experiment.save_checkpoint()
-
-    # def is_interval(self, interval: int, global_step: Optional[int] = None):
-    #     if global_step is None:
-    #         global_step = tracker.get_global_step()
-    #
-    #     if global_step - self.__loop_step < 0:
-    #         return False
-    #
-    #     if global_step // interval > (global_step - self.__loop_step) // interval:
-    #         return True
-    #     else:
-    #         return False
 
     def __next__(self):
         if self.__signal_received is not None:
@@ -152,18 +133,6 @@ class TrainingLoop:
         if global_step - self.__last_new_line_step >= self.__log_new_line_interval:
             tracker.new_line()
             self.__last_new_line_step = global_step
-        # if self.is_interval(self.__log_write_interval, global_step):
-        #     tracker.save()
-        # if self.is_interval(self.__log_new_line_interval, global_step):
-        #     logger.log()
-
-        # if (self.__is_save_models and
-        #         self.is_interval(self.__save_models_interval, global_step)):
-        #     experiment.save_checkpoint()
-        if (self.__is_save_models and
-                global_step - self.__last_save_step >= self.__save_models_interval):
-            experiment.save_checkpoint()
-            self.__last_save_step = global_step
 
         return global_step
 
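
Note: the block removed above was the loop's step-interval gate for checkpointing: an action fires once the global step has advanced by at least the configured interval since it last fired, the same pattern the surviving code uses for `tracker.save()` and new lines. Here is that pattern in isolation, with illustrative names; the print is a stand-in for an actual checkpoint call such as `experiment.save_checkpoint()`.

    class IntervalGate:
        def __init__(self, interval: int):
            self.interval = interval
            self.last_fired_step = 0

        def should_fire(self, global_step: int) -> bool:
            # Fire when at least `interval` steps have passed since the last firing
            if global_step - self.last_fired_step >= self.interval:
                self.last_fired_step = global_step
                return True
            return False


    gate = IntervalGate(interval=1000)
    for step in range(0, 5001, 500):
        if gate.should_fire(step):
            print(f'save checkpoint at step {step}')
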
@@ -198,9 +167,6 @@ class TrainingLoopConfigs(BaseConfigs):
     Arguments:
         loop_count (int): Total number of steps. Defaults to ``10``.
         loop_step (int): Number of steps to increment per iteration. Defaults to ``1``.
-        is_save_models (bool): Whether to call :func:`labml.experiment.save_checkpoint` on each iteration.
-            Defaults to ``False``.
-        save_models_interval (int): The interval (in steps) to save models. Defaults to ``1``.
         log_new_line_interval (int): The interval (in steps) to print a new line to the screen.
             Defaults to ``1``.
         log_write_interval (int): The interval (in steps) to call :func:`labml.tracker.save`.
@@ -210,10 +176,8 @@ class TrainingLoopConfigs(BaseConfigs):
     """
     loop_count: int = 10
     loop_step: int = 1
-    is_save_models: bool = False
     log_new_line_interval: int = 1
     log_write_interval: int = 1
-    save_models_interval: int = 1
     is_loop_on_interrupt: bool = False
 
     training_loop: TrainingLoop
@@ -223,19 +187,15 @@ class TrainingLoopConfigs(BaseConfigs):
 def _loop_configs(c: TrainingLoopConfigs):
     return TrainingLoop(loop_count=c.loop_count,
                         loop_step=c.loop_step,
-                        is_save_models=c.is_save_models,
                         log_new_line_interval=c.log_new_line_interval,
                         log_write_interval=c.log_write_interval,
-                        save_models_interval=c.save_models_interval,
                         is_loop_on_interrupt=c.is_loop_on_interrupt)
 
 
 meta_config(TrainingLoopConfigs.loop_step,
             TrainingLoopConfigs.loop_count,
-            TrainingLoopConfigs.is_save_models,
             TrainingLoopConfigs.log_new_line_interval,
             TrainingLoopConfigs.log_write_interval,
-            TrainingLoopConfigs.save_models_interval,
             TrainingLoopConfigs.is_loop_on_interrupt)
 
 
@@ -14,7 +14,7 @@ from labml.configs import option
 from labml_nn.helpers.datasets import MNISTConfigs
 from labml_nn.helpers.device import DeviceConfigs
 from labml_nn.helpers.metrics import Accuracy
-from labml_nn.helpers.trainer import TrainValidConfigs, BatchIndex, hook_model_outputs
+from labml_nn.helpers.trainer import TrainValidConfigs, BatchIndex
 from labml_nn.optimizers.configs import OptimizerConfigs
 
 
@@ -22,6 +22,7 @@ class Model(nn.Module):
     """
     ## The model
     """
+
     def __init__(self):
         super().__init__()
         self.conv1 = nn.Conv2d(1, 20, 5, 1)
@@ -60,7 +61,6 @@ class Configs(MNISTConfigs, TrainValidConfigs):
     def init(self):
         tracker.set_queue("loss.*", 20, True)
         tracker.set_scalar("accuracy.*", True)
-        hook_model_outputs(self.mode, self.model, 'model')
         self.state_modules = [self.accuracy_func]
 
     def step(self, batch: any, batch_idx: BatchIndex):
@@ -41,7 +41,7 @@ import torch.nn as nn
 from labml import lab, experiment, tracker, monit
 from labml_nn.helpers.device import DeviceConfigs
 from labml_nn.helpers.optimizer import OptimizerConfigs
-from labml_nn.helpers.trainer import TrainValidConfigs, hook_model_outputs, BatchIndex
+from labml_nn.helpers.trainer import TrainValidConfigs, BatchIndex
 from torch import optim
 from torch.utils.data import Dataset, DataLoader
 
@@ -530,10 +530,6 @@ class Configs(TrainValidConfigs):
         # Create validation data loader
         self.valid_loader = DataLoader(self.valid_dataset, self.batch_size)
 
-        # Add hooks to monitor layer outputs on Tensorboard
-        hook_model_outputs(self.mode, self.encoder, 'encoder')
-        hook_model_outputs(self.mode, self.decoder, 'decoder')
-
         # Configure the tracker to print the total train/validation loss
         tracker.set_scalar("loss.total.*", True)
 
@@ -12,13 +12,12 @@ from typing import List, Tuple, NamedTuple
 
 import torch
 import torch.nn as nn
 
 from labml import experiment, tracker, monit, logger
 from labml.configs import option
 from labml.logger import Text
-from labml_nn.helpers.metrics import SimpleStateModule
-from labml_nn.helpers.trainer import BatchIndex, hook_model_outputs
 from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs
+from labml_nn.helpers.metrics import SimpleStateModule
+from labml_nn.helpers.trainer import BatchIndex
 from labml_nn.transformers.compressive import CompressiveTransformer, AttentionReconstructionLoss, \
     CompressiveTransformerLayer, Conv1dCompression
 
@@ -119,8 +118,6 @@ class Configs(NLPAutoRegressionConfigs):
         tracker.set_scalar("loss.*", True)
         # Do not print the attention reconstruction loss in the terminal
         tracker.set_scalar("ar_loss.*", False)
-        # Add a hook to log module outputs
-        hook_model_outputs(self.mode, self.model, 'model')
         # This will keep the accuracy metric stats and memories separate for training and validation.
         self.state_modules = [self.accuracy, self.memory]
 
@@ -12,13 +12,12 @@ from typing import List
 
 import torch
 import torch.nn as nn
-from labml.logger import Text
 
 from labml import experiment, tracker, monit, logger
 from labml.configs import option
-from labml_nn.helpers.metrics import SimpleStateModule
-from labml_nn.helpers.trainer import BatchIndex, hook_model_outputs
+from labml.logger import Text
 from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs
+from labml_nn.helpers.metrics import SimpleStateModule
+from labml_nn.helpers.trainer import BatchIndex
 from labml_nn.transformers.xl import TransformerXL, TransformerXLLayer
 
 
@@ -95,8 +94,6 @@ class Configs(NLPAutoRegressionConfigs):
         # Set tracker configurations
         tracker.set_scalar("accuracy.*", True)
         tracker.set_scalar("loss.*", True)
-        # Add a hook to log module outputs
-        hook_model_outputs(self.mode, self.model, 'model')
         # This will keep the accuracy metric stats and memories separate for training and validation.
         self.state_modules = [self.accuracy, self.memory]
 