diff --git a/labml_nn/capsule_networks/mnist.py b/labml_nn/capsule_networks/mnist.py index 5fc81a9b..72ed7cda 100644 --- a/labml_nn/capsule_networks/mnist.py +++ b/labml_nn/capsule_networks/mnist.py @@ -14,7 +14,7 @@ import labml.utils.pytorch as pytorch_utils from labml import experiment, tracker from labml.configs import option from labml_helpers.datasets.mnist import MNISTConfigs -from labml_helpers.metrics.accuracy import Accuracy +from labml_helpers.metrics.accuracy import Accuracy, AccuracyDirect from labml_helpers.module import Module from labml_helpers.train_valid import SimpleTrainValidConfigs, BatchIndex from labml_nn.capsule_networks import Squash, Router, MarginLoss @@ -99,7 +99,7 @@ class Configs(MNISTConfigs, SimpleTrainValidConfigs): model: nn.Module = 'capsule_network_model' reconstruction_loss = nn.MSELoss() margin_loss = MarginLoss(n_labels=10) - accuracy = Accuracy() + accuracy = AccuracyDirect() def init(self): # Print losses and accuracy to screen