remove labml_helpers dep

Varuna Jayasiri
2025-07-20 08:56:03 +05:30
parent b1ba92c166
commit 5bdedcffec
39 changed files with 1356 additions and 71 deletions

View File

@ -17,8 +17,8 @@ from torch import nn
from torch.utils.data import DataLoader
from labml import tracker, experiment
from labml_helpers.metrics.accuracy import AccuracyDirect
from labml_helpers.train_valid import SimpleTrainValidConfigs, BatchIndex
from labml_nn.helpers.metrics import AccuracyDirect
from labml_nn.helpers.trainer import SimpleTrainValidConfigs, BatchIndex
from labml_nn.adaptive_computation.parity import ParityDataset
from labml_nn.adaptive_computation.ponder_net import ParityPonderGRU, ReconstructionLoss, RegularizationLoss
@ -26,7 +26,7 @@ from labml_nn.adaptive_computation.ponder_net import ParityPonderGRU, Reconstruc
class Configs(SimpleTrainValidConfigs):
"""
Configurations with a
[simple training loop](https://docs.labml.ai/api/helpers.html#labml_helpers.train_valid.SimpleTrainValidConfigs)
[simple training loop](../../helpers/trainer.html)
"""
# Number of epochs

View File

@ -16,13 +16,12 @@ from typing import Any
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data
from labml import experiment, tracker
from labml.configs import option
from labml_helpers.datasets.mnist import MNISTConfigs
from labml_helpers.metrics.accuracy import AccuracyDirect
from labml_helpers.train_valid import SimpleTrainValidConfigs, BatchIndex
from labml_nn.capsule_networks import Squash, Router, MarginLoss
from labml_nn.helpers.datasets import MNISTConfigs
from labml_nn.helpers.metrics import AccuracyDirect
from labml_nn.helpers.trainer import SimpleTrainValidConfigs, BatchIndex
class MNISTCapsuleNetworkModel(nn.Module):

View File

@ -26,7 +26,7 @@ from PIL import Image
from labml import lab, tracker, experiment, monit
from labml.configs import BaseConfigs, option
from labml_helpers.device import DeviceConfigs
from labml_nn.helpers.device import DeviceConfigs
from labml_nn.diffusion.ddpm import DenoiseDiffusion
from labml_nn.diffusion.ddpm.unet import UNet
@ -36,7 +36,7 @@ class Configs(BaseConfigs):
## Configurations
"""
# Device to train the model on.
# [`DeviceConfigs`](https://docs.labml.ai/api/helpers.html#labml_helpers.device.DeviceConfigs)
# [`DeviceConfigs`](../../helpers/device.html)
# picks up an available CUDA device or defaults to CPU.
device: torch.device = DeviceConfigs()

View File

@ -75,7 +75,7 @@ from torch import nn
from labml import experiment, tracker
from labml.configs import option
from labml_helpers.train_valid import BatchIndex
from labml_nn.helpers.trainer import BatchIndex
from labml_nn.distillation.large import LargeModel
from labml_nn.distillation.small import SmallModel
from labml_nn.experiments.cifar10 import CIFAR10Configs

View File

@ -13,7 +13,7 @@ import torch.nn as nn
from labml import lab
from labml.configs import option
from labml_helpers.datasets.cifar10 import CIFAR10Configs as CIFAR10DatasetConfigs
from labml_nn.helpers.datasets import CIFAR10Configs as CIFAR10DatasetConfigs
from labml_nn.experiments.mnist import MNISTConfigs
@ -21,8 +21,7 @@ class CIFAR10Configs(CIFAR10DatasetConfigs, MNISTConfigs):
"""
## Configurations
This extends from CIFAR 10 dataset configurations from
[`labml_helpers`](https://github.com/labmlai/labml/tree/master/helpers)
This extends from [CIFAR 10 dataset configurations](../helpers/datasets.html)
and [`MNISTConfigs`](mnist.html).
"""
# Use CIFAR10 dataset by default

View File

@ -13,10 +13,10 @@ import torch.utils.data
from labml import tracker
from labml.configs import option
from labml_helpers.datasets.mnist import MNISTConfigs as MNISTDatasetConfigs
from labml_helpers.device import DeviceConfigs
from labml_helpers.metrics.accuracy import Accuracy
from labml_helpers.train_valid import TrainValidConfigs, BatchIndex, hook_model_outputs
from labml_nn.helpers.datasets import MNISTConfigs as MNISTDatasetConfigs
from labml_nn.helpers.device import DeviceConfigs
from labml_nn.helpers.metrics import Accuracy
from labml_nn.helpers.trainer import TrainValidConfigs, BatchIndex, hook_model_outputs
from labml_nn.optimizers.configs import OptimizerConfigs

View File

@ -17,10 +17,10 @@ from torch.utils.data import DataLoader, RandomSampler
from labml import lab, monit, logger, tracker
from labml.configs import option
from labml.logger import Text
from labml_helpers.datasets.text import TextDataset, SequentialDataLoader, SequentialUnBatchedDataset, TextFileDataset
from labml_helpers.device import DeviceConfigs
from labml_helpers.metrics.accuracy import Accuracy
from labml_helpers.train_valid import TrainValidConfigs, hook_model_outputs, BatchIndex
from labml_nn.helpers.datasets import TextDataset, SequentialDataLoader, SequentialUnBatchedDataset, TextFileDataset
from labml_nn.helpers.device import DeviceConfigs
from labml_nn.helpers.metrics import Accuracy
from labml_nn.helpers.trainer import TrainValidConfigs, hook_model_outputs, BatchIndex
from labml_nn.optimizers.configs import OptimizerConfigs

View File

@ -20,9 +20,9 @@ from torchtext.vocab import Vocab
from labml import lab, tracker, monit
from labml.configs import option
from labml_helpers.device import DeviceConfigs
from labml_helpers.metrics.accuracy import Accuracy
from labml_helpers.train_valid import TrainValidConfigs, hook_model_outputs, BatchIndex
from labml_nn.helpers.device import DeviceConfigs
from labml_nn.helpers.metrics import Accuracy
from labml_nn.helpers.trainer import TrainValidConfigs, hook_model_outputs, BatchIndex
from labml_nn.optimizers.configs import OptimizerConfigs

View File

@ -49,7 +49,7 @@ from labml import lab, tracker, experiment, monit
from labml.configs import BaseConfigs
from labml.utils.download import download_file
from labml.utils.pytorch import get_modules
from labml_helpers.device import DeviceConfigs
from labml_nn.helpers.device import DeviceConfigs
class GeneratorResNet(nn.Module):

View File

@ -16,10 +16,10 @@ from torchvision import transforms
from labml import tracker, monit, experiment
from labml.configs import option, calculate
from labml_helpers.datasets.mnist import MNISTConfigs
from labml_helpers.device import DeviceConfigs
from labml_helpers.optimizer import OptimizerConfigs
from labml_helpers.train_valid import TrainValidConfigs, hook_model_outputs, BatchIndex
from labml_nn.helpers.datasets import MNISTConfigs
from labml_nn.helpers.device import DeviceConfigs
from labml_nn.helpers.optimizer import OptimizerConfigs
from labml_nn.helpers.trainer import TrainValidConfigs, hook_model_outputs, BatchIndex
from labml_nn.gan.original import DiscriminatorLogitsLoss, GeneratorLogitsLoss

View File

@ -39,8 +39,8 @@ from PIL import Image
from labml import tracker, lab, monit, experiment
from labml.configs import BaseConfigs
from labml_helpers.device import DeviceConfigs
from labml_helpers.train_valid import ModeState, hook_model_outputs
from labml_nn.helpers.device import DeviceConfigs
from labml_nn.helpers.trainer import ModeState, hook_model_outputs
from labml_nn.gan.stylegan import Discriminator, Generator, MappingNetwork, GradientPenalty, PathLengthPenalty
from labml_nn.gan.wasserstein import DiscriminatorLoss, GeneratorLoss
from labml_nn.utils import cycle_dataloader
@ -88,7 +88,7 @@ class Configs(BaseConfigs):
"""
# Device to train the model on.
# [`DeviceConfigs`](https://docs.labml.ai/api/helpers.html#labml_helpers.device.DeviceConfigs)
# [`DeviceConfigs`](../../helpers/device.html)
# picks up an available CUDA device or defaults to CPU.
device: torch.device = DeviceConfigs()

View File

@ -17,7 +17,7 @@ from torch import nn
from labml import lab, monit, tracker, experiment
from labml.configs import BaseConfigs, option, calculate
from labml.utils import download
from labml_helpers.device import DeviceConfigs
from labml_nn.helpers.device import DeviceConfigs
from labml_nn.graphs.gat import GraphAttentionLayer
from labml_nn.optimizers.configs import OptimizerConfigs

View File

View File

@ -0,0 +1,322 @@
import random
from pathlib import PurePath, Path
from typing import List, Callable, Dict, Optional
from torchvision import datasets, transforms
import torch
from labml import lab
from labml import monit
from labml.configs import BaseConfigs
from labml.configs import aggregate, option
from labml.utils.download import download_file
from torch.utils.data import DataLoader
from torch.utils.data import IterableDataset, Dataset
def _dataset(is_train, transform):
return datasets.MNIST(str(lab.get_data_path()),
train=is_train,
download=True,
transform=transform)
class MNISTConfigs(BaseConfigs):
"""
Configurable MNIST data set.
Arguments:
dataset_name (str): name of the data set, ``MNIST``
dataset_transforms (torchvision.transforms.Compose): image transformations
train_dataset (torchvision.datasets.MNIST): training dataset
valid_dataset (torchvision.datasets.MNIST): validation dataset
train_loader (torch.utils.data.DataLoader): training data loader
valid_loader (torch.utils.data.DataLoader): validation data loader
train_batch_size (int): training batch size
valid_batch_size (int): validation batch size
train_loader_shuffle (bool): whether to shuffle training data
valid_loader_shuffle (bool): whether to shuffle validation data
"""
dataset_name: str = 'MNIST'
dataset_transforms: transforms.Compose
train_dataset: datasets.MNIST
valid_dataset: datasets.MNIST
train_loader: DataLoader
valid_loader: DataLoader
train_batch_size: int = 64
valid_batch_size: int = 1024
train_loader_shuffle: bool = True
valid_loader_shuffle: bool = False
@option(MNISTConfigs.dataset_transforms)
def mnist_transforms():
return transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
])
@option(MNISTConfigs.train_dataset)
def mnist_train_dataset(c: MNISTConfigs):
return _dataset(True, c.dataset_transforms)
@option(MNISTConfigs.valid_dataset)
def mnist_valid_dataset(c: MNISTConfigs):
return _dataset(False, c.dataset_transforms)
@option(MNISTConfigs.train_loader)
def mnist_train_loader(c: MNISTConfigs):
return DataLoader(c.train_dataset,
batch_size=c.train_batch_size,
shuffle=c.train_loader_shuffle)
@option(MNISTConfigs.valid_loader)
def mnist_valid_loader(c: MNISTConfigs):
return DataLoader(c.valid_dataset,
batch_size=c.valid_batch_size,
shuffle=c.valid_loader_shuffle)
aggregate(MNISTConfigs.dataset_name, 'MNIST',
(MNISTConfigs.dataset_transforms, 'mnist_transforms'),
(MNISTConfigs.train_dataset, 'mnist_train_dataset'),
(MNISTConfigs.valid_dataset, 'mnist_valid_dataset'),
(MNISTConfigs.train_loader, 'mnist_train_loader'),
(MNISTConfigs.valid_loader, 'mnist_valid_loader'))
def _cifar10_dataset(is_train, transform):
return datasets.CIFAR10(str(lab.get_data_path()),
train=is_train,
download=True,
transform=transform)
class CIFAR10Configs(BaseConfigs):
"""
Configurable CIFAR 10 data set.
Arguments:
dataset_name (str): name of the data set, ``CIFAR10``
dataset_transforms (torchvision.transforms.Compose): image transformations
train_dataset (torchvision.datasets.CIFAR10): training dataset
valid_dataset (torchvision.datasets.CIFAR10): validation dataset
train_loader (torch.utils.data.DataLoader): training data loader
valid_loader (torch.utils.data.DataLoader): validation data loader
train_batch_size (int): training batch size
valid_batch_size (int): validation batch size
train_loader_shuffle (bool): whether to shuffle training data
valid_loader_shuffle (bool): whether to shuffle validation data
"""
dataset_name: str = 'CIFAR10'
dataset_transforms: transforms.Compose
train_dataset: datasets.CIFAR10
valid_dataset: datasets.CIFAR10
train_loader: DataLoader
valid_loader: DataLoader
train_batch_size: int = 64
valid_batch_size: int = 1024
train_loader_shuffle: bool = True
valid_loader_shuffle: bool = False
@CIFAR10Configs.calc(CIFAR10Configs.dataset_transforms)
def cifar10_transforms():
return transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
@CIFAR10Configs.calc(CIFAR10Configs.train_dataset)
def cifar10_train_dataset(c: CIFAR10Configs):
return _cifar10_dataset(True, c.dataset_transforms)
@CIFAR10Configs.calc(CIFAR10Configs.valid_dataset)
def cifar10_valid_dataset(c: CIFAR10Configs):
return _cifar10_dataset(False, c.dataset_transforms)
@CIFAR10Configs.calc(CIFAR10Configs.train_loader)
def cifar10_train_loader(c: CIFAR10Configs):
return DataLoader(c.train_dataset,
batch_size=c.train_batch_size,
shuffle=c.train_loader_shuffle)
@CIFAR10Configs.calc(CIFAR10Configs.valid_loader)
def cifar10_valid_loader(c: CIFAR10Configs):
return DataLoader(c.valid_dataset,
batch_size=c.valid_batch_size,
shuffle=c.valid_loader_shuffle)
CIFAR10Configs.aggregate(CIFAR10Configs.dataset_name, 'CIFAR10',
(CIFAR10Configs.dataset_transforms, 'cifar10_transforms'),
(CIFAR10Configs.train_dataset, 'cifar10_train_dataset'),
(CIFAR10Configs.valid_dataset, 'cifar10_valid_dataset'),
(CIFAR10Configs.train_loader, 'cifar10_train_loader'),
(CIFAR10Configs.valid_loader, 'cifar10_valid_loader'))
class TextDataset:
itos: List[str]
stoi: Dict[str, int]
n_tokens: int
train: str
valid: str
standard_tokens: List[str] = []
@staticmethod
def load(path: PurePath):
with open(str(path), 'r') as f:
return f.read()
def __init__(self, path: PurePath, tokenizer: Callable, train: str, valid: str, test: str, *,
n_tokens: Optional[int] = None,
stoi: Optional[Dict[str, int]] = None,
itos: Optional[List[str]] = None):
self.test = test
self.valid = valid
self.train = train
self.tokenizer = tokenizer
self.path = path
if n_tokens or stoi or itos:
assert stoi and itos and n_tokens
self.n_tokens = n_tokens
self.stoi = stoi
self.itos = itos
else:
self.n_tokens = len(self.standard_tokens)
self.stoi = {t: i for i, t in enumerate(self.standard_tokens)}
with monit.section("Tokenize"):
tokens = self.tokenizer(self.train) + self.tokenizer(self.valid)
tokens = sorted(list(set(tokens)))
for t in monit.iterate("Build vocabulary", tokens):
self.stoi[t] = self.n_tokens
self.n_tokens += 1
self.itos = [''] * self.n_tokens
for t, n in self.stoi.items():
self.itos[n] = t
def text_to_i(self, text: str) -> torch.Tensor:
tokens = self.tokenizer(text)
return torch.tensor([self.stoi[s] for s in tokens if s in self.stoi], dtype=torch.long)
def __repr__(self):
return f'{len(self.train) / 1_000_000 :,.2f}M, {len(self.valid) / 1_000_000 :,.2f}M - {str(self.path)}'
class SequentialDataLoader(IterableDataset):
def __init__(self, *, text: str, dataset: TextDataset,
batch_size: int, seq_len: int):
self.seq_len = seq_len
data = dataset.text_to_i(text)
n_batch = data.shape[0] // batch_size
data = data.narrow(0, 0, n_batch * batch_size)
data = data.view(batch_size, -1).t().contiguous()
self.data = data
def __len__(self):
return self.data.shape[0] // self.seq_len
def __iter__(self):
self.idx = 0
return self
def __next__(self):
if self.idx >= self.data.shape[0] - 1:
raise StopIteration()
seq_len = min(self.seq_len, self.data.shape[0] - 1 - self.idx)
i = self.idx + seq_len
data = self.data[self.idx: i]
target = self.data[self.idx + 1: i + 1]
self.idx = i
return data, target
def __getitem__(self, idx):
seq_len = min(self.seq_len, self.data.shape[0] - 1 - idx)
i = idx + seq_len
data = self.data[idx: i]
target = self.data[idx + 1: i + 1]
return data, target
class SequentialUnBatchedDataset(Dataset):
def __init__(self, *, text: str, dataset: TextDataset,
seq_len: int,
is_random_offset: bool = True):
self.is_random_offset = is_random_offset
self.seq_len = seq_len
self.data = dataset.text_to_i(text)
def __len__(self):
return (self.data.shape[0] - 1) // self.seq_len
def __getitem__(self, idx):
start = idx * self.seq_len
assert start + self.seq_len + 1 <= self.data.shape[0]
if self.is_random_offset:
start += random.randint(0, min(self.seq_len - 1, self.data.shape[0] - (start + self.seq_len + 1)))
end = start + self.seq_len
data = self.data[start: end]
target = self.data[start + 1: end + 1]
return data, target
class TextFileDataset(TextDataset):
standard_tokens = []
def __init__(self, path: PurePath, tokenizer: Callable, *,
url: Optional[str] = None,
filter_subset: Optional[int] = None):
path = Path(path)
if not path.exists():
if not url:
raise FileNotFoundError(str(path))
else:
download_file(url, path)
with monit.section("Load data"):
text = self.load(path)
if filter_subset:
text = text[:filter_subset]
split = int(len(text) * .9)
train = text[:split]
valid = text[split:]
super().__init__(path, tokenizer, train, valid, '')
def _test_tiny_shakespeare():
from labml import lab
_ = TextFileDataset(lab.get_data_path() / 'tiny_shakespeare.txt', lambda x: list(x),
url='https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt')
if __name__ == '__main__':
_test_tiny_shakespeare()
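
The file above consolidates the dataset helpers that previously lived in labml_helpers.datasets. A minimal usage sketch of the text helpers, assuming only the classes defined above plus torch and labml; the sequence length and batch size are illustrative:

from torch.utils.data import DataLoader
from labml import lab
from labml_nn.helpers.datasets import TextFileDataset, SequentialUnBatchedDataset

# Character-level Tiny Shakespeare, downloaded on first use (same URL as the test above)
dataset = TextFileDataset(
    lab.get_data_path() / 'tiny_shakespeare.txt', lambda x: list(x),
    url='https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt')
# Fixed-length (input, target) pairs, with the target shifted by one token
train = SequentialUnBatchedDataset(text=dataset.train, dataset=dataset, seq_len=128)
loader = DataLoader(train, batch_size=32, shuffle=True)
data, target = next(iter(loader))  # both are [32, 128] tensors of token ids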

View File

@ -0,0 +1,70 @@
import torch
from labml.configs import BaseConfigs, hyperparams, option
class DeviceInfo:
def __init__(self, *,
use_cuda: bool,
cuda_device: int):
self.use_cuda = use_cuda
self.cuda_device = cuda_device
self.cuda_count = torch.cuda.device_count()
self.is_cuda = self.use_cuda and torch.cuda.is_available()
if not self.is_cuda:
self.device = torch.device('cpu')
else:
if self.cuda_device < self.cuda_count:
self.device = torch.device('cuda', self.cuda_device)
else:
self.device = torch.device('cuda', self.cuda_count - 1)
def __str__(self):
if not self.is_cuda:
return "CPU"
if self.cuda_device < self.cuda_count:
return f"GPU:{self.cuda_device} - {torch.cuda.get_device_name(self.cuda_device)}"
else:
return (f"GPU:{self.cuda_count - 1}({self.cuda_device}) "
f"- {torch.cuda.get_device_name(self.cuda_count - 1)}")
class DeviceConfigs(BaseConfigs):
r"""
This is a configurable module to get a single device to train the model on.
It picks up a CUDA device when one is available and falls back to CPU otherwise.
It also has small conveniences such as showing the
actual device name in the configurations view of
`labml app <https://github.com/labmlai/labml/tree/master/app>`_
Arguments:
cuda_device (int): The CUDA device number. Defaults to ``0``.
use_cuda (bool): Whether to use CUDA devices. Defaults to ``True``.
"""
cuda_device: int = 0
use_cuda: bool = True
device_info: DeviceInfo
device: torch.device
def __init__(self):
super().__init__(_primary='device')
@option(DeviceConfigs.device)
def _device(c: DeviceConfigs):
return c.device_info.device
hyperparams(DeviceConfigs.cuda_device, DeviceConfigs.use_cuda,
is_hyperparam=False)
@option(DeviceConfigs.device_info)
def _device_info(c: DeviceConfigs):
return DeviceInfo(use_cuda=c.use_cuda,
cuda_device=c.cuda_device)
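
A minimal sketch of resolving a device with the DeviceInfo class above. In the experiments touched by this commit, DeviceConfigs is instead embedded in a larger configs class as `device: torch.device = DeviceConfigs()`; the tiny linear layer here is only for illustration:

import torch
from labml_nn.helpers.device import DeviceInfo

info = DeviceInfo(use_cuda=True, cuda_device=0)
print(info)  # e.g. "GPU:0 - <device name>" when CUDA is available, otherwise "CPU"
model = torch.nn.Linear(16, 4).to(info.device)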

labml_nn/helpers/metrics.py Normal file
View File

@ -0,0 +1,122 @@
import dataclasses
from abc import ABC
import torch
from labml import tracker
class StateModule:
def __init__(self):
pass
# def __call__(self):
# raise NotImplementedError
def create_state(self) -> any:
raise NotImplementedError
def set_state(self, data: any):
raise NotImplementedError
def on_epoch_start(self):
raise NotImplementedError
def on_epoch_end(self):
raise NotImplementedError
class Metric(StateModule, ABC):
def track(self):
pass
@dataclasses.dataclass
class AccuracyState:
samples: int = 0
correct: int = 0
def reset(self):
self.samples = 0
self.correct = 0
class Accuracy(Metric):
data: AccuracyState
def __init__(self, ignore_index: int = -1):
super().__init__()
self.ignore_index = ignore_index
def __call__(self, output: torch.Tensor, target: torch.Tensor):
output = output.view(-1, output.shape[-1])
target = target.view(-1)
pred = output.argmax(dim=-1)
mask = target == self.ignore_index
pred.masked_fill_(mask, self.ignore_index)
n_masked = mask.sum().item()
self.data.correct += pred.eq(target).sum().item() - n_masked
self.data.samples += len(target) - n_masked
def create_state(self):
return AccuracyState()
def set_state(self, data: any):
self.data = data
def on_epoch_start(self):
self.data.reset()
def on_epoch_end(self):
self.track()
def track(self):
if self.data.samples == 0:
return
tracker.add("accuracy.", self.data.correct / self.data.samples)
class AccuracyMovingAvg(Metric):
def __init__(self, ignore_index: int = -1, queue_size: int = 5):
super().__init__()
self.ignore_index = ignore_index
tracker.set_queue('accuracy.*', queue_size, is_print=True)
def __call__(self, output: torch.Tensor, target: torch.Tensor):
output = output.view(-1, output.shape[-1])
target = target.view(-1)
pred = output.argmax(dim=-1)
mask = target == self.ignore_index
pred.masked_fill_(mask, self.ignore_index)
n_masked = mask.sum().item()
if len(target) - n_masked > 0:
tracker.add('accuracy.', (pred.eq(target).sum().item() - n_masked) / (len(target) - n_masked))
def create_state(self):
return None
def set_state(self, data: any):
pass
def on_epoch_start(self):
pass
def on_epoch_end(self):
pass
class BinaryAccuracy(Accuracy):
def __call__(self, output: torch.Tensor, target: torch.Tensor):
pred = output.view(-1) > 0
target = target.view(-1)
self.data.correct += pred.eq(target).sum().item()
self.data.samples += len(target)
class AccuracyDirect(Accuracy):
data: AccuracyState
def __call__(self, output: torch.Tensor, target: torch.Tensor):
output = output.view(-1)
target = target.view(-1)
self.data.correct += output.eq(target).sum().item()
self.data.samples += len(target)
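
A minimal sketch of driving the Accuracy metric above by hand, outside the trainer machinery that normally manages its state; the random logits and labels are only for illustration:

import torch
from labml_nn.helpers.metrics import Accuracy

acc = Accuracy(ignore_index=-1)
acc.set_state(acc.create_state())  # the trainer normally creates and sets this state
acc.on_epoch_start()

output = torch.randn(8, 10)          # logits for 8 samples, 10 classes
target = torch.randint(0, 10, (8,))  # ground-truth class indices
acc(output, target)                  # accumulates correct and sample counts
print(acc.data.correct / acc.data.samples)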

View File

@ -0,0 +1,97 @@
from typing import Tuple
import torch
from labml import tracker
from labml.configs import BaseConfigs, option, meta_config
class OptimizerConfigs(BaseConfigs):
r"""
This creates a configurable optimizer.
Arguments:
learning_rate (float): Learning rate of the optimizer. Defaults to ``0.01``.
momentum (float): Momentum of the optimizer. Defaults to ``0.5``.
parameters: Model parameters to optimize.
d_model (int): Embedding size of the model (for Noam optimizer).
betas (Tuple[float, float]): Betas for Adam optimizer. Defaults to ``(0.9, 0.999)``.
eps (float): Epsilon for Adam/RMSProp optimizers. Defaults to ``1e-8``.
step_factor (int): Step factor for Noam optimizer. Defaults to ``1024``.
There is also a more flexible implementation, with more optimizer options, in ``labml_nn``;
`we recommend using that <https://nn.labml.ai/optimizers/configs.html>`_.
"""
optimizer: torch.optim.Adam
learning_rate: float = 0.01
momentum: float = 0.5
parameters: any
d_model: int
betas: Tuple[float, float] = (0.9, 0.999)
eps: float = 1e-8
step_factor: int = 1024
def __init__(self):
super().__init__(_primary='optimizer')
meta_config(OptimizerConfigs.parameters)
@option(OptimizerConfigs.optimizer, 'SGD')
def sgd_optimizer(c: OptimizerConfigs):
return torch.optim.SGD(c.parameters, c.learning_rate, c.momentum)
@option(OptimizerConfigs.optimizer, 'Adam')
def adam_optimizer(c: OptimizerConfigs):
return torch.optim.Adam(c.parameters, lr=c.learning_rate,
betas=c.betas, eps=c.eps)
class NoamOpt:
def __init__(self, model_size: int, learning_rate: float, warmup: int, step_factor: int, optimizer):
self.step_factor = step_factor
self.optimizer = optimizer
self.warmup = warmup
self.learning_rate = learning_rate
self.model_size = model_size
self._rate = 0
def step(self):
rate = self.rate(tracker.get_global_step() / self.step_factor)
for p in self.optimizer.param_groups:
p['lr'] = rate
self._rate = rate
self.optimizer.step()
def rate(self, step):
factor = self.model_size ** (-0.5) * min(step ** (-0.5), step * self.warmup ** (-1.5))
return self.learning_rate * factor
def zero_grad(self):
self.optimizer.zero_grad()
@option(OptimizerConfigs.optimizer, 'Noam')
def noam_optimizer(c: OptimizerConfigs):
optimizer = torch.optim.Adam(c.parameters, lr=0.0, betas=c.betas, eps=c.eps)
return NoamOpt(c.d_model, 1, 2000, c.step_factor, optimizer)
def _test_noam_optimizer():
import matplotlib.pyplot as plt
import numpy as np
# Only `rate()` is plotted, so no wrapped optimizer is needed
opts = [NoamOpt(512, 1, 4000, 1024, None),
NoamOpt(512, 1, 8000, 1024, None),
NoamOpt(2048, 1, 2000, 1024, None)]
plt.plot(np.arange(1, 20000), [[opt.rate(i) for opt in opts] for i in range(1, 20000)])
plt.legend(["512:4000", "512:8000", "2048:2000"])
plt.title("Optimizer")
plt.show()
if __name__ == '__main__':
_test_noam_optimizer()
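
A minimal sketch of the NoamOpt wrapper above around a plain Adam optimizer; the model, warm-up, and step-factor values are illustrative. NoamOpt.step reads the labml global step to compute the learning rate, so the sketch sets one explicitly:

import torch
from labml import tracker
from labml_nn.helpers.optimizer import NoamOpt

model = torch.nn.Linear(512, 512)
inner = torch.optim.Adam(model.parameters(), lr=0.0, betas=(0.9, 0.999), eps=1e-8)
opt = NoamOpt(model_size=512, learning_rate=1.0, warmup=2000, step_factor=1024, optimizer=inner)

tracker.set_global_step(1024)  # NoamOpt divides the global step by step_factor when computing the rate
loss = model(torch.randn(4, 512)).pow(2).mean()
loss.backward()
opt.step()       # sets the learning rate on every param group, then steps Adam
opt.zero_grad()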

View File

@ -0,0 +1,84 @@
from typing import Tuple, List
class Schedule:
def __call__(self, x):
raise NotImplementedError()
class Flat(Schedule):
def __init__(self, value):
self.__value = value
def __call__(self, x):
return self.__value
def __str__(self):
return f"Schedule({self.__value})"
class Dynamic(Schedule):
def __init__(self, value):
self.__value = value
def __call__(self, x):
return self.__value
def update(self, value):
self.__value = value
def __str__(self):
return "Dynamic"
class Piecewise(Schedule):
"""
## Piecewise schedule
"""
def __init__(self, endpoints: List[Tuple[float, float]], outside_value: float = None):
"""
### Initialize
`endpoints` is list of pairs `(x, y)`.
The values between endpoints are linearly interpolated.
`y` values outside the range covered by `x` are
`outside_value`.
"""
# `(x, y)` pairs should be sorted
indexes = [e[0] for e in endpoints]
assert indexes == sorted(indexes)
self._outside_value = outside_value
self._endpoints = endpoints
def __call__(self, x):
"""
### Find `y` for given `x`
"""
# iterate through each segment
for (x1, y1), (x2, y2) in zip(self._endpoints[:-1], self._endpoints[1:]):
# interpolate if `x` is within the segment
if x1 <= x < x2:
dx = float(x - x1) / (x2 - x1)
return y1 + dx * (y2 - y1)
# return outside value otherwise
return self._outside_value
def __str__(self):
endpoints = ", ".join([f"({e[0]}, {e[1]})" for e in self._endpoints])
return f"Schedule[{endpoints}, {self._outside_value}]"
class RelativePiecewise(Piecewise):
def __init__(self, relative_endpoits: List[Tuple[float, float]], total_steps: int):
endpoints = []
for e in relative_endpoits:
index = int(total_steps * e[0])
assert index >= 0
endpoints.append((index, e[1]))
super().__init__(endpoints, outside_value=relative_endpoits[-1][1])
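
A minimal sketch of the Piecewise and RelativePiecewise schedules above; the breakpoints are illustrative:

from labml_nn.helpers.schedule import Piecewise, RelativePiecewise

# Linear warm-up from 0 to 1 over the first 1,000 steps, then flat
warmup = Piecewise([(0, 0.0), (1_000, 1.0)], outside_value=1.0)
print(warmup(500))    # 0.5, linearly interpolated inside the segment
print(warmup(5_000))  # 1.0, the outside value

# The same shape with breakpoints given as fractions of a 10,000-step run
coef = RelativePiecewise([(0.1, 1.0), (1.0, 0.1)], total_steps=10_000)
print(coef(1_000), coef(9_999))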

labml_nn/helpers/trainer.py Normal file
View File

@ -0,0 +1,598 @@
import signal
import typing
from typing import Dict, List, Callable
from typing import Optional, Tuple, Any, Collection
import labml.utils.pytorch as pytorch_utils
import torch.optim
import torch.utils.data
from labml import tracker, logger, experiment, monit
from labml.configs import BaseConfigs, meta_config, option
from labml.internal.monitor import Loop
from labml.logger import Text
from torch import nn
from .device import DeviceConfigs
from .metrics import StateModule
class TrainingLoopIterator(Collection):
def __init__(self, start: int, total: int, step: Optional[int]):
self.step = step
self.total = total
self.start = start
self.i = None
def __iter__(self):
self.i = None
return self
def __next__(self):
if self.step is not None:
if self.i is None:
self.i = self.start
else:
self.i += self.step
else:
if self.i is None:
self.i = 0
else:
self.i += 1
if self.i >= self.total:
raise StopIteration()
if self.step is None:
return tracker.get_global_step()
else:
return self.i
def __len__(self) -> int:
if self.step is not None:
return (self.total - self.start) // self.step
else:
return self.total
def __contains__(self, x: object) -> bool:
return False
class TrainingLoop:
_iter: Optional[TrainingLoopIterator]
__loop: Loop
__signal_received: Optional[Tuple[Any, Any]]
def __init__(self, *,
loop_count: int,
loop_step: Optional[int],
is_save_models: bool,
log_new_line_interval: int,
log_write_interval: int,
save_models_interval: int,
is_loop_on_interrupt: bool):
self.__loop_count = loop_count
self.__loop_step = loop_step
self.__is_save_models = is_save_models
self.__log_new_line_interval = log_new_line_interval
self.__log_write_interval = log_write_interval
self.__last_write_step = 0
self.__last_new_line_step = 0
self.__save_models_interval = save_models_interval
self.__last_save_step = 0
self.__signal_received = None
self.__is_loop_on_interrupt = is_loop_on_interrupt
self._iter = None
def __iter__(self):
self._iter = TrainingLoopIterator(tracker.get_global_step(),
self.__loop_count,
self.__loop_step)
self.__loop = monit.loop(typing.cast(Collection, self._iter))
iter(self.__loop)
try:
self.old_handler = signal.signal(signal.SIGINT, self.__handler)
except ValueError:
pass
return self
@property
def idx(self):
if not self._iter:
return 0
if not self._iter.i:
return 0
if self.__loop_step is None:
return self._iter.i
return self._iter.i / self.__loop_step
def __finish(self):
try:
signal.signal(signal.SIGINT, self.old_handler)
except ValueError:
pass
tracker.save()
tracker.new_line()
if self.__is_save_models:
logger.log("Saving model...")
experiment.save_checkpoint()
# def is_interval(self, interval: int, global_step: Optional[int] = None):
# if global_step is None:
# global_step = tracker.get_global_step()
#
# if global_step - self.__loop_step < 0:
# return False
#
# if global_step // interval > (global_step - self.__loop_step) // interval:
# return True
# else:
# return False
def __next__(self):
if self.__signal_received is not None:
logger.log('\nKilling Loop.', Text.danger)
monit.finish_loop()
self.__finish()
raise StopIteration("SIGINT")
try:
global_step = next(self.__loop)
except StopIteration as e:
self.__finish()
raise e
tracker.set_global_step(global_step)
if global_step - self.__last_write_step >= self.__log_write_interval:
tracker.save()
self.__last_write_step = global_step
if global_step - self.__last_new_line_step >= self.__log_new_line_interval:
tracker.new_line()
self.__last_new_line_step = global_step
# if self.is_interval(self.__log_write_interval, global_step):
# tracker.save()
# if self.is_interval(self.__log_new_line_interval, global_step):
# logger.log()
# if (self.__is_save_models and
# self.is_interval(self.__save_models_interval, global_step)):
# experiment.save_checkpoint()
if (self.__is_save_models and
global_step - self.__last_save_step >= self.__save_models_interval):
experiment.save_checkpoint()
self.__last_save_step = global_step
return global_step
def __handler(self, sig, frame):
# Pass second interrupt without delaying
if self.__signal_received is not None:
logger.log('\nSIGINT received twice. Stopping...', Text.danger)
self.old_handler(*self.__signal_received)
return
if self.__is_loop_on_interrupt:
# Store the interrupt signal for later
self.__signal_received = (sig, frame)
logger.log('\nSIGINT received. Delaying KeyboardInterrupt.', Text.danger)
else:
self.__finish()
logger.log('Killing loop...', Text.danger)
self.old_handler(sig, frame)
def __str__(self):
return "LabTrainingLoop"
class TrainingLoopConfigs(BaseConfigs):
r"""
This is a configurable training loop. You can extend this class for your configurations
if they involve a training loop.
>>> for step in conf.training_loop:
>>> ...
Arguments:
loop_count (int): Total number of steps. Defaults to ``10``.
loop_step (int): Number of steps to increment per iteration. Defaults to ``1``.
is_save_models (bool): Whether to call :func:`labml.experiment.save_checkpoint` on each iteration.
Defaults to ``False``.
save_models_interval (int): The interval (in steps) to save models. Defaults to ``1``.
log_new_line_interval (int): The interval (in steps) to print a new line to the screen.
Defaults to ``1``.
log_write_interval (int): The interval (in steps) to call :func:`labml.tracker.save`.
Defaults to ``1``.
is_loop_on_interrupt (bool): Whether to handle keyboard interrupts and wait until an iteration is complete.
Defaults to ``False``.
"""
loop_count: int = 10
loop_step: int = 1
is_save_models: bool = False
log_new_line_interval: int = 1
log_write_interval: int = 1
save_models_interval: int = 1
is_loop_on_interrupt: bool = False
training_loop: TrainingLoop
@option(TrainingLoopConfigs.training_loop)
def _loop_configs(c: TrainingLoopConfigs):
return TrainingLoop(loop_count=c.loop_count,
loop_step=c.loop_step,
is_save_models=c.is_save_models,
log_new_line_interval=c.log_new_line_interval,
log_write_interval=c.log_write_interval,
save_models_interval=c.save_models_interval,
is_loop_on_interrupt=c.is_loop_on_interrupt)
meta_config(TrainingLoopConfigs.loop_step,
TrainingLoopConfigs.loop_count,
TrainingLoopConfigs.is_save_models,
TrainingLoopConfigs.log_new_line_interval,
TrainingLoopConfigs.log_write_interval,
TrainingLoopConfigs.save_models_interval,
TrainingLoopConfigs.is_loop_on_interrupt)
class ModeState:
def __init__(self):
self._rollback_stack = []
self.is_train = False
self.is_log_activations = False
self.is_log_parameters = False
self.is_optimize = False
def _enter(self, mode: Dict[str, any]):
rollback = {}
for k, v in mode.items():
if v is None:
continue
rollback[k] = getattr(self, k)
setattr(self, k, v)
self._rollback_stack.append(rollback)
return len(self._rollback_stack)
def _exit(self, n: int):
assert n == len(self._rollback_stack)
rollback = self._rollback_stack[-1]
self._rollback_stack.pop(-1)
for k, v in rollback.items():
setattr(self, k, v)
def update(self, *,
is_train: Optional[bool] = None,
is_log_parameters: Optional[bool] = None,
is_log_activations: Optional[bool] = None,
is_optimize: Optional[bool] = None):
return Mode(self,
is_train=is_train,
is_log_parameters=is_log_parameters,
is_log_activations=is_log_activations,
is_optimize=is_optimize)
class Mode:
def __init__(self, mode: ModeState, **kwargs: any):
self.mode = mode
self.update = {}
for k, v in kwargs.items():
if v is not None:
self.update[k] = v
self.idx = -1
def __enter__(self):
self.idx = self.mode._enter(self.update)
def __exit__(self, exc_type, exc_val, exc_tb):
self.mode._exit(self.idx)
class ForwardHook:
def __init__(self, mode: ModeState, model_name, name: str, module: torch.nn.Module):
self.mode = mode
self.model_name = model_name
self.name = name
self.module = module
module.register_forward_hook(self)
def save(self, name: str, output):
if isinstance(output, torch.Tensor):
pytorch_utils.store_var(name, output)
elif isinstance(output, tuple):
for i, o in enumerate(output):
self.save(f"{name}.{i}", o)
def __call__(self, module, i, o):
if not self.mode.is_log_activations:
return
self.save(f"module.{self.model_name}.{self.name}", o)
def hook_model_outputs(mode: ModeState, model: torch.nn.Module, model_name: str = "model"):
for name, module in model.named_modules():
if name == '':
name = 'full'
ForwardHook(mode, model_name, name, module)
class Trainer:
def __init__(self, *,
name: str,
mode: ModeState,
data_loader: torch.utils.data.DataLoader,
inner_iterations: int,
state_modules: List[StateModule],
is_track_time: bool,
step: Callable[[any, 'BatchIndex'], None]):
self.is_track_time = is_track_time
self.mode = mode
self.name = name
self.step = step
self.state_modules = state_modules
self.__iterable = None
self.__states = [sm.create_state() for sm in self.state_modules]
self.inner_iterations = inner_iterations
self.data_loader = data_loader
self._batch_index = BatchIndex(len(self.data_loader), self.inner_iterations)
def set_data_loader(self, data_loader: torch.utils.data.DataLoader):
self.data_loader = data_loader
self._batch_index = BatchIndex(len(data_loader), self.inner_iterations)
self.__iterable = None
def __call__(self):
for sm, s in zip(self.state_modules, self.__states):
sm.set_state(s)
if self.__iterable is None or self._batch_index.completed:
self.__iterable = iter(self.data_loader)
self._batch_index.reset(len(self.data_loader), self.inner_iterations)
for sm in self.state_modules:
sm.on_epoch_start()
with torch.set_grad_enabled(self.mode.is_train):
self.__iterate()
if self._batch_index.completed:
for sm in self.state_modules:
sm.on_epoch_end()
def __iterate(self):
with monit.section(self.name, is_partial=True, is_track=self.is_track_time):
if self._batch_index.idx == 0:
monit.progress(0)
while not self._batch_index.iteration_completed:
batch = next(self.__iterable)
self.step(batch, self._batch_index)
self._batch_index.step()
monit.progress(self._batch_index.epoch_progress)
self._batch_index.step_inner()
class BatchIndex:
idx: int
total: int
iteration: int
total_iterations: int
def __init__(self, total: int, total_iterations: int):
self.total_iterations = total_iterations
self.total = total
def is_interval(self, interval: int):
if interval <= 0:
return False
if self.idx + 1 == self.total:
return True
else:
return (self.idx + 1) % interval == 0
@property
def is_last(self):
return self.idx + 1 == self.total
@property
def completed(self):
return self.iteration >= self.total_iterations
@property
def iteration_completed(self):
# // is important so that the last step happens on the last iteration
return self.idx >= (self.iteration + 1) * self.total // self.total_iterations
@property
def epoch_progress(self):
return self.idx / self.total
def step(self):
self.idx += 1
def step_inner(self):
self.iteration += 1
def reset(self, total: int, total_iterations: int):
self.total = total
self.total_iterations = total_iterations
self.idx = 0
self.iteration = 0
class TrainValidConfigs(TrainingLoopConfigs):
r"""
This is a configurable module that you can extend for experiments that involve
training and validation datasets (i.e. most DL experiments).
Arguments:
epochs (int): Number of epochs to train on. Defaults to ``10``.
train_loader (torch.utils.data.DataLoader): Training data loader.
valid_loader (torch.utils.data.DataLoader): Validation data loader.
inner_iterations (int): Number of times to switch between training and validation
within an epoch. Defaults to ``1``.
You can override the ``init`` and ``step`` functions. There is also a ``sample`` function
that you can override to generate samples every time it switches between training and validation.
"""
state_modules: List[StateModule]
mode: ModeState
epochs: int = 10
trainer: Trainer
validator: Trainer
train_loader: torch.utils.data.DataLoader
valid_loader: torch.utils.data.DataLoader
loop_count = '_data_loop_count'
loop_step = None
inner_iterations: int = 1
is_track_time: bool = False
def init(self):
pass
def step(self, batch: Any, batch_idx: BatchIndex):
raise NotImplementedError
def run_step(self):
for i in range(self.inner_iterations):
with tracker.namespace('sample'):
self.sample()
with self.mode.update(is_train=True):
with tracker.namespace('train'):
self.trainer()
if self.validator:
with tracker.namespace('valid'):
self.validator()
tracker.save()
def run(self):
with monit.section("Initialize"):
self.init()
_ = self.validator
_ = self.trainer
for _ in self.training_loop:
self.run_step()
def sample(self):
pass
@option(TrainValidConfigs.trainer)
def _default_trainer(c: TrainValidConfigs):
return Trainer(name='Train',
mode=c.mode,
data_loader=c.train_loader,
inner_iterations=c.inner_iterations,
state_modules=c.state_modules,
is_track_time=c.is_track_time,
step=c.step)
@option(TrainValidConfigs.validator)
def _default_validator(c: TrainValidConfigs):
return Trainer(name='Valid',
mode=c.mode,
data_loader=c.valid_loader,
inner_iterations=c.inner_iterations,
state_modules=c.state_modules,
is_track_time=c.is_track_time,
step=c.step)
@option(TrainValidConfigs.loop_count)
def _data_loop_count(c: TrainValidConfigs):
return c.epochs
class SimpleTrainValidConfigs(TrainValidConfigs):
r"""
This is a configurable module that works for many standard DL experiments.
Arguments:
model: A PyTorch model.
optimizer: A PyTorch optimizer to update the model.
device: The device to train the model on. This defaults to a configurable device.
loss_func: A function to calculate the loss. This should accept ``model_output, target`` as
arguments.
update_batches (int): Number of batches to accumulate before taking an optimizer step.
Defaults to ``1``.
log_params_updates (int): How often (number of batches) to track model parameters and gradients.
Defaults to a large number; i.e. logs every epoch.
log_activations_batches (int): How often to log model activations.
Defaults to a large number; i.e. logs every epoch.
log_save_batches (int): How often to call :func:`labml.tracker.save`.
"""
optimizer: torch.optim.Adam
model: nn.Module
device: torch.device = DeviceConfigs()
loss_func: nn.Module
update_batches: int = 1
log_params_updates: int = 2 ** 32  # set to 0 to disable
log_activations_batches: int = 2 ** 32  # set to 0 to disable
log_save_batches: int = 1
state_modules: List[StateModule] = []
def init(self):
pass
def step(self, batch: Any, batch_idx: BatchIndex):
self.model.train(self.mode.is_train)
data, target = batch[0].to(self.device), batch[1].to(self.device)
if self.mode.is_train:
tracker.add_global_step(len(data))
is_log_activations = batch_idx.is_interval(self.log_activations_batches)
with monit.section("model"):
with self.mode.update(is_log_activations=is_log_activations):
output = self.model(data)
loss = self.loss_func(output, target)
tracker.add("loss.", loss)
if self.mode.is_train:
with monit.section('backward'):
loss.backward()
if batch_idx.is_interval(self.update_batches):
with monit.section('optimize'):
self.optimizer.step()
if batch_idx.is_interval(self.log_params_updates):
tracker.add('model', self.model)
self.optimizer.zero_grad()
if batch_idx.is_interval(self.log_save_batches):
tracker.save()
meta_config(SimpleTrainValidConfigs.update_batches,
SimpleTrainValidConfigs.log_params_updates,
SimpleTrainValidConfigs.log_activations_batches)
@option(SimpleTrainValidConfigs.optimizer)
def _default_optimizer(c: SimpleTrainValidConfigs):
from .optimizer import OptimizerConfigs
opt_conf = OptimizerConfigs()
opt_conf.parameters = c.model.parameters()
return opt_conf
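
A minimal end-to-end sketch of wiring the relocated helpers together: a SimpleTrainValidConfigs subclass trained on MNIST. The option and override keys follow the files added in this commit; the experiment name, model, and hyper-parameter values are illustrative, not part of the commit:

from torch import nn
from labml import experiment
from labml.configs import option
from labml_nn.helpers.datasets import MNISTConfigs
from labml_nn.helpers.trainer import SimpleTrainValidConfigs


class Configs(MNISTConfigs, SimpleTrainValidConfigs):
    model: nn.Module
    loss_func = nn.CrossEntropyLoss()


@option(Configs.model)
def _model(c: Configs):
    # A deliberately tiny classifier; real experiments plug in their own models
    return nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10)).to(c.device)


def main():
    conf = Configs()
    experiment.create(name='mnist_simple_trainer')
    experiment.configs(conf, {'optimizer.optimizer': 'Adam',
                              'optimizer.learning_rate': 1e-3,
                              'epochs': 5})
    with experiment.start():
        conf.run()


if __name__ == '__main__':
    main()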

View File

@ -19,7 +19,7 @@ from transformers import AutoTokenizer, AutoModelForCausalLM
from labml import lab, monit, tracker
from labml.configs import BaseConfigs, option
from labml.utils.download import download_file
from labml_helpers.device import DeviceConfigs
from labml_nn.helpers.device import DeviceConfigs
from labml_nn.lora.gpt2 import GPTModel

View File

@ -11,11 +11,10 @@ import torch.utils.data
from labml import experiment, tracker
from labml.configs import option
from labml_helpers.datasets.mnist import MNISTConfigs
from labml_helpers.device import DeviceConfigs
from labml_helpers.metrics.accuracy import Accuracy
from labml_helpers.seed import SeedConfigs
from labml_helpers.train_valid import TrainValidConfigs, BatchIndex, hook_model_outputs
from labml_nn.helpers.datasets import MNISTConfigs
from labml_nn.helpers.device import DeviceConfigs
from labml_nn.helpers.metrics import Accuracy
from labml_nn.helpers.trainer import TrainValidConfigs, BatchIndex, hook_model_outputs
from labml_nn.optimizers.configs import OptimizerConfigs
@ -48,7 +47,6 @@ class Configs(MNISTConfigs, TrainValidConfigs):
"""
optimizer: torch.optim.Adam
model: nn.Module
set_seed = SeedConfigs()
device: torch.device = DeviceConfigs()
epochs: int = 10
@ -126,7 +124,6 @@ def main():
# Specify the optimizer
'optimizer.optimizer': 'Adam',
'optimizer.learning_rate': 1.5e-4})
conf.set_seed.set()
experiment.add_pytorch_models(dict(model=conf.model))
with experiment.start():
conf.run()

View File

@ -18,7 +18,7 @@ MyAdam...[DONE] 1,192.89ms
import torch
import torch.nn as nn
from labml_helpers.device import DeviceInfo
from labml_nn.helpers.device import DeviceInfo
from torch.optim import Adam as TorchAdam
from labml import monit

View File

@ -17,7 +17,7 @@ import torch
from labml import tracker, experiment, logger, monit
from labml.internal.configs.dynamic_hyperparam import FloatDynamicHyperParam
from labml_helpers.schedule import Piecewise
from labml_nn.helpers.schedule import Piecewise
from labml_nn.rl.dqn import QFuncLoss
from labml_nn.rl.dqn.model import Model
from labml_nn.rl.dqn.replay_buffer import ReplayBuffer

View File

@ -5,7 +5,7 @@ import torch
from labml import experiment, monit
from labml import logger
from labml.logger import Text
from labml_helpers.datasets.text import TextDataset
from labml_nn.helpers.datasets import TextDataset
from labml_nn.sampling import Sampler
from labml_nn.sampling.greedy import GreedySampler
from labml_nn.sampling.nucleus import NucleusSampler

View File

@ -39,9 +39,9 @@ from matplotlib import pyplot as plt
import torch
import torch.nn as nn
from labml import lab, experiment, tracker, monit
from labml_helpers.device import DeviceConfigs
from labml_helpers.optimizer import OptimizerConfigs
from labml_helpers.train_valid import TrainValidConfigs, hook_model_outputs, BatchIndex
from labml_nn.helpers.device import DeviceConfigs
from labml_nn.helpers.optimizer import OptimizerConfigs
from labml_nn.helpers.trainer import TrainValidConfigs, hook_model_outputs, BatchIndex
from torch import optim
from torch.utils.data import Dataset, DataLoader

View File

@ -16,7 +16,7 @@ from torch.utils.data import DataLoader
from labml import experiment, tracker
from labml.configs import option, calculate
from labml_helpers.datasets.text import SequentialUnBatchedDataset
from labml_nn.helpers.datasets import SequentialUnBatchedDataset
from labml_nn.transformers.alibi import AlibiMultiHeadAttention
from labml_nn.experiments.nlp_autoregression import transpose_batch
from labml_nn.transformers import TransformerConfigs

View File

@ -13,7 +13,7 @@ on an NLP auto-regression task (with Tiny Shakespeare dataset) with [Sophia-G op
import torch
from labml import experiment, tracker
from labml_helpers.train_valid import BatchIndex
from labml_nn.helpers.trainer import BatchIndex
from labml_nn.optimizers.sophia import Sophia
from labml_nn.transformers.basic.autoregressive_experiment import Configs as TransformerAutoRegressionConfigs

View File

@ -56,7 +56,6 @@ import torch
import torch.nn.functional as F
from torch import nn
from labml_helpers.module import TypedModuleList
from labml_nn.transformers.feed_forward import FeedForward
from labml_nn.transformers.mha import PrepareForMultiHeadAttention
from labml_nn.transformers.xl.relative_mha import RelativeMultiHeadAttention
@ -233,7 +232,7 @@ class AttentionReconstructionLoss:
attention reconstruction loss, we detach all other parameters except $f_c$
from the gradient computation.
"""
def __init__(self, layers: TypedModuleList[CompressiveTransformerLayer]):
def __init__(self, layers: nn.ModuleList):
"""
`layers` is the list of Compressive Transformer layers
"""

View File

@ -16,8 +16,8 @@ import torch.nn as nn
from labml import experiment, tracker, monit, logger
from labml.configs import option
from labml.logger import Text
from labml_helpers.metrics.simple_state import SimpleStateModule
from labml_helpers.train_valid import BatchIndex, hook_model_outputs
from labml_nn.helpers.metrics import SimpleStateModule
from labml_nn.helpers.trainer import BatchIndex, hook_model_outputs
from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs
from labml_nn.transformers.compressive import CompressiveTransformer, AttentionReconstructionLoss, \
CompressiveTransformerLayer, Conv1dCompression

View File

@ -16,8 +16,8 @@ from torch import nn
from labml import experiment, tracker, logger
from labml.configs import option
from labml.logger import Text
from labml_helpers.metrics.accuracy import Accuracy
from labml_helpers.train_valid import BatchIndex
from labml_nn.helpers.metrics import Accuracy
from labml_nn.helpers.trainer import BatchIndex
from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs
from labml_nn.transformers import Encoder, Generator
from labml_nn.transformers import TransformerConfigs

View File

@ -20,7 +20,7 @@ import numpy as np
import torch
from labml import lab, monit
from labml_helpers.datasets.text import TextFileDataset
from labml_nn.helpers.datasets import TextFileDataset
from labml_nn.transformers.retro.bert_embeddings import BERTChunkEmbeddings

View File

@ -20,7 +20,7 @@ import torch
from torch.utils.data import Dataset as PyTorchDataset
from labml import lab, monit
from labml_helpers.datasets.text import TextFileDataset, TextDataset
from labml_nn.helpers.datasets import TextFileDataset, TextDataset
from labml_nn.transformers.retro.database import RetroIndex

View File

@ -17,7 +17,7 @@ from torch.utils.data import DataLoader, RandomSampler
from labml import monit, lab, tracker, experiment, logger
from labml.logger import Text
from labml_helpers.datasets.text import TextFileDataset
from labml_nn.helpers.datasets import TextFileDataset
from labml_nn.optimizers.noam import Noam
from labml_nn.transformers.retro import model as retro
from labml_nn.transformers.retro.dataset import Dataset, RetroIndex

View File

@ -16,7 +16,7 @@ import torch.nn as nn
from labml import experiment, tracker
from labml.configs import option
from labml_helpers.train_valid import BatchIndex
from labml_nn.helpers.trainer import BatchIndex
from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs

View File

@ -16,8 +16,8 @@ from labml.logger import Text
from labml import experiment, tracker, monit, logger
from labml.configs import option
from labml_helpers.metrics.simple_state import SimpleStateModule
from labml_helpers.train_valid import BatchIndex, hook_model_outputs
from labml_nn.helpers.metrics import SimpleStateModule
from labml_nn.helpers.trainer import BatchIndex, hook_model_outputs
from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs
from labml_nn.transformers.xl import TransformerXL, TransformerXLLayer

View File

@ -18,8 +18,8 @@ import torch.utils.data
from labml import tracker, experiment
from labml.configs import option, calculate
from labml_helpers.schedule import Schedule, RelativePiecewise
from labml_helpers.train_valid import BatchIndex
from labml_nn.helpers.schedule import Schedule, RelativePiecewise
from labml_nn.helpers.trainer import BatchIndex
from labml_nn.experiments.mnist import MNISTConfigs
from labml_nn.uncertainty.evidence import KLDivergenceLoss, TrackStatistics, MaximumLikelihoodLoss, \
CrossEntropyBayesRisk, SquaredErrorBayesRisk
@ -178,7 +178,7 @@ def kl_div_coef(c: Configs):
### KL Divergence Loss Coefficient Schedule
"""
# Create a [relative piecewise schedule](https://docs.labml.ai/api/helpers.html#labml_helpers.schedule.Piecewise)
# Create a [relative piecewise schedule](../../helpers/schedule.html)
return RelativePiecewise(c.kl_div_coef_schedule, c.epochs * len(c.train_dataset))

View File

@ -24,7 +24,7 @@ from torch import nn
from labml import lab, tracker, experiment, monit
from labml.configs import BaseConfigs
from labml_helpers.device import DeviceConfigs
from labml_nn.helpers.device import DeviceConfigs
from labml_nn.unet.carvana import CarvanaDataset
from labml_nn.unet import UNet
@ -34,7 +34,7 @@ class Configs(BaseConfigs):
## Configurations
"""
# Device to train the model on.
# [`DeviceConfigs`](https://docs.labml.ai/api/helpers.html#labml_helpers.device.DeviceConfigs)
# [`DeviceConfigs`](../helpers/device.html)
# picks up an available CUDA device or defaults to CPU.
device: torch.device = DeviceConfigs()

View File

@ -8,19 +8,18 @@ summary: A bunch of utility functions and classes
"""
import copy
from torch import nn
from torch.utils.data import Dataset, IterableDataset
from labml_helpers.module import M, TypedModuleList
def clone_module_list(module: M, n: int) -> TypedModuleList[M]:
def clone_module_list(module: nn.Module, n: int) -> nn.ModuleList:
"""
## Clone Module
Make a `nn.ModuleList` with clones of a given module
"""
return TypedModuleList([copy.deepcopy(module) for _ in range(n)])
return nn.ModuleList([copy.deepcopy(module) for _ in range(n)])
def cycle_dataloader(data_loader):

View File

@ -21,7 +21,6 @@ setuptools.setup(
'test',
'test.*')),
install_requires=['labml==0.4.168',
'labml-helpers==0.4.89',
'torch',
'torchtext',
'torchvision',