mirror of
https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
synced 2025-10-29 09:38:56 +08:00
♻️ tracker for tracking models
This commit is contained in:
@ -10,11 +10,10 @@ import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torch.utils.data
|
||||
|
||||
import labml.utils.pytorch as pytorch_utils
|
||||
from labml import experiment, tracker
|
||||
from labml.configs import option
|
||||
from labml_helpers.datasets.mnist import MNISTConfigs
|
||||
from labml_helpers.metrics.accuracy import Accuracy, AccuracyDirect
|
||||
from labml_helpers.metrics.accuracy import AccuracyDirect
|
||||
from labml_helpers.module import Module
|
||||
from labml_helpers.train_valid import SimpleTrainValidConfigs, BatchIndex
|
||||
from labml_nn.capsule_networks import Squash, Router, MarginLoss
|
||||
@ -141,7 +140,7 @@ class Configs(MNISTConfigs, SimpleTrainValidConfigs):
|
||||
self.optimizer.step()
|
||||
# Log parameters and gradients
|
||||
if batch_idx.is_last:
|
||||
pytorch_utils.store_model_indicators(self.model)
|
||||
tracker.add('model', self.model)
|
||||
self.optimizer.zero_grad()
|
||||
|
||||
tracker.save()
|
||||
|
||||
@ -9,9 +9,8 @@ import torch.nn as nn
|
||||
import torch.utils.data
|
||||
from torchvision import transforms
|
||||
|
||||
import labml.utils.pytorch as pytorch_utils
|
||||
from labml import tracker, monit, experiment
|
||||
from labml.configs import option, calculate
|
||||
from labml.configs import option
|
||||
from labml_helpers.datasets.mnist import MNISTConfigs
|
||||
from labml_helpers.device import DeviceConfigs
|
||||
from labml_helpers.module import Module
|
||||
@ -135,7 +134,7 @@ class Configs(MNISTConfigs, TrainValidConfigs):
|
||||
self.discriminator_optimizer.zero_grad()
|
||||
loss.backward()
|
||||
if batch_idx.is_last:
|
||||
pytorch_utils.store_model_indicators(self.discriminator, 'discriminator')
|
||||
tracker.add('discriminator', self.discriminator)
|
||||
self.discriminator_optimizer.step()
|
||||
|
||||
# Train the generator
|
||||
@ -154,7 +153,7 @@ class Configs(MNISTConfigs, TrainValidConfigs):
|
||||
self.generator_optimizer.zero_grad()
|
||||
loss.backward()
|
||||
if batch_idx.is_last:
|
||||
pytorch_utils.store_model_indicators(self.generator, 'generator')
|
||||
tracker.add('generator', self.generator)
|
||||
self.generator_optimizer.step()
|
||||
|
||||
tracker.save()
|
||||
|
||||
@ -34,7 +34,6 @@ from torch.utils.data import Dataset, DataLoader
|
||||
|
||||
import einops
|
||||
from labml import lab, experiment, tracker, monit
|
||||
from labml.utils import pytorch as pytorch_utils
|
||||
from labml_helpers.device import DeviceConfigs
|
||||
from labml_helpers.module import Module
|
||||
from labml_helpers.optimizer import OptimizerConfigs
|
||||
@ -586,8 +585,7 @@ class Configs(TrainValidConfigs):
|
||||
loss.backward()
|
||||
# Log model parameters and gradients
|
||||
if batch_idx.is_last:
|
||||
pytorch_utils.store_model_indicators(self.encoder, 'encoder')
|
||||
pytorch_utils.store_model_indicators(self.decoder, 'decoder')
|
||||
tracker.add(encoder=self.encoder, decoder=self.decoder)
|
||||
# Clip gradients
|
||||
nn.utils.clip_grad_norm_(self.encoder.parameters(), self.grad_clip)
|
||||
nn.utils.clip_grad_norm_(self.decoder.parameters(), self.grad_clip)
|
||||
|
||||
@ -10,7 +10,6 @@ import torch
|
||||
import torch.nn as nn
|
||||
from torchtext.data.utils import get_tokenizer
|
||||
|
||||
import labml.utils.pytorch as pytorch_utils
|
||||
from labml import lab, experiment, monit, tracker, logger
|
||||
from labml.configs import option
|
||||
from labml.logger import Text
|
||||
@ -174,7 +173,7 @@ class Configs(SimpleTrainValidConfigs):
|
||||
loss.backward()
|
||||
self.optimizer.step()
|
||||
if batch_idx.is_last:
|
||||
pytorch_utils.store_model_indicators(self.model)
|
||||
tracker.add('model', self.model)
|
||||
self.optimizer.zero_grad()
|
||||
|
||||
tracker.save()
|
||||
|
||||
Reference in New Issue
Block a user