MNISTඅත්හදා බැලීම

11import torch.nn as nn
12import torch.utils.data
13from labml_helpers.module import Module
14
15from labml import tracker
16from labml.configs import option
17from labml_helpers.datasets.mnist import MNISTConfigs as MNISTDatasetConfigs
18from labml_helpers.device import DeviceConfigs
19from labml_helpers.metrics.accuracy import Accuracy
20from labml_helpers.train_valid import TrainValidConfigs, BatchIndex, hook_model_outputs
21from labml_nn.optimizers.configs import OptimizerConfigs

පුහුණුකරුමානකරණ

24class MNISTConfigs(MNISTDatasetConfigs, TrainValidConfigs):

ප්රශස්තකරණය

32    optimizer: torch.optim.Adam

පුහුණුඋපාංගය

34    device: torch.device = DeviceConfigs()

වර්ගීකරණආකෘතිය

37    model: Module

පුහුණුකිරීමට එපොච් ගණන

39    epochs: int = 10

එපෝච්තුළ පුහුණුව සහ වලංගු කිරීම අතර මාරු වීමට වාර ගණන

42    inner_iterations = 10

නිරවද්යතාශ්රිතය

45    accuracy = Accuracy()

පාඩුශ්රිතය

47    loss_func = nn.CrossEntropyLoss()

ආරම්භකකරණය

49    def init(self):

ට්රැකර්වින්යාසයන් සකසන්න

54        tracker.set_scalar("loss.*", True)
55        tracker.set_scalar("accuracy.*", True)

මොඩියුලප්රතිදානයන් ලොග් කිරීමට කොක්කක් එක් කරන්න

57        hook_model_outputs(self.mode, self.model, 'model')

රාජ්යමොඩියුලයක් ලෙස නිරවද්යතාව එක් කරන්න. RNs සඳහා පුහුණුව සහ වලංගු කිරීම අතර රාජ්යයන් ගබඩා කිරීම අදහස් කරන බැවින් නම බොහෝ විට ව්යාකූල වේ. මෙය පුහුණුව සහ වලංගු කිරීම සඳහා නිරවද්යතා මෙට්රික් සංඛ්යාන වෙනම තබා ගනී.

62        self.state_modules = [self.accuracy]

පුහුණුවහෝ වලංගු කිරීමේ පියවර

64    def step(self, batch: any, batch_idx: BatchIndex):

පුහුණුව/ඇගයීම්මාදිලිය

70        self.model.train(self.mode.is_train)

උපාංගයවෙත දත්ත ගෙනයන්න

73        data, target = batch[0].to(self.device), batch[1].to(self.device)

පුහුණුප්රකාරයේදී ගෝලීය පියවර (සැකසූ සාම්පල ගණන) යාවත්කාලීන කරන්න

76        if self.mode.is_train:
77            tracker.add_global_step(len(data))

ආකෘතිප්රතිදානයන් ග්රහණය කර ගත යුතුද යන්න

80        with self.mode.update(is_log_activations=batch_idx.is_last):

ආදර්ශප්රතිදානයන් ලබා ගන්න.

82            output = self.model(data)

ගණනයකිරීම සහ ලොග් වීම

85        loss = self.loss_func(output, target)
86        tracker.add("loss.", loss)

ගණනයකිරීම සහ ලොග් කිරීමේ නිරවද්යතාවය

89        self.accuracy(output, target)
90        self.accuracy.track()

ආකෘතියපුහුණු කරන්න

93        if self.mode.is_train:

අනුක්රමිකගණනය කරන්න

95            loss.backward()

ප්රශස්තිකරණපියවර ගන්න

97            self.optimizer.step()

සෑමයුගලයකම අවසාන කණ්ඩායමේ ආදර්ශ පරාමිතීන් සහ අනුක්රමික ලොග් කරන්න

99            if batch_idx.is_last:
100                tracker.add('model', self.model)

අනුක්රමිකඉවත්

102            self.optimizer.zero_grad()

ලුහුබැඳඇති ප්රමිතික සුරකින්න

105        tracker.save()

පෙරනිමිප්රශස්තිකරණ වින්යාසයන්

108@option(MNISTConfigs.optimizer)
109def _optimizer(c: MNISTConfigs):
113    opt_conf = OptimizerConfigs()
114    opt_conf.parameters = c.model.parameters()
115    opt_conf.optimizer = 'Adam'
116    return opt_conf