mirror of
				https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
				synced 2025-10-31 18:58:43 +08:00 
			
		
		
		
	
		
			
				
	
	
		
			114 lines
		
	
	
		
			3.4 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			114 lines
		
	
	
		
			3.4 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| """
 | |
| ---
 | |
| title: CIFAR10 Experiment
 | |
| summary: >
 | |
|   This is a reusable trainer for CIFAR10 dataset
 | |
| ---
 | |
| 
 | |
| # CIFAR10 Experiment
 | |
| """
 | |
| from typing import List
 | |
| 
 | |
| import torch.nn as nn
 | |
| 
 | |
| from labml import lab
 | |
| from labml.configs import option
 | |
| from labml_helpers.datasets.cifar10 import CIFAR10Configs as CIFAR10DatasetConfigs
 | |
| from labml_helpers.module import Module
 | |
| from labml_nn.experiments.mnist import MNISTConfigs
 | |
| 
 | |
| 
 | |
class CIFAR10Configs(CIFAR10DatasetConfigs, MNISTConfigs):
    """
    ## Configurations

    Inherits from both the CIFAR 10 dataset configurations in
     [`labml_helpers`](https://github.com/labmlai/labml/tree/master/helpers)
     and from [`MNISTConfigs`](mnist.html).
    """
    # The dataset name defaults to CIFAR10
    dataset_name: str = "CIFAR10"
 | |
| 
 | |
| 
 | |
@option(CIFAR10Configs.train_dataset)
def cifar10_train_augmented():
    """
    ### Augmented CIFAR 10 train dataset

    Builds the training split with random-crop and horizontal-flip
    augmentation, followed by tensor conversion and normalization.
    """
    from torchvision.datasets import CIFAR10
    from torchvision.transforms import transforms

    # Directory the dataset is stored in (and downloaded to)
    data_root = str(lab.get_data_path())

    augmentation = transforms.Compose([
        # Pad by 4 pixels and take a random $32 \times 32$ crop
        transforms.RandomCrop(32, padding=4),
        # Randomly mirror the image horizontally
        transforms.RandomHorizontalFlip(),
        # Convert the image to a tensor
        transforms.ToTensor(),
        # Normalize each channel with mean 0.5 and standard deviation 0.5
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    return CIFAR10(data_root, train=True, download=True, transform=augmentation)
 | |
| 
 | |
| 
 | |
@option(CIFAR10Configs.valid_dataset)
def cifar10_valid_no_augment():
    """
    ### Non-augmented CIFAR 10 validation dataset

    Builds the `train=False` split with only tensor conversion and
    normalization; no augmentation is applied.
    """
    from torchvision.datasets import CIFAR10
    from torchvision.transforms import transforms

    # Directory the dataset is stored in (and downloaded to)
    data_root = str(lab.get_data_path())

    preprocessing = transforms.Compose([
        # Convert the image to a tensor
        transforms.ToTensor(),
        # Normalize each channel with mean 0.5 and standard deviation 0.5
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    return CIFAR10(data_root, train=False, download=True, transform=preprocessing)
 | |
| 
 | |
| 
 | |
class CIFAR10VGGModel(Module):
    """
    ### VGG model for CIFAR-10 classification

    A stack of five convolutional blocks, each ending in max pooling,
    followed by a single linear layer that produces 10 logits.
    """

    def conv_block(self, in_channels, out_channels) -> nn.Module:
        """
        A $3 \times 3$ convolution (padding 1) followed by an in-place ReLU.

        * `in_channels` is the number of input channels
        * `out_channels` is the number of output channels
        """
        convolution = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        activation = nn.ReLU(inplace=True)
        return nn.Sequential(convolution, activation)

    def __init__(self, blocks: List[List[int]]):
        """
        * `blocks` lists, for each of the five blocks, the number of output
          channels of every convolution in that block
        """
        super().__init__()

        # Five $2 \times 2$ pooling layers reduce a $32 \times 32$ CIFAR 10 image
        # to an output of size $1 \times 1$, so exactly five blocks are required.
        assert len(blocks) == 5

        modules = []
        # The input images have 3 (RGB) channels
        prev_channels = 3
        for block_channels in blocks:
            for n_channels in block_channels:
                # `conv_block` returns a `nn.Sequential`; its sub-modules are
                # appended individually so the final model stays flat.
                modules.extend(self.conv_block(prev_channels, n_channels))
                prev_channels = n_channels
            # Every block ends with a $2 \times 2$ max pooling
            modules.append(nn.MaxPool2d(kernel_size=2, stride=2))

        # The feature extractor as a single sequential model
        self.layers = nn.Sequential(*modules)
        # Linear layer that maps the final channels to the 10 class logits
        self.fc = nn.Linear(prev_channels, 10)

    def forward(self, x):
        """
        Run the VGG layers on `x`, flatten everything except the first
        dimension, and return the output of the final linear layer.
        """
        features = self.layers(x)
        # Collapse all dimensions after the first for the linear layer
        flattened = features.view(features.shape[0], -1)
        return self.fc(flattened)
 | 
