mirror of
				https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
				synced 2025-10-31 10:48:49 +08:00 
			
		
		
		
	♻️ tracker for tracking models
This commit is contained in:
		| @ -10,11 +10,10 @@ import torch.nn as nn | |||||||
| import torch.nn.functional as F | import torch.nn.functional as F | ||||||
| import torch.utils.data | import torch.utils.data | ||||||
|  |  | ||||||
| import labml.utils.pytorch as pytorch_utils |  | ||||||
| from labml import experiment, tracker | from labml import experiment, tracker | ||||||
| from labml.configs import option | from labml.configs import option | ||||||
| from labml_helpers.datasets.mnist import MNISTConfigs | from labml_helpers.datasets.mnist import MNISTConfigs | ||||||
| from labml_helpers.metrics.accuracy import Accuracy, AccuracyDirect | from labml_helpers.metrics.accuracy import AccuracyDirect | ||||||
| from labml_helpers.module import Module | from labml_helpers.module import Module | ||||||
| from labml_helpers.train_valid import SimpleTrainValidConfigs, BatchIndex | from labml_helpers.train_valid import SimpleTrainValidConfigs, BatchIndex | ||||||
| from labml_nn.capsule_networks import Squash, Router, MarginLoss | from labml_nn.capsule_networks import Squash, Router, MarginLoss | ||||||
| @ -141,7 +140,7 @@ class Configs(MNISTConfigs, SimpleTrainValidConfigs): | |||||||
|             self.optimizer.step() |             self.optimizer.step() | ||||||
|             # Log parameters and gradients |             # Log parameters and gradients | ||||||
|             if batch_idx.is_last: |             if batch_idx.is_last: | ||||||
|                 pytorch_utils.store_model_indicators(self.model) |                 tracker.add('model', self.model) | ||||||
|             self.optimizer.zero_grad() |             self.optimizer.zero_grad() | ||||||
|  |  | ||||||
|             tracker.save() |             tracker.save() | ||||||
|  | |||||||
| @ -9,9 +9,8 @@ import torch.nn as nn | |||||||
| import torch.utils.data | import torch.utils.data | ||||||
| from torchvision import transforms | from torchvision import transforms | ||||||
|  |  | ||||||
| import labml.utils.pytorch as pytorch_utils |  | ||||||
| from labml import tracker, monit, experiment | from labml import tracker, monit, experiment | ||||||
| from labml.configs import option, calculate | from labml.configs import option | ||||||
| from labml_helpers.datasets.mnist import MNISTConfigs | from labml_helpers.datasets.mnist import MNISTConfigs | ||||||
| from labml_helpers.device import DeviceConfigs | from labml_helpers.device import DeviceConfigs | ||||||
| from labml_helpers.module import Module | from labml_helpers.module import Module | ||||||
| @ -135,7 +134,7 @@ class Configs(MNISTConfigs, TrainValidConfigs): | |||||||
|                     self.discriminator_optimizer.zero_grad() |                     self.discriminator_optimizer.zero_grad() | ||||||
|                     loss.backward() |                     loss.backward() | ||||||
|                     if batch_idx.is_last: |                     if batch_idx.is_last: | ||||||
|                         pytorch_utils.store_model_indicators(self.discriminator, 'discriminator') |                         tracker.add('discriminator', self.discriminator) | ||||||
|                     self.discriminator_optimizer.step() |                     self.discriminator_optimizer.step() | ||||||
|  |  | ||||||
|         # Train the generator |         # Train the generator | ||||||
| @ -154,7 +153,7 @@ class Configs(MNISTConfigs, TrainValidConfigs): | |||||||
|                 self.generator_optimizer.zero_grad() |                 self.generator_optimizer.zero_grad() | ||||||
|                 loss.backward() |                 loss.backward() | ||||||
|                 if batch_idx.is_last: |                 if batch_idx.is_last: | ||||||
|                     pytorch_utils.store_model_indicators(self.generator, 'generator') |                     tracker.add('generator', self.generator) | ||||||
|                 self.generator_optimizer.step() |                 self.generator_optimizer.step() | ||||||
|  |  | ||||||
|         tracker.save() |         tracker.save() | ||||||
|  | |||||||
| @ -34,7 +34,6 @@ from torch.utils.data import Dataset, DataLoader | |||||||
|  |  | ||||||
| import einops | import einops | ||||||
| from labml import lab, experiment, tracker, monit | from labml import lab, experiment, tracker, monit | ||||||
| from labml.utils import pytorch as pytorch_utils |  | ||||||
| from labml_helpers.device import DeviceConfigs | from labml_helpers.device import DeviceConfigs | ||||||
| from labml_helpers.module import Module | from labml_helpers.module import Module | ||||||
| from labml_helpers.optimizer import OptimizerConfigs | from labml_helpers.optimizer import OptimizerConfigs | ||||||
| @ -586,8 +585,7 @@ class Configs(TrainValidConfigs): | |||||||
|                 loss.backward() |                 loss.backward() | ||||||
|                 # Log model parameters and gradients |                 # Log model parameters and gradients | ||||||
|                 if batch_idx.is_last: |                 if batch_idx.is_last: | ||||||
|                     pytorch_utils.store_model_indicators(self.encoder, 'encoder') |                     tracker.add(encoder=self.encoder, decoder=self.decoder) | ||||||
|                     pytorch_utils.store_model_indicators(self.decoder, 'decoder') |  | ||||||
|                 # Clip gradients |                 # Clip gradients | ||||||
|                 nn.utils.clip_grad_norm_(self.encoder.parameters(), self.grad_clip) |                 nn.utils.clip_grad_norm_(self.encoder.parameters(), self.grad_clip) | ||||||
|                 nn.utils.clip_grad_norm_(self.decoder.parameters(), self.grad_clip) |                 nn.utils.clip_grad_norm_(self.decoder.parameters(), self.grad_clip) | ||||||
|  | |||||||
| @ -10,7 +10,6 @@ import torch | |||||||
| import torch.nn as nn | import torch.nn as nn | ||||||
| from torchtext.data.utils import get_tokenizer | from torchtext.data.utils import get_tokenizer | ||||||
|  |  | ||||||
| import labml.utils.pytorch as pytorch_utils |  | ||||||
| from labml import lab, experiment, monit, tracker, logger | from labml import lab, experiment, monit, tracker, logger | ||||||
| from labml.configs import option | from labml.configs import option | ||||||
| from labml.logger import Text | from labml.logger import Text | ||||||
| @ -174,7 +173,7 @@ class Configs(SimpleTrainValidConfigs): | |||||||
|             loss.backward() |             loss.backward() | ||||||
|             self.optimizer.step() |             self.optimizer.step() | ||||||
|             if batch_idx.is_last: |             if batch_idx.is_last: | ||||||
|                 pytorch_utils.store_model_indicators(self.model) |                 tracker.add('model', self.model) | ||||||
|             self.optimizer.zero_grad() |             self.optimizer.zero_grad() | ||||||
|  |  | ||||||
|         tracker.save() |         tracker.save() | ||||||
|  | |||||||
		Reference in New Issue
	
	Block a user
	 Varuna Jayasiri
					Varuna Jayasiri