mirror of
https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
synced 2025-08-14 01:13:00 +08:00
remove labml_helpers dep
@@ -17,8 +17,8 @@ from torch import nn
 from torch.utils.data import DataLoader
 
 from labml import tracker, experiment
-from labml_helpers.metrics.accuracy import AccuracyDirect
-from labml_helpers.train_valid import SimpleTrainValidConfigs, BatchIndex
+from labml_nn.helpers.metrics import AccuracyDirect
+from labml_nn.helpers.trainer import SimpleTrainValidConfigs, BatchIndex
 from labml_nn.adaptive_computation.parity import ParityDataset
 from labml_nn.adaptive_computation.ponder_net import ParityPonderGRU, ReconstructionLoss, RegularizationLoss
 
@@ -26,7 +26,7 @@ from labml_nn.adaptive_computation.ponder_net import ParityPonderGRU, Reconstruc
 class Configs(SimpleTrainValidConfigs):
     """
     Configurations with a
-    [simple training loop](https://docs.labml.ai/api/helpers.html#labml_helpers.train_valid.SimpleTrainValidConfigs)
+    [simple training loop](../../helpers/trainer.html)
     """
 
     # Number of epochs
@@ -16,13 +16,12 @@ from typing import Any
 import torch.nn as nn
 import torch.nn.functional as F
 import torch.utils.data
 
 from labml import experiment, tracker
 from labml.configs import option
-from labml_helpers.datasets.mnist import MNISTConfigs
-from labml_helpers.metrics.accuracy import AccuracyDirect
-from labml_helpers.train_valid import SimpleTrainValidConfigs, BatchIndex
 from labml_nn.capsule_networks import Squash, Router, MarginLoss
+from labml_nn.helpers.datasets import MNISTConfigs
+from labml_nn.helpers.metrics import AccuracyDirect
+from labml_nn.helpers.trainer import SimpleTrainValidConfigs, BatchIndex
 
 
 class MNISTCapsuleNetworkModel(nn.Module):
@@ -26,7 +26,7 @@ from PIL import Image
 
 from labml import lab, tracker, experiment, monit
 from labml.configs import BaseConfigs, option
-from labml_helpers.device import DeviceConfigs
+from labml_nn.helpers.device import DeviceConfigs
 from labml_nn.diffusion.ddpm import DenoiseDiffusion
 from labml_nn.diffusion.ddpm.unet import UNet
 
@@ -36,7 +36,7 @@ class Configs(BaseConfigs):
     ## Configurations
     """
     # Device to train the model on.
-    # [`DeviceConfigs`](https://docs.labml.ai/api/helpers.html#labml_helpers.device.DeviceConfigs)
+    # [`DeviceConfigs`](../../device.html)
     # picks up an available CUDA device or defaults to CPU.
     device: torch.device = DeviceConfigs()
 
@@ -75,7 +75,7 @@ from torch import nn
 
 from labml import experiment, tracker
 from labml.configs import option
-from labml_helpers.train_valid import BatchIndex
+from labml_nn.helpers.trainer import BatchIndex
 from labml_nn.distillation.large import LargeModel
 from labml_nn.distillation.small import SmallModel
 from labml_nn.experiments.cifar10 import CIFAR10Configs
@@ -13,7 +13,7 @@ import torch.nn as nn
 
 from labml import lab
 from labml.configs import option
-from labml_helpers.datasets.cifar10 import CIFAR10Configs as CIFAR10DatasetConfigs
+from labml_nn.helpers.datasets import CIFAR10Configs as CIFAR10DatasetConfigs
 from labml_nn.experiments.mnist import MNISTConfigs
 
 
@@ -21,8 +21,7 @@ class CIFAR10Configs(CIFAR10DatasetConfigs, MNISTConfigs):
     """
     ## Configurations
 
-    This extends from CIFAR 10 dataset configurations from
-    [`labml_helpers`](https://github.com/labmlai/labml/tree/master/helpers)
+    This extends from [CIFAR 10 dataset configurations](../helpers/datasets.html)
     and [`MNISTConfigs`](mnist.html).
     """
     # Use CIFAR10 dataset by default
@@ -13,10 +13,10 @@ import torch.utils.data
 
 from labml import tracker
 from labml.configs import option
-from labml_helpers.datasets.mnist import MNISTConfigs as MNISTDatasetConfigs
-from labml_helpers.device import DeviceConfigs
-from labml_helpers.metrics.accuracy import Accuracy
-from labml_helpers.train_valid import TrainValidConfigs, BatchIndex, hook_model_outputs
+from labml_nn.helpers.datasets import MNISTConfigs as MNISTDatasetConfigs
+from labml_nn.helpers.device import DeviceConfigs
+from labml_nn.helpers.metrics import Accuracy
+from labml_nn.helpers.trainer import TrainValidConfigs, BatchIndex, hook_model_outputs
 from labml_nn.optimizers.configs import OptimizerConfigs
 
 
@@ -17,10 +17,10 @@ from torch.utils.data import DataLoader, RandomSampler
 from labml import lab, monit, logger, tracker
 from labml.configs import option
 from labml.logger import Text
-from labml_helpers.datasets.text import TextDataset, SequentialDataLoader, SequentialUnBatchedDataset, TextFileDataset
-from labml_helpers.device import DeviceConfigs
-from labml_helpers.metrics.accuracy import Accuracy
-from labml_helpers.train_valid import TrainValidConfigs, hook_model_outputs, BatchIndex
+from labml_nn.helpers.datasets import TextDataset, SequentialDataLoader, SequentialUnBatchedDataset, TextFileDataset
+from labml_nn.helpers.device import DeviceConfigs
+from labml_nn.helpers.metrics import Accuracy
+from labml_nn.helpers.trainer import TrainValidConfigs, hook_model_outputs, BatchIndex
 from labml_nn.optimizers.configs import OptimizerConfigs
 
 
@@ -20,9 +20,9 @@ from torchtext.vocab import Vocab
 
 from labml import lab, tracker, monit
 from labml.configs import option
-from labml_helpers.device import DeviceConfigs
-from labml_helpers.metrics.accuracy import Accuracy
-from labml_helpers.train_valid import TrainValidConfigs, hook_model_outputs, BatchIndex
+from labml_nn.helpers.device import DeviceConfigs
+from labml_nn.helpers.metrics import Accuracy
+from labml_nn.helpers.trainer import TrainValidConfigs, hook_model_outputs, BatchIndex
 from labml_nn.optimizers.configs import OptimizerConfigs
 
 
@@ -49,7 +49,7 @@ from labml import lab, tracker, experiment, monit
 from labml.configs import BaseConfigs
 from labml.utils.download import download_file
 from labml.utils.pytorch import get_modules
-from labml_helpers.device import DeviceConfigs
+from labml_nn.helpers.device import DeviceConfigs
 
 
 class GeneratorResNet(nn.Module):
@@ -16,10 +16,10 @@ from torchvision import transforms
 
 from labml import tracker, monit, experiment
 from labml.configs import option, calculate
-from labml_helpers.datasets.mnist import MNISTConfigs
-from labml_helpers.device import DeviceConfigs
-from labml_helpers.optimizer import OptimizerConfigs
-from labml_helpers.train_valid import TrainValidConfigs, hook_model_outputs, BatchIndex
+from labml_nn.helpers.datasets import MNISTConfigs
+from labml_nn.helpers.device import DeviceConfigs
+from labml_nn.helpers.optimizer import OptimizerConfigs
+from labml_nn.helpers.trainer import TrainValidConfigs, hook_model_outputs, BatchIndex
 from labml_nn.gan.original import DiscriminatorLogitsLoss, GeneratorLogitsLoss
 
 
@@ -39,8 +39,8 @@ from PIL import Image
 
 from labml import tracker, lab, monit, experiment
 from labml.configs import BaseConfigs
-from labml_helpers.device import DeviceConfigs
-from labml_helpers.train_valid import ModeState, hook_model_outputs
+from labml_nn.helpers.device import DeviceConfigs
+from labml_nn.helpers.trainer import ModeState, hook_model_outputs
 from labml_nn.gan.stylegan import Discriminator, Generator, MappingNetwork, GradientPenalty, PathLengthPenalty
 from labml_nn.gan.wasserstein import DiscriminatorLoss, GeneratorLoss
 from labml_nn.utils import cycle_dataloader
@@ -88,7 +88,7 @@ class Configs(BaseConfigs):
     """
 
     # Device to train the model on.
-    # [`DeviceConfigs`](https://docs.labml.ai/api/helpers.html#labml_helpers.device.DeviceConfigs)
+    # [`DeviceConfigs`](../../helpers/device.html)
     # picks up an available CUDA device or defaults to CPU.
     device: torch.device = DeviceConfigs()
 
@@ -17,7 +17,7 @@ from torch import nn
 from labml import lab, monit, tracker, experiment
 from labml.configs import BaseConfigs, option, calculate
 from labml.utils import download
-from labml_helpers.device import DeviceConfigs
+from labml_nn.helpers.device import DeviceConfigs
 from labml_nn.graphs.gat import GraphAttentionLayer
 from labml_nn.optimizers.configs import OptimizerConfigs
 
0	labml_nn/helpers/__init__.py	Normal file
322	labml_nn/helpers/datasets.py	Normal file
@@ -0,0 +1,322 @@
import random
from pathlib import PurePath, Path
from typing import List, Callable, Dict, Optional

from torchvision import datasets, transforms

import torch
from labml import lab
from labml import monit
from labml.configs import BaseConfigs
from labml.configs import aggregate, option
from labml.utils.download import download_file
from torch.utils.data import DataLoader
from torch.utils.data import IterableDataset, Dataset


# Renamed from `_dataset`: the CIFAR10 helper below had the same name, and
# since the name is resolved at call time, the MNIST options would otherwise
# have loaded CIFAR10.
def _mnist_dataset(is_train, transform):
    return datasets.MNIST(str(lab.get_data_path()),
                          train=is_train,
                          download=True,
                          transform=transform)


class MNISTConfigs(BaseConfigs):
    """
    Configurable MNIST data set.

    Arguments:
        dataset_name (str): name of the data set, ``MNIST``
        dataset_transforms (torchvision.transforms.Compose): image transformations
        train_dataset (torchvision.datasets.MNIST): training dataset
        valid_dataset (torchvision.datasets.MNIST): validation dataset

        train_loader (torch.utils.data.DataLoader): training data loader
        valid_loader (torch.utils.data.DataLoader): validation data loader

        train_batch_size (int): training batch size
        valid_batch_size (int): validation batch size

        train_loader_shuffle (bool): whether to shuffle training data
        valid_loader_shuffle (bool): whether to shuffle validation data
    """

    dataset_name: str = 'MNIST'
    dataset_transforms: transforms.Compose
    train_dataset: datasets.MNIST
    valid_dataset: datasets.MNIST

    train_loader: DataLoader
    valid_loader: DataLoader

    train_batch_size: int = 64
    valid_batch_size: int = 1024

    train_loader_shuffle: bool = True
    valid_loader_shuffle: bool = False


@option(MNISTConfigs.dataset_transforms)
def mnist_transforms():
    return transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])


@option(MNISTConfigs.train_dataset)
def mnist_train_dataset(c: MNISTConfigs):
    return _mnist_dataset(True, c.dataset_transforms)


@option(MNISTConfigs.valid_dataset)
def mnist_valid_dataset(c: MNISTConfigs):
    return _mnist_dataset(False, c.dataset_transforms)


@option(MNISTConfigs.train_loader)
def mnist_train_loader(c: MNISTConfigs):
    return DataLoader(c.train_dataset,
                      batch_size=c.train_batch_size,
                      shuffle=c.train_loader_shuffle)


@option(MNISTConfigs.valid_loader)
def mnist_valid_loader(c: MNISTConfigs):
    return DataLoader(c.valid_dataset,
                      batch_size=c.valid_batch_size,
                      shuffle=c.valid_loader_shuffle)


aggregate(MNISTConfigs.dataset_name, 'MNIST',
          (MNISTConfigs.dataset_transforms, 'mnist_transforms'),
          (MNISTConfigs.train_dataset, 'mnist_train_dataset'),
          (MNISTConfigs.valid_dataset, 'mnist_valid_dataset'),
          (MNISTConfigs.train_loader, 'mnist_train_loader'),
          (MNISTConfigs.valid_loader, 'mnist_valid_loader'))


def _cifar10_dataset(is_train, transform):
    return datasets.CIFAR10(str(lab.get_data_path()),
                            train=is_train,
                            download=True,
                            transform=transform)


class CIFAR10Configs(BaseConfigs):
    """
    Configurable CIFAR 10 data set.

    Arguments:
        dataset_name (str): name of the data set, ``CIFAR10``
        dataset_transforms (torchvision.transforms.Compose): image transformations
        train_dataset (torchvision.datasets.CIFAR10): training dataset
        valid_dataset (torchvision.datasets.CIFAR10): validation dataset

        train_loader (torch.utils.data.DataLoader): training data loader
        valid_loader (torch.utils.data.DataLoader): validation data loader

        train_batch_size (int): training batch size
        valid_batch_size (int): validation batch size

        train_loader_shuffle (bool): whether to shuffle training data
        valid_loader_shuffle (bool): whether to shuffle validation data
    """
    dataset_name: str = 'CIFAR10'
    dataset_transforms: transforms.Compose
    train_dataset: datasets.CIFAR10
    valid_dataset: datasets.CIFAR10

    train_loader: DataLoader
    valid_loader: DataLoader

    train_batch_size: int = 64
    valid_batch_size: int = 1024

    train_loader_shuffle: bool = True
    valid_loader_shuffle: bool = False


@CIFAR10Configs.calc(CIFAR10Configs.dataset_transforms)
def cifar10_transforms():
    return transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])


@CIFAR10Configs.calc(CIFAR10Configs.train_dataset)
def cifar10_train_dataset(c: CIFAR10Configs):
    return _cifar10_dataset(True, c.dataset_transforms)


@CIFAR10Configs.calc(CIFAR10Configs.valid_dataset)
def cifar10_valid_dataset(c: CIFAR10Configs):
    return _cifar10_dataset(False, c.dataset_transforms)


@CIFAR10Configs.calc(CIFAR10Configs.train_loader)
def cifar10_train_loader(c: CIFAR10Configs):
    return DataLoader(c.train_dataset,
                      batch_size=c.train_batch_size,
                      shuffle=c.train_loader_shuffle)


@CIFAR10Configs.calc(CIFAR10Configs.valid_loader)
def cifar10_valid_loader(c: CIFAR10Configs):
    return DataLoader(c.valid_dataset,
                      batch_size=c.valid_batch_size,
                      shuffle=c.valid_loader_shuffle)


CIFAR10Configs.aggregate(CIFAR10Configs.dataset_name, 'CIFAR10',
                         (CIFAR10Configs.dataset_transforms, 'cifar10_transforms'),
                         (CIFAR10Configs.train_dataset, 'cifar10_train_dataset'),
                         (CIFAR10Configs.valid_dataset, 'cifar10_valid_dataset'),
                         (CIFAR10Configs.train_loader, 'cifar10_train_loader'),
                         (CIFAR10Configs.valid_loader, 'cifar10_valid_loader'))


class TextDataset:
    itos: List[str]
    stoi: Dict[str, int]
    n_tokens: int
    train: str
    valid: str
    standard_tokens: List[str] = []

    @staticmethod
    def load(path: PurePath):
        with open(str(path), 'r') as f:
            return f.read()

    def __init__(self, path: PurePath, tokenizer: Callable, train: str, valid: str, test: str, *,
                 n_tokens: Optional[int] = None,
                 stoi: Optional[Dict[str, int]] = None,
                 itos: Optional[List[str]] = None):
        self.test = test
        self.valid = valid
        self.train = train
        self.tokenizer = tokenizer
        self.path = path

        if n_tokens or stoi or itos:
            assert stoi and itos and n_tokens
            self.n_tokens = n_tokens
            self.stoi = stoi
            self.itos = itos
        else:
            self.n_tokens = len(self.standard_tokens)
            self.stoi = {t: i for i, t in enumerate(self.standard_tokens)}

            with monit.section("Tokenize"):
                tokens = self.tokenizer(self.train) + self.tokenizer(self.valid)
                tokens = sorted(list(set(tokens)))

            for t in monit.iterate("Build vocabulary", tokens):
                self.stoi[t] = self.n_tokens
                self.n_tokens += 1

            self.itos = [''] * self.n_tokens
            for t, n in self.stoi.items():
                self.itos[n] = t

    def text_to_i(self, text: str) -> torch.Tensor:
        tokens = self.tokenizer(text)
        return torch.tensor([self.stoi[s] for s in tokens if s in self.stoi], dtype=torch.long)

    def __repr__(self):
        return f'{len(self.train) / 1_000_000 :,.2f}M, {len(self.valid) / 1_000_000 :,.2f}M - {str(self.path)}'


class SequentialDataLoader(IterableDataset):
    def __init__(self, *, text: str, dataset: TextDataset,
                 batch_size: int, seq_len: int):
        self.seq_len = seq_len
        data = dataset.text_to_i(text)
        n_batch = data.shape[0] // batch_size
        data = data.narrow(0, 0, n_batch * batch_size)
        data = data.view(batch_size, -1).t().contiguous()
        self.data = data

    def __len__(self):
        return self.data.shape[0] // self.seq_len

    def __iter__(self):
        self.idx = 0
        return self

    def __next__(self):
        if self.idx >= self.data.shape[0] - 1:
            raise StopIteration()

        seq_len = min(self.seq_len, self.data.shape[0] - 1 - self.idx)
        i = self.idx + seq_len
        data = self.data[self.idx: i]
        target = self.data[self.idx + 1: i + 1]
        self.idx = i
        return data, target

    def __getitem__(self, idx):
        seq_len = min(self.seq_len, self.data.shape[0] - 1 - idx)
        i = idx + seq_len
        data = self.data[idx: i]
        target = self.data[idx + 1: i + 1]
        return data, target


class SequentialUnBatchedDataset(Dataset):
    def __init__(self, *, text: str, dataset: TextDataset,
                 seq_len: int,
                 is_random_offset: bool = True):
        self.is_random_offset = is_random_offset
        self.seq_len = seq_len
        self.data = dataset.text_to_i(text)

    def __len__(self):
        return (self.data.shape[0] - 1) // self.seq_len

    def __getitem__(self, idx):
        start = idx * self.seq_len
        assert start + self.seq_len + 1 <= self.data.shape[0]
        if self.is_random_offset:
            start += random.randint(0, min(self.seq_len - 1, self.data.shape[0] - (start + self.seq_len + 1)))

        end = start + self.seq_len
        data = self.data[start: end]
        target = self.data[start + 1: end + 1]
        return data, target


class TextFileDataset(TextDataset):
    standard_tokens = []

    def __init__(self, path: PurePath, tokenizer: Callable, *,
                 url: Optional[str] = None,
                 filter_subset: Optional[int] = None):
        path = Path(path)
        if not path.exists():
            if not url:
                raise FileNotFoundError(str(path))
            else:
                download_file(url, path)

        with monit.section("Load data"):
            text = self.load(path)
            if filter_subset:
                text = text[:filter_subset]
            split = int(len(text) * .9)
            train = text[:split]
            valid = text[split:]

        super().__init__(path, tokenizer, train, valid, '')


def _test_tiny_shakespeare():
    from labml import lab
    _ = TextFileDataset(lab.get_data_path() / 'tiny_shakespeare.txt', lambda x: list(x),
                        url='https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt')


if __name__ == '__main__':
    _test_tiny_shakespeare()
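A rough sketch of how these dataset configurations are typically consumed, following the experiment pattern elsewhere in this commit; the `DataConfigs` class and experiment name are illustrative, not part of the commit:

    from labml import experiment
    from labml_nn.helpers.datasets import MNISTConfigs


    class DataConfigs(MNISTConfigs):
        pass


    # Options resolve lazily: reading `conf.train_loader` runs
    # `mnist_train_loader`, which in turn resolves `train_dataset`
    # and `dataset_transforms` through their `@option` functions.
    conf = DataConfigs()
    experiment.create(name='mnist_data_demo')
    experiment.configs(conf, {'train_batch_size': 128})
    with experiment.start():
        for data, target in conf.train_loader:
            break  # one batch, just to show the loader resolves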
70	labml_nn/helpers/device.py	Normal file
@@ -0,0 +1,70 @@
import torch

from labml.configs import BaseConfigs, hyperparams, option


class DeviceInfo:
    def __init__(self, *,
                 use_cuda: bool,
                 cuda_device: int):
        self.use_cuda = use_cuda
        self.cuda_device = cuda_device
        self.cuda_count = torch.cuda.device_count()

        self.is_cuda = self.use_cuda and torch.cuda.is_available()
        if not self.is_cuda:
            self.device = torch.device('cpu')
        else:
            if self.cuda_device < self.cuda_count:
                self.device = torch.device('cuda', self.cuda_device)
            else:
                self.device = torch.device('cuda', self.cuda_count - 1)

    def __str__(self):
        if not self.is_cuda:
            return "CPU"

        if self.cuda_device < self.cuda_count:
            return f"GPU:{self.cuda_device} - {torch.cuda.get_device_name(self.cuda_device)}"
        else:
            return (f"GPU:{self.cuda_count - 1}({self.cuda_device}) "
                    f"- {torch.cuda.get_device_name(self.cuda_count - 1)}")


class DeviceConfigs(BaseConfigs):
    r"""
    This is a configurable module to get a single device to train a model on.
    It picks up CUDA devices, falling back to CPU if none are available.

    It has other small advantages, such as showing the actual device name
    on the configurations view of
    `labml app <https://github.com/labmlai/labml/tree/master/app>`_.

    Arguments:
        cuda_device (int): The CUDA device number. Defaults to ``0``.
        use_cuda (bool): Whether to use CUDA devices. Defaults to ``True``.
    """
    cuda_device: int = 0
    use_cuda: bool = True

    device_info: DeviceInfo

    device: torch.device

    def __init__(self):
        super().__init__(_primary='device')


@option(DeviceConfigs.device)
def _device(c: DeviceConfigs):
    return c.device_info.device


hyperparams(DeviceConfigs.cuda_device, DeviceConfigs.use_cuda,
            is_hyperparam=False)


@option(DeviceConfigs.device_info)
def _device_info(c: DeviceConfigs):
    return DeviceInfo(use_cuda=c.use_cuda,
                      cuda_device=c.cuda_device)
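A minimal sketch of how `DeviceConfigs` is used, mirroring the `device: torch.device = DeviceConfigs()` pattern in the experiment files above; the config class here is illustrative:

    import torch
    from labml.configs import BaseConfigs
    from labml_nn.helpers.device import DeviceConfigs


    class Configs(BaseConfigs):
        # Resolves to a CUDA device when available, otherwise CPU;
        # DeviceConfigs is constructed with _primary='device', so the
        # nested config collapses to a single torch.device value.
        device: torch.device = DeviceConfigs()


    conf = Configs()
    model = torch.nn.Linear(4, 4).to(conf.device)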
122	labml_nn/helpers/metrics.py	Normal file
@@ -0,0 +1,122 @@
import dataclasses
from abc import ABC

import torch
from labml import tracker


class StateModule:
    def __init__(self):
        pass

    # def __call__(self):
    #     raise NotImplementedError

    def create_state(self) -> any:
        raise NotImplementedError

    def set_state(self, data: any):
        raise NotImplementedError

    def on_epoch_start(self):
        raise NotImplementedError

    def on_epoch_end(self):
        raise NotImplementedError


class Metric(StateModule, ABC):
    def track(self):
        pass


@dataclasses.dataclass
class AccuracyState:
    samples: int = 0
    correct: int = 0

    def reset(self):
        self.samples = 0
        self.correct = 0


class Accuracy(Metric):
    data: AccuracyState

    def __init__(self, ignore_index: int = -1):
        super().__init__()
        self.ignore_index = ignore_index

    def __call__(self, output: torch.Tensor, target: torch.Tensor):
        output = output.view(-1, output.shape[-1])
        target = target.view(-1)
        pred = output.argmax(dim=-1)
        mask = target == self.ignore_index
        pred.masked_fill_(mask, self.ignore_index)
        n_masked = mask.sum().item()
        self.data.correct += pred.eq(target).sum().item() - n_masked
        self.data.samples += len(target) - n_masked

    def create_state(self):
        return AccuracyState()

    def set_state(self, data: any):
        self.data = data

    def on_epoch_start(self):
        self.data.reset()

    def on_epoch_end(self):
        self.track()

    def track(self):
        if self.data.samples == 0:
            return
        tracker.add("accuracy.", self.data.correct / self.data.samples)


class AccuracyMovingAvg(Metric):
    def __init__(self, ignore_index: int = -1, queue_size: int = 5):
        super().__init__()
        self.ignore_index = ignore_index
        tracker.set_queue('accuracy.*', queue_size, is_print=True)

    def __call__(self, output: torch.Tensor, target: torch.Tensor):
        output = output.view(-1, output.shape[-1])
        target = target.view(-1)
        pred = output.argmax(dim=-1)
        mask = target == self.ignore_index
        pred.masked_fill_(mask, self.ignore_index)
        n_masked = mask.sum().item()
        if len(target) - n_masked > 0:
            tracker.add('accuracy.', (pred.eq(target).sum().item() - n_masked) / (len(target) - n_masked))

    def create_state(self):
        return None

    def set_state(self, data: any):
        pass

    def on_epoch_start(self):
        pass

    def on_epoch_end(self):
        pass


class BinaryAccuracy(Accuracy):
    def __call__(self, output: torch.Tensor, target: torch.Tensor):
        pred = output.view(-1) > 0
        target = target.view(-1)
        self.data.correct += pred.eq(target).sum().item()
        self.data.samples += len(target)


class AccuracyDirect(Accuracy):
    data: AccuracyState

    def __call__(self, output: torch.Tensor, target: torch.Tensor):
        output = output.view(-1)
        target = target.view(-1)
        self.data.correct += output.eq(target).sum().item()
        self.data.samples += len(target)
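The `Metric`/`StateModule` split keeps per-epoch counters outside the metric object, so a trainer can hold separate states for the train and validation passes and swap them in before each pass. A rough sketch of the lifecycle, with made-up tensors:

    import torch
    from labml_nn.helpers.metrics import Accuracy

    acc = Accuracy(ignore_index=-1)
    state = acc.create_state()   # AccuracyState(samples=0, correct=0)
    acc.set_state(state)
    acc.on_epoch_start()         # resets the counters

    output = torch.randn(8, 10)              # logits: 8 samples, 10 classes
    target = torch.randint(0, 10, (8,))
    acc(output, target)                      # accumulates correct/samples

    acc.on_epoch_end()           # calls track() -> tracker.add("accuracy.", ...)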
97	labml_nn/helpers/optimizer.py	Normal file
@@ -0,0 +1,97 @@
from typing import Tuple

import torch
from labml import tracker

from labml.configs import BaseConfigs, option, meta_config


class OptimizerConfigs(BaseConfigs):
    r"""
    This creates a configurable optimizer.

    Arguments:
        learning_rate (float): Learning rate of the optimizer. Defaults to ``0.01``.
        momentum (float): Momentum of the optimizer. Defaults to ``0.5``.
        parameters: Model parameters to optimize.
        d_model (int): Embedding size of the model (for Noam optimizer).
        betas (Tuple[float, float]): Betas for Adam optimizer. Defaults to ``(0.9, 0.999)``.
        eps (float): Epsilon for Adam/RMSProp optimizers. Defaults to ``1e-8``.
        step_factor (int): Step factor for Noam optimizer. Defaults to ``1024``.

    There is also a better implementation (with more options) in ``labml_nn``.
    `We recommend using that <https://nn.labml.ai/optimizers/configs.html>`_.
    """

    optimizer: torch.optim.Adam
    learning_rate: float = 0.01
    momentum: float = 0.5
    parameters: any
    d_model: int
    betas: Tuple[float, float] = (0.9, 0.999)
    eps: float = 1e-8
    step_factor: int = 1024

    def __init__(self):
        super().__init__(_primary='optimizer')


meta_config(OptimizerConfigs.parameters)


@option(OptimizerConfigs.optimizer, 'SGD')
def sgd_optimizer(c: OptimizerConfigs):
    return torch.optim.SGD(c.parameters, c.learning_rate, c.momentum)


@option(OptimizerConfigs.optimizer, 'Adam')
def adam_optimizer(c: OptimizerConfigs):
    return torch.optim.Adam(c.parameters, lr=c.learning_rate,
                            betas=c.betas, eps=c.eps)


class NoamOpt:
    def __init__(self, model_size: int, learning_rate: float, warmup: int, step_factor: int, optimizer):
        self.step_factor = step_factor
        self.optimizer = optimizer
        self.warmup = warmup
        self.learning_rate = learning_rate
        self.model_size = model_size
        self._rate = 0

    def step(self):
        rate = self.rate(tracker.get_global_step() / self.step_factor)
        for p in self.optimizer.param_groups:
            p['lr'] = rate
        self._rate = rate
        self.optimizer.step()

    def rate(self, step):
        factor = self.model_size ** (-0.5) * min(step ** (-0.5), step * self.warmup ** (-1.5))
        return self.learning_rate * factor

    def zero_grad(self):
        self.optimizer.zero_grad()


@option(OptimizerConfigs.optimizer, 'Noam')
def noam_optimizer(c: OptimizerConfigs):
    optimizer = torch.optim.Adam(c.parameters, lr=0.0, betas=c.betas, eps=c.eps)
    return NoamOpt(c.d_model, 1, 2000, c.step_factor, optimizer)


def _test_noam_optimizer():
    import matplotlib.pyplot as plt
    import numpy as np

    # `rate` uses neither `step_factor` nor `optimizer`, so dummies suffice here
    opts = [NoamOpt(512, 1, 4000, 1, None),
            NoamOpt(512, 1, 8000, 1, None),
            NoamOpt(2048, 1, 2000, 1, None)]
    plt.plot(np.arange(1, 20000), [[opt.rate(i) for opt in opts] for i in range(1, 20000)])
    plt.legend(["512:4000", "512:8000", "2048:2000"])
    plt.title("Optimizer")
    plt.show()


if __name__ == '__main__':
    _test_noam_optimizer()
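The Noam schedule above computes lr = learning_rate * d_model^(-1/2) * min(step^(-1/2), step * warmup^(-3/2)), with the global step scaled down by `step_factor`. A rough usage sketch of the configurable optimizer; the toy model is illustrative:

    import torch.nn as nn
    from labml_nn.helpers.optimizer import OptimizerConfigs

    model = nn.Linear(16, 16)

    opt_conf = OptimizerConfigs()
    opt_conf.parameters = model.parameters()
    opt_conf.optimizer = 'Adam'      # selects the @option registered as 'Adam'
    opt_conf.learning_rate = 3e-4

    optimizer = opt_conf.optimizer   # resolved lazily on first access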
84	labml_nn/helpers/schedule.py	Normal file
@@ -0,0 +1,84 @@
from typing import Tuple, List


class Schedule:
    def __call__(self, x):
        raise NotImplementedError()


class Flat(Schedule):
    def __init__(self, value):
        self.__value = value

    def __call__(self, x):
        return self.__value

    def __str__(self):
        return f"Schedule({self.__value})"


class Dynamic(Schedule):
    def __init__(self, value):
        self.__value = value

    def __call__(self, x):
        return self.__value

    def update(self, value):
        self.__value = value

    def __str__(self):
        return "Dynamic"


class Piecewise(Schedule):
    """
    ## Piecewise schedule
    """

    def __init__(self, endpoints: List[Tuple[float, float]], outside_value: float = None):
        """
        ### Initialize

        `endpoints` is a list of pairs `(x, y)`.
        The values between endpoints are linearly interpolated.
        `y` values outside the range covered by `x` are
        `outside_value`.
        """

        # `(x, y)` pairs should be sorted
        indexes = [e[0] for e in endpoints]
        assert indexes == sorted(indexes)

        self._outside_value = outside_value
        self._endpoints = endpoints

    def __call__(self, x):
        """
        ### Find `y` for given `x`
        """

        # iterate through each segment
        for (x1, y1), (x2, y2) in zip(self._endpoints[:-1], self._endpoints[1:]):
            # interpolate if `x` is within the segment
            if x1 <= x < x2:
                dx = float(x - x1) / (x2 - x1)
                return y1 + dx * (y2 - y1)

        # return outside value otherwise
        return self._outside_value

    def __str__(self):
        endpoints = ", ".join([f"({e[0]}, {e[1]})" for e in self._endpoints])
        return f"Schedule[{endpoints}, {self._outside_value}]"


class RelativePiecewise(Piecewise):
    def __init__(self, relative_endpoints: List[Tuple[float, float]], total_steps: int):
        endpoints = []
        for e in relative_endpoints:
            index = int(total_steps * e[0])
            assert index >= 0
            endpoints.append((index, e[1]))

        super().__init__(endpoints, outside_value=relative_endpoints[-1][1])
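A quick sketch of how these schedules behave; the numbers are arbitrary:

    from labml_nn.helpers.schedule import Piecewise, RelativePiecewise

    # Linear interpolation between (0, 0.0) and (100, 1.0), then flat at 1.0
    s = Piecewise([(0, 0.0), (100, 1.0)], outside_value=1.0)
    assert s(0) == 0.0
    assert abs(s(50) - 0.5) < 1e-9
    assert s(200) == 1.0

    # Same shape, but endpoints given as fractions of a total step count
    rs = RelativePiecewise([(0.0, 0.0), (0.5, 1.0)], total_steps=1000)
    assert abs(rs(250) - 0.5) < 1e-9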
598	labml_nn/helpers/trainer.py	Normal file
@@ -0,0 +1,598 @@
import signal
import typing
from typing import Dict, List, Callable
from typing import Optional, Tuple, Any, Collection

import labml.utils.pytorch as pytorch_utils
import torch.optim
import torch.utils.data
from labml import tracker, logger, experiment, monit
from labml.configs import BaseConfigs, meta_config, option
from labml.internal.monitor import Loop
from labml.logger import Text
from torch import nn
from .device import DeviceConfigs
from .metrics import StateModule


class TrainingLoopIterator(Collection):
    def __init__(self, start: int, total: int, step: Optional[int]):
        self.step = step
        self.total = total
        self.start = start
        self.i = None

    def __iter__(self):
        self.i = None
        return self

    def __next__(self):
        if self.step is not None:
            if self.i is None:
                self.i = self.start
            else:
                self.i += self.step
        else:
            if self.i is None:
                self.i = 0
            else:
                self.i += 1

        if self.i >= self.total:
            raise StopIteration()

        if self.step is None:
            return tracker.get_global_step()
        else:
            return self.i

    def __len__(self) -> int:
        if self.step is not None:
            return (self.total - self.start) // self.step
        else:
            return self.total

    def __contains__(self, x: object) -> bool:
        return False


class TrainingLoop:
    _iter: Optional[TrainingLoopIterator]
    __loop: Loop
    __signal_received: Optional[Tuple[Any, Any]]

    def __init__(self, *,
                 loop_count: int,
                 loop_step: Optional[int],
                 is_save_models: bool,
                 log_new_line_interval: int,
                 log_write_interval: int,
                 save_models_interval: int,
                 is_loop_on_interrupt: bool):
        self.__loop_count = loop_count
        self.__loop_step = loop_step
        self.__is_save_models = is_save_models
        self.__log_new_line_interval = log_new_line_interval
        self.__log_write_interval = log_write_interval
        self.__last_write_step = 0
        self.__last_new_line_step = 0
        self.__save_models_interval = save_models_interval
        self.__last_save_step = 0
        self.__signal_received = None
        self.__is_loop_on_interrupt = is_loop_on_interrupt
        self._iter = None

    def __iter__(self):
        self._iter = TrainingLoopIterator(tracker.get_global_step(),
                                          self.__loop_count,
                                          self.__loop_step)

        self.__loop = monit.loop(typing.cast(Collection, self._iter))

        iter(self.__loop)
        try:
            self.old_handler = signal.signal(signal.SIGINT, self.__handler)
        except ValueError:
            pass
        return self

    @property
    def idx(self):
        if not self._iter:
            return 0
        if not self._iter.i:
            return 0
        if self.__loop_step is None:
            return self._iter.i
        return self._iter.i / self.__loop_step

    def __finish(self):
        try:
            signal.signal(signal.SIGINT, self.old_handler)
        except ValueError:
            pass
        tracker.save()
        tracker.new_line()
        if self.__is_save_models:
            logger.log("Saving model...")
            experiment.save_checkpoint()

    # def is_interval(self, interval: int, global_step: Optional[int] = None):
    #     if global_step is None:
    #         global_step = tracker.get_global_step()
    #
    #     if global_step - self.__loop_step < 0:
    #         return False
    #
    #     if global_step // interval > (global_step - self.__loop_step) // interval:
    #         return True
    #     else:
    #         return False

    def __next__(self):
        if self.__signal_received is not None:
            logger.log('\nKilling Loop.', Text.danger)
            monit.finish_loop()
            self.__finish()
            raise StopIteration("SIGINT")

        try:
            global_step = next(self.__loop)
        except StopIteration as e:
            self.__finish()
            raise e

        tracker.set_global_step(global_step)

        if global_step - self.__last_write_step >= self.__log_write_interval:
            tracker.save()
            self.__last_write_step = global_step
        if global_step - self.__last_new_line_step >= self.__log_new_line_interval:
            tracker.new_line()
            self.__last_new_line_step = global_step
        # if self.is_interval(self.__log_write_interval, global_step):
        #     tracker.save()
        # if self.is_interval(self.__log_new_line_interval, global_step):
        #     logger.log()

        # if (self.__is_save_models and
        #         self.is_interval(self.__save_models_interval, global_step)):
        #     experiment.save_checkpoint()
        if (self.__is_save_models and
                global_step - self.__last_save_step >= self.__save_models_interval):
            experiment.save_checkpoint()
            self.__last_save_step = global_step

        return global_step

    def __handler(self, sig, frame):
        # Pass second interrupt without delaying
        if self.__signal_received is not None:
            logger.log('\nSIGINT received twice. Stopping...', Text.danger)
            self.old_handler(*self.__signal_received)
            return

        if self.__is_loop_on_interrupt:
            # Store the interrupt signal for later
            self.__signal_received = (sig, frame)
            logger.log('\nSIGINT received. Delaying KeyboardInterrupt.', Text.danger)
        else:
            self.__finish()
            logger.log('Killing loop...', Text.danger)
            self.old_handler(sig, frame)

    def __str__(self):
        return "LabTrainingLoop"


class TrainingLoopConfigs(BaseConfigs):
    r"""
    This is a configurable training loop. You can extend this class for your configurations
    if it involves a training loop.

    >>> for step in conf.training_loop:
    >>>     ...

    Arguments:
        loop_count (int): Total number of steps. Defaults to ``10``.
        loop_step (int): Number of steps to increment per iteration. Defaults to ``1``.
        is_save_models (bool): Whether to call :func:`labml.experiment.save_checkpoint` on each iteration.
            Defaults to ``False``.
        save_models_interval (int): The interval (in steps) to save models. Defaults to ``1``.
        log_new_line_interval (int): The interval (in steps) to print a new line to the screen.
            Defaults to ``1``.
        log_write_interval (int): The interval (in steps) to call :func:`labml.tracker.save`.
            Defaults to ``1``.
        is_loop_on_interrupt (bool): Whether to handle keyboard interrupts and wait until an iteration is complete.
            Defaults to ``False``.
    """
    loop_count: int = 10
    loop_step: int = 1
    is_save_models: bool = False
    log_new_line_interval: int = 1
    log_write_interval: int = 1
    save_models_interval: int = 1
    is_loop_on_interrupt: bool = False

    training_loop: TrainingLoop


@option(TrainingLoopConfigs.training_loop)
def _loop_configs(c: TrainingLoopConfigs):
    return TrainingLoop(loop_count=c.loop_count,
                        loop_step=c.loop_step,
                        is_save_models=c.is_save_models,
                        log_new_line_interval=c.log_new_line_interval,
                        log_write_interval=c.log_write_interval,
                        save_models_interval=c.save_models_interval,
                        is_loop_on_interrupt=c.is_loop_on_interrupt)


meta_config(TrainingLoopConfigs.loop_step,
            TrainingLoopConfigs.loop_count,
            TrainingLoopConfigs.is_save_models,
            TrainingLoopConfigs.log_new_line_interval,
            TrainingLoopConfigs.log_write_interval,
            TrainingLoopConfigs.save_models_interval,
            TrainingLoopConfigs.is_loop_on_interrupt)


class ModeState:
    def __init__(self):
        self._rollback_stack = []

        self.is_train = False
        self.is_log_activations = False
        self.is_log_parameters = False
        self.is_optimize = False

    def _enter(self, mode: Dict[str, any]):
        rollback = {}
        for k, v in mode.items():
            if v is None:
                continue
            rollback[k] = getattr(self, k)
            setattr(self, k, v)

        self._rollback_stack.append(rollback)

        return len(self._rollback_stack)

    def _exit(self, n: int):
        assert n == len(self._rollback_stack)

        rollback = self._rollback_stack[-1]
        self._rollback_stack.pop(-1)

        for k, v in rollback.items():
            setattr(self, k, v)

    def update(self, *,
               is_train: Optional[bool] = None,
               is_log_parameters: Optional[bool] = None,
               is_log_activations: Optional[bool] = None,
               is_optimize: Optional[bool] = None):
        return Mode(self,
                    is_train=is_train,
                    is_log_parameters=is_log_parameters,
                    is_log_activations=is_log_activations,
                    is_optimize=is_optimize)


class Mode:
    def __init__(self, mode: ModeState, **kwargs: any):
        self.mode = mode
        self.update = {}
        for k, v in kwargs.items():
            if v is not None:
                self.update[k] = v

        self.idx = -1

    def __enter__(self):
        self.idx = self.mode._enter(self.update)

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.mode._exit(self.idx)


class ForwardHook:
    def __init__(self, mode: ModeState, model_name, name: str, module: torch.nn.Module):
        self.mode = mode
        self.model_name = model_name
        self.name = name
        self.module = module
        module.register_forward_hook(self)

    def save(self, name: str, output):
        if isinstance(output, torch.Tensor):
            pytorch_utils.store_var(name, output)
        elif isinstance(output, tuple):
            for i, o in enumerate(output):
                self.save(f"{name}.{i}", o)

    def __call__(self, module, i, o):
        if not self.mode.is_log_activations:
            return

        self.save(f"module.{self.model_name}.{self.name}", o)


def hook_model_outputs(mode: ModeState, model: torch.nn.Module, model_name: str = "model"):
    for name, module in model.named_modules():
        if name == '':
            name = 'full'
        ForwardHook(mode, model_name, name, module)


class Trainer:
    def __init__(self, *,
                 name: str,
                 mode: ModeState,
                 data_loader: torch.utils.data.DataLoader,
                 inner_iterations: int,
                 state_modules: List[StateModule],
                 is_track_time: bool,
                 step: Callable[[any, 'BatchIndex'], None]):
        self.is_track_time = is_track_time
        self.mode = mode
        self.name = name
        self.step = step
        self.state_modules = state_modules
        self.__iterable = None
        self.__states = [sm.create_state() for sm in self.state_modules]
        self.inner_iterations = inner_iterations
        self.data_loader = data_loader
        self._batch_index = BatchIndex(len(self.data_loader), self.inner_iterations)

    def set_data_loader(self, data_loader: torch.utils.data.DataLoader):
        self.data_loader = data_loader
        self._batch_index = BatchIndex(len(data_loader), self.inner_iterations)
        self.__iterable = None

    def __call__(self):
        for sm, s in zip(self.state_modules, self.__states):
            sm.set_state(s)

        if self.__iterable is None or self._batch_index.completed:
            self.__iterable = iter(self.data_loader)
            self._batch_index.reset(len(self.data_loader), self.inner_iterations)
            for sm in self.state_modules:
                sm.on_epoch_start()
        with torch.set_grad_enabled(self.mode.is_train):
            self.__iterate()

        if self._batch_index.completed:
            for sm in self.state_modules:
                sm.on_epoch_end()

    def __iterate(self):
        with monit.section(self.name, is_partial=True, is_track=self.is_track_time):
            if self._batch_index.idx == 0:
                monit.progress(0)
            while not self._batch_index.iteration_completed:
                batch = next(self.__iterable)

                self.step(batch, self._batch_index)

                self._batch_index.step()
                monit.progress(self._batch_index.epoch_progress)

        self._batch_index.step_inner()


class BatchIndex:
    idx: int
    total: int
    iteration: int
    total_iterations: int

    def __init__(self, total: int, total_iterations: int):
        self.total_iterations = total_iterations
        self.total = total

    def is_interval(self, interval: int):
        if interval <= 0:
            return False
        if self.idx + 1 == self.total:
            return True
        else:
            return (self.idx + 1) % interval == 0

    @property
    def is_last(self):
        return self.idx + 1 == self.total

    @property
    def completed(self):
        return self.iteration >= self.total_iterations

    @property
    def iteration_completed(self):
        # // is important so that the last step happens on the last iteration
        return self.idx >= (self.iteration + 1) * self.total // self.total_iterations

    @property
    def epoch_progress(self):
        return self.idx / self.total

    def step(self):
        self.idx += 1

    def step_inner(self):
        self.iteration += 1

    def reset(self, total: int, total_iterations: int):
        self.total = total
        self.total_iterations = total_iterations
        self.idx = 0
        self.iteration = 0


class TrainValidConfigs(TrainingLoopConfigs):
    r"""
    This is a configurable module that you can extend for experiments that involve
    training and validation datasets (i.e. most DL experiments).

    Arguments:
        epochs (int): Number of epochs to train on. Defaults to ``10``.
        train_loader (torch.utils.data.DataLoader): Training data loader.
        valid_loader (torch.utils.data.DataLoader): Validation data loader.
        inner_iterations (int): Number of times to switch between training and validation
            within an epoch. Defaults to ``1``.

    You can override the ``init`` and ``step`` functions. There is also a ``sample`` function
    that you can override to generate samples every time it switches between training and validation.
    """
    state_modules: List[StateModule]

    mode: ModeState

    epochs: int = 10

    trainer: Trainer
    validator: Trainer
    train_loader: torch.utils.data.DataLoader
    valid_loader: torch.utils.data.DataLoader

    loop_count = '_data_loop_count'
    loop_step = None

    inner_iterations: int = 1

    is_track_time: bool = False

    def init(self):
        pass

    def step(self, batch: Any, batch_idx: BatchIndex):
        raise NotImplementedError

    def run_step(self):
        for i in range(self.inner_iterations):
            with tracker.namespace('sample'):
                self.sample()
            with self.mode.update(is_train=True):
                with tracker.namespace('train'):
                    self.trainer()
            if self.validator:
                with tracker.namespace('valid'):
                    self.validator()
            tracker.save()

    def run(self):
        with monit.section("Initialize"):
            self.init()
        _ = self.validator
        _ = self.trainer
        for _ in self.training_loop:
            self.run_step()

    def sample(self):
        pass


@option(TrainValidConfigs.trainer)
def _default_trainer(c: TrainValidConfigs):
    return Trainer(name='Train',
                   mode=c.mode,
                   data_loader=c.train_loader,
                   inner_iterations=c.inner_iterations,
                   state_modules=c.state_modules,
                   is_track_time=c.is_track_time,
                   step=c.step)


@option(TrainValidConfigs.validator)
def _default_validator(c: TrainValidConfigs):
    return Trainer(name='Valid',
                   mode=c.mode,
                   data_loader=c.valid_loader,
                   inner_iterations=c.inner_iterations,
                   state_modules=c.state_modules,
                   is_track_time=c.is_track_time,
                   step=c.step)


@option(TrainValidConfigs.loop_count)
def _data_loop_count(c: TrainValidConfigs):
    return c.epochs


class SimpleTrainValidConfigs(TrainValidConfigs):
    r"""
    This is a configurable module that works for many standard DL experiments.

    Arguments:
        model: A PyTorch model.
        optimizer: A PyTorch optimizer to update the model.
        device: The device to train the model on. This defaults to a configurable device.
        loss_func: A function to calculate the loss. This should accept ``model_output, target`` as
            arguments.
        update_batches (int): Number of batches to accumulate before taking an optimizer step.
            Defaults to ``1``.
        log_params_updates (int): How often (number of batches) to track model parameters and gradients.
            Defaults to a large number; i.e. logs every epoch.
        log_activations_batches (int): How often to log model activations.
            Defaults to a large number; i.e. logs every epoch.
        log_save_batches (int): How often to call :func:`labml.tracker.save`.
    """
    optimizer: torch.optim.Adam
    model: nn.Module
    device: torch.device = DeviceConfigs()

    loss_func: nn.Module

    update_batches: int = 1
    log_params_updates: int = 2 ** 32  # 0 if not
    log_activations_batches: int = 2 ** 32  # 0 if not
    log_save_batches: int = 1

    state_modules: List[StateModule] = []

    def init(self):
        pass

    def step(self, batch: Any, batch_idx: BatchIndex):
        self.model.train(self.mode.is_train)
        data, target = batch[0].to(self.device), batch[1].to(self.device)

        if self.mode.is_train:
            tracker.add_global_step(len(data))

        is_log_activations = batch_idx.is_interval(self.log_activations_batches)
        with monit.section("model"):
            with self.mode.update(is_log_activations=is_log_activations):
                output = self.model(data)

        loss = self.loss_func(output, target)
        tracker.add("loss.", loss)

        if self.mode.is_train:
            with monit.section('backward'):
                loss.backward()

            if batch_idx.is_interval(self.update_batches):
                with monit.section('optimize'):
                    self.optimizer.step()
                if batch_idx.is_interval(self.log_params_updates):
                    tracker.add('model', self.model)
                self.optimizer.zero_grad()

            if batch_idx.is_interval(self.log_save_batches):
                tracker.save()


meta_config(SimpleTrainValidConfigs.update_batches,
            SimpleTrainValidConfigs.log_params_updates,
            SimpleTrainValidConfigs.log_activations_batches)


@option(SimpleTrainValidConfigs.optimizer)
def _default_optimizer(c: SimpleTrainValidConfigs):
    from .optimizer import OptimizerConfigs
    opt_conf = OptimizerConfigs()
    opt_conf.parameters = c.model.parameters()
    return opt_conf
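A rough usage sketch of `SimpleTrainValidConfigs`, mirroring the MNIST and PonderNet experiments touched by this commit; the toy model, data, and experiment name are illustrative, and the exact config wiring may differ:

    import torch
    import torch.nn as nn
    from torch.utils.data import DataLoader, TensorDataset

    from labml import experiment
    from labml.configs import option
    from labml_nn.helpers.trainer import SimpleTrainValidConfigs


    class Configs(SimpleTrainValidConfigs):
        epochs: int = 2
        loss_func = nn.CrossEntropyLoss()


    @option(Configs.model)
    def _model(c: Configs):
        # Build the model on the configured device
        return nn.Linear(8, 2).to(c.device)


    def _loader():
        x, y = torch.randn(256, 8), torch.randint(0, 2, (256,))
        return DataLoader(TensorDataset(x, y), batch_size=32)


    conf = Configs()
    conf.train_loader = _loader()
    conf.valid_loader = _loader()
    experiment.create(name='simple_train_valid_demo')
    experiment.configs(conf, {'optimizer.optimizer': 'Adam'})
    with experiment.start():
        conf.run()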
@@ -19,7 +19,7 @@ from transformers import AutoTokenizer, AutoModelForCausalLM
 from labml import lab, monit, tracker
 from labml.configs import BaseConfigs, option
 from labml.utils.download import download_file
-from labml_helpers.device import DeviceConfigs
+from labml_nn.helpers.device import DeviceConfigs
 from labml_nn.lora.gpt2 import GPTModel
 
 
@@ -11,11 +11,10 @@ import torch.utils.data
 
 from labml import experiment, tracker
 from labml.configs import option
-from labml_helpers.datasets.mnist import MNISTConfigs
-from labml_helpers.device import DeviceConfigs
-from labml_helpers.metrics.accuracy import Accuracy
-from labml_helpers.seed import SeedConfigs
-from labml_helpers.train_valid import TrainValidConfigs, BatchIndex, hook_model_outputs
+from labml_nn.helpers.datasets import MNISTConfigs
+from labml_nn.helpers.device import DeviceConfigs
+from labml_nn.helpers.metrics import Accuracy
+from labml_nn.helpers.trainer import TrainValidConfigs, BatchIndex, hook_model_outputs
 from labml_nn.optimizers.configs import OptimizerConfigs
 
 
@@ -48,7 +47,6 @@ class Configs(MNISTConfigs, TrainValidConfigs):
     """
     optimizer: torch.optim.Adam
     model: nn.Module
-    set_seed = SeedConfigs()
     device: torch.device = DeviceConfigs()
     epochs: int = 10
 
@@ -126,7 +124,6 @@ def main():
         # Specify the optimizer
         'optimizer.optimizer': 'Adam',
         'optimizer.learning_rate': 1.5e-4})
-    conf.set_seed.set()
     experiment.add_pytorch_models(dict(model=conf.model))
     with experiment.start():
         conf.run()
@@ -18,7 +18,7 @@ MyAdam...[DONE] 1,192.89ms
 
 import torch
 import torch.nn as nn
-from labml_helpers.device import DeviceInfo
+from labml_nn.helpers.device import DeviceInfo
 from torch.optim import Adam as TorchAdam
 
 from labml import monit
@@ -17,7 +17,7 @@ import torch
 
 from labml import tracker, experiment, logger, monit
 from labml.internal.configs.dynamic_hyperparam import FloatDynamicHyperParam
-from labml_helpers.schedule import Piecewise
+from labml_nn.helpers.schedule import Piecewise
 from labml_nn.rl.dqn import QFuncLoss
 from labml_nn.rl.dqn.model import Model
 from labml_nn.rl.dqn.replay_buffer import ReplayBuffer
@@ -5,7 +5,7 @@ import torch
 from labml import experiment, monit
 from labml import logger
 from labml.logger import Text
-from labml_helpers.datasets.text import TextDataset
+from labml_nn.helpers.datasets import TextDataset
 from labml_nn.sampling import Sampler
 from labml_nn.sampling.greedy import GreedySampler
 from labml_nn.sampling.nucleus import NucleusSampler
@@ -39,9 +39,9 @@ from matplotlib import pyplot as plt
 import torch
 import torch.nn as nn
 from labml import lab, experiment, tracker, monit
-from labml_helpers.device import DeviceConfigs
-from labml_helpers.optimizer import OptimizerConfigs
-from labml_helpers.train_valid import TrainValidConfigs, hook_model_outputs, BatchIndex
+from labml_nn.helpers.device import DeviceConfigs
+from labml_nn.helpers.optimizer import OptimizerConfigs
+from labml_nn.helpers.trainer import TrainValidConfigs, hook_model_outputs, BatchIndex
 from torch import optim
 from torch.utils.data import Dataset, DataLoader
 
 
@@ -16,7 +16,7 @@ from torch.utils.data import DataLoader
 
 from labml import experiment, tracker
 from labml.configs import option, calculate
-from labml_helpers.datasets.text import SequentialUnBatchedDataset
+from labml_nn.helpers.datasets import SequentialUnBatchedDataset
 from labml_nn.transformers.alibi import AlibiMultiHeadAttention
 from labml_nn.experiments.nlp_autoregression import transpose_batch
 from labml_nn.transformers import TransformerConfigs
@@ -13,7 +13,7 @@ on an NLP auto-regression task (with Tiny Shakespeare dataset) with [Sophia-G op
 import torch
 
 from labml import experiment, tracker
-from labml_helpers.train_valid import BatchIndex
+from labml_nn.helpers.trainer import BatchIndex
 from labml_nn.optimizers.sophia import Sophia
 from labml_nn.transformers.basic.autoregressive_experiment import Configs as TransformerAutoRegressionConfigs
 
 
@@ -56,7 +56,6 @@ import torch
 import torch.nn.functional as F
 from torch import nn
 
-from labml_helpers.module import TypedModuleList
 from labml_nn.transformers.feed_forward import FeedForward
 from labml_nn.transformers.mha import PrepareForMultiHeadAttention
 from labml_nn.transformers.xl.relative_mha import RelativeMultiHeadAttention
@@ -233,7 +232,7 @@ class AttentionReconstructionLoss:
     attention reconstruction loss, we detach all other parameters except $f_c$
     from the gradient computation.
     """
-    def __init__(self, layers: TypedModuleList[CompressiveTransformerLayer]):
+    def __init__(self, layers: nn.ModuleList):
        """
        `layers` is the list of Compressive Transformer layers
        """
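The detaching described above can be pictured with a toy module: everything except the compression function's own parameters enters the loss as constants. A simplified illustration, not the actual implementation:

    import torch
    import torch.nn as nn

    attn_output = torch.randn(4, 8, requires_grad=True)   # stands in for a layer's memories
    f_c = nn.Linear(8, 8)                                 # stands in for the compression function

    # Detach the input so gradients flow only into f_c's parameters
    reconstruction = f_c(attn_output.detach())
    loss = nn.functional.mse_loss(reconstruction, attn_output.detach())
    loss.backward()

    assert attn_output.grad is None        # no gradient reached the transformer
    assert f_c.weight.grad is not None     # only f_c was trained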
@@ -16,8 +16,8 @@ import torch.nn as nn
 from labml import experiment, tracker, monit, logger
 from labml.configs import option
 from labml.logger import Text
-from labml_helpers.metrics.simple_state import SimpleStateModule
-from labml_helpers.train_valid import BatchIndex, hook_model_outputs
+from labml_nn.helpers.metrics import SimpleStateModule
+from labml_nn.helpers.trainer import BatchIndex, hook_model_outputs
 from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs
 from labml_nn.transformers.compressive import CompressiveTransformer, AttentionReconstructionLoss, \
     CompressiveTransformerLayer, Conv1dCompression
@@ -16,8 +16,8 @@ from torch import nn
 from labml import experiment, tracker, logger
 from labml.configs import option
 from labml.logger import Text
-from labml_helpers.metrics.accuracy import Accuracy
-from labml_helpers.train_valid import BatchIndex
+from labml_nn.helpers.metrics import Accuracy
+from labml_nn.helpers.trainer import BatchIndex
 from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs
 from labml_nn.transformers import Encoder, Generator
 from labml_nn.transformers import TransformerConfigs
@@ -20,7 +20,7 @@ import numpy as np
 import torch
 
 from labml import lab, monit
-from labml_helpers.datasets.text import TextFileDataset
+from labml_nn.helpers.datasets import TextFileDataset
 from labml_nn.transformers.retro.bert_embeddings import BERTChunkEmbeddings
 
 
@@ -20,7 +20,7 @@ import torch
 from torch.utils.data import Dataset as PyTorchDataset
 
 from labml import lab, monit
-from labml_helpers.datasets.text import TextFileDataset, TextDataset
+from labml_nn.helpers.datasets import TextFileDataset, TextDataset
 from labml_nn.transformers.retro.database import RetroIndex
 
 
@@ -17,7 +17,7 @@ from torch.utils.data import DataLoader, RandomSampler
 
 from labml import monit, lab, tracker, experiment, logger
 from labml.logger import Text
-from labml_helpers.datasets.text import TextFileDataset
+from labml_nn.helpers.datasets import TextFileDataset
 from labml_nn.optimizers.noam import Noam
 from labml_nn.transformers.retro import model as retro
 from labml_nn.transformers.retro.dataset import Dataset, RetroIndex
@@ -16,7 +16,7 @@ import torch.nn as nn
 
 from labml import experiment, tracker
 from labml.configs import option
-from labml_helpers.train_valid import BatchIndex
+from labml_nn.helpers.trainer import BatchIndex
 from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs
 
 
@@ -16,8 +16,8 @@ from labml.logger import Text
 
 from labml import experiment, tracker, monit, logger
 from labml.configs import option
-from labml_helpers.metrics.simple_state import SimpleStateModule
-from labml_helpers.train_valid import BatchIndex, hook_model_outputs
+from labml_nn.helpers.metrics import SimpleStateModule
+from labml_nn.helpers.trainer import BatchIndex, hook_model_outputs
 from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs
 from labml_nn.transformers.xl import TransformerXL, TransformerXLLayer
 
@@ -18,8 +18,8 @@ import torch.utils.data
 
 from labml import tracker, experiment
 from labml.configs import option, calculate
-from labml_helpers.schedule import Schedule, RelativePiecewise
-from labml_helpers.train_valid import BatchIndex
+from labml_nn.helpers.schedule import Schedule, RelativePiecewise
+from labml_nn.helpers.trainer import BatchIndex
 from labml_nn.experiments.mnist import MNISTConfigs
 from labml_nn.uncertainty.evidence import KLDivergenceLoss, TrackStatistics, MaximumLikelihoodLoss, \
     CrossEntropyBayesRisk, SquaredErrorBayesRisk
@@ -178,7 +178,7 @@ def kl_div_coef(c: Configs):
     ### KL Divergence Loss Coefficient Schedule
     """
 
-    # Create a [relative piecewise schedule](https://docs.labml.ai/api/helpers.html#labml_helpers.schedule.Piecewise)
+    # Create a [relative piecewise schedule](../../helpers/schedule.html)
     return RelativePiecewise(c.kl_div_coef_schedule, c.epochs * len(c.train_dataset))
 
 
@@ -24,7 +24,7 @@ from torch import nn
 
 from labml import lab, tracker, experiment, monit
 from labml.configs import BaseConfigs
-from labml_helpers.device import DeviceConfigs
+from labml_nn.helpers.device import DeviceConfigs
 from labml_nn.unet.carvana import CarvanaDataset
 from labml_nn.unet import UNet
 
@@ -34,7 +34,7 @@ class Configs(BaseConfigs):
     ## Configurations
     """
     # Device to train the model on.
-    # [`DeviceConfigs`](https://docs.labml.ai/api/helpers.html#labml_helpers.device.DeviceConfigs)
+    # [`DeviceConfigs`](../helpers/device.html)
     # picks up an available CUDA device or defaults to CPU.
     device: torch.device = DeviceConfigs()
 
@@ -8,19 +8,18 @@ summary: A bunch of utility functions and classes
 """
 
 import copy
 
 from torch import nn
 from torch.utils.data import Dataset, IterableDataset
 
-from labml_helpers.module import M, TypedModuleList
-
 
-def clone_module_list(module: M, n: int) -> TypedModuleList[M]:
+def clone_module_list(module: nn.Module, n: int) -> nn.ModuleList:
     """
     ## Clone Module
 
     Make a `nn.ModuleList` with clones of a given module
     """
-    return TypedModuleList([copy.deepcopy(module) for _ in range(n)])
+    return nn.ModuleList([copy.deepcopy(module) for _ in range(n)])
 
 
 def cycle_dataloader(data_loader):
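After this change `clone_module_list` returns a plain `nn.ModuleList`; a quick usage sketch:

    import torch.nn as nn
    from labml_nn.utils import clone_module_list

    layer = nn.Linear(16, 16)
    layers = clone_module_list(layer, 6)   # six independent deep copies

    assert len(layers) == 6
    assert layers[0] is not layers[1]      # separate parameters, not shared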