10from typing import List
11
12import torch.nn as nn
13
14from labml import lab
15from labml.configs import option
16from labml_helpers.datasets.cifar10 import CIFAR10Configs as CIFAR10DatasetConfigs
17from labml_helpers.module import Module
18from labml_nn.experiments.mnist import MNISTConfigs21class CIFAR10Configs(CIFAR10DatasetConfigs, MNISTConfigs):デフォルトで CIFAR10 データセットを使用
30    dataset_name: str = 'CIFAR10'33@option(CIFAR10Configs.train_dataset)
34def cifar10_train_augmented():38    from torchvision.datasets import CIFAR10
39    from torchvision.transforms import transforms
40    return CIFAR10(str(lab.get_data_path()),
41                   train=True,
42                   download=True,
43                   transform=transforms.Compose([パッドとクロップ
45                       transforms.RandomCrop(32, padding=4),ランダム水平反転
47                       transforms.RandomHorizontalFlip(),49                       transforms.ToTensor(),
50                       transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
51                   ]))54@option(CIFAR10Configs.valid_dataset)
55def cifar10_valid_no_augment():59    from torchvision.datasets import CIFAR10
60    from torchvision.transforms import transforms
61    return CIFAR10(str(lab.get_data_path()),
62                   train=False,
63                   download=True,
64                   transform=transforms.Compose([
65                       transforms.ToTensor(),
66                       transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
67                   ]))70class CIFAR10VGGModel(Module):コンボリューションとアクティベーションの組み合わせ
75    def conv_block(self, in_channels, out_channels) -> nn.Module:79        return nn.Sequential(
80            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
81            nn.ReLU(inplace=True),
82        )84    def __init__(self, blocks: List[List[int]]):
85        super().__init__()5つのプーリングレイヤーでサイズの出力が得られます。CIFAR 10 の画像サイズは
89        assert len(blocks) == 5
90        layers = []RGB チャンネル
92        in_channels = 3各ブロックの各レイヤーのチャンネル数
94        for block in blocks:コンボリューション、ノーマライゼーション、アクティベーションレイヤー
96            for channels in block:
97                layers += self.conv_block(in_channels, channels)
98                in_channels = channels各ブロック終了時の最大プーリング
100            layers += [nn.MaxPool2d(kernel_size=2, stride=2)]レイヤーを含むシーケンシャルモデルの作成
103        self.layers = nn.Sequential(*layers)最終ロジットレイヤー
105        self.fc = nn.Linear(in_channels, 10)107    def forward(self, x):VGG レイヤー
109        x = self.layers(x)分類レイヤーの形状を変更
111        x = x.view(x.shape[0], -1)最終線形レイヤー
113        return self.fc(x)