mirror of
https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
synced 2025-10-30 02:08:50 +08:00
accuracy capsnet fix
This commit is contained in:
@ -14,7 +14,7 @@ import labml.utils.pytorch as pytorch_utils
|
||||
from labml import experiment, tracker
|
||||
from labml.configs import option
|
||||
from labml_helpers.datasets.mnist import MNISTConfigs
|
||||
from labml_helpers.metrics.accuracy import Accuracy
|
||||
from labml_helpers.metrics.accuracy import Accuracy, AccuracyDirect
|
||||
from labml_helpers.module import Module
|
||||
from labml_helpers.train_valid import SimpleTrainValidConfigs, BatchIndex
|
||||
from labml_nn.capsule_networks import Squash, Router, MarginLoss
|
||||
@ -99,7 +99,7 @@ class Configs(MNISTConfigs, SimpleTrainValidConfigs):
|
||||
model: nn.Module = 'capsule_network_model'
|
||||
reconstruction_loss = nn.MSELoss()
|
||||
margin_loss = MarginLoss(n_labels=10)
|
||||
accuracy = Accuracy()
|
||||
accuracy = AccuracyDirect()
|
||||
|
||||
def init(self):
|
||||
# Print losses and accuracy to screen
|
||||
|
||||
Reference in New Issue
Block a user