♻️ tracker for tracking models

This commit is contained in:
Varuna Jayasiri
2020-12-02 13:58:24 +05:30
parent 5238b5d432
commit 6f6b185aae
4 changed files with 7 additions and 12 deletions

View File

@ -10,11 +10,10 @@ import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
import torch.utils.data import torch.utils.data
import labml.utils.pytorch as pytorch_utils
from labml import experiment, tracker from labml import experiment, tracker
from labml.configs import option from labml.configs import option
from labml_helpers.datasets.mnist import MNISTConfigs from labml_helpers.datasets.mnist import MNISTConfigs
from labml_helpers.metrics.accuracy import Accuracy, AccuracyDirect from labml_helpers.metrics.accuracy import AccuracyDirect
from labml_helpers.module import Module from labml_helpers.module import Module
from labml_helpers.train_valid import SimpleTrainValidConfigs, BatchIndex from labml_helpers.train_valid import SimpleTrainValidConfigs, BatchIndex
from labml_nn.capsule_networks import Squash, Router, MarginLoss from labml_nn.capsule_networks import Squash, Router, MarginLoss
@ -141,7 +140,7 @@ class Configs(MNISTConfigs, SimpleTrainValidConfigs):
self.optimizer.step() self.optimizer.step()
# Log parameters and gradients # Log parameters and gradients
if batch_idx.is_last: if batch_idx.is_last:
pytorch_utils.store_model_indicators(self.model) tracker.add('model', self.model)
self.optimizer.zero_grad() self.optimizer.zero_grad()
tracker.save() tracker.save()

View File

@ -9,9 +9,8 @@ import torch.nn as nn
import torch.utils.data import torch.utils.data
from torchvision import transforms from torchvision import transforms
import labml.utils.pytorch as pytorch_utils
from labml import tracker, monit, experiment from labml import tracker, monit, experiment
from labml.configs import option, calculate from labml.configs import option
from labml_helpers.datasets.mnist import MNISTConfigs from labml_helpers.datasets.mnist import MNISTConfigs
from labml_helpers.device import DeviceConfigs from labml_helpers.device import DeviceConfigs
from labml_helpers.module import Module from labml_helpers.module import Module
@ -135,7 +134,7 @@ class Configs(MNISTConfigs, TrainValidConfigs):
self.discriminator_optimizer.zero_grad() self.discriminator_optimizer.zero_grad()
loss.backward() loss.backward()
if batch_idx.is_last: if batch_idx.is_last:
pytorch_utils.store_model_indicators(self.discriminator, 'discriminator') tracker.add('discriminator', self.discriminator)
self.discriminator_optimizer.step() self.discriminator_optimizer.step()
# Train the generator # Train the generator
@ -154,7 +153,7 @@ class Configs(MNISTConfigs, TrainValidConfigs):
self.generator_optimizer.zero_grad() self.generator_optimizer.zero_grad()
loss.backward() loss.backward()
if batch_idx.is_last: if batch_idx.is_last:
pytorch_utils.store_model_indicators(self.generator, 'generator') tracker.add('generator', self.generator)
self.generator_optimizer.step() self.generator_optimizer.step()
tracker.save() tracker.save()

View File

@ -34,7 +34,6 @@ from torch.utils.data import Dataset, DataLoader
import einops import einops
from labml import lab, experiment, tracker, monit from labml import lab, experiment, tracker, monit
from labml.utils import pytorch as pytorch_utils
from labml_helpers.device import DeviceConfigs from labml_helpers.device import DeviceConfigs
from labml_helpers.module import Module from labml_helpers.module import Module
from labml_helpers.optimizer import OptimizerConfigs from labml_helpers.optimizer import OptimizerConfigs
@ -586,8 +585,7 @@ class Configs(TrainValidConfigs):
loss.backward() loss.backward()
# Log model parameters and gradients # Log model parameters and gradients
if batch_idx.is_last: if batch_idx.is_last:
pytorch_utils.store_model_indicators(self.encoder, 'encoder') tracker.add(encoder=self.encoder, decoder=self.decoder)
pytorch_utils.store_model_indicators(self.decoder, 'decoder')
# Clip gradients # Clip gradients
nn.utils.clip_grad_norm_(self.encoder.parameters(), self.grad_clip) nn.utils.clip_grad_norm_(self.encoder.parameters(), self.grad_clip)
nn.utils.clip_grad_norm_(self.decoder.parameters(), self.grad_clip) nn.utils.clip_grad_norm_(self.decoder.parameters(), self.grad_clip)

View File

@ -10,7 +10,6 @@ import torch
import torch.nn as nn import torch.nn as nn
from torchtext.data.utils import get_tokenizer from torchtext.data.utils import get_tokenizer
import labml.utils.pytorch as pytorch_utils
from labml import lab, experiment, monit, tracker, logger from labml import lab, experiment, monit, tracker, logger
from labml.configs import option from labml.configs import option
from labml.logger import Text from labml.logger import Text
@ -174,7 +173,7 @@ class Configs(SimpleTrainValidConfigs):
loss.backward() loss.backward()
self.optimizer.step() self.optimizer.step()
if batch_idx.is_last: if batch_idx.is_last:
pytorch_utils.store_model_indicators(self.model) tracker.add('model', self.model)
self.optimizer.zero_grad() self.optimizer.zero_grad()
tracker.save() tracker.save()