From 5bdedcffec669dd6244d1f8c5f4e5c764b6e183c Mon Sep 17 00:00:00 2001 From: Varuna Jayasiri Date: Sun, 20 Jul 2025 08:56:03 +0530 Subject: [PATCH] remove labml_helpers dep --- .../ponder_net/experiment.py | 6 +- labml_nn/capsule_networks/mnist.py | 7 +- labml_nn/diffusion/ddpm/experiment.py | 4 +- labml_nn/distillation/__init__.py | 2 +- labml_nn/experiments/cifar10.py | 5 +- labml_nn/experiments/mnist.py | 8 +- labml_nn/experiments/nlp_autoregression.py | 8 +- labml_nn/experiments/nlp_classification.py | 6 +- labml_nn/gan/cycle_gan/__init__.py | 2 +- labml_nn/gan/original/experiment.py | 8 +- labml_nn/gan/stylegan/experiment.py | 6 +- labml_nn/graphs/gat/experiment.py | 2 +- labml_nn/helpers/__init__.py | 0 labml_nn/helpers/datasets.py | 322 ++++++++++ labml_nn/helpers/device.py | 70 ++ labml_nn/helpers/metrics.py | 122 ++++ labml_nn/helpers/optimizer.py | 97 +++ labml_nn/helpers/schedule.py | 84 +++ labml_nn/helpers/trainer.py | 598 ++++++++++++++++++ labml_nn/lora/experiment.py | 2 +- labml_nn/optimizers/mnist_experiment.py | 11 +- labml_nn/optimizers/performance_test.py | 2 +- labml_nn/rl/dqn/experiment.py | 2 +- labml_nn/sampling/experiment_tiny.py | 2 +- labml_nn/sketch_rnn/__init__.py | 6 +- labml_nn/transformers/alibi/experiment.py | 2 +- labml_nn/transformers/basic/with_sophia.py | 2 +- labml_nn/transformers/compressive/__init__.py | 3 +- .../transformers/compressive/experiment.py | 4 +- labml_nn/transformers/mlm/experiment.py | 4 +- labml_nn/transformers/retro/database.py | 2 +- labml_nn/transformers/retro/dataset.py | 2 +- labml_nn/transformers/retro/train.py | 2 +- labml_nn/transformers/switch/experiment.py | 2 +- labml_nn/transformers/xl/experiment.py | 4 +- labml_nn/uncertainty/evidence/experiment.py | 6 +- labml_nn/unet/experiment.py | 4 +- labml_nn/utils/__init__.py | 7 +- setup.py | 1 - 39 files changed, 1356 insertions(+), 71 deletions(-) create mode 100644 labml_nn/helpers/__init__.py create mode 100644 labml_nn/helpers/datasets.py create mode 100644 labml_nn/helpers/device.py create mode 100644 labml_nn/helpers/metrics.py create mode 100644 labml_nn/helpers/optimizer.py create mode 100644 labml_nn/helpers/schedule.py create mode 100644 labml_nn/helpers/trainer.py diff --git a/labml_nn/adaptive_computation/ponder_net/experiment.py b/labml_nn/adaptive_computation/ponder_net/experiment.py index 8917002d..c264cacb 100644 --- a/labml_nn/adaptive_computation/ponder_net/experiment.py +++ b/labml_nn/adaptive_computation/ponder_net/experiment.py @@ -17,8 +17,8 @@ from torch import nn from torch.utils.data import DataLoader from labml import tracker, experiment -from labml_helpers.metrics.accuracy import AccuracyDirect -from labml_helpers.train_valid import SimpleTrainValidConfigs, BatchIndex +from labml_nn.helpers.metrics import AccuracyDirect +from labml_nn.helpers.trainer import SimpleTrainValidConfigs, BatchIndex from labml_nn.adaptive_computation.parity import ParityDataset from labml_nn.adaptive_computation.ponder_net import ParityPonderGRU, ReconstructionLoss, RegularizationLoss @@ -26,7 +26,7 @@ from labml_nn.adaptive_computation.ponder_net import ParityPonderGRU, Reconstruc class Configs(SimpleTrainValidConfigs): """ Configurations with a - [simple training loop](https://docs.labml.ai/api/helpers.html#labml_helpers.train_valid.SimpleTrainValidConfigs) + [simple training loop](../../helpers/trainer.html) """ # Number of epochs diff --git a/labml_nn/capsule_networks/mnist.py b/labml_nn/capsule_networks/mnist.py index c84d565c..45d3b61b 100644 --- 
a/labml_nn/capsule_networks/mnist.py +++ b/labml_nn/capsule_networks/mnist.py @@ -16,13 +16,12 @@ from typing import Any import torch.nn as nn import torch.nn.functional as F import torch.utils.data - from labml import experiment, tracker from labml.configs import option -from labml_helpers.datasets.mnist import MNISTConfigs -from labml_helpers.metrics.accuracy import AccuracyDirect -from labml_helpers.train_valid import SimpleTrainValidConfigs, BatchIndex from labml_nn.capsule_networks import Squash, Router, MarginLoss +from labml_nn.helpers.datasets import MNISTConfigs +from labml_nn.helpers.metrics import AccuracyDirect +from labml_nn.helpers.trainer import SimpleTrainValidConfigs, BatchIndex class MNISTCapsuleNetworkModel(nn.Module): diff --git a/labml_nn/diffusion/ddpm/experiment.py b/labml_nn/diffusion/ddpm/experiment.py index d882f100..a185522c 100644 --- a/labml_nn/diffusion/ddpm/experiment.py +++ b/labml_nn/diffusion/ddpm/experiment.py @@ -26,7 +26,7 @@ from PIL import Image from labml import lab, tracker, experiment, monit from labml.configs import BaseConfigs, option -from labml_helpers.device import DeviceConfigs +from labml_nn.helpers.device import DeviceConfigs from labml_nn.diffusion.ddpm import DenoiseDiffusion from labml_nn.diffusion.ddpm.unet import UNet @@ -36,7 +36,7 @@ class Configs(BaseConfigs): ## Configurations """ # Device to train the model on. - # [`DeviceConfigs`](https://docs.labml.ai/api/helpers.html#labml_helpers.device.DeviceConfigs) + # [`DeviceConfigs`](../../device.html) # picks up an available CUDA device or defaults to CPU. device: torch.device = DeviceConfigs() diff --git a/labml_nn/distillation/__init__.py b/labml_nn/distillation/__init__.py index 082dcbd3..a8d0d11b 100644 --- a/labml_nn/distillation/__init__.py +++ b/labml_nn/distillation/__init__.py @@ -75,7 +75,7 @@ from torch import nn from labml import experiment, tracker from labml.configs import option -from labml_helpers.train_valid import BatchIndex +from labml_nn.helpers.trainer import BatchIndex from labml_nn.distillation.large import LargeModel from labml_nn.distillation.small import SmallModel from labml_nn.experiments.cifar10 import CIFAR10Configs diff --git a/labml_nn/experiments/cifar10.py b/labml_nn/experiments/cifar10.py index c82b3d3d..e9b57e3f 100644 --- a/labml_nn/experiments/cifar10.py +++ b/labml_nn/experiments/cifar10.py @@ -13,7 +13,7 @@ import torch.nn as nn from labml import lab from labml.configs import option -from labml_helpers.datasets.cifar10 import CIFAR10Configs as CIFAR10DatasetConfigs +from labml_nn.helpers.datasets import CIFAR10Configs as CIFAR10DatasetConfigs from labml_nn.experiments.mnist import MNISTConfigs @@ -21,8 +21,7 @@ class CIFAR10Configs(CIFAR10DatasetConfigs, MNISTConfigs): """ ## Configurations - This extends from CIFAR 10 dataset configurations from - [`labml_helpers`](https://github.com/labmlai/labml/tree/master/helpers) + This extends from [CIFAR 10 dataset configurations](../helpers/datasets.html) and [`MNISTConfigs`](mnist.html). 
""" # Use CIFAR10 dataset by default diff --git a/labml_nn/experiments/mnist.py b/labml_nn/experiments/mnist.py index 63e2f501..0245427d 100644 --- a/labml_nn/experiments/mnist.py +++ b/labml_nn/experiments/mnist.py @@ -13,10 +13,10 @@ import torch.utils.data from labml import tracker from labml.configs import option -from labml_helpers.datasets.mnist import MNISTConfigs as MNISTDatasetConfigs -from labml_helpers.device import DeviceConfigs -from labml_helpers.metrics.accuracy import Accuracy -from labml_helpers.train_valid import TrainValidConfigs, BatchIndex, hook_model_outputs +from labml_nn.helpers.datasets import MNISTConfigs as MNISTDatasetConfigs +from labml_nn.helpers.device import DeviceConfigs +from labml_nn.helpers.metrics import Accuracy +from labml_nn.helpers.trainer import TrainValidConfigs, BatchIndex, hook_model_outputs from labml_nn.optimizers.configs import OptimizerConfigs diff --git a/labml_nn/experiments/nlp_autoregression.py b/labml_nn/experiments/nlp_autoregression.py index 914b1ef2..647f8e16 100644 --- a/labml_nn/experiments/nlp_autoregression.py +++ b/labml_nn/experiments/nlp_autoregression.py @@ -17,10 +17,10 @@ from torch.utils.data import DataLoader, RandomSampler from labml import lab, monit, logger, tracker from labml.configs import option from labml.logger import Text -from labml_helpers.datasets.text import TextDataset, SequentialDataLoader, SequentialUnBatchedDataset, TextFileDataset -from labml_helpers.device import DeviceConfigs -from labml_helpers.metrics.accuracy import Accuracy -from labml_helpers.train_valid import TrainValidConfigs, hook_model_outputs, BatchIndex +from labml_nn.helpers.datasets import TextDataset, SequentialDataLoader, SequentialUnBatchedDataset, TextFileDataset +from labml_nn.helpers.device import DeviceConfigs +from labml_nn.helpers.metrics import Accuracy +from labml_nn.helpers.trainer import TrainValidConfigs, hook_model_outputs, BatchIndex from labml_nn.optimizers.configs import OptimizerConfigs diff --git a/labml_nn/experiments/nlp_classification.py b/labml_nn/experiments/nlp_classification.py index 76f70598..39fc9394 100644 --- a/labml_nn/experiments/nlp_classification.py +++ b/labml_nn/experiments/nlp_classification.py @@ -20,9 +20,9 @@ from torchtext.vocab import Vocab from labml import lab, tracker, monit from labml.configs import option -from labml_helpers.device import DeviceConfigs -from labml_helpers.metrics.accuracy import Accuracy -from labml_helpers.train_valid import TrainValidConfigs, hook_model_outputs, BatchIndex +from labml_nn.helpers.device import DeviceConfigs +from labml_nn.helpers.metrics import Accuracy +from labml_nn.helpers.trainer import TrainValidConfigs, hook_model_outputs, BatchIndex from labml_nn.optimizers.configs import OptimizerConfigs diff --git a/labml_nn/gan/cycle_gan/__init__.py b/labml_nn/gan/cycle_gan/__init__.py index e34b6501..a09823c7 100644 --- a/labml_nn/gan/cycle_gan/__init__.py +++ b/labml_nn/gan/cycle_gan/__init__.py @@ -49,7 +49,7 @@ from labml import lab, tracker, experiment, monit from labml.configs import BaseConfigs from labml.utils.download import download_file from labml.utils.pytorch import get_modules -from labml_helpers.device import DeviceConfigs +from labml_nn.helpers.device import DeviceConfigs class GeneratorResNet(nn.Module): diff --git a/labml_nn/gan/original/experiment.py b/labml_nn/gan/original/experiment.py index 9330d5ee..4a8d6b43 100644 --- a/labml_nn/gan/original/experiment.py +++ b/labml_nn/gan/original/experiment.py @@ -16,10 +16,10 @@ from torchvision import 
transforms from labml import tracker, monit, experiment from labml.configs import option, calculate -from labml_helpers.datasets.mnist import MNISTConfigs -from labml_helpers.device import DeviceConfigs -from labml_helpers.optimizer import OptimizerConfigs -from labml_helpers.train_valid import TrainValidConfigs, hook_model_outputs, BatchIndex +from labml_nn.helpers.datasets import MNISTConfigs +from labml_nn.helpers.device import DeviceConfigs +from labml_nn.helpers.optimizer import OptimizerConfigs +from labml_nn.helpers.trainer import TrainValidConfigs, hook_model_outputs, BatchIndex from labml_nn.gan.original import DiscriminatorLogitsLoss, GeneratorLogitsLoss diff --git a/labml_nn/gan/stylegan/experiment.py b/labml_nn/gan/stylegan/experiment.py index 7dfa9448..195961ee 100644 --- a/labml_nn/gan/stylegan/experiment.py +++ b/labml_nn/gan/stylegan/experiment.py @@ -39,8 +39,8 @@ from PIL import Image from labml import tracker, lab, monit, experiment from labml.configs import BaseConfigs -from labml_helpers.device import DeviceConfigs -from labml_helpers.train_valid import ModeState, hook_model_outputs +from labml_nn.helpers.device import DeviceConfigs +from labml_nn.helpers.trainer import ModeState, hook_model_outputs from labml_nn.gan.stylegan import Discriminator, Generator, MappingNetwork, GradientPenalty, PathLengthPenalty from labml_nn.gan.wasserstein import DiscriminatorLoss, GeneratorLoss from labml_nn.utils import cycle_dataloader @@ -88,7 +88,7 @@ class Configs(BaseConfigs): """ # Device to train the model on. - # [`DeviceConfigs`](https://docs.labml.ai/api/helpers.html#labml_helpers.device.DeviceConfigs) + # [`DeviceConfigs`](../../helpers/device.html) # picks up an available CUDA device or defaults to CPU. device: torch.device = DeviceConfigs() diff --git a/labml_nn/graphs/gat/experiment.py b/labml_nn/graphs/gat/experiment.py index 09dad477..1f080bd9 100644 --- a/labml_nn/graphs/gat/experiment.py +++ b/labml_nn/graphs/gat/experiment.py @@ -17,7 +17,7 @@ from torch import nn from labml import lab, monit, tracker, experiment from labml.configs import BaseConfigs, option, calculate from labml.utils import download -from labml_helpers.device import DeviceConfigs +from labml_nn.helpers.device import DeviceConfigs from labml_nn.graphs.gat import GraphAttentionLayer from labml_nn.optimizers.configs import OptimizerConfigs diff --git a/labml_nn/helpers/__init__.py b/labml_nn/helpers/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/labml_nn/helpers/datasets.py b/labml_nn/helpers/datasets.py new file mode 100644 index 00000000..4fbe331a --- /dev/null +++ b/labml_nn/helpers/datasets.py @@ -0,0 +1,322 @@ +import random +from pathlib import PurePath, Path +from typing import List, Callable, Dict, Optional + +from torchvision import datasets, transforms + +import torch +from labml import lab +from labml import monit +from labml.configs import BaseConfigs +from labml.configs import aggregate, option +from labml.utils.download import download_file +from torch.utils.data import DataLoader +from torch.utils.data import IterableDataset, Dataset + + +def _dataset(is_train, transform): + return datasets.MNIST(str(lab.get_data_path()), + train=is_train, + download=True, + transform=transform) + + +class MNISTConfigs(BaseConfigs): + """ + Configurable MNIST data set. 
+ + Arguments: + dataset_name (str): name of the data set, ``MNIST`` + dataset_transforms (torchvision.transforms.Compose): image transformations + train_dataset (torchvision.datasets.MNIST): training dataset + valid_dataset (torchvision.datasets.MNIST): validation dataset + + train_loader (torch.utils.data.DataLoader): training data loader + valid_loader (torch.utils.data.DataLoader): validation data loader + + train_batch_size (int): training batch size + valid_batch_size (int): validation batch size + + train_loader_shuffle (bool): whether to shuffle training data + valid_loader_shuffle (bool): whether to shuffle validation data + """ + + dataset_name: str = 'MNIST' + dataset_transforms: transforms.Compose + train_dataset: datasets.MNIST + valid_dataset: datasets.MNIST + + train_loader: DataLoader + valid_loader: DataLoader + + train_batch_size: int = 64 + valid_batch_size: int = 1024 + + train_loader_shuffle: bool = True + valid_loader_shuffle: bool = False + + +@option(MNISTConfigs.dataset_transforms) +def mnist_transforms(): + return transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize((0.1307,), (0.3081,)) + ]) + + +@option(MNISTConfigs.train_dataset) +def mnist_train_dataset(c: MNISTConfigs): + return _dataset(True, c.dataset_transforms) + + +@option(MNISTConfigs.valid_dataset) +def mnist_valid_dataset(c: MNISTConfigs): + return _dataset(False, c.dataset_transforms) + + +@option(MNISTConfigs.train_loader) +def mnist_train_loader(c: MNISTConfigs): + return DataLoader(c.train_dataset, + batch_size=c.train_batch_size, + shuffle=c.train_loader_shuffle) + + +@option(MNISTConfigs.valid_loader) +def mnist_valid_loader(c: MNISTConfigs): + return DataLoader(c.valid_dataset, + batch_size=c.valid_batch_size, + shuffle=c.valid_loader_shuffle) + + +aggregate(MNISTConfigs.dataset_name, 'MNIST', + (MNISTConfigs.dataset_transforms, 'mnist_transforms'), + (MNISTConfigs.train_dataset, 'mnist_train_dataset'), + (MNISTConfigs.valid_dataset, 'mnist_valid_dataset'), + (MNISTConfigs.train_loader, 'mnist_train_loader'), + (MNISTConfigs.valid_loader, 'mnist_valid_loader')) + + +def _dataset(is_train, transform): + return datasets.CIFAR10(str(lab.get_data_path()), + train=is_train, + download=True, + transform=transform) + + +class CIFAR10Configs(BaseConfigs): + """ + Configurable CIFAR 10 data set. 
+ + Arguments: + dataset_name (str): name of the data set, ``CIFAR10`` + dataset_transforms (torchvision.transforms.Compose): image transformations + train_dataset (torchvision.datasets.CIFAR10): training dataset + valid_dataset (torchvision.datasets.CIFAR10): validation dataset + + train_loader (torch.utils.data.DataLoader): training data loader + valid_loader (torch.utils.data.DataLoader): validation data loader + + train_batch_size (int): training batch size + valid_batch_size (int): validation batch size + + train_loader_shuffle (bool): whether to shuffle training data + valid_loader_shuffle (bool): whether to shuffle validation data + """ + dataset_name: str = 'CIFAR10' + dataset_transforms: transforms.Compose + train_dataset: datasets.CIFAR10 + valid_dataset: datasets.CIFAR10 + + train_loader: DataLoader + valid_loader: DataLoader + + train_batch_size: int = 64 + valid_batch_size: int = 1024 + + train_loader_shuffle: bool = True + valid_loader_shuffle: bool = False + + +@CIFAR10Configs.calc(CIFAR10Configs.dataset_transforms) +def cifar10_transforms(): + return transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) + ]) + + +@CIFAR10Configs.calc(CIFAR10Configs.train_dataset) +def cifar10_train_dataset(c: CIFAR10Configs): + return _dataset(True, c.dataset_transforms) + + +@CIFAR10Configs.calc(CIFAR10Configs.valid_dataset) +def cifar10_valid_dataset(c: CIFAR10Configs): + return _dataset(False, c.dataset_transforms) + + +@CIFAR10Configs.calc(CIFAR10Configs.train_loader) +def cifar10_train_loader(c: CIFAR10Configs): + return DataLoader(c.train_dataset, + batch_size=c.train_batch_size, + shuffle=c.train_loader_shuffle) + + +@CIFAR10Configs.calc(CIFAR10Configs.valid_loader) +def cifar10_valid_loader(c: CIFAR10Configs): + return DataLoader(c.valid_dataset, + batch_size=c.valid_batch_size, + shuffle=c.valid_loader_shuffle) + + +CIFAR10Configs.aggregate(CIFAR10Configs.dataset_name, 'CIFAR10', + (CIFAR10Configs.dataset_transforms, 'cifar10_transforms'), + (CIFAR10Configs.train_dataset, 'cifar10_train_dataset'), + (CIFAR10Configs.valid_dataset, 'cifar10_valid_dataset'), + (CIFAR10Configs.train_loader, 'cifar10_train_loader'), + (CIFAR10Configs.valid_loader, 'cifar10_valid_loader')) + + +class TextDataset: + itos: List[str] + stoi: Dict[str, int] + n_tokens: int + train: str + valid: str + standard_tokens: List[str] = [] + + @staticmethod + def load(path: PurePath): + with open(str(path), 'r') as f: + return f.read() + + def __init__(self, path: PurePath, tokenizer: Callable, train: str, valid: str, test: str, *, + n_tokens: Optional[int] = None, + stoi: Optional[Dict[str, int]] = None, + itos: Optional[List[str]] = None): + self.test = test + self.valid = valid + self.train = train + self.tokenizer = tokenizer + self.path = path + + if n_tokens or stoi or itos: + assert stoi and itos and n_tokens + self.n_tokens = n_tokens + self.stoi = stoi + self.itos = itos + else: + self.n_tokens = len(self.standard_tokens) + self.stoi = {t: i for i, t in enumerate(self.standard_tokens)} + + with monit.section("Tokenize"): + tokens = self.tokenizer(self.train) + self.tokenizer(self.valid) + tokens = sorted(list(set(tokens))) + + for t in monit.iterate("Build vocabulary", tokens): + self.stoi[t] = self.n_tokens + self.n_tokens += 1 + + self.itos = [''] * self.n_tokens + for t, n in self.stoi.items(): + self.itos[n] = t + + def text_to_i(self, text: str) -> torch.Tensor: + tokens = self.tokenizer(text) + return torch.tensor([self.stoi[s] for s in tokens if s in 
self.stoi], dtype=torch.long) + + def __repr__(self): + return f'{len(self.train) / 1_000_000 :,.2f}M, {len(self.valid) / 1_000_000 :,.2f}M - {str(self.path)}' + + +class SequentialDataLoader(IterableDataset): + def __init__(self, *, text: str, dataset: TextDataset, + batch_size: int, seq_len: int): + self.seq_len = seq_len + data = dataset.text_to_i(text) + n_batch = data.shape[0] // batch_size + data = data.narrow(0, 0, n_batch * batch_size) + data = data.view(batch_size, -1).t().contiguous() + self.data = data + + def __len__(self): + return self.data.shape[0] // self.seq_len + + def __iter__(self): + self.idx = 0 + return self + + def __next__(self): + if self.idx >= self.data.shape[0] - 1: + raise StopIteration() + + seq_len = min(self.seq_len, self.data.shape[0] - 1 - self.idx) + i = self.idx + seq_len + data = self.data[self.idx: i] + target = self.data[self.idx + 1: i + 1] + self.idx = i + return data, target + + def __getitem__(self, idx): + seq_len = min(self.seq_len, self.data.shape[0] - 1 - idx) + i = idx + seq_len + data = self.data[idx: i] + target = self.data[idx + 1: i + 1] + return data, target + + +class SequentialUnBatchedDataset(Dataset): + def __init__(self, *, text: str, dataset: TextDataset, + seq_len: int, + is_random_offset: bool = True): + self.is_random_offset = is_random_offset + self.seq_len = seq_len + self.data = dataset.text_to_i(text) + + def __len__(self): + return (self.data.shape[0] - 1) // self.seq_len + + def __getitem__(self, idx): + start = idx * self.seq_len + assert start + self.seq_len + 1 <= self.data.shape[0] + if self.is_random_offset: + start += random.randint(0, min(self.seq_len - 1, self.data.shape[0] - (start + self.seq_len + 1))) + + end = start + self.seq_len + data = self.data[start: end] + target = self.data[start + 1: end + 1] + return data, target + + +class TextFileDataset(TextDataset): + standard_tokens = [] + + def __init__(self, path: PurePath, tokenizer: Callable, *, + url: Optional[str] = None, + filter_subset: Optional[int] = None): + path = Path(path) + if not path.exists(): + if not url: + raise FileNotFoundError(str(path)) + else: + download_file(url, path) + + with monit.section("Load data"): + text = self.load(path) + if filter_subset: + text = text[:filter_subset] + split = int(len(text) * .9) + train = text[:split] + valid = text[split:] + + super().__init__(path, tokenizer, train, valid, '') + + +def _test_tiny_shakespeare(): + from labml import lab + _ = TextFileDataset(lab.get_data_path() / 'tiny_shakespeare.txt', lambda x: list(x), + url='https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt') + + +if __name__ == '__main__': + _test_tiny_shakespeare() diff --git a/labml_nn/helpers/device.py b/labml_nn/helpers/device.py new file mode 100644 index 00000000..dbee21d8 --- /dev/null +++ b/labml_nn/helpers/device.py @@ -0,0 +1,70 @@ +import torch + +from labml.configs import BaseConfigs, hyperparams, option + + +class DeviceInfo: + def __init__(self, *, + use_cuda: bool, + cuda_device: int): + self.use_cuda = use_cuda + self.cuda_device = cuda_device + self.cuda_count = torch.cuda.device_count() + + self.is_cuda = self.use_cuda and torch.cuda.is_available() + if not self.is_cuda: + self.device = torch.device('cpu') + else: + if self.cuda_device < self.cuda_count: + self.device = torch.device('cuda', self.cuda_device) + else: + self.device = torch.device('cuda', self.cuda_count - 1) + + def __str__(self): + if not self.is_cuda: + return "CPU" + + if self.cuda_device < self.cuda_count: + 
return f"GPU:{self.cuda_device} - {torch.cuda.get_device_name(self.cuda_device)}" + else: + return (f"GPU:{self.cuda_count - 1}({self.cuda_device}) " + f"- {torch.cuda.get_device_name(self.cuda_count - 1)}") + + +class DeviceConfigs(BaseConfigs): + r""" + This is a configurable module to get a single device to train model on. + It can pick up CUDA devices and it will fall back to CPU if they are not available. + + It has other small advantages such as being able to view the + actual device name on configurations view of + `labml app `_ + + Arguments: + cuda_device (int): The CUDA device number. Defaults to ``0``. + use_cuda (bool): Whether to use CUDA devices. Defaults to ``True``. + """ + cuda_device: int = 0 + use_cuda: bool = True + + device_info: DeviceInfo + + device: torch.device + + def __init__(self): + super().__init__(_primary='device') + + +@option(DeviceConfigs.device) +def _device(c: DeviceConfigs): + return c.device_info.device + + +hyperparams(DeviceConfigs.cuda_device, DeviceConfigs.use_cuda, + is_hyperparam=False) + + +@option(DeviceConfigs.device_info) +def _device_info(c: DeviceConfigs): + return DeviceInfo(use_cuda=c.use_cuda, + cuda_device=c.cuda_device) diff --git a/labml_nn/helpers/metrics.py b/labml_nn/helpers/metrics.py new file mode 100644 index 00000000..e07e2e1d --- /dev/null +++ b/labml_nn/helpers/metrics.py @@ -0,0 +1,122 @@ +import dataclasses +from abc import ABC + +import torch +from labml import tracker + + +class StateModule: + def __init__(self): + pass + + # def __call__(self): + # raise NotImplementedError + + def create_state(self) -> any: + raise NotImplementedError + + def set_state(self, data: any): + raise NotImplementedError + + def on_epoch_start(self): + raise NotImplementedError + + def on_epoch_end(self): + raise NotImplementedError + + +class Metric(StateModule, ABC): + def track(self): + pass + + +@dataclasses.dataclass +class AccuracyState: + samples: int = 0 + correct: int = 0 + + def reset(self): + self.samples = 0 + self.correct = 0 + + +class Accuracy(Metric): + data: AccuracyState + + def __init__(self, ignore_index: int = -1): + super().__init__() + self.ignore_index = ignore_index + + def __call__(self, output: torch.Tensor, target: torch.Tensor): + output = output.view(-1, output.shape[-1]) + target = target.view(-1) + pred = output.argmax(dim=-1) + mask = target == self.ignore_index + pred.masked_fill_(mask, self.ignore_index) + n_masked = mask.sum().item() + self.data.correct += pred.eq(target).sum().item() - n_masked + self.data.samples += len(target) - n_masked + + def create_state(self): + return AccuracyState() + + def set_state(self, data: any): + self.data = data + + def on_epoch_start(self): + self.data.reset() + + def on_epoch_end(self): + self.track() + + def track(self): + if self.data.samples == 0: + return + tracker.add("accuracy.", self.data.correct / self.data.samples) + + +class AccuracyMovingAvg(Metric): + def __init__(self, ignore_index: int = -1, queue_size: int = 5): + super().__init__() + self.ignore_index = ignore_index + tracker.set_queue('accuracy.*', queue_size, is_print=True) + + def __call__(self, output: torch.Tensor, target: torch.Tensor): + output = output.view(-1, output.shape[-1]) + target = target.view(-1) + pred = output.argmax(dim=-1) + mask = target == self.ignore_index + pred.masked_fill_(mask, self.ignore_index) + n_masked = mask.sum().item() + if len(target) - n_masked > 0: + tracker.add('accuracy.', (pred.eq(target).sum().item() - n_masked) / (len(target) - n_masked)) + + def 
create_state(self): + return None + + def set_state(self, data: any): + pass + + def on_epoch_start(self): + pass + + def on_epoch_end(self): + pass + + +class BinaryAccuracy(Accuracy): + def __call__(self, output: torch.Tensor, target: torch.Tensor): + pred = output.view(-1) > 0 + target = target.view(-1) + self.data.correct += pred.eq(target).sum().item() + self.data.samples += len(target) + + +class AccuracyDirect(Accuracy): + data: AccuracyState + + def __call__(self, output: torch.Tensor, target: torch.Tensor): + output = output.view(-1) + target = target.view(-1) + self.data.correct += output.eq(target).sum().item() + self.data.samples += len(target) diff --git a/labml_nn/helpers/optimizer.py b/labml_nn/helpers/optimizer.py new file mode 100644 index 00000000..485e6768 --- /dev/null +++ b/labml_nn/helpers/optimizer.py @@ -0,0 +1,97 @@ +from typing import Tuple + +import torch +from labml import tracker + +from labml.configs import BaseConfigs, option, meta_config + + +class OptimizerConfigs(BaseConfigs): + r""" + This creates a configurable optimizer. + + Arguments: + learning_rate (float): Learning rate of the optimizer. Defaults to ``0.01``. + momentum (float): Momentum of the optimizer. Defaults to ``0.5``. + parameters: Model parameters to optimize. + d_model (int): Embedding size of the model (for Noam optimizer). + betas (Tuple[float, float]): Betas for Adam optimizer. Defaults to ``(0.9, 0.999)``. + eps (float): Epsilon for Adam/RMSProp optimizers. Defaults to ``1e-8``. + step_factor (int): Step factor for Noam optimizer. Defaults to ``1024``. + + Also there is a better (more options) implementation in ``labml_nn``. + `We recommend using that `_. + """ + + optimizer: torch.optim.Adam + learning_rate: float = 0.01 + momentum: float = 0.5 + parameters: any + d_model: int + betas: Tuple[float, float] = (0.9, 0.999) + eps: float = 1e-8 + step_factor: int = 1024 + + def __init__(self): + super().__init__(_primary='optimizer') + + +meta_config(OptimizerConfigs.parameters) + + +@option(OptimizerConfigs.optimizer, 'SGD') +def sgd_optimizer(c: OptimizerConfigs): + return torch.optim.SGD(c.parameters, c.learning_rate, c.momentum) + + +@option(OptimizerConfigs.optimizer, 'Adam') +def adam_optimizer(c: OptimizerConfigs): + return torch.optim.Adam(c.parameters, lr=c.learning_rate, + betas=c.betas, eps=c.eps) + + +class NoamOpt: + def __init__(self, model_size: int, learning_rate: float, warmup: int, step_factor: int, optimizer): + self.step_factor = step_factor + self.optimizer = optimizer + self.warmup = warmup + self.learning_rate = learning_rate + self.model_size = model_size + self._rate = 0 + + def step(self): + rate = self.rate(tracker.get_global_step() / self.step_factor) + for p in self.optimizer.param_groups: + p['lr'] = rate + self._rate = rate + self.optimizer.step() + + def rate(self, step): + factor = self.model_size ** (-0.5) * min(step ** (-0.5), step * self.warmup ** (-1.5)) + return self.learning_rate * factor + + def zero_grad(self): + self.optimizer.zero_grad() + + +@option(OptimizerConfigs.optimizer, 'Noam') +def noam_optimizer(c: OptimizerConfigs): + optimizer = torch.optim.Adam(c.parameters, lr=0.0, betas=c.betas, eps=c.eps) + return NoamOpt(c.d_model, 1, 2000, c.step_factor, optimizer) + + +def _test_noam_optimizer(): + import matplotlib.pyplot as plt + import numpy as np + + opts = [NoamOpt(512, 1, 4000, None), + NoamOpt(512, 1, 8000, None), + NoamOpt(2048, 1, 2000, None)] + plt.plot(np.arange(1, 20000), [[opt.rate(i) for opt in opts] for i in range(1, 20000)]) + 
plt.legend(["512:4000", "512:8000", "256:4000"]) + plt.title("Optimizer") + plt.show() + + +if __name__ == '__main__': + _test_noam_optimizer() diff --git a/labml_nn/helpers/schedule.py b/labml_nn/helpers/schedule.py new file mode 100644 index 00000000..5d0e8cd5 --- /dev/null +++ b/labml_nn/helpers/schedule.py @@ -0,0 +1,84 @@ +from typing import Tuple, List + + +class Schedule: + def __call__(self, x): + raise NotImplementedError() + + +class Flat(Schedule): + def __init__(self, value): + self.__value = value + + def __call__(self, x): + return self.__value + + def __str__(self): + return f"Schedule({self.__value})" + + +class Dynamic(Schedule): + def __init__(self, value): + self.__value = value + + def __call__(self, x): + return self.__value + + def update(self, value): + self.__value = value + + def __str__(self): + return "Dynamic" + + +class Piecewise(Schedule): + """ + ## Piecewise schedule + """ + + def __init__(self, endpoints: List[Tuple[float, float]], outside_value: float = None): + """ + ### Initialize + + `endpoints` is list of pairs `(x, y)`. + The values between endpoints are linearly interpolated. + `y` values outside the range covered by `x` are + `outside_value`. + """ + + # `(x, y)` pairs should be sorted + indexes = [e[0] for e in endpoints] + assert indexes == sorted(indexes) + + self._outside_value = outside_value + self._endpoints = endpoints + + def __call__(self, x): + """ + ### Find `y` for given `x` + """ + + # iterate through each segment + for (x1, y1), (x2, y2) in zip(self._endpoints[:-1], self._endpoints[1:]): + # interpolate if `x` is within the segment + if x1 <= x < x2: + dx = float(x - x1) / (x2 - x1) + return y1 + dx * (y2 - y1) + + # return outside value otherwise + return self._outside_value + + def __str__(self): + endpoints = ", ".join([f"({e[0]}, {e[1]})" for e in self._endpoints]) + return f"Schedule[{endpoints}, {self._outside_value}]" + + +class RelativePiecewise(Piecewise): + def __init__(self, relative_endpoits: List[Tuple[float, float]], total_steps: int): + endpoints = [] + for e in relative_endpoits: + index = int(total_steps * e[0]) + assert index >= 0 + endpoints.append((index, e[1])) + + super().__init__(endpoints, outside_value=relative_endpoits[-1][1]) diff --git a/labml_nn/helpers/trainer.py b/labml_nn/helpers/trainer.py new file mode 100644 index 00000000..a2f9b335 --- /dev/null +++ b/labml_nn/helpers/trainer.py @@ -0,0 +1,598 @@ +import signal +import typing +from typing import Dict, List, Callable +from typing import Optional, Tuple, Any, Collection + +import labml.utils.pytorch as pytorch_utils +import torch.optim +import torch.optim +import torch.utils.data +import torch.utils.data +from labml import tracker, logger, experiment, monit +from labml.configs import BaseConfigs, meta_config, option +from labml.internal.monitor import Loop +from labml.logger import Text +from torch import nn +from .device import DeviceConfigs +from .metrics import StateModule + + +class TrainingLoopIterator(Collection): + def __init__(self, start: int, total: int, step: Optional[int]): + self.step = step + self.total = total + self.start = start + self.i = None + + def __iter__(self): + self.i = None + return self + + def __next__(self): + if self.step is not None: + if self.i is None: + self.i = self.start + else: + self.i += self.step + else: + if self.i is None: + self.i = 0 + else: + self.i += 1 + + if self.i >= self.total: + raise StopIteration() + + if self.step is None: + return tracker.get_global_step() + else: + return self.i + + def 
__len__(self) -> int: + if self.step is not None: + return (self.total - self.start) // self.step + else: + return self.total + + def __contains__(self, x: object) -> bool: + return False + + +class TrainingLoop: + _iter: Optional[TrainingLoopIterator] + __loop: Loop + __signal_received: Optional[Tuple[Any, Any]] + + def __init__(self, *, + loop_count: int, + loop_step: Optional[int], + is_save_models: bool, + log_new_line_interval: int, + log_write_interval: int, + save_models_interval: int, + is_loop_on_interrupt: bool): + self.__loop_count = loop_count + self.__loop_step = loop_step + self.__is_save_models = is_save_models + self.__log_new_line_interval = log_new_line_interval + self.__log_write_interval = log_write_interval + self.__last_write_step = 0 + self.__last_new_line_step = 0 + self.__save_models_interval = save_models_interval + self.__last_save_step = 0 + self.__signal_received = None + self.__is_loop_on_interrupt = is_loop_on_interrupt + self._iter = None + + def __iter__(self): + self._iter = TrainingLoopIterator(tracker.get_global_step(), + self.__loop_count, + self.__loop_step) + + self.__loop = monit.loop(typing.cast(Collection, self._iter)) + + iter(self.__loop) + try: + self.old_handler = signal.signal(signal.SIGINT, self.__handler) + except ValueError: + pass + return self + + @property + def idx(self): + if not self._iter: + return 0 + if not self._iter.i: + return 0 + if self.__loop_step is None: + return self._iter.i + return self._iter.i / self.__loop_step + + def __finish(self): + try: + signal.signal(signal.SIGINT, self.old_handler) + except ValueError: + pass + tracker.save() + tracker.new_line() + if self.__is_save_models: + logger.log("Saving model...") + experiment.save_checkpoint() + + # def is_interval(self, interval: int, global_step: Optional[int] = None): + # if global_step is None: + # global_step = tracker.get_global_step() + # + # if global_step - self.__loop_step < 0: + # return False + # + # if global_step // interval > (global_step - self.__loop_step) // interval: + # return True + # else: + # return False + + def __next__(self): + if self.__signal_received is not None: + logger.log('\nKilling Loop.', Text.danger) + monit.finish_loop() + self.__finish() + raise StopIteration("SIGINT") + + try: + global_step = next(self.__loop) + except StopIteration as e: + self.__finish() + raise e + + tracker.set_global_step(global_step) + + if global_step - self.__last_write_step >= self.__log_write_interval: + tracker.save() + self.__last_write_step = global_step + if global_step - self.__last_new_line_step >= self.__log_new_line_interval: + tracker.new_line() + self.__last_new_line_step = global_step + # if self.is_interval(self.__log_write_interval, global_step): + # tracker.save() + # if self.is_interval(self.__log_new_line_interval, global_step): + # logger.log() + + # if (self.__is_save_models and + # self.is_interval(self.__save_models_interval, global_step)): + # experiment.save_checkpoint() + if (self.__is_save_models and + global_step - self.__last_save_step >= self.__save_models_interval): + experiment.save_checkpoint() + self.__last_save_step = global_step + + return global_step + + def __handler(self, sig, frame): + # Pass second interrupt without delaying + if self.__signal_received is not None: + logger.log('\nSIGINT received twice. 
Stopping...', Text.danger) + self.old_handler(*self.__signal_received) + return + + if self.__is_loop_on_interrupt: + # Store the interrupt signal for later + self.__signal_received = (sig, frame) + logger.log('\nSIGINT received. Delaying KeyboardInterrupt.', Text.danger) + else: + self.__finish() + logger.log('Killing loop...', Text.danger) + self.old_handler(sig, frame) + + def __str__(self): + return "LabTrainingLoop" + + +class TrainingLoopConfigs(BaseConfigs): + r""" + This is a configurable training loop. You can extend this class for your configurations + if it involves a training loop. + + >>> for step in conf.training_loop: + >>> ... + + Arguments: + loop_count (int): Total number of steps. Defaults to ``10``. + loop_step (int): Number of steps to increment per iteration. Defaults to ``1``. + is_save_models (bool): Whether to call :func:`labml.experiment.save_checkpoint` on each iteration. + Defaults to ``False``. + save_models_interval (int): The interval (in steps) to save models. Defaults to ``1``. + log_new_line_interval (int): The interval (in steps) to print a new line to the screen. + Defaults to ``1``. + log_write_interval (int): The interval (in steps) to call :func:`labml.tracker.save`. + Defaults to ``1``. + is_loop_on_interrupt (bool): Whether to handle keyboard interrupts and wait until a iteration is complete. + Defaults to ``False``. + """ + loop_count: int = 10 + loop_step: int = 1 + is_save_models: bool = False + log_new_line_interval: int = 1 + log_write_interval: int = 1 + save_models_interval: int = 1 + is_loop_on_interrupt: bool = False + + training_loop: TrainingLoop + + +@option(TrainingLoopConfigs.training_loop) +def _loop_configs(c: TrainingLoopConfigs): + return TrainingLoop(loop_count=c.loop_count, + loop_step=c.loop_step, + is_save_models=c.is_save_models, + log_new_line_interval=c.log_new_line_interval, + log_write_interval=c.log_write_interval, + save_models_interval=c.save_models_interval, + is_loop_on_interrupt=c.is_loop_on_interrupt) + + +meta_config(TrainingLoopConfigs.loop_step, + TrainingLoopConfigs.loop_count, + TrainingLoopConfigs.is_save_models, + TrainingLoopConfigs.log_new_line_interval, + TrainingLoopConfigs.log_write_interval, + TrainingLoopConfigs.save_models_interval, + TrainingLoopConfigs.is_loop_on_interrupt) + + +class ModeState: + def __init__(self): + self._rollback_stack = [] + + self.is_train = False + self.is_log_activations = False + self.is_log_parameters = False + self.is_optimize = False + + def _enter(self, mode: Dict[str, any]): + rollback = {} + for k, v in mode.items(): + if v is None: + continue + rollback[k] = getattr(self, k) + setattr(self, k, v) + + self._rollback_stack.append(rollback) + + return len(self._rollback_stack) + + def _exit(self, n: int): + assert n == len(self._rollback_stack) + + rollback = self._rollback_stack[-1] + self._rollback_stack.pop(-1) + + for k, v in rollback.items(): + setattr(self, k, v) + + def update(self, *, + is_train: Optional[bool] = None, + is_log_parameters: Optional[bool] = None, + is_log_activations: Optional[bool] = None, + is_optimize: Optional[bool] = None): + return Mode(self, + is_train=is_train, + is_log_parameters=is_log_parameters, + is_log_activations=is_log_activations, + is_optimize=is_optimize) + + +class Mode: + def __init__(self, mode: ModeState, **kwargs: any): + self.mode = mode + self.update = {} + for k, v in kwargs.items(): + if v is not None: + self.update[k] = v + + self.idx = -1 + + def __enter__(self): + self.idx = self.mode._enter(self.update) + + def 
__exit__(self, exc_type, exc_val, exc_tb): + self.mode._exit(self.idx) + + +class ForwardHook: + def __init__(self, mode: ModeState, model_name, name: str, module: torch.nn.Module): + self.mode = mode + self.model_name = model_name + self.name = name + self.module = module + module.register_forward_hook(self) + + def save(self, name: str, output): + if isinstance(output, torch.Tensor): + pytorch_utils.store_var(name, output) + elif isinstance(output, tuple): + for i, o in enumerate(output): + self.save(f"{name}.{i}", o) + + def __call__(self, module, i, o): + if not self.mode.is_log_activations: + return + + self.save(f"module.{self.model_name}.{self.name}", o) + + +def hook_model_outputs(mode: ModeState, model: torch.nn.Module, model_name: str = "model"): + for name, module in model.named_modules(): + if name == '': + name = 'full' + ForwardHook(mode, model_name, name, module) + + +class Trainer: + def __init__(self, *, + name: str, + mode: ModeState, + data_loader: torch.utils.data.DataLoader, + inner_iterations: int, + state_modules: List[StateModule], + is_track_time: bool, + step: Callable[[any, 'BatchIndex'], None]): + self.is_track_time = is_track_time + self.mode = mode + self.name = name + self.step = step + self.state_modules = state_modules + self.__iterable = None + self.__states = [sm.create_state() for sm in self.state_modules] + self.inner_iterations = inner_iterations + self.data_loader = data_loader + self._batch_index = BatchIndex(len(self.data_loader), self.inner_iterations) + + def set_data_loader(self, data_loader: torch.utils.data.DataLoader): + self.data_loader = data_loader + self._batch_index = BatchIndex(len(data_loader), self.inner_iterations) + self.__iterable = None + + def __call__(self): + for sm, s in zip(self.state_modules, self.__states): + sm.set_state(s) + + if self.__iterable is None or self._batch_index.completed: + self.__iterable = iter(self.data_loader) + self._batch_index.reset(len(self.data_loader), self.inner_iterations) + for sm in self.state_modules: + sm.on_epoch_start() + with torch.set_grad_enabled(self.mode.is_train): + self.__iterate() + + if self._batch_index.completed: + for sm in self.state_modules: + sm.on_epoch_end() + + def __iterate(self): + with monit.section(self.name, is_partial=True, is_track=self.is_track_time): + if self._batch_index.idx == 0: + monit.progress(0) + while not self._batch_index.iteration_completed: + batch = next(self.__iterable) + + self.step(batch, self._batch_index) + + self._batch_index.step() + monit.progress(self._batch_index.epoch_progress) + + self._batch_index.step_inner() + + +class BatchIndex: + idx: int + total: int + iteration: int + total_iterations: int + + def __init__(self, total: int, total_iterations: int): + self.total_iterations = total_iterations + self.total = total + + def is_interval(self, interval: int): + if interval <= 0: + return False + if self.idx + 1 == self.total: + return True + else: + return (self.idx + 1) % interval == 0 + + @property + def is_last(self): + return self.idx + 1 == self.total + + @property + def completed(self): + return self.iteration >= self.total_iterations + + @property + def iteration_completed(self): + # // is important so that the last step happens on the last iteration + return self.idx >= (self.iteration + 1) * self.total // self.total_iterations + + @property + def epoch_progress(self): + return self.idx / self.total + + def step(self): + self.idx += 1 + + def step_inner(self): + self.iteration += 1 + + def reset(self, total: int, total_iterations: 
int): + self.total = total + self.total_iterations = total_iterations + self.idx = 0 + self.iteration = 0 + + +class TrainValidConfigs(TrainingLoopConfigs): + r""" + This is a configurable module that you can extend for experiments that involve a + training and validation datasets (i.e. most DL experiments). + + Arguments: + epochs (int): Number of epochs to train on. Defaults to ``10``. + train_loader (torch.utils.data.DataLoader): Training data loader. + valid_loader (torch.utils.data.DataLoader): Training data loader. + inner_iterations (int): Number of times to switch between training and validation + within an epoch. Defaults to ``1``. + + You can override ``init``, ``step`` functions. There is also a ``sample`` function + that you can override to generate samples ever time it switches between training and validation. + """ + state_modules: List[StateModule] + + mode: ModeState + + epochs: int = 10 + + trainer: Trainer + validator: Trainer + train_loader: torch.utils.data.DataLoader + valid_loader: torch.utils.data.DataLoader + + loop_count = '_data_loop_count' + loop_step = None + + inner_iterations: int = 1 + + is_track_time: bool = False + + def init(self): + pass + + def step(self, batch: Any, batch_idx: BatchIndex): + raise NotImplementedError + + def run_step(self): + for i in range(self.inner_iterations): + with tracker.namespace('sample'): + self.sample() + with self.mode.update(is_train=True): + with tracker.namespace('train'): + self.trainer() + if self.validator: + with tracker.namespace('valid'): + self.validator() + tracker.save() + + def run(self): + with monit.section("Initialize"): + self.init() + _ = self.validator + _ = self.trainer + for _ in self.training_loop: + self.run_step() + + def sample(self): + pass + + +@option(TrainValidConfigs.trainer) +def _default_trainer(c: TrainValidConfigs): + return Trainer(name='Train', + mode=c.mode, + data_loader=c.train_loader, + inner_iterations=c.inner_iterations, + state_modules=c.state_modules, + is_track_time=c.is_track_time, + step=c.step) + + +@option(TrainValidConfigs.validator) +def _default_validator(c: TrainValidConfigs): + return Trainer(name='Valid', + mode=c.mode, + data_loader=c.valid_loader, + inner_iterations=c.inner_iterations, + state_modules=c.state_modules, + is_track_time=c.is_track_time, + step=c.step) + + +@option(TrainValidConfigs.loop_count) +def _data_loop_count(c: TrainValidConfigs): + return c.epochs + + +class SimpleTrainValidConfigs(TrainValidConfigs): + r""" + This is a configurable module that works for many standard DL experiments. + + Arguments: + model: A PyTorch model. + optimizer: A PyTorch optimizer to update model. + device: The device to train the model on. This defaults to a configurable device + loss_function: A function to calculate the loss. This should accept ``model_output, target`` as + arguments. + update_batches (int): Number of batches to accumulate before taking an optimizer step. + Defaults to ``1``. + log_params_updates (int): How often (number of batches) to track model parameters and gradients. + Defaults to a large number; i.e. logs every epoch. + log_activations_batches (int): How often to log model activations. + Defaults to a large number; i.e. logs every epoch. + log_save_batches (int): How often to call :func:`labml.tracker.save`. 
+ """ + optimizer: torch.optim.Adam + model: nn.Module + device: torch.device = DeviceConfigs() + + loss_func: nn.Module + + update_batches: int = 1 + log_params_updates: int = 2 ** 32 # 0 if not + log_activations_batches: int = 2 ** 32 # 0 if not + log_save_batches: int = 1 + + state_modules: List[StateModule] = [] + + def init(self): + pass + + def step(self, batch: Any, batch_idx: BatchIndex): + self.model.train(self.mode.is_train) + data, target = batch[0].to(self.device), batch[1].to(self.device) + + if self.mode.is_train: + tracker.add_global_step(len(data)) + + is_log_activations = batch_idx.is_interval(self.log_activations_batches) + with monit.section("model"): + with self.mode.update(is_log_activations=is_log_activations): + output = self.model(data) + + loss = self.loss_func(output, target) + tracker.add("loss.", loss) + + if self.mode.is_train: + with monit.section('backward'): + loss.backward() + + if batch_idx.is_interval(self.update_batches): + with monit.section('optimize'): + self.optimizer.step() + if batch_idx.is_interval(self.log_params_updates): + tracker.add('model', self.model) + self.optimizer.zero_grad() + + if batch_idx.is_interval(self.log_save_batches): + tracker.save() + + +meta_config(SimpleTrainValidConfigs.update_batches, + SimpleTrainValidConfigs.log_params_updates, + SimpleTrainValidConfigs.log_activations_batches) + + +@option(SimpleTrainValidConfigs.optimizer) +def _default_optimizer(c: SimpleTrainValidConfigs): + from .optimizer import OptimizerConfigs + opt_conf = OptimizerConfigs() + opt_conf.parameters = c.model.parameters() + return opt_conf diff --git a/labml_nn/lora/experiment.py b/labml_nn/lora/experiment.py index 0746f060..e6f05cf2 100644 --- a/labml_nn/lora/experiment.py +++ b/labml_nn/lora/experiment.py @@ -19,7 +19,7 @@ from transformers import AutoTokenizer, AutoModelForCausalLM from labml import lab, monit, tracker from labml.configs import BaseConfigs, option from labml.utils.download import download_file -from labml_helpers.device import DeviceConfigs +from labml_nn.helpers.device import DeviceConfigs from labml_nn.lora.gpt2 import GPTModel diff --git a/labml_nn/optimizers/mnist_experiment.py b/labml_nn/optimizers/mnist_experiment.py index cd3338d9..74e2af0c 100644 --- a/labml_nn/optimizers/mnist_experiment.py +++ b/labml_nn/optimizers/mnist_experiment.py @@ -11,11 +11,10 @@ import torch.utils.data from labml import experiment, tracker from labml.configs import option -from labml_helpers.datasets.mnist import MNISTConfigs -from labml_helpers.device import DeviceConfigs -from labml_helpers.metrics.accuracy import Accuracy -from labml_helpers.seed import SeedConfigs -from labml_helpers.train_valid import TrainValidConfigs, BatchIndex, hook_model_outputs +from labml_nn.helpers.datasets import MNISTConfigs +from labml_nn.helpers.device import DeviceConfigs +from labml_nn.helpers.metrics import Accuracy +from labml_nn.helpers.trainer import TrainValidConfigs, BatchIndex, hook_model_outputs from labml_nn.optimizers.configs import OptimizerConfigs @@ -48,7 +47,6 @@ class Configs(MNISTConfigs, TrainValidConfigs): """ optimizer: torch.optim.Adam model: nn.Module - set_seed = SeedConfigs() device: torch.device = DeviceConfigs() epochs: int = 10 @@ -126,7 +124,6 @@ def main(): # Specify the optimizer 'optimizer.optimizer': 'Adam', 'optimizer.learning_rate': 1.5e-4}) - conf.set_seed.set() experiment.add_pytorch_models(dict(model=conf.model)) with experiment.start(): conf.run() diff --git a/labml_nn/optimizers/performance_test.py 
b/labml_nn/optimizers/performance_test.py index 918a1e49..629efea3 100644 --- a/labml_nn/optimizers/performance_test.py +++ b/labml_nn/optimizers/performance_test.py @@ -18,7 +18,7 @@ MyAdam...[DONE] 1,192.89ms import torch import torch.nn as nn -from labml_helpers.device import DeviceInfo +from labml_nn.helpers.device import DeviceInfo from torch.optim import Adam as TorchAdam from labml import monit diff --git a/labml_nn/rl/dqn/experiment.py b/labml_nn/rl/dqn/experiment.py index 9d3ad38c..2a3af438 100644 --- a/labml_nn/rl/dqn/experiment.py +++ b/labml_nn/rl/dqn/experiment.py @@ -17,7 +17,7 @@ import torch from labml import tracker, experiment, logger, monit from labml.internal.configs.dynamic_hyperparam import FloatDynamicHyperParam -from labml_helpers.schedule import Piecewise +from labml_nn.helpers.schedule import Piecewise from labml_nn.rl.dqn import QFuncLoss from labml_nn.rl.dqn.model import Model from labml_nn.rl.dqn.replay_buffer import ReplayBuffer diff --git a/labml_nn/sampling/experiment_tiny.py b/labml_nn/sampling/experiment_tiny.py index 989432af..3fd352ad 100644 --- a/labml_nn/sampling/experiment_tiny.py +++ b/labml_nn/sampling/experiment_tiny.py @@ -5,7 +5,7 @@ import torch from labml import experiment, monit from labml import logger from labml.logger import Text -from labml_helpers.datasets.text import TextDataset +from labml_nn.helpers.datasets import TextDataset from labml_nn.sampling import Sampler from labml_nn.sampling.greedy import GreedySampler from labml_nn.sampling.nucleus import NucleusSampler diff --git a/labml_nn/sketch_rnn/__init__.py b/labml_nn/sketch_rnn/__init__.py index 3bee40c4..2234f70f 100644 --- a/labml_nn/sketch_rnn/__init__.py +++ b/labml_nn/sketch_rnn/__init__.py @@ -39,9 +39,9 @@ from matplotlib import pyplot as plt import torch import torch.nn as nn from labml import lab, experiment, tracker, monit -from labml_helpers.device import DeviceConfigs -from labml_helpers.optimizer import OptimizerConfigs -from labml_helpers.train_valid import TrainValidConfigs, hook_model_outputs, BatchIndex +from labml_nn.helpers.device import DeviceConfigs +from labml_nn.helpers.optimizer import OptimizerConfigs +from labml_nn.helpers.trainer import TrainValidConfigs, hook_model_outputs, BatchIndex from torch import optim from torch.utils.data import Dataset, DataLoader diff --git a/labml_nn/transformers/alibi/experiment.py b/labml_nn/transformers/alibi/experiment.py index e1230921..28eb0477 100644 --- a/labml_nn/transformers/alibi/experiment.py +++ b/labml_nn/transformers/alibi/experiment.py @@ -16,7 +16,7 @@ from torch.utils.data import DataLoader from labml import experiment, tracker from labml.configs import option, calculate -from labml_helpers.datasets.text import SequentialUnBatchedDataset +from labml_nn.helpers.datasets import SequentialUnBatchedDataset from labml_nn.transformers.alibi import AlibiMultiHeadAttention from labml_nn.experiments.nlp_autoregression import transpose_batch from labml_nn.transformers import TransformerConfigs diff --git a/labml_nn/transformers/basic/with_sophia.py b/labml_nn/transformers/basic/with_sophia.py index 65c119e1..83e3bc60 100644 --- a/labml_nn/transformers/basic/with_sophia.py +++ b/labml_nn/transformers/basic/with_sophia.py @@ -13,7 +13,7 @@ on an NLP auto-regression task (with Tiny Shakespeare dataset) with [Sophia-G op import torch from labml import experiment, tracker -from labml_helpers.train_valid import BatchIndex +from labml_nn.helpers.trainer import BatchIndex from labml_nn.optimizers.sophia import Sophia from 
labml_nn.transformers.basic.autoregressive_experiment import Configs as TransformerAutoRegressionConfigs diff --git a/labml_nn/transformers/compressive/__init__.py b/labml_nn/transformers/compressive/__init__.py index 28e8ee6f..96339e0c 100644 --- a/labml_nn/transformers/compressive/__init__.py +++ b/labml_nn/transformers/compressive/__init__.py @@ -56,7 +56,6 @@ import torch import torch.nn.functional as F from torch import nn -from labml_helpers.module import TypedModuleList from labml_nn.transformers.feed_forward import FeedForward from labml_nn.transformers.mha import PrepareForMultiHeadAttention from labml_nn.transformers.xl.relative_mha import RelativeMultiHeadAttention @@ -233,7 +232,7 @@ class AttentionReconstructionLoss: attention reconstruction loss, we detach all other parameters except $f_c$ from the gradient computation. """ - def __init__(self, layers: TypedModuleList[CompressiveTransformerLayer]): + def __init__(self, layers: nn.ModuleList): """ `layers` is the list of Compressive Transformer layers """ diff --git a/labml_nn/transformers/compressive/experiment.py b/labml_nn/transformers/compressive/experiment.py index ab563ead..cd1d2b84 100644 --- a/labml_nn/transformers/compressive/experiment.py +++ b/labml_nn/transformers/compressive/experiment.py @@ -16,8 +16,8 @@ import torch.nn as nn from labml import experiment, tracker, monit, logger from labml.configs import option from labml.logger import Text -from labml_helpers.metrics.simple_state import SimpleStateModule -from labml_helpers.train_valid import BatchIndex, hook_model_outputs +from labml_nn.helpers.metrics import SimpleStateModule +from labml_nn.helpers.trainer import BatchIndex, hook_model_outputs from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs from labml_nn.transformers.compressive import CompressiveTransformer, AttentionReconstructionLoss, \ CompressiveTransformerLayer, Conv1dCompression diff --git a/labml_nn/transformers/mlm/experiment.py b/labml_nn/transformers/mlm/experiment.py index 862887bb..9267f737 100644 --- a/labml_nn/transformers/mlm/experiment.py +++ b/labml_nn/transformers/mlm/experiment.py @@ -16,8 +16,8 @@ from torch import nn from labml import experiment, tracker, logger from labml.configs import option from labml.logger import Text -from labml_helpers.metrics.accuracy import Accuracy -from labml_helpers.train_valid import BatchIndex +from labml_nn.helpers.metrics import Accuracy +from labml_nn.helpers.trainer import BatchIndex from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs from labml_nn.transformers import Encoder, Generator from labml_nn.transformers import TransformerConfigs diff --git a/labml_nn/transformers/retro/database.py b/labml_nn/transformers/retro/database.py index aed52f68..90d0a200 100644 --- a/labml_nn/transformers/retro/database.py +++ b/labml_nn/transformers/retro/database.py @@ -20,7 +20,7 @@ import numpy as np import torch from labml import lab, monit -from labml_helpers.datasets.text import TextFileDataset +from labml_nn.helpers.datasets import TextFileDataset from labml_nn.transformers.retro.bert_embeddings import BERTChunkEmbeddings diff --git a/labml_nn/transformers/retro/dataset.py b/labml_nn/transformers/retro/dataset.py index 2ef257a6..42ff9f47 100644 --- a/labml_nn/transformers/retro/dataset.py +++ b/labml_nn/transformers/retro/dataset.py @@ -20,7 +20,7 @@ import torch from torch.utils.data import Dataset as PyTorchDataset from labml import lab, monit -from labml_helpers.datasets.text import TextFileDataset, 
TextDataset +from labml_nn.helpers.datasets import TextFileDataset, TextDataset from labml_nn.transformers.retro.database import RetroIndex diff --git a/labml_nn/transformers/retro/train.py b/labml_nn/transformers/retro/train.py index 7dbd2b18..ec4d7bb4 100644 --- a/labml_nn/transformers/retro/train.py +++ b/labml_nn/transformers/retro/train.py @@ -17,7 +17,7 @@ from torch.utils.data import DataLoader, RandomSampler from labml import monit, lab, tracker, experiment, logger from labml.logger import Text -from labml_helpers.datasets.text import TextFileDataset +from labml_nn.helpers.datasets import TextFileDataset from labml_nn.optimizers.noam import Noam from labml_nn.transformers.retro import model as retro from labml_nn.transformers.retro.dataset import Dataset, RetroIndex diff --git a/labml_nn/transformers/switch/experiment.py b/labml_nn/transformers/switch/experiment.py index cbc788c4..1490f803 100644 --- a/labml_nn/transformers/switch/experiment.py +++ b/labml_nn/transformers/switch/experiment.py @@ -16,7 +16,7 @@ import torch.nn as nn from labml import experiment, tracker from labml.configs import option -from labml_helpers.train_valid import BatchIndex +from labml_nn.helpers.trainer import BatchIndex from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs diff --git a/labml_nn/transformers/xl/experiment.py b/labml_nn/transformers/xl/experiment.py index a88cc582..84107b78 100644 --- a/labml_nn/transformers/xl/experiment.py +++ b/labml_nn/transformers/xl/experiment.py @@ -16,8 +16,8 @@ from labml.logger import Text from labml import experiment, tracker, monit, logger from labml.configs import option -from labml_helpers.metrics.simple_state import SimpleStateModule -from labml_helpers.train_valid import BatchIndex, hook_model_outputs +from labml_nn.helpers.metrics import SimpleStateModule +from labml_nn.helpers.trainer import BatchIndex, hook_model_outputs from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs from labml_nn.transformers.xl import TransformerXL, TransformerXLLayer diff --git a/labml_nn/uncertainty/evidence/experiment.py b/labml_nn/uncertainty/evidence/experiment.py index b1eb6bdc..edcbef50 100644 --- a/labml_nn/uncertainty/evidence/experiment.py +++ b/labml_nn/uncertainty/evidence/experiment.py @@ -18,8 +18,8 @@ import torch.utils.data from labml import tracker, experiment from labml.configs import option, calculate -from labml_helpers.schedule import Schedule, RelativePiecewise -from labml_helpers.train_valid import BatchIndex +from labml_nn.helpers.schedule import Schedule, RelativePiecewise +from labml_nn.helpers.trainer import BatchIndex from labml_nn.experiments.mnist import MNISTConfigs from labml_nn.uncertainty.evidence import KLDivergenceLoss, TrackStatistics, MaximumLikelihoodLoss, \ CrossEntropyBayesRisk, SquaredErrorBayesRisk @@ -178,7 +178,7 @@ def kl_div_coef(c: Configs): ### KL Divergence Loss Coefficient Schedule """ - # Create a [relative piecewise schedule](https://docs.labml.ai/api/helpers.html#labml_helpers.schedule.Piecewise) + # Create a [relative piecewise schedule](../../helpers/schedule.html) return RelativePiecewise(c.kl_div_coef_schedule, c.epochs * len(c.train_dataset)) diff --git a/labml_nn/unet/experiment.py b/labml_nn/unet/experiment.py index 5095c86d..06dbe269 100644 --- a/labml_nn/unet/experiment.py +++ b/labml_nn/unet/experiment.py @@ -24,7 +24,7 @@ from torch import nn from labml import lab, tracker, experiment, monit from labml.configs import BaseConfigs -from labml_helpers.device import 
DeviceConfigs +from labml_nn.helpers.device import DeviceConfigs from labml_nn.unet.carvana import CarvanaDataset from labml_nn.unet import UNet @@ -34,7 +34,7 @@ class Configs(BaseConfigs): ## Configurations """ # Device to train the model on. - # [`DeviceConfigs`](https://docs.labml.ai/api/helpers.html#labml_helpers.device.DeviceConfigs) + # [`DeviceConfigs`](../helpers/device.html) # picks up an available CUDA device or defaults to CPU. device: torch.device = DeviceConfigs() diff --git a/labml_nn/utils/__init__.py b/labml_nn/utils/__init__.py index 982a11b9..459cd04f 100644 --- a/labml_nn/utils/__init__.py +++ b/labml_nn/utils/__init__.py @@ -8,19 +8,18 @@ summary: A bunch of utility functions and classes """ import copy - +from torch import nn from torch.utils.data import Dataset, IterableDataset -from labml_helpers.module import M, TypedModuleList -def clone_module_list(module: M, n: int) -> TypedModuleList[M]: +def clone_module_list(module: nn.Module, n: int) -> nn.ModuleList: """ ## Clone Module Make a `nn.ModuleList` with clones of a given module """ - return TypedModuleList([copy.deepcopy(module) for _ in range(n)]) + return nn.ModuleList([copy.deepcopy(module) for _ in range(n)]) def cycle_dataloader(data_loader): diff --git a/setup.py b/setup.py index 93160d8d..8f87a39c 100644 --- a/setup.py +++ b/setup.py @@ -21,7 +21,6 @@ setuptools.setup( 'test', 'test.*')), install_requires=['labml==0.4.168', - 'labml-helpers==0.4.89', 'torch', 'torchtext', 'torchvision',
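Migration note: this patch vendors the previously external `labml_helpers` package into `labml_nn.helpers` and drops the `labml-helpers` requirement from setup.py. For downstream code that still imports `labml_helpers`, the mapping is mechanical; a minimal before/after sketch, using only module names that appear in the diffs above:

    # Before (labml_helpers, no longer installed)
    from labml_helpers.datasets.mnist import MNISTConfigs
    from labml_helpers.datasets.text import TextFileDataset
    from labml_helpers.device import DeviceConfigs
    from labml_helpers.metrics.accuracy import Accuracy
    from labml_helpers.optimizer import OptimizerConfigs
    from labml_helpers.schedule import Piecewise
    from labml_helpers.train_valid import TrainValidConfigs, SimpleTrainValidConfigs, BatchIndex

    # After (vendored into this repository by this patch)
    from labml_nn.helpers.datasets import MNISTConfigs, TextFileDataset
    from labml_nn.helpers.device import DeviceConfigs
    from labml_nn.helpers.metrics import Accuracy
    from labml_nn.helpers.optimizer import OptimizerConfigs
    from labml_nn.helpers.schedule import Piecewise
    from labml_nn.helpers.trainer import TrainValidConfigs, SimpleTrainValidConfigs, BatchIndex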
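The relocated trainer keeps the same configuration-driven API that the ponder-net, capsule-network and optimizer experiments in this patch rely on. As a sanity check that the moved pieces compose, here is a minimal sketch of an experiment built only on `labml_nn.helpers`; the linear classifier and the experiment name are stand-ins, not part of the patch:

    import torch.nn as nn

    from labml import experiment
    from labml.configs import option
    from labml_nn.helpers.datasets import MNISTConfigs
    from labml_nn.helpers.trainer import SimpleTrainValidConfigs


    class Configs(MNISTConfigs, SimpleTrainValidConfigs):
        # The default `step` in SimpleTrainValidConfigs runs the
        # forward/backward/optimize cycle and tracks `loss.`.
        loss_func = nn.CrossEntropyLoss()
        epochs: int = 5
        model: nn.Module


    @option(Configs.model)
    def _model(c: Configs):
        # Stand-in classifier; any nn.Module that maps MNIST images to 10 logits works
        return nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10)).to(c.device)


    def main():
        experiment.create(name='helpers_smoke_test')
        conf = Configs()
        experiment.configs(conf, {'optimizer.optimizer': 'Adam',
                                  'optimizer.learning_rate': 1e-3})
        with experiment.start():
            conf.run()


    if __name__ == '__main__':
        main()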
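One thing to watch: the compressive-transformer and Transformer-XL experiments above now import `SimpleStateModule` from `labml_nn.helpers.metrics`, but that class is not part of the metrics module shown in this patch. A minimal sketch of what those experiments need, assuming it mirrors `labml_helpers.metrics.simple_state` (a per-trainer slot for arbitrary state such as Transformer-XL memories) and building on the `StateModule` interface defined in `labml_nn/helpers/metrics.py`:

    from labml_nn.helpers.metrics import StateModule


    class SimpleStateModule(StateModule):
        # `Trainer` calls `create_state()` once per trainer and `set_state()`
        # before each run, so keeping the value inside a list gives the train
        # and valid trainers independent copies of the state.
        data: list

        def set(self, data):
            self.data[0] = data

        def get(self):
            return self.data[0]

        def create_state(self) -> any:
            return [None]

        def set_state(self, data: any):
            self.data = data

        def on_epoch_start(self):
            pass

        def on_epoch_end(self):
            pass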
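The `Piecewise` and `RelativePiecewise` schedules moved into `labml_nn/helpers/schedule.py` are used by the DQN and evidential-uncertainty experiments above. A short usage sketch with illustrative numbers:

    from labml_nn.helpers.schedule import Piecewise, RelativePiecewise

    # Linear warm-up to 1.0 over the first 1,000 steps, then linear decay to
    # 0.1 at step 10,000; outside the covered range it returns `outside_value`.
    coef = Piecewise([(0, 0.0), (1_000, 1.0), (10_000, 0.1)], outside_value=0.1)
    assert coef(500) == 0.5
    assert coef(20_000) == 0.1

    # The same shape given as fractions of a total step budget, as the
    # evidential-uncertainty experiment does for its KL divergence coefficient.
    kl_coef = RelativePiecewise([(0.0, 0.0), (0.1, 1.0), (1.0, 0.1)], 10_000)
    assert kl_coef(1_000) == 1.0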