diff --git a/labml_nn/capsule_networks/__init__.py b/labml_nn/capsule_networks/__init__.py index 44db731f..95ff1ca4 100644 --- a/labml_nn/capsule_networks/__init__.py +++ b/labml_nn/capsule_networks/__init__.py @@ -2,17 +2,17 @@ --- title: Capsule Networks summary: > - PyTorch implementation/tutorial of Capsule Networks. + PyTorch implementation and tutorial of Capsule Networks. Capsule networks is neural network architecture that embeds features as capsules and routes them with a voting mechanism to next layer of capsules. --- # Capsule Networks -This is an implementation of [Dynamic Routing Between Capsules](https://arxiv.org/abs/1710.09829). +This is a PyTorch implementation and tutorial of [Dynamic Routing Between Capsules](https://arxiv.org/abs/1710.09829). -Capsule networks is neural network architecture that embeds features as capsules and routes them -with a voting mechanism to next layer of capsules. +Capsule networks is neural network architecture that embeds features +as capsules and routes them with a voting mechanism to next layer of capsules. Unlike in other implementations of models, we've included a sample, because it is difficult to understand some of the concepts with just the modules. diff --git a/labml_nn/capsule_networks/mnist.py b/labml_nn/capsule_networks/mnist.py index a0eb137d..e011ef1f 100644 --- a/labml_nn/capsule_networks/mnist.py +++ b/labml_nn/capsule_networks/mnist.py @@ -6,6 +6,8 @@ summary: Code for training Capsule Networks on MNIST dataset # Classify MNIST digits with Capsule Networks +This is an annotated PyTorch code to classify MNIST digits with PyTorch. + This paper implements the experiment described in paper [Dynamic Routing Between Capsules](https://arxiv.org/abs/1710.09829). """ @@ -161,11 +163,13 @@ def main(): """ Run the experiment """ - conf = Configs() experiment.create(name='capsule_network_mnist') + conf = Configs() experiment.configs(conf, {'optimizer.optimizer': 'Adam', - 'optimizer.learning_rate': 1e-3, - 'device.cuda_device': 1}) + 'optimizer.learning_rate': 1e-3}) + + experiment.add_pytorch_models({'model': conf.model}) + with experiment.start(): conf.run()