mirror of
https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
synced 2025-10-31 02:39:16 +08:00
cleanup hook model outputs
This commit is contained in:
@ -14,7 +14,7 @@ from torch.utils.data import DataLoader
|
||||
from torch.utils.data import IterableDataset, Dataset
|
||||
|
||||
|
||||
def _dataset(is_train, transform):
|
||||
def _mnist_dataset(is_train, transform):
|
||||
return datasets.MNIST(str(lab.get_data_path()),
|
||||
train=is_train,
|
||||
download=True,
|
||||
@ -66,12 +66,12 @@ def mnist_transforms():
|
||||
|
||||
@option(MNISTConfigs.train_dataset)
|
||||
def mnist_train_dataset(c: MNISTConfigs):
|
||||
return _dataset(True, c.dataset_transforms)
|
||||
return _mnist_dataset(True, c.dataset_transforms)
|
||||
|
||||
|
||||
@option(MNISTConfigs.valid_dataset)
|
||||
def mnist_valid_dataset(c: MNISTConfigs):
|
||||
return _dataset(False, c.dataset_transforms)
|
||||
return _mnist_dataset(False, c.dataset_transforms)
|
||||
|
||||
|
||||
@option(MNISTConfigs.train_loader)
|
||||
@ -96,7 +96,7 @@ aggregate(MNISTConfigs.dataset_name, 'MNIST',
|
||||
(MNISTConfigs.valid_loader, 'mnist_valid_loader'))
|
||||
|
||||
|
||||
def _dataset(is_train, transform):
|
||||
def _cifar_dataset(is_train, transform):
|
||||
return datasets.CIFAR10(str(lab.get_data_path()),
|
||||
train=is_train,
|
||||
download=True,
|
||||
@ -147,12 +147,12 @@ def cifar10_transforms():
|
||||
|
||||
@CIFAR10Configs.calc(CIFAR10Configs.train_dataset)
|
||||
def cifar10_train_dataset(c: CIFAR10Configs):
|
||||
return _dataset(True, c.dataset_transforms)
|
||||
return _cifar_dataset(True, c.dataset_transforms)
|
||||
|
||||
|
||||
@CIFAR10Configs.calc(CIFAR10Configs.valid_dataset)
|
||||
def cifar10_valid_dataset(c: CIFAR10Configs):
|
||||
return _dataset(False, c.dataset_transforms)
|
||||
return _cifar_dataset(False, c.dataset_transforms)
|
||||
|
||||
|
||||
@CIFAR10Configs.calc(CIFAR10Configs.train_loader)
|
||||
|
||||
Reference in New Issue
Block a user