diff --git a/labml_nn/capsule_networks/mnist.py b/labml_nn/capsule_networks/mnist.py index 6ed840c9..ea4df6ab 100644 --- a/labml_nn/capsule_networks/mnist.py +++ b/labml_nn/capsule_networks/mnist.py @@ -10,11 +10,10 @@ import torch.nn as nn import torch.nn.functional as F import torch.utils.data -import labml.utils.pytorch as pytorch_utils from labml import experiment, tracker from labml.configs import option from labml_helpers.datasets.mnist import MNISTConfigs -from labml_helpers.metrics.accuracy import Accuracy, AccuracyDirect +from labml_helpers.metrics.accuracy import AccuracyDirect from labml_helpers.module import Module from labml_helpers.train_valid import SimpleTrainValidConfigs, BatchIndex from labml_nn.capsule_networks import Squash, Router, MarginLoss @@ -141,7 +140,7 @@ class Configs(MNISTConfigs, SimpleTrainValidConfigs): self.optimizer.step() # Log parameters and gradients if batch_idx.is_last: - pytorch_utils.store_model_indicators(self.model) + tracker.add('model', self.model) self.optimizer.zero_grad() tracker.save() diff --git a/labml_nn/gan/simple_mnist_experiment.py b/labml_nn/gan/simple_mnist_experiment.py index 9e109120..b801946a 100644 --- a/labml_nn/gan/simple_mnist_experiment.py +++ b/labml_nn/gan/simple_mnist_experiment.py @@ -9,9 +9,8 @@ import torch.nn as nn import torch.utils.data from torchvision import transforms -import labml.utils.pytorch as pytorch_utils from labml import tracker, monit, experiment -from labml.configs import option, calculate +from labml.configs import option from labml_helpers.datasets.mnist import MNISTConfigs from labml_helpers.device import DeviceConfigs from labml_helpers.module import Module @@ -135,7 +134,7 @@ class Configs(MNISTConfigs, TrainValidConfigs): self.discriminator_optimizer.zero_grad() loss.backward() if batch_idx.is_last: - pytorch_utils.store_model_indicators(self.discriminator, 'discriminator') + tracker.add('discriminator', self.discriminator) self.discriminator_optimizer.step() # Train the generator @@ -154,7 +153,7 @@ class Configs(MNISTConfigs, TrainValidConfigs): self.generator_optimizer.zero_grad() loss.backward() if batch_idx.is_last: - pytorch_utils.store_model_indicators(self.generator, 'generator') + tracker.add('generator', self.generator) self.generator_optimizer.step() tracker.save() diff --git a/labml_nn/sketch_rnn/__init__.py b/labml_nn/sketch_rnn/__init__.py index 11fe0fd8..b736f67f 100644 --- a/labml_nn/sketch_rnn/__init__.py +++ b/labml_nn/sketch_rnn/__init__.py @@ -34,7 +34,6 @@ from torch.utils.data import Dataset, DataLoader import einops from labml import lab, experiment, tracker, monit -from labml.utils import pytorch as pytorch_utils from labml_helpers.device import DeviceConfigs from labml_helpers.module import Module from labml_helpers.optimizer import OptimizerConfigs @@ -586,8 +585,7 @@ class Configs(TrainValidConfigs): loss.backward() # Log model parameters and gradients if batch_idx.is_last: - pytorch_utils.store_model_indicators(self.encoder, 'encoder') - pytorch_utils.store_model_indicators(self.decoder, 'decoder') + tracker.add(encoder=self.encoder, decoder=self.decoder) # Clip gradients nn.utils.clip_grad_norm_(self.encoder.parameters(), self.grad_clip) nn.utils.clip_grad_norm_(self.decoder.parameters(), self.grad_clip) diff --git a/labml_nn/transformers/knn/train_model.py b/labml_nn/transformers/knn/train_model.py index e159ff36..2e42d967 100644 --- a/labml_nn/transformers/knn/train_model.py +++ b/labml_nn/transformers/knn/train_model.py @@ -10,7 +10,6 @@ import torch import torch.nn as nn from torchtext.data.utils import get_tokenizer -import labml.utils.pytorch as pytorch_utils from labml import lab, experiment, monit, tracker, logger from labml.configs import option from labml.logger import Text @@ -174,7 +173,7 @@ class Configs(SimpleTrainValidConfigs): loss.backward() self.optimizer.step() if batch_idx.is_last: - pytorch_utils.store_model_indicators(self.model) + tracker.add('model', self.model) self.optimizer.zero_grad() tracker.save()