From 5ecec73cc5981d4fcf8c88ce5562abbe80821745 Mon Sep 17 00:00:00 2001 From: Varuna Jayasiri Date: Wed, 18 Nov 2020 09:52:09 +0530 Subject: [PATCH] accuracy capsnet fix --- labml_nn/capsule_networks/mnist.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/labml_nn/capsule_networks/mnist.py b/labml_nn/capsule_networks/mnist.py index 5fc81a9b..72ed7cda 100644 --- a/labml_nn/capsule_networks/mnist.py +++ b/labml_nn/capsule_networks/mnist.py @@ -14,7 +14,7 @@ import labml.utils.pytorch as pytorch_utils from labml import experiment, tracker from labml.configs import option from labml_helpers.datasets.mnist import MNISTConfigs -from labml_helpers.metrics.accuracy import Accuracy +from labml_helpers.metrics.accuracy import Accuracy, AccuracyDirect from labml_helpers.module import Module from labml_helpers.train_valid import SimpleTrainValidConfigs, BatchIndex from labml_nn.capsule_networks import Squash, Router, MarginLoss @@ -99,7 +99,7 @@ class Configs(MNISTConfigs, SimpleTrainValidConfigs): model: nn.Module = 'capsule_network_model' reconstruction_loss = nn.MSELoss() margin_loss = MarginLoss(n_labels=10) - accuracy = Accuracy() + accuracy = AccuracyDirect() def init(self): # Print losses and accuracy to screen