11import torch.nn as nn
12import torch.utils.data
13from labml_helpers.module import Module
14
15from labml import tracker
16from labml.configs import option
17from labml_helpers.datasets.mnist import MNISTConfigs as MNISTDatasetConfigs
18from labml_helpers.device import DeviceConfigs
19from labml_helpers.metrics.accuracy import Accuracy
20from labml_helpers.train_valid import TrainValidConfigs, BatchIndex, hook_model_outputs
21from labml_nn.optimizers.configs import OptimizerConfigs24class MNISTConfigs(MNISTDatasetConfigs, TrainValidConfigs):オプティマイザー
32 optimizer: torch.optim.Adamトレーニングデバイス
34 device: torch.device = DeviceConfigs()分類モデル
37 model: Moduleトレーニング対象エポックの数
39 epochs: int = 101 つのエポック内でトレーニングと検証を切り替える回数
42 inner_iterations = 10精度機能
45 accuracy = Accuracy()損失関数
47 loss_func = nn.CrossEntropyLoss()49 def init(self):トラッカー構成を設定
54 tracker.set_scalar("loss.*", True)
55 tracker.set_scalar("accuracy.*", True)モジュール出力をログに記録するフックを追加
57 hook_model_outputs(self.mode, self.model, 'model')ステートモジュールとして精度を追加してください。この名前は、RNN のトレーニングと検証の間の状態を保存するためのものなので、おそらくわかりにくいでしょう。これにより、精度指標の統計情報がトレーニング用と検証用に別々に保持されます。
62 self.state_modules = [self.accuracy]64 def step(self, batch: any, batch_idx: BatchIndex):トレーニング/評価モード
70 self.model.train(self.mode.is_train)データをデバイスに移動
73 data, target = batch[0].to(self.device), batch[1].to(self.device)トレーニングモード時にグローバルステップ (処理されたサンプル数) を更新
76 if self.mode.is_train:
77 tracker.add_global_step(len(data))モデル出力をキャプチャするかどうか
80 with self.mode.update(is_log_activations=batch_idx.is_last):モデル出力を取得します。
82 output = self.model(data)損失の計算と記録
85 loss = self.loss_func(output, target)
86 tracker.add("loss.", loss)精度の計算と記録
89 self.accuracy(output, target)
90 self.accuracy.track()モデルのトレーニング
93 if self.mode.is_train:勾配の計算
95 loss.backward()最適化の一歩を踏み出す
97 self.optimizer.step()各エポックの最後のバッチでモデルパラメータと勾配を記録します
99 if batch_idx.is_last:
100 tracker.add('model', self.model)グラデーションをクリア
102 self.optimizer.zero_grad()追跡したメトリクスを保存する
105 tracker.save()108@option(MNISTConfigs.optimizer)
109def _optimizer(c: MNISTConfigs):113 opt_conf = OptimizerConfigs()
114 opt_conf.parameters = c.model.parameters()
115 opt_conf.optimizer = 'Adam'
116 return opt_conf