From 6f6b185aae8df6ee8e47fc76a5dcf4b4899b64a1 Mon Sep 17 00:00:00 2001
From: Varuna Jayasiri
Date: Wed, 2 Dec 2020 13:58:24 +0530
Subject: [PATCH] =?UTF-8?q?=E2=99=BB=EF=B8=8F=20=20tracker=20for=20trackin?=
 =?UTF-8?q?g=20models?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 labml_nn/capsule_networks/mnist.py       | 5 ++---
 labml_nn/gan/simple_mnist_experiment.py  | 7 +++----
 labml_nn/sketch_rnn/__init__.py          | 4 +---
 labml_nn/transformers/knn/train_model.py | 3 +--
 4 files changed, 7 insertions(+), 12 deletions(-)

diff --git a/labml_nn/capsule_networks/mnist.py b/labml_nn/capsule_networks/mnist.py
index 6ed840c9..ea4df6ab 100644
--- a/labml_nn/capsule_networks/mnist.py
+++ b/labml_nn/capsule_networks/mnist.py
@@ -10,11 +10,10 @@ import torch.nn as nn
 import torch.nn.functional as F
 import torch.utils.data
 
-import labml.utils.pytorch as pytorch_utils
 from labml import experiment, tracker
 from labml.configs import option
 from labml_helpers.datasets.mnist import MNISTConfigs
-from labml_helpers.metrics.accuracy import Accuracy, AccuracyDirect
+from labml_helpers.metrics.accuracy import AccuracyDirect
 from labml_helpers.module import Module
 from labml_helpers.train_valid import SimpleTrainValidConfigs, BatchIndex
 from labml_nn.capsule_networks import Squash, Router, MarginLoss
@@ -141,7 +140,7 @@ class Configs(MNISTConfigs, SimpleTrainValidConfigs):
             self.optimizer.step()
             # Log parameters and gradients
             if batch_idx.is_last:
-                pytorch_utils.store_model_indicators(self.model)
+                tracker.add('model', self.model)
             self.optimizer.zero_grad()
 
         tracker.save()
diff --git a/labml_nn/gan/simple_mnist_experiment.py b/labml_nn/gan/simple_mnist_experiment.py
index 9e109120..b801946a 100644
--- a/labml_nn/gan/simple_mnist_experiment.py
+++ b/labml_nn/gan/simple_mnist_experiment.py
@@ -9,9 +9,8 @@ import torch.nn as nn
 import torch.utils.data
 from torchvision import transforms
 
-import labml.utils.pytorch as pytorch_utils
 from labml import tracker, monit, experiment
-from labml.configs import option, calculate
+from labml.configs import option
 from labml_helpers.datasets.mnist import MNISTConfigs
 from labml_helpers.device import DeviceConfigs
 from labml_helpers.module import Module
@@ -135,7 +134,7 @@ class Configs(MNISTConfigs, TrainValidConfigs):
             self.discriminator_optimizer.zero_grad()
             loss.backward()
             if batch_idx.is_last:
-                pytorch_utils.store_model_indicators(self.discriminator, 'discriminator')
+                tracker.add('discriminator', self.discriminator)
             self.discriminator_optimizer.step()
 
             # Train the generator
@@ -154,7 +153,7 @@ class Configs(MNISTConfigs, TrainValidConfigs):
             self.generator_optimizer.zero_grad()
             loss.backward()
             if batch_idx.is_last:
-                pytorch_utils.store_model_indicators(self.generator, 'generator')
+                tracker.add('generator', self.generator)
             self.generator_optimizer.step()
 
         tracker.save()
diff --git a/labml_nn/sketch_rnn/__init__.py b/labml_nn/sketch_rnn/__init__.py
index 11fe0fd8..b736f67f 100644
--- a/labml_nn/sketch_rnn/__init__.py
+++ b/labml_nn/sketch_rnn/__init__.py
@@ -34,7 +34,6 @@ from torch.utils.data import Dataset, DataLoader
 
 import einops
 from labml import lab, experiment, tracker, monit
-from labml.utils import pytorch as pytorch_utils
 from labml_helpers.device import DeviceConfigs
 from labml_helpers.module import Module
 from labml_helpers.optimizer import OptimizerConfigs
@@ -586,8 +585,7 @@ class Configs(TrainValidConfigs):
             loss.backward()
             # Log model parameters and gradients
             if batch_idx.is_last:
-                pytorch_utils.store_model_indicators(self.encoder, 'encoder')
-                pytorch_utils.store_model_indicators(self.decoder, 'decoder')
+                tracker.add(encoder=self.encoder, decoder=self.decoder)
             # Clip gradients
             nn.utils.clip_grad_norm_(self.encoder.parameters(), self.grad_clip)
             nn.utils.clip_grad_norm_(self.decoder.parameters(), self.grad_clip)
diff --git a/labml_nn/transformers/knn/train_model.py b/labml_nn/transformers/knn/train_model.py
index e159ff36..2e42d967 100644
--- a/labml_nn/transformers/knn/train_model.py
+++ b/labml_nn/transformers/knn/train_model.py
@@ -10,7 +10,6 @@ import torch
 import torch.nn as nn
 from torchtext.data.utils import get_tokenizer
 
-import labml.utils.pytorch as pytorch_utils
 from labml import lab, experiment, monit, tracker, logger
 from labml.configs import option
 from labml.logger import Text
@@ -174,7 +173,7 @@ class Configs(SimpleTrainValidConfigs):
             loss.backward()
             self.optimizer.step()
             if batch_idx.is_last:
-                pytorch_utils.store_model_indicators(self.model)
+                tracker.add('model', self.model)
             self.optimizer.zero_grad()
 
         tracker.save()
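
Reviewer note (not part of the patch): every hunk above makes the same substitution,
replacing pytorch_utils.store_model_indicators(module, name) with tracker.add(name, module)
or tracker.add(name=module). Below is a minimal, self-contained sketch of that call
pattern; the experiment.create()/experiment.start() scaffolding and the stand-in
nn.Linear module are assumptions for illustration only, not taken from this patch.

    import torch
    import torch.nn as nn
    from labml import experiment, tracker

    # Stand-in module; the real code passes self.model, self.encoder, self.decoder, etc.
    model = nn.Linear(4, 2)

    experiment.create(name='tracker_add_sketch')
    with experiment.start():
        # Run a forward/backward pass so the module has gradients to log
        loss = model(torch.randn(8, 4)).sum()
        loss.backward()
        # Positional form, as used in mnist.py and knn/train_model.py
        tracker.add('model', model)
        # Keyword form for several modules, as used in sketch_rnn/__init__.py:
        #     tracker.add(encoder=encoder, decoder=decoder)
        tracker.save()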