mirror of
https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
synced 2025-10-30 02:08:50 +08:00
accuracy capsnet fix
This commit is contained in:
@ -14,7 +14,7 @@ import labml.utils.pytorch as pytorch_utils
|
|||||||
from labml import experiment, tracker
|
from labml import experiment, tracker
|
||||||
from labml.configs import option
|
from labml.configs import option
|
||||||
from labml_helpers.datasets.mnist import MNISTConfigs
|
from labml_helpers.datasets.mnist import MNISTConfigs
|
||||||
from labml_helpers.metrics.accuracy import Accuracy
|
from labml_helpers.metrics.accuracy import Accuracy, AccuracyDirect
|
||||||
from labml_helpers.module import Module
|
from labml_helpers.module import Module
|
||||||
from labml_helpers.train_valid import SimpleTrainValidConfigs, BatchIndex
|
from labml_helpers.train_valid import SimpleTrainValidConfigs, BatchIndex
|
||||||
from labml_nn.capsule_networks import Squash, Router, MarginLoss
|
from labml_nn.capsule_networks import Squash, Router, MarginLoss
|
||||||
@ -99,7 +99,7 @@ class Configs(MNISTConfigs, SimpleTrainValidConfigs):
|
|||||||
model: nn.Module = 'capsule_network_model'
|
model: nn.Module = 'capsule_network_model'
|
||||||
reconstruction_loss = nn.MSELoss()
|
reconstruction_loss = nn.MSELoss()
|
||||||
margin_loss = MarginLoss(n_labels=10)
|
margin_loss = MarginLoss(n_labels=10)
|
||||||
accuracy = Accuracy()
|
accuracy = AccuracyDirect()
|
||||||
|
|
||||||
def init(self):
|
def init(self):
|
||||||
# Print losses and accuracy to screen
|
# Print losses and accuracy to screen
|
||||||
|
|||||||
Reference in New Issue
Block a user