cleanup hook model outputs

This commit is contained in:
Varuna Jayasiri
2025-07-20 09:02:34 +05:30
parent 5bdedcffec
commit a713c92b82
12 changed files with 36 additions and 142 deletions

View File

@ -14,7 +14,7 @@ from torch.utils.data import DataLoader
from torch.utils.data import IterableDataset, Dataset
def _dataset(is_train, transform):
def _mnist_dataset(is_train, transform):
return datasets.MNIST(str(lab.get_data_path()),
train=is_train,
download=True,
@ -66,12 +66,12 @@ def mnist_transforms():
@option(MNISTConfigs.train_dataset)
def mnist_train_dataset(c: MNISTConfigs):
return _dataset(True, c.dataset_transforms)
return _mnist_dataset(True, c.dataset_transforms)
@option(MNISTConfigs.valid_dataset)
def mnist_valid_dataset(c: MNISTConfigs):
return _dataset(False, c.dataset_transforms)
return _mnist_dataset(False, c.dataset_transforms)
@option(MNISTConfigs.train_loader)
@ -96,7 +96,7 @@ aggregate(MNISTConfigs.dataset_name, 'MNIST',
(MNISTConfigs.valid_loader, 'mnist_valid_loader'))
def _dataset(is_train, transform):
def _cifar_dataset(is_train, transform):
return datasets.CIFAR10(str(lab.get_data_path()),
train=is_train,
download=True,
@ -147,12 +147,12 @@ def cifar10_transforms():
@CIFAR10Configs.calc(CIFAR10Configs.train_dataset)
def cifar10_train_dataset(c: CIFAR10Configs):
return _dataset(True, c.dataset_transforms)
return _cifar_dataset(True, c.dataset_transforms)
@CIFAR10Configs.calc(CIFAR10Configs.valid_dataset)
def cifar10_valid_dataset(c: CIFAR10Configs):
return _dataset(False, c.dataset_transforms)
return _cifar_dataset(False, c.dataset_transforms)
@CIFAR10Configs.calc(CIFAR10Configs.train_loader)