CIFAR10 හි විශාල ආකෘතියක් පුහුණු කරන්න

මෙය ආසවනයසඳහා CIFAR 10 හි විශාල ආකෘතියක් පුහුණු කරයි.

View Run

15import torch.nn as nn
16
17from labml import experiment, logger
18from labml.configs import option
19from labml_nn.experiments.cifar10 import CIFAR10Configs, CIFAR10VGGModel
20from labml_nn.normalization.batch_norm import BatchNorm

වින්යාසකිරීම්

සියලුමදත්ත කට්ටල ආශ්රිත වින්යාසයන්, ප්රශස්තකරණය සහ පුහුණු ලූපයක් නිර්වචනය කරන අපි භාවිතා CIFAR10Configs කරමු.

23class Configs(CIFAR10Configs):
30    pass

CIFA-10වර්ගීකරණය සඳහා VGG විලාසිතාවේ ආකෘතිය

මෙය සාමාන්ය VGG විලාසිතාවේ ගෘහ නිර්මාණ ශිල්පයෙන්ලබා ගනී.

33class LargeModel(CIFAR10VGGModel):

සංවහනස්තරයක් සහ සක්රිය කිරීම් සාදන්න

40    def conv_block(self, in_channels, out_channels) -> nn.Module:
44        return nn.Sequential(

හැලීම

46            nn.Dropout(0.1),

සංවහනස්ථරය

48            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),

කණ්ඩායම්සාමාන්යකරණය

50            BatchNorm(out_channels, track_running_stats=False),

Reluසක්රිය

52            nn.ReLU(inplace=True),
53        )
55    def __init__(self):

ලබාදී ඇති සංවහන ප්රමාණ (නාලිකා) සහිත ආකෘතියක් සාදන්න

57        super().__init__([[64, 64], [128, 128], [256, 256, 256], [512, 512, 512], [512, 512, 512]])

ආකෘතියසාදන්න

60@option(Configs.model)
61def _large_model(c: Configs):
65    return LargeModel().to(c.device)
68def main():

අත්හදාබැලීම සාදන්න

70    experiment.create(name='cifar10', comment='large model')

වින්යාසයන්සාදන්න

72    conf = Configs()

වින්යාසයන්පූරණය කරන්න

74    experiment.configs(conf, {
75        'optimizer.optimizer': 'Adam',
76        'optimizer.learning_rate': 2.5e-4,
77        'is_save_models': True,
78        'epochs': 20,
79    })

ඉතිරිකිරීම/පැටවීම සඳහා ආකෘතිය සකසන්න

81    experiment.add_pytorch_models({'model': conf.model})

ආකෘතියේපරාමිති ගණන මුද්රණය කරන්න

83    logger.inspect(params=(sum(p.numel() for p in conf.model.parameters() if p.requires_grad)))

අත්හදාබැලීම ආරම්භ කර පුහුණු ලූපය ක්රියාත්මක කරන්න

85    with experiment.start():
86        conf.run()

90if __name__ == '__main__':
91    main()