accuracy capsnet fix

This commit is contained in:
Varuna Jayasiri
2020-11-18 09:52:09 +05:30
parent d9c4d68e1c
commit 5ecec73cc5

View File

@ -14,7 +14,7 @@ import labml.utils.pytorch as pytorch_utils
from labml import experiment, tracker from labml import experiment, tracker
from labml.configs import option from labml.configs import option
from labml_helpers.datasets.mnist import MNISTConfigs from labml_helpers.datasets.mnist import MNISTConfigs
from labml_helpers.metrics.accuracy import Accuracy from labml_helpers.metrics.accuracy import Accuracy, AccuracyDirect
from labml_helpers.module import Module from labml_helpers.module import Module
from labml_helpers.train_valid import SimpleTrainValidConfigs, BatchIndex from labml_helpers.train_valid import SimpleTrainValidConfigs, BatchIndex
from labml_nn.capsule_networks import Squash, Router, MarginLoss from labml_nn.capsule_networks import Squash, Router, MarginLoss
@ -99,7 +99,7 @@ class Configs(MNISTConfigs, SimpleTrainValidConfigs):
model: nn.Module = 'capsule_network_model' model: nn.Module = 'capsule_network_model'
reconstruction_loss = nn.MSELoss() reconstruction_loss = nn.MSELoss()
margin_loss = MarginLoss(n_labels=10) margin_loss = MarginLoss(n_labels=10)
accuracy = Accuracy() accuracy = AccuracyDirect()
def init(self): def init(self):
# Print losses and accuracy to screen # Print losses and accuracy to screen