import random
from pathlib import PurePath, Path
from typing import List, Callable, Dict, Optional
from torchvision import datasets, transforms
import torch
from labml import lab
from labml import monit
from labml.configs import BaseConfigs
from labml.configs import aggregate, option
from labml.utils.download import download_file
from torch.utils.data import DataLoader
from torch.utils.data import IterableDataset, Dataset
def _mnist_dataset(is_train, transform):
    return datasets.MNIST(str(lab.get_data_path()),
                          train=is_train,
                          download=True,
                          transform=transform)

class MNISTConfigs(BaseConfigs):
    """
    Configurable MNIST data set.

    Arguments:
        dataset_name (str): name of the data set, ``MNIST``
        dataset_transforms (torchvision.transforms.Compose): image transformations
        train_dataset (torchvision.datasets.MNIST): training dataset
        valid_dataset (torchvision.datasets.MNIST): validation dataset

        train_loader (torch.utils.data.DataLoader): training data loader
        valid_loader (torch.utils.data.DataLoader): validation data loader

        train_batch_size (int): training batch size
        valid_batch_size (int): validation batch size

        train_loader_shuffle (bool): whether to shuffle training data
        valid_loader_shuffle (bool): whether to shuffle validation data
    """

    dataset_name: str = 'MNIST'
    dataset_transforms: transforms.Compose
    train_dataset: datasets.MNIST
    valid_dataset: datasets.MNIST

    train_loader: DataLoader
    valid_loader: DataLoader

    train_batch_size: int = 64
    valid_batch_size: int = 1024

    train_loader_shuffle: bool = True
    valid_loader_shuffle: bool = False

@option(MNISTConfigs.dataset_transforms)
def mnist_transforms():
    return transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])

@option(MNISTConfigs.train_dataset)
def mnist_train_dataset(c: MNISTConfigs):
    return _mnist_dataset(True, c.dataset_transforms)


@option(MNISTConfigs.valid_dataset)
def mnist_valid_dataset(c: MNISTConfigs):
    return _mnist_dataset(False, c.dataset_transforms)


@option(MNISTConfigs.train_loader)
def mnist_train_loader(c: MNISTConfigs):
    return DataLoader(c.train_dataset,
                      batch_size=c.train_batch_size,
                      shuffle=c.train_loader_shuffle)


@option(MNISTConfigs.valid_loader)
def mnist_valid_loader(c: MNISTConfigs):
    return DataLoader(c.valid_dataset,
                      batch_size=c.valid_batch_size,
                      shuffle=c.valid_loader_shuffle)

aggregate(MNISTConfigs.dataset_name, 'MNIST',
          (MNISTConfigs.dataset_transforms, 'mnist_transforms'),
          (MNISTConfigs.train_dataset, 'mnist_train_dataset'),
          (MNISTConfigs.valid_dataset, 'mnist_valid_dataset'),
          (MNISTConfigs.train_loader, 'mnist_train_loader'),
          (MNISTConfigs.valid_loader, 'mnist_valid_loader'))

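# --- Hedged usage sketch (not part of the original module) ---
# Shows how the pieces above fit together without going through the labml
# config machinery: build the normalization transforms, the training dataset
# and a `DataLoader` by hand. The function name `_example_mnist_loader` is
# made up for illustration and mirrors what the `mnist_*` options compute.
def _example_mnist_loader() -> DataLoader:
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])
    train_dataset = _mnist_dataset(True, transform)
    return DataLoader(train_dataset, batch_size=64, shuffle=True)
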
def _cifar_dataset(is_train, transform):
    return datasets.CIFAR10(str(lab.get_data_path()),
                            train=is_train,
                            download=True,
                            transform=transform)

class CIFAR10Configs(BaseConfigs):
    """
    Configurable CIFAR 10 data set.

    Arguments:
        dataset_name (str): name of the data set, ``CIFAR10``
        dataset_transforms (torchvision.transforms.Compose): image transformations
        train_dataset (torchvision.datasets.CIFAR10): training dataset
        valid_dataset (torchvision.datasets.CIFAR10): validation dataset

        train_loader (torch.utils.data.DataLoader): training data loader
        valid_loader (torch.utils.data.DataLoader): validation data loader

        train_batch_size (int): training batch size
        valid_batch_size (int): validation batch size

        train_loader_shuffle (bool): whether to shuffle training data
        valid_loader_shuffle (bool): whether to shuffle validation data
    """

    dataset_name: str = 'CIFAR10'
    dataset_transforms: transforms.Compose
    train_dataset: datasets.CIFAR10
    valid_dataset: datasets.CIFAR10

    train_loader: DataLoader
    valid_loader: DataLoader

    train_batch_size: int = 64
    valid_batch_size: int = 1024

    train_loader_shuffle: bool = True
    valid_loader_shuffle: bool = False

@CIFAR10Configs.calc(CIFAR10Configs.dataset_transforms)
def cifar10_transforms():
    return transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])


@CIFAR10Configs.calc(CIFAR10Configs.train_dataset)
def cifar10_train_dataset(c: CIFAR10Configs):
    return _cifar_dataset(True, c.dataset_transforms)


@CIFAR10Configs.calc(CIFAR10Configs.valid_dataset)
def cifar10_valid_dataset(c: CIFAR10Configs):
    return _cifar_dataset(False, c.dataset_transforms)


@CIFAR10Configs.calc(CIFAR10Configs.train_loader)
def cifar10_train_loader(c: CIFAR10Configs):
    return DataLoader(c.train_dataset,
                      batch_size=c.train_batch_size,
                      shuffle=c.train_loader_shuffle)


@CIFAR10Configs.calc(CIFAR10Configs.valid_loader)
def cifar10_valid_loader(c: CIFAR10Configs):
    return DataLoader(c.valid_dataset,
                      batch_size=c.valid_batch_size,
                      shuffle=c.valid_loader_shuffle)

CIFAR10Configs.aggregate(CIFAR10Configs.dataset_name, 'CIFAR10',
                         (CIFAR10Configs.dataset_transforms, 'cifar10_transforms'),
                         (CIFAR10Configs.train_dataset, 'cifar10_train_dataset'),
                         (CIFAR10Configs.valid_dataset, 'cifar10_valid_dataset'),
                         (CIFAR10Configs.train_loader, 'cifar10_train_loader'),
                         (CIFAR10Configs.valid_loader, 'cifar10_valid_loader'))

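# --- Hedged sketch (not part of the original module) ---
# labml allows several calculator functions to be registered for the same
# config item and selected by function name when configuring an experiment.
# The option below is an illustrative alternative transform pipeline with a
# random horizontal flip; the name `cifar10_augmented_transforms` is made up,
# and the multi-option selection behaviour is an assumption about labml here.
@CIFAR10Configs.calc(CIFAR10Configs.dataset_transforms)
def cifar10_augmented_transforms():
    return transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
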
class TextDataset:
    itos: List[str]
    stoi: Dict[str, int]
    n_tokens: int
    train: str
    valid: str
    standard_tokens: List[str] = []

    @staticmethod
    def load(path: PurePath):
        with open(str(path), 'r') as f:
            return f.read()

    def __init__(self, path: PurePath, tokenizer: Callable, train: str, valid: str, test: str, *,
                 n_tokens: Optional[int] = None,
                 stoi: Optional[Dict[str, int]] = None,
                 itos: Optional[List[str]] = None):
        self.test = test
        self.valid = valid
        self.train = train
        self.tokenizer = tokenizer
        self.path = path

        # Use the supplied vocabulary if one is given
        if n_tokens or stoi or itos:
            assert stoi and itos and n_tokens
            self.n_tokens = n_tokens
            self.stoi = stoi
            self.itos = itos
        # Otherwise build the vocabulary from the training and validation text
        else:
            self.n_tokens = len(self.standard_tokens)
            self.stoi = {t: i for i, t in enumerate(self.standard_tokens)}

            with monit.section("Tokenize"):
                tokens = self.tokenizer(self.train) + self.tokenizer(self.valid)
                tokens = sorted(list(set(tokens)))

            for t in monit.iterate("Build vocabulary", tokens):
                self.stoi[t] = self.n_tokens
                self.n_tokens += 1

            self.itos = [''] * self.n_tokens
            for t, n in self.stoi.items():
                self.itos[n] = t

    def text_to_i(self, text: str) -> torch.Tensor:
        # Map a string to a tensor of token ids, skipping tokens not in the vocabulary
        tokens = self.tokenizer(text)
        return torch.tensor([self.stoi[s] for s in tokens if s in self.stoi], dtype=torch.long)

    def __repr__(self):
        return f'{len(self.train) / 1_000_000 :,.2f}M, {len(self.valid) / 1_000_000 :,.2f}M - {str(self.path)}'

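# --- Hedged usage sketch (not part of the original module) ---
# Builds a character-level `TextDataset` from in-memory strings: the tokenizer
# splits text into characters and the vocabulary is built from the training and
# validation splits. The strings and the function name are made up for
# illustration.
def _example_char_dataset() -> TextDataset:
    dataset = TextDataset(PurePath('in-memory'), lambda s: list(s),
                          train='the quick brown fox jumps over the lazy dog ' * 100,
                          valid='the quick brown fox',
                          test='')
    # `text_to_i` maps a string to a tensor of token ids
    ids = dataset.text_to_i('the fox')
    assert ids.dtype == torch.long
    return dataset
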
class SequentialDataLoader(IterableDataset):
    def __init__(self, *, text: str, dataset: TextDataset,
                 batch_size: int, seq_len: int):
        self.seq_len = seq_len
        data = dataset.text_to_i(text)
        # Trim the data so it divides evenly into `batch_size` columns,
        # then lay it out with shape `[n_batch, batch_size]`
        n_batch = data.shape[0] // batch_size
        data = data.narrow(0, 0, n_batch * batch_size)
        data = data.view(batch_size, -1).t().contiguous()
        self.data = data

    def __len__(self):
        return self.data.shape[0] // self.seq_len

    def __iter__(self):
        self.idx = 0
        return self

    def __next__(self):
        if self.idx >= self.data.shape[0] - 1:
            raise StopIteration()

        seq_len = min(self.seq_len, self.data.shape[0] - 1 - self.idx)
        i = self.idx + seq_len
        data = self.data[self.idx: i]
        # Targets are the inputs shifted forward by one token
        target = self.data[self.idx + 1: i + 1]
        self.idx = i

        return data, target

    def __getitem__(self, idx):
        seq_len = min(self.seq_len, self.data.shape[0] - 1 - idx)
        i = idx + seq_len
        data = self.data[idx: i]
        target = self.data[idx + 1: i + 1]

        return data, target

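# --- Hedged usage sketch (not part of the original module) ---
# Iterates a `SequentialDataLoader` built on top of a `TextDataset` (for
# example the one from `_example_char_dataset` above). Each step yields an
# input tensor of shape `[seq_len, batch_size]` together with the target
# tensor shifted forward by one token. The function name is made up.
def _example_sequential_iteration(dataset: TextDataset):
    loader = SequentialDataLoader(text=dataset.train, dataset=dataset,
                                  batch_size=4, seq_len=32)
    for data, target in loader:
        assert data.shape == target.shape
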
class SequentialUnBatchedDataset(Dataset):
    def __init__(self, *, text: str, dataset: TextDataset,
                 seq_len: int,
                 is_random_offset: bool = True):
        self.is_random_offset = is_random_offset
        self.seq_len = seq_len
        self.data = dataset.text_to_i(text)

    def __len__(self):
        return (self.data.shape[0] - 1) // self.seq_len

    def __getitem__(self, idx):
        start = idx * self.seq_len
        assert start + self.seq_len + 1 <= self.data.shape[0]
        # Optionally shift the start by a random offset so that epochs
        # do not always see the same sequence boundaries
        if self.is_random_offset:
            start += random.randint(0, min(self.seq_len - 1, self.data.shape[0] - (start + self.seq_len + 1)))

        end = start + self.seq_len
        data = self.data[start: end]
        target = self.data[start + 1: end + 1]

        return data, target

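# --- Hedged usage sketch (not part of the original module) ---
# `SequentialUnBatchedDataset` returns one `(input, target)` pair of length
# `seq_len` per index, so a standard `DataLoader` can shuffle and batch the
# samples; batches then have shape `[batch_size, seq_len]`. The function name
# is made up for illustration.
def _example_unbatched_loader(dataset: TextDataset) -> DataLoader:
    train_ds = SequentialUnBatchedDataset(text=dataset.train, dataset=dataset,
                                          seq_len=32, is_random_offset=False)
    return DataLoader(train_ds, batch_size=4, shuffle=True)
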
class TextFileDataset(TextDataset):
    standard_tokens = []

    def __init__(self, path: PurePath, tokenizer: Callable, *,
                 url: Optional[str] = None,
                 filter_subset: Optional[int] = None):
        path = Path(path)
        # Download the file if it is missing and a URL is given
        if not path.exists():
            if not url:
                raise FileNotFoundError(str(path))
            else:
                download_file(url, path)

        with monit.section("Load data"):
            text = self.load(path)
            if filter_subset:
                text = text[:filter_subset]
            # Split the text 90/10 into training and validation sets
            split = int(len(text) * .9)
            train = text[:split]
            valid = text[split:]

        super().__init__(path, tokenizer, train, valid, '')

def _test_tiny_shakespeare():
    from labml import lab
    _ = TextFileDataset(lab.get_data_path() / 'tiny_shakespeare.txt', lambda x: list(x),
                        url='https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt')


if __name__ == '__main__':
    _test_tiny_shakespeare()