CIFAR 10 මත කොන්වී මික්සර් පුහුණු කරන්න

මෙමතිර රචනය CIFAR මත ConvMixer දුම්රිය 10 දත්ත සමුදාය.

මෙයකඩදාසි ප්රති results ල ප්රතිනිෂ්පාදනය කිරීමේ උත්සාහයක් නොවේ. පුහුණුව සඳහා PyTorch Image Models (timm) හි ඇති රූප වැඩිදියුණු කිරීම් භාවිතා කරයි. සරල බව සඳහා අපි මෙය කර නැත - එමඟින් අපගේ වලංගු කිරීමේ නිරවද්යතාවය පහත වැටීමට හේතු වේ.

View Run

20from labml import experiment
21from labml.configs import option
22from labml_nn.experiments.cifar10 import CIFAR10Configs

වින්යාසකිරීම්

සියලුමදත්ත කට්ටල ආශ්රිත වින්යාසයන්, ප්රශස්තකරණය සහ පුහුණු ලූපයක් නිර්වචනය කරන අපි භාවිතා CIFAR10Configs කරමු.

25class Configs(CIFAR10Configs):

පැච්එකක ප්රමාණය,

34    patch_size: int = 2

පැච්කාවැද්දීම් වල නාලිකා ගණන,

36    d_model: int = 256

ConvMixer ස්ථර හෝ ගැඹුර ගණන,

38    n_layers: int = 8

ගැඹුර-නැණවත්සංවහනයේ කර්නල් ප්රමාණය,

40    kernel_size: int = 7

කර්තව්යයේපන්ති ගණන

42    n_classes: int = 10

ආකෘතියසාදන්න

45@option(Configs.model)
46def _conv_mixer(c: Configs):
50    from labml_nn.conv_mixer import ConvMixerLayer, ConvMixer, ClassificationHead, PatchEmbeddings

ConvMixerනිර්මාණය

53    return ConvMixer(ConvMixerLayer(c.d_model, c.kernel_size), c.n_layers,
54                     PatchEmbeddings(c.d_model, c.patch_size, 3),
55                     ClassificationHead(c.d_model, c.n_classes)).to(c.device)
58def main():

අත්හදාබැලීම සාදන්න

60    experiment.create(name='ConvMixer', comment='cifar10')

වින්යාසයන්සාදන්න

62    conf = Configs()

වින්යාසයන්පූරණය කරන්න

64    experiment.configs(conf, {

ප්රශස්තකරණය

66        'optimizer.optimizer': 'Adam',
67        'optimizer.learning_rate': 2.5e-4,

ඊපොච්සහ කණ්ඩායම් ප්රමාණය පුහුණු කිරීම

70        'epochs': 150,
71        'train_batch_size': 64,

සරලරූප වැඩි කිරීම

74        'train_dataset': 'cifar10_train_augmented',

වලංගුකිරීම සඳහා රූප වැඩි නොකරන්න

76        'valid_dataset': 'cifar10_valid_no_augment',
77    })

ඉතිරිකිරීම/පැටවීම සඳහා ආකෘතිය සකසන්න

79    experiment.add_pytorch_models({'model': conf.model})

අත්හදාබැලීම ආරම්භ කර පුහුණු ලූපය ක්රියාත්මක කරන්න

81    with experiment.start():
82        conf.run()

86if __name__ == '__main__':
87    main()