Mirror of https://github.com/labmlai/annotated_deep_learning_paper_implementations.git (synced 2025-08-18 11:52:00 +08:00)
experiment logs
@@ -74,6 +74,16 @@ class NLPClassificationConfigs(TrainValidConfigs):
     # Validation data loader
     valid_loader: DataLoader = 'ag_news'
 
+    # Whether to log model parameters and gradients (once per epoch).
+    # These are summarized stats per layer, but it could still lead
+    # to many indicators for very deep networks.
+    is_log_model_params_grads: bool = False
+
+    # Whether to log model activations (once per epoch).
+    # These are summarized stats per layer, but it could still lead
+    # to many indicators for very deep networks.
+    is_log_model_activations: bool = False
+
     def init(self):
         """
         ### Initialization
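Both flags default to False, so the extra indicators stay off unless a run opts in. A minimal sketch of opting in, assuming the usual labml experiment setup; the import path and experiment name here are illustrative, not part of this commit.

# Usage sketch (assumptions: import path and experiment name are
# illustrative; only the two flags come from this commit).
from labml import experiment
from labml_nn.experiments.nlp_classification import NLPClassificationConfigs

conf = NLPClassificationConfigs()
experiment.create(name='ag_news_classification')
experiment.configs(conf, {
    # Opt in to the per-layer stats added in this commit; both default to False.
    'is_log_model_params_grads': True,
    'is_log_model_activations': True,
})
with experiment.start():
    conf.run()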
@@ -102,7 +112,7 @@ class NLPClassificationConfigs(TrainValidConfigs):
         tracker.add_global_step(data.shape[1])
 
         # Whether to capture model outputs
-        with self.mode.update(is_log_activations=batch_idx.is_last):
+        with self.mode.update(is_log_activations=batch_idx.is_last and self.is_log_model_activations):
             # Get model outputs.
             # It's returning a tuple for states when using RNNs.
             # This is not implemented yet. 😜
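The patched context manager turns on activation capture only on the last batch of an epoch, and only when the new flag is set. As a standalone illustration of what per-layer activation summaries involve, here is a sketch using plain PyTorch forward hooks; labml's own capture goes through its mode/tracker machinery, so this is an approximation, not the library's implementation.

# Standalone sketch of the idea behind is_log_model_activations,
# using plain PyTorch forward hooks (illustrative, not labml's API).
import torch
import torch.nn as nn

def summarize_activations(model: nn.Module, batch: torch.Tensor) -> dict:
    """Run one forward pass and record per-layer activation mean/std."""
    stats = {}
    handles = []

    def make_hook(name: str):
        def hook(module, inputs, output):
            if torch.is_tensor(output):
                stats[f'{name}.mean'] = output.mean().item()
                stats[f'{name}.std'] = output.std().item()
        return hook

    # Register a hook on every named submodule (skip the root module itself).
    for name, module in model.named_modules():
        if name:
            handles.append(module.register_forward_hook(make_hook(name)))
    try:
        with torch.no_grad():
            model(batch)
    finally:
        for handle in handles:
            handle.remove()
    return stats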
@@ -125,7 +135,7 @@ class NLPClassificationConfigs(TrainValidConfigs):
             # Take optimizer step
             self.optimizer.step()
             # Log the model parameters and gradients on last batch of every epoch
-            if batch_idx.is_last:
+            if batch_idx.is_last and self.is_log_model_params_grads:
                 tracker.add('model', self.model)
             # Clear the gradients
             self.optimizer.zero_grad()
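Likewise, tracker.add('model', self.model) is now gated on both the last batch and the new flag. The kind of per-layer summary it produces can be sketched by hand; the helper below is illustrative, not labml's implementation.

# Standalone sketch of the per-layer summaries behind
# is_log_model_params_grads (illustrative only).
import torch.nn as nn

def summarize_params_grads(model: nn.Module) -> dict:
    """Collect per-parameter L2 norms of weights and their gradients."""
    stats = {}
    for name, param in model.named_parameters():
        stats[f'{name}.param'] = param.detach().norm().item()
        # Gradients exist only after a backward pass, and in the patched
        # training loop above they are logged only on the last batch.
        if param.grad is not None:
            stats[f'{name}.grad'] = param.grad.detach().norm().item()
    return stats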