14from typing import Any
15
16import torch.nn as nn
17import torch.utils.data
18
19from labml import tracker, experiment
20from labml.configs import option, calculate
21from labml_helpers.module import Module
22from labml_helpers.schedule import Schedule, RelativePiecewise
23from labml_helpers.train_valid import BatchIndex
24from labml_nn.experiments.mnist import MNISTConfigs
25from labml_nn.uncertainty.evidence import KLDivergenceLoss, TrackStatistics, MaximumLikelihoodLoss, \
26    CrossEntropyBayesRisk, SquaredErrorBayesRisk

Lenetපදනම් කරගත් ආකෘතිය සිට MNIST වර්ගීකරණය

29class Model(Module):
34    def __init__(self, dropout: float):
35        super().__init__()

පළමු කැටි ගැසුණු ස්ථරය

37        self.conv1 = nn.Conv2d(1, 20, kernel_size=5)

Reluසක්රිය

39        self.act1 = nn.ReLU()

උපරිම තටාක

41        self.max_pool1 = nn.MaxPool2d(2, 2)

දෙවන කැටි ගැසුණු ස්ථරය

43        self.conv2 = nn.Conv2d(20, 50, kernel_size=5)

Reluසක්රිය

45        self.act2 = nn.ReLU()

උපරිම තටාක

47        self.max_pool2 = nn.MaxPool2d(2, 2)

විශේෂාංග සිතියම් ගත කරන පළමු පූර්ණ සම්බන්ධිත ස්ථරය

49        self.fc1 = nn.Linear(50 * 4 * 4, 500)

Reluසක්රිය

51        self.act3 = nn.ReLU()

පන්ති සඳහා නිමැවුම් සාක්ෂි සඳහා අවසාන පූර්ණ සම්බන්ධිත ස්ථරය. Negative ණාත්මක නොවන සාක්ෂි ලබා ගැනීම සඳහා ආකෘතියෙන් පිටත RelU හෝ Softplus සක්රිය කිරීම මේ සඳහා යොදනු ලැබේ

55        self.fc2 = nn.Linear(500, 10)

සැඟවුණුස්තරය සඳහා අතහැර දැමීම

57        self.dropout = nn.Dropout(p=dropout)
  • x හැඩයේ MNIST රූප කාණ්ඩයයි [batch_size, 1, 28, 28]
59    def __call__(self, x: torch.Tensor):

පළමුකැටි ගැසීම සහ උපරිම තටාක යොදන්න. ප්රති result ලය හැඩය ඇත [batch_size, 20, 12, 12]

65        x = self.max_pool1(self.act1(self.conv1(x)))

දෙවනකැටි ගැසීම සහ උපරිම තටාක යොදන්න. ප්රති result ලය හැඩය ඇත [batch_size, 50, 4, 4]

68        x = self.max_pool2(self.act2(self.conv2(x)))

ටෙන්සරයහැඩයට සමතලා කරන්න [batch_size, 50 * 4 * 4]

70        x = x.view(x.shape[0], -1)

සැඟවුණුස්ථරය යොදන්න

72        x = self.act3(self.fc1(x))

අතහැරදැමීම යොදන්න

74        x = self.dropout(x)

අවසානස්ථරය යොදන්න සහ ආපසු යන්න

76        return self.fc2(x)

වින්යාසකිරීම්

අපි MNISTConfigs වින්යාසයන් භාවිතා කරමු.

79class Configs(MNISTConfigs):
87    kl_div_loss = KLDivergenceLoss()

KLඅපසරනය විධිමත් කිරීමේ සංගුණකය කාලසටහන

89    kl_div_coef: Schedule

KLඅපසරනය විධිමත් කිරීමේ සංගුණකය කාලසටහන

91    kl_div_coef_schedule = [(0, 0.), (0.2, 0.01), (1, 1.)]

ලුහුබැඳීමසඳහාසංඛ්යාන මොඩියුලය

93    stats = TrackStatistics()

හැලීම

95    dropout: float = 0.5

ආදර්ශප්රතිදානය ශුන්ය නොවන සාක්ෂි බවට පරිවර්තනය කිරීමේ මොඩියුලය

97    outputs_to_evidence: Module

ආරම්භකකරණය

99    def init(self):

ට්රැකර්වින්යාසයන් සකසන්න

104        tracker.set_scalar("loss.*", True)
105        tracker.set_scalar("accuracy.*", True)
106        tracker.set_histogram('u.*', True)
107        tracker.set_histogram('prob.*', False)
108        tracker.set_scalar('annealing_coef.*', False)
109        tracker.set_scalar('kl_div_loss.*', False)

112        self.state_modules = []

පුහුණුවහෝ වලංගු කිරීමේ පියවර

114    def step(self, batch: Any, batch_idx: BatchIndex):

පුහුණුව/ඇගයීම්මාදිලිය

120        self.model.train(self.mode.is_train)

උපාංගයවෙත දත්ත ගෙනයන්න

123        data, target = batch[0].to(self.device), batch[1].to(self.device)

එක්-උණුසුම්කේත කරන ලද ඉලක්ක

126        eye = torch.eye(10).to(torch.float).to(self.device)
127        target = eye[target]

පුහුණුප්රකාරයේදී ගෝලීය පියවර (සැකසූ සාම්පල ගණන) යාවත්කාලීන කරන්න

130        if self.mode.is_train:
131            tracker.add_global_step(len(data))

ආදර්ශප්රතිදානයන් ලබා ගන්න

134        outputs = self.model(data)

සාක්ෂිලබා ගන්න

136        evidence = self.outputs_to_evidence(outputs)

අලාභයගණනය කරන්න

139        loss = self.loss_func(evidence, target)

KLඅපසරනය විධිමත් කිරීමේ අලාභය ගණනය කරන්න

141        kl_div_loss = self.kl_div_loss(evidence, target)
142        tracker.add("loss.", loss)
143        tracker.add("kl_div_loss.", kl_div_loss)

KLඅපසරනය පාඩු සංගුණකය

146        annealing_coef = min(1., self.kl_div_coef(tracker.get_global_step()))
147        tracker.add("annealing_coef.", annealing_coef)

මුළුඅලාභය

150        loss = loss + annealing_coef * kl_div_loss

සංඛ්යාලේඛනනිරීක්ෂණය කරන්න

153        self.stats(evidence, target)

ආකෘතියපුහුණු කරන්න

156        if self.mode.is_train:

අනුක්රමිකගණනය කරන්න

158            loss.backward()

ප්රශස්තිකරණපියවර ගන්න

160            self.optimizer.step()

අනුක්රමිකඉවත්

162            self.optimizer.zero_grad()

ලුහුබැඳඇති ප්රමිතික සුරකින්න

165        tracker.save()

ආකෘතියසාදන්න

168@option(Configs.model)
169def mnist_model(c: Configs):
173    return Model(c.dropout).to(c.device)

KLඅපසරනය පාඩු සංගුණක උපලේඛනය

176@option(Configs.kl_div_coef)
177def kl_div_coef(c: Configs):
183    return RelativePiecewise(c.kl_div_coef_schedule, c.epochs * len(c.train_dataset))
187calculate(Configs.loss_func, 'max_likelihood_loss', lambda: MaximumLikelihoodLoss())
189calculate(Configs.loss_func, 'cross_entropy_bayes_risk', lambda: CrossEntropyBayesRisk())
191calculate(Configs.loss_func, 'squared_error_bayes_risk', lambda: SquaredErrorBayesRisk())

සාක්ෂිගණනය කිරීමට RELU

194calculate(Configs.outputs_to_evidence, 'relu', lambda: nn.ReLU())

සාක්ෂිගණනය කිරීමට සොෆ්ට්ප්ලස්

196calculate(Configs.outputs_to_evidence, 'softplus', lambda: nn.Softplus())
199def main():

අත්හදාබැලීම සාදන්න

201    experiment.create(name='evidence_mnist')

වින්යාසයන්සාදන්න

203    conf = Configs()

වින්යාසයන්පූරණය කරන්න

205    experiment.configs(conf, {
206        'optimizer.optimizer': 'Adam',
207        'optimizer.learning_rate': 0.001,
208        'optimizer.weight_decay': 0.005,

'loss_func':' max_likelihood_loss ',' අහිමි_func ':' cross_entropy_bayes_risk ',

212        'loss_func': 'squared_error_bayes_risk',
213
214        'outputs_to_evidence': 'softplus',
215
216        'dropout': 0.5,
217    })

අත්හදාබැලීම ආරම්භ කර පුහුණු ලූපය ක්රියාත්මක කරන්න

219    with experiment.start():
220        conf.run()

224if __name__ == '__main__':
225    main()