♻️ tracker for tracking models

Varuna Jayasiri
2020-12-02 13:58:24 +05:30
parent 5238b5d432
commit 6f6b185aae
4 changed files with 7 additions and 12 deletions
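
This refactor swaps the helper pytorch_utils.store_model_indicators(...) for tracker.add(...) from labml, so model parameters and gradients are registered with the tracker directly, and the now-unused labml.utils.pytorch import is dropped from all four files. A minimal sketch of the new pattern (the toy model here is hypothetical, not from the repo):

    import torch.nn as nn
    from labml import tracker

    model = nn.Linear(4, 2)  # stand-in for the real model

    # Old: pytorch_utils.store_model_indicators(model)
    # New: register the model with the tracker so its parameter and
    # gradient indicators are recorded when tracker.save() runs.
    tracker.add('model', model)
    tracker.save()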

View File

@@ -10,11 +10,10 @@ import torch.nn as nn
 import torch.nn.functional as F
 import torch.utils.data
-import labml.utils.pytorch as pytorch_utils
 from labml import experiment, tracker
 from labml.configs import option
 from labml_helpers.datasets.mnist import MNISTConfigs
-from labml_helpers.metrics.accuracy import Accuracy, AccuracyDirect
+from labml_helpers.metrics.accuracy import AccuracyDirect
 from labml_helpers.module import Module
 from labml_helpers.train_valid import SimpleTrainValidConfigs, BatchIndex
 from labml_nn.capsule_networks import Squash, Router, MarginLoss
@@ -141,7 +140,7 @@ class Configs(MNISTConfigs, SimpleTrainValidConfigs):
             self.optimizer.step()
             # Log parameters and gradients
             if batch_idx.is_last:
-                pytorch_utils.store_model_indicators(self.model)
+                tracker.add('model', self.model)
             self.optimizer.zero_grad()
         tracker.save()
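
Note that the call stays behind the batch_idx.is_last guard, so the comparatively heavy per-parameter indicators are only recorded on the last batch of each epoch, while tracker.save() still flushes the cheap scalar metrics every step. A sketch of that shape with an illustrative plain loop (batches, model, optimizer, and compute_loss are assumed, not from the repo):

    for i, batch in enumerate(batches):
        loss = compute_loss(model, batch)  # hypothetical helper
        loss.backward()
        optimizer.step()
        if i == len(batches) - 1:          # analogous to batch_idx.is_last
            tracker.add('model', model)    # heavy indicators, once per epoch
        optimizer.zero_grad()
        tracker.save()                     # flush this step's metrics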

View File

@@ -9,9 +9,8 @@ import torch.nn as nn
 import torch.utils.data
 from torchvision import transforms
-import labml.utils.pytorch as pytorch_utils
 from labml import tracker, monit, experiment
-from labml.configs import option, calculate
+from labml.configs import option
 from labml_helpers.datasets.mnist import MNISTConfigs
 from labml_helpers.device import DeviceConfigs
 from labml_helpers.module import Module
@@ -135,7 +134,7 @@ class Configs(MNISTConfigs, TrainValidConfigs):
             self.discriminator_optimizer.zero_grad()
             loss.backward()
             if batch_idx.is_last:
-                pytorch_utils.store_model_indicators(self.discriminator, 'discriminator')
+                tracker.add('discriminator', self.discriminator)
             self.discriminator_optimizer.step()
         # Train the generator
@@ -154,7 +153,7 @@ class Configs(MNISTConfigs, TrainValidConfigs):
             self.generator_optimizer.zero_grad()
             loss.backward()
             if batch_idx.is_last:
-                pytorch_utils.store_model_indicators(self.generator, 'generator')
+                tracker.add('generator', self.generator)
             self.generator_optimizer.step()
         tracker.save()
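
When two models are trained in the same loop, each is registered under its own name so their indicators stay separate, mirroring the name argument the old helper took. A sketch with placeholder modules standing in for the GAN networks:

    import torch.nn as nn
    from labml import tracker

    generator = nn.Linear(8, 8)      # placeholder modules,
    discriminator = nn.Linear(8, 1)  # not the actual GAN networks

    tracker.add('generator', generator)
    tracker.add('discriminator', discriminator)
    tracker.save()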

View File

@@ -34,7 +34,6 @@ from torch.utils.data import Dataset, DataLoader
 import einops
 from labml import lab, experiment, tracker, monit
-from labml.utils import pytorch as pytorch_utils
 from labml_helpers.device import DeviceConfigs
 from labml_helpers.module import Module
 from labml_helpers.optimizer import OptimizerConfigs
@@ -586,8 +585,7 @@ class Configs(TrainValidConfigs):
             loss.backward()
             # Log model parameters and gradients
             if batch_idx.is_last:
-                pytorch_utils.store_model_indicators(self.encoder, 'encoder')
-                pytorch_utils.store_model_indicators(self.decoder, 'decoder')
+                tracker.add(encoder=self.encoder, decoder=self.decoder)
             # Clip gradients
             nn.utils.clip_grad_norm_(self.encoder.parameters(), self.grad_clip)
             nn.utils.clip_grad_norm_(self.decoder.parameters(), self.grad_clip)
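
Here the two store_model_indicators calls collapse into a single tracker.add using the keyword form, where each keyword appears to serve as the model's name. An equivalent sketch with placeholder modules:

    import torch.nn as nn
    from labml import tracker

    encoder = nn.LSTM(5, 16)  # placeholders for the sketch-rnn
    decoder = nn.LSTM(5, 16)  # encoder/decoder pair

    # One call registers both models; names come from the keywords.
    tracker.add(encoder=encoder, decoder=decoder)
    tracker.save()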

View File

@@ -10,7 +10,6 @@ import torch
 import torch.nn as nn
 from torchtext.data.utils import get_tokenizer
-import labml.utils.pytorch as pytorch_utils
 from labml import lab, experiment, monit, tracker, logger
 from labml.configs import option
 from labml.logger import Text
@@ -174,7 +173,7 @@ class Configs(SimpleTrainValidConfigs):
             loss.backward()
             self.optimizer.step()
             if batch_idx.is_last:
-                pytorch_utils.store_model_indicators(self.model)
+                tracker.add('model', self.model)
             self.optimizer.zero_grad()
         tracker.save()
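
Across all four files the shape of the change is the same: models are registered once per epoch via tracker.add, and the tracker.save() call already present at the end of each step presumably picks up their parameter and gradient indicators, so the removed pytorch_utils helper needs no other replacement.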