12from typing import List, Optional
13
14from torch import nn
15
16from labml import experiment
17from labml.configs import option
18from labml_nn.experiments.cifar10 import CIFAR10Configs
19from labml_nn.resnet import ResNetBase
සියලුමදත්ත කට්ටල ආශ්රිත වින්යාසයන්, ප්රශස්තකරණය සහ පුහුණු ලූපයක් නිර්වචනය කරන අපි භාවිතා CIFAR10Configs
කරමු.
22class Configs(CIFAR10Configs):
එක්එක් ලක්ෂණය සිතියම ප්රමාණය සඳහා කුට්ටි දැමිමේ අංකය
31 n_blocks: List[int] = [3, 3, 3]
එක්එක් විශේෂාංග සිතියම් ප්රමාණය සඳහා නාලිකා ගණන
33 n_channels: List[int] = [16, 32, 64]
බාධකප්රමාණ
35 bottlenecks: Optional[List[int]] = None
ආරම්භකසංවලිත ස්ථරයේ කර්නල් ප්රමාණය
37 first_kernel_size: int = 3
40@option(Configs.model)
41def _resnet(c: Configs):
46 base = ResNetBase(c.n_blocks, c.n_channels, c.bottlenecks, img_channels=3, first_kernel_size=c.first_kernel_size)
වර්ගීකරණයසඳහා රේඛීය ස්ථරය
48 classification = nn.Linear(c.n_channels[-1], 10)
ඒවාගොඩගසන්න
51 model = nn.Sequential(base, classification)
උපාංගයවෙත ආකෘතිය ගෙනයන්න
53 return model.to(c.device)
56def main():
අත්හදාබැලීම සාදන්න
58 experiment.create(name='resnet', comment='cifar10')
වින්යාසයන්සාදන්න
60 conf = Configs()
වින්යාසයන්පූරණය කරන්න
62 experiment.configs(conf, {
63 'bottlenecks': [8, 16, 16],
64 'n_blocks': [6, 6, 6],
65
66 'optimizer.optimizer': 'Adam',
67 'optimizer.learning_rate': 2.5e-4,
68
69 'epochs': 500,
70 'train_batch_size': 256,
71
72 'train_dataset': 'cifar10_train_augmented',
73 'valid_dataset': 'cifar10_valid_no_augment',
74 })
ඉතිරිකිරීම/පැටවීම සඳහා ආකෘතිය සකසන්න
76 experiment.add_pytorch_models({'model': conf.model})
අත්හදාබැලීම ආරම්භ කර පුහුණු ලූපය ක්රියාත්මක කරන්න
78 with experiment.start():
79 conf.run()
83if __name__ == '__main__':
84 main()