from typing import List, Optional

from torch import nn

from labml import experiment
from labml.configs import option
from labml_nn.experiments.cifar10 import CIFAR10Configs
from labml_nn.resnet import ResNetBase


class Configs(CIFAR10Configs):
    """
    We use `CIFAR10Configs`, which defines all the dataset-related configurations,
    the optimizer, and a training loop.
    """

    # Number of blocks for each feature map size
    n_blocks: List[int] = [3, 3, 3]
    # Number of channels for each feature map size
    n_channels: List[int] = [16, 32, 64]
    # Bottleneck sizes
    bottlenecks: Optional[List[int]] = None
    # Kernel size of the initial convolution layer
    first_kernel_size: int = 3
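

# An illustrative aside (not part of the original experiment): with the default,
# non-bottleneck configuration each residual block has two 3x3 convolutions, so
# the overall depth is roughly `2 * sum(n_blocks) + 2` weighted layers; the
# defaults `n_blocks = [3, 3, 3]` give the classic 20-layer CIFAR ResNet.
# The helper name below is made up for this sketch.
def _approx_depth(n_blocks: List[int]) -> int:
    # Two convolutions per basic block, plus the initial convolution and the
    # final linear classifier
    return 2 * sum(n_blocks) + 2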


@option(Configs.model)
def _resnet(c: Configs):
    # Create the ResNet base from the configuration options
    base = ResNetBase(c.n_blocks, c.n_channels, c.bottlenecks, img_channels=3, first_kernel_size=c.first_kernel_size)
    # Linear layer for classification
    classification = nn.Linear(c.n_channels[-1], 10)

    # Stack them
    model = nn.Sequential(base, classification)
    # Move the model to the device
    return model.to(c.device)
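

# A quick shape sanity check (an illustrative sketch, not part of the original
# experiment; the helper name is made up): build the same base + classifier
# stack by hand and pass a random CIFAR-10 sized batch through it.
def _check_output_shape():
    import torch

    base = ResNetBase([3, 3, 3], [16, 32, 64], None, img_channels=3, first_kernel_size=3)
    # The base outputs a `[batch_size, n_channels[-1]]` tensor, which is why the
    # classifier above is `nn.Linear(c.n_channels[-1], 10)`
    model = nn.Sequential(base, nn.Linear(64, 10))
    x = torch.randn(4, 3, 32, 32)
    assert model(x).shape == (4, 10)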


def main():
    # Create experiment
    experiment.create(name='resnet', comment='cifar10')
    # Create configurations
    conf = Configs()
    # Load configurations
    experiment.configs(conf, {
        'bottlenecks': [8, 16, 16],
        'n_blocks': [6, 6, 6],

        'optimizer.optimizer': 'Adam',
        'optimizer.learning_rate': 2.5e-4,

        'epochs': 500,
        'train_batch_size': 256,

        'train_dataset': 'cifar10_train_augmented',
        'valid_dataset': 'cifar10_valid_no_augment',
    })
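    # Note (added for clarity, not in the original file): these overrides switch
    # the model to bottleneck residual blocks with bottleneck sizes [8, 16, 16]
    # and use 6 blocks per feature-map size instead of the default 3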
    # Set model for saving/loading
    experiment.add_pytorch_models({'model': conf.model})
    # Start the experiment and run the training loop
    with experiment.start():
        conf.run()


if __name__ == '__main__':
    main()
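
# Usage note (an assumption about the file location, not stated in the original):
# if this script lives at `labml_nn/resnet/experiment.py`, it can be run with
#
#     python -m labml_nn.resnet.experiment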