Files
2023-04-02 12:10:18 +05:30

90 lines
2.4 KiB
Python

"""
---
title: Train a large model on CIFAR-10
summary: >
  Train a large model on CIFAR-10 for distillation.
---
# Train a large model on CIFAR-10

This trains a large model on CIFAR-10 for [distillation](index.html).
"""
import torch.nn as nn
from labml import experiment, logger
from labml.configs import option
from labml_nn.experiments.cifar10 import CIFAR10Configs, CIFAR10VGGModel
from labml_nn.normalization.batch_norm import BatchNorm
class Configs(CIFAR10Configs):
    """
    ## Configurations

    This class adds no settings of its own. Everything is inherited from
    [`CIFAR10Configs`](../experiments/cifar10.html), which defines all the
    dataset related configurations, the optimizer, and a training loop.
    """
class LargeModel(CIFAR10VGGModel):
    """
    ### Large VGG style model for CIFAR-10 classification

    A concrete variant of the
    [generic VGG style architecture](../experiments/cifar10.html).
    It fixes the channel layout handed to the base class and defines
    `conv_block`.
    """

    def __init__(self):
        # Output channel counts of the convolutions, in five groups.
        # NOTE(review): presumably each inner list is one stage of the
        # base VGG architecture - confirm against `CIFAR10VGGModel`.
        channel_groups = [
            [64, 64],
            [128, 128],
            [256, 256, 256],
            [512, 512, 512],
            [512, 512, 512],
        ]
        super().__init__(channel_groups)

    def conv_block(self, in_channels, out_channels) -> nn.Module:
        """
        Create a convolution layer together with its regularization,
        normalization and activation.

        * `in_channels` is the number of channels fed to the convolution
        * `out_channels` is the number of channels the convolution produces

        Returns the layers wrapped in a single `nn.Sequential` module.
        """
        layers = [
            # Dropout with probability 0.1, applied before the convolution
            nn.Dropout(0.1),
            # 3x3 convolution with padding of 1
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            # Batch normalization, created with `track_running_stats=False`
            BatchNorm(out_channels, track_running_stats=False),
            # In-place ReLU activation
            nn.ReLU(inplace=True),
        ]
        return nn.Sequential(*layers)
@option(Configs.model)
def _large_model(c: Configs):
    """
    ### Create model

    Registered as an option for `Configs.model`: builds a `LargeModel`
    and moves it to the configured device `c.device`.
    """
    model = LargeModel()
    return model.to(c.device)
def main():
    """
    ### Train the large model

    Creates the experiment, applies the configuration overrides, registers
    the model for saving/loading, logs the number of trainable parameters,
    and then runs the training loop.
    """
    # Create the experiment
    experiment.create(name='cifar10', comment='large model')
    # Create the configurations object
    configs = Configs()
    # Values that override the configuration defaults
    overrides = {
        'optimizer.optimizer': 'Adam',
        'optimizer.learning_rate': 2.5e-4,
        'is_save_models': True,
        'epochs': 20,
    }
    # Load the configurations
    experiment.configs(configs, overrides)
    # Register the model for saving/loading
    experiment.add_pytorch_models({'model': configs.model})
    # Count only the parameters that require gradients, and print the total
    trainable_sizes = [p.numel() for p in configs.model.parameters() if p.requires_grad]
    logger.inspect(params=sum(trainable_sizes))
    # Start the experiment and run the training loop
    with experiment.start():
        configs.run()
# Run the experiment only when this file is executed as a script
if __name__ == '__main__':
    main()