Mirror of https://github.com/labmlai/annotated_deep_learning_paper_implementations.git, synced 2025-10-30 10:18:50 +08:00

Commit: ♻️ tracker for tracking models

@@ -10,11 +10,10 @@ import torch.nn as nn
 import torch.nn.functional as F
 import torch.utils.data
 
-import labml.utils.pytorch as pytorch_utils
 from labml import experiment, tracker
 from labml.configs import option
 from labml_helpers.datasets.mnist import MNISTConfigs
-from labml_helpers.metrics.accuracy import Accuracy, AccuracyDirect
+from labml_helpers.metrics.accuracy import AccuracyDirect
 from labml_helpers.module import Module
 from labml_helpers.train_valid import SimpleTrainValidConfigs, BatchIndex
 from labml_nn.capsule_networks import Squash, Router, MarginLoss

@@ -141,7 +140,7 @@ class Configs(MNISTConfigs, SimpleTrainValidConfigs):
         self.optimizer.step()
         # Log parameters and gradients
         if batch_idx.is_last:
-            pytorch_utils.store_model_indicators(self.model)
+            tracker.add('model', self.model)
         self.optimizer.zero_grad()
 
         tracker.save()

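This hunk is the core of the change: the old `labml.utils.pytorch.store_model_indicators` helper is dropped in favor of `tracker.add`, which takes the module directly and queues its parameter and gradient statistics as indicators for the current step. A minimal sketch of the new pattern, using a hypothetical linear model and training loop that are not part of this repository:

```python
import torch
import torch.nn as nn
from labml import experiment, tracker

# Hypothetical stand-ins for the model and optimizer in the diff.
model = nn.Linear(784, 10)
optimizer = torch.optim.Adam(model.parameters())

experiment.create(name='tracker_add_example')
with experiment.start():
    for step in range(100):
        tracker.add_global_step()
        loss = model(torch.randn(32, 784)).square().mean()
        loss.backward()
        optimizer.step()
        # Mirrors the `if batch_idx.is_last:` guard in the diff:
        # register the model only occasionally, since parameter and
        # gradient statistics are relatively expensive to log.
        if (step + 1) % 10 == 0:
            tracker.add('model', model)
        optimizer.zero_grad()
        # Flush everything queued for this step.
        tracker.save()
```
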
@@ -9,9 +9,8 @@ import torch.nn as nn
 import torch.utils.data
 from torchvision import transforms
 
-import labml.utils.pytorch as pytorch_utils
 from labml import tracker, monit, experiment
-from labml.configs import option, calculate
+from labml.configs import option
 from labml_helpers.datasets.mnist import MNISTConfigs
 from labml_helpers.device import DeviceConfigs
 from labml_helpers.module import Module

@@ -135,7 +134,7 @@ class Configs(MNISTConfigs, TrainValidConfigs):
         self.discriminator_optimizer.zero_grad()
         loss.backward()
         if batch_idx.is_last:
-            pytorch_utils.store_model_indicators(self.discriminator, 'discriminator')
+            tracker.add('discriminator', self.discriminator)
         self.discriminator_optimizer.step()
 
         # Train the generator

@@ -154,7 +153,7 @@ class Configs(MNISTConfigs, TrainValidConfigs):
         self.generator_optimizer.zero_grad()
         loss.backward()
         if batch_idx.is_last:
-            pytorch_utils.store_model_indicators(self.generator, 'generator')
+            tracker.add('generator', self.generator)
         self.generator_optimizer.step()
 
         tracker.save()

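The GAN experiment makes the same swap twice, once per network. Each module is registered under its own name, so the discriminator's and the generator's statistics land under separate indicator prefixes. A small sketch of the idea, with hypothetical stand-in modules:

```python
import torch
import torch.nn as nn
from labml import experiment, tracker

# Hypothetical stand-ins for the experiment's two networks.
generator = nn.Sequential(nn.Linear(100, 784), nn.Tanh())
discriminator = nn.Linear(784, 1)

experiment.create(name='gan_tracking_example')
with experiment.start():
    # Run a backward pass so gradient statistics exist as well.
    discriminator(generator(torch.randn(16, 100))).mean().backward()
    # Distinct names keep the two models' indicators apart.
    tracker.add('discriminator', discriminator)
    tracker.add('generator', generator)
    tracker.save()
```
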
@@ -34,7 +34,6 @@ from torch.utils.data import Dataset, DataLoader
 
 import einops
 from labml import lab, experiment, tracker, monit
-from labml.utils import pytorch as pytorch_utils
 from labml_helpers.device import DeviceConfigs
 from labml_helpers.module import Module
 from labml_helpers.optimizer import OptimizerConfigs

@@ -586,8 +585,7 @@ class Configs(TrainValidConfigs):
         loss.backward()
         # Log model parameters and gradients
         if batch_idx.is_last:
-            pytorch_utils.store_model_indicators(self.encoder, 'encoder')
-            pytorch_utils.store_model_indicators(self.decoder, 'decoder')
+            tracker.add(encoder=self.encoder, decoder=self.decoder)
         # Clip gradients
         nn.utils.clip_grad_norm_(self.encoder.parameters(), self.grad_clip)
         nn.utils.clip_grad_norm_(self.decoder.parameters(), self.grad_clip)

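In the encoder-decoder experiment two `store_model_indicators` calls collapse into one, because `tracker.add` also accepts keyword arguments: each keyword becomes the indicator name for the module passed as its value. A sketch of the keyword form, with hypothetical encoder and decoder stand-ins:

```python
import torch.nn as nn
from labml import experiment, tracker

# Hypothetical stand-ins for the encoder/decoder pair in the diff.
encoder = nn.LSTM(input_size=5, hidden_size=256, batch_first=True)
decoder = nn.LSTM(input_size=5, hidden_size=512, batch_first=True)

experiment.create(name='kwargs_tracking_example')
with experiment.start():
    # One call registers both modules; the keyword names 'encoder'
    # and 'decoder' become the indicator prefixes.
    tracker.add(encoder=encoder, decoder=decoder)
    tracker.save()
```
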
@@ -10,7 +10,6 @@ import torch
 import torch.nn as nn
 from torchtext.data.utils import get_tokenizer
 
-import labml.utils.pytorch as pytorch_utils
 from labml import lab, experiment, monit, tracker, logger
 from labml.configs import option
 from labml.logger import Text

@@ -174,7 +173,7 @@ class Configs(SimpleTrainValidConfigs):
         loss.backward()
         self.optimizer.step()
         if batch_idx.is_last:
-            pytorch_utils.store_model_indicators(self.model)
+            tracker.add('model', self.model)
         self.optimizer.zero_grad()
 
         tracker.save()

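The pattern is identical across all four experiments: drop the `labml.utils.pytorch` import, hand the model to `tracker.add` inside the `is_last` guard, and let the existing `tracker.save()` call flush it along with the other per-step metrics. Scalars and modules go through the same call, as in this minimal sketch with a hypothetical model and loss:

```python
import torch
import torch.nn as nn
from labml import experiment, tracker

model = nn.Linear(4, 1)  # hypothetical model
loss = model(torch.randn(8, 4)).mean()
loss.backward()

experiment.create(name='save_example')
with experiment.start():
    # Scalars and modules queue through the same tracker.add() call;
    # tracker.save() writes out everything recorded for the step.
    tracker.add('loss.train', loss)
    tracker.add('model', model)
    tracker.save()
```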