diff --git a/labml_nn/capsule_networks/mnist.py b/labml_nn/capsule_networks/mnist.py index 72ed7cda..6ed840c9 100644 --- a/labml_nn/capsule_networks/mnist.py +++ b/labml_nn/capsule_networks/mnist.py @@ -124,8 +124,7 @@ class Configs(MNISTConfigs, SimpleTrainValidConfigs): tracker.add_global_step(len(data)) # Whether to log activations - is_log_activations = batch_idx.is_interval(self.log_activations_batches) - with self.mode.update(is_log_activations=is_log_activations): + with self.mode.update(is_log_activations=batch_idx.is_last): # Run the model caps, reconstructions, pred = self.model(data) @@ -141,7 +140,7 @@ class Configs(MNISTConfigs, SimpleTrainValidConfigs): self.optimizer.step() # Log parameters and gradients - if batch_idx.is_interval(self.log_params_updates): + if batch_idx.is_last: pytorch_utils.store_model_indicators(self.model) self.optimizer.zero_grad() diff --git a/labml_nn/gan/simple_mnist_experiment.py b/labml_nn/gan/simple_mnist_experiment.py index 10c1895d..9e109120 100644 --- a/labml_nn/gan/simple_mnist_experiment.py +++ b/labml_nn/gan/simple_mnist_experiment.py @@ -93,8 +93,6 @@ class Configs(MNISTConfigs, TrainValidConfigs): label_smoothing: float = 0.2 discriminator_k: int = 1 - log_params_updates: int = 2 ** 32 # 0 if not - def init(self): self.state_modules = [] self.generator = Generator().to(self.device) @@ -136,7 +134,7 @@ class Configs(MNISTConfigs, TrainValidConfigs): if self.mode.is_train: self.discriminator_optimizer.zero_grad() loss.backward() - if batch_idx.is_interval(self.log_params_updates): + if batch_idx.is_last: pytorch_utils.store_model_indicators(self.discriminator, 'discriminator') self.discriminator_optimizer.step() @@ -155,7 +153,7 @@ class Configs(MNISTConfigs, TrainValidConfigs): if self.mode.is_train: self.generator_optimizer.zero_grad() loss.backward() - if batch_idx.is_interval(self.log_params_updates): + if batch_idx.is_last: pytorch_utils.store_model_indicators(self.generator, 'generator') self.generator_optimizer.step() diff --git a/labml_nn/sketch_rnn/__init__.py b/labml_nn/sketch_rnn/__init__.py index 8c366c7f..11fe0fd8 100644 --- a/labml_nn/sketch_rnn/__init__.py +++ b/labml_nn/sketch_rnn/__init__.py @@ -23,7 +23,7 @@ Took help from [PyTorch Sketch RNN](https://github.com/alexis-jacq/Pytorch-Sketc """ import math -from typing import Optional, Tuple +from typing import Optional, Tuple, Any import numpy as np import torch @@ -34,13 +34,11 @@ from torch.utils.data import Dataset, DataLoader import einops from labml import lab, experiment, tracker, monit -from labml.configs import option from labml.utils import pytorch as pytorch_utils from labml_helpers.device import DeviceConfigs from labml_helpers.module import Module from labml_helpers.optimizer import OptimizerConfigs -from labml_helpers.train_valid import TrainValidConfigs, BatchStepProtocol, hook_model_outputs, \ - MODE_STATE +from labml_helpers.train_valid import TrainValidConfigs, hook_model_outputs, BatchIndex class StrokesDataset(Dataset): @@ -452,53 +450,104 @@ class Sampler: plt.show() -class StrokesBatchStep(BatchStepProtocol): +class Configs(TrainValidConfigs): """ - ## Train/Validation modules + ## Configurations + + These are default configurations which can be later adjusted by passing a `dict`. 
""" - def __init__(self, encoder: EncoderRNN, decoder: DecoderRNN, - optimizer: Optional[torch.optim.Adam], - kl_div_loss_weight: float, grad_clip: float): - self.grad_clip = grad_clip - self.kl_div_loss_weight = kl_div_loss_weight - self.encoder = encoder - self.decoder = decoder + # Device configurations to pick the device to run the experiment + device: torch.device = DeviceConfigs() + # + encoder: EncoderRNN + decoder: DecoderRNN + optimizer: optim.Adam + sampler: Sampler + + dataset_name: str + train_loader: DataLoader + valid_loader: DataLoader + train_dataset: StrokesDataset + valid_dataset: StrokesDataset + + # Encoder and decoder sizes + enc_hidden_size = 256 + dec_hidden_size = 512 + + # Batch size + batch_size = 100 + + # Number of features in $z$ + d_z = 128 + # Number of distributions in the mixture, $M$ + n_distributions = 20 + + # Weight of KL divergence loss, $w_{KL}$ + kl_div_loss_weight = 0.5 + # Gradient clipping + grad_clip = 1. + # Temperature $\tau$ for sampling + temperature = 0.4 + + # Filter out stroke sequences longer than $200$ + max_seq_length = 200 + + epochs = 100 + + kl_div_loss = KLDivLoss() + reconstruction_loss = ReconstructionLoss() + + def init(self): + # Initialize encoder & decoder + self.encoder = EncoderRNN(self.d_z, self.enc_hidden_size).to(self.device) + self.decoder = DecoderRNN(self.d_z, self.dec_hidden_size, self.n_distributions).to(self.device) + + # Set optimizer, things like type of optimizer and learning rate are configurable + optimizer = OptimizerConfigs() + optimizer.parameters = list(self.encoder.parameters()) + list(self.decoder.parameters()) self.optimizer = optimizer - self.kl_div_loss = KLDivLoss() - self.reconstruction_loss = ReconstructionLoss() + + # Create sampler + self.sampler = Sampler(self.encoder, self.decoder) + + # `npz` file path is `data/sketch/[DATASET NAME].npz` + path = lab.get_data_path() / 'sketch' / f'{self.dataset_name}.npz' + # Load the numpy file. + dataset = np.load(str(path), encoding='latin1', allow_pickle=True) + + # Create training dataset + self.train_dataset = StrokesDataset(dataset['train'], self.max_seq_length) + # Create validation dataset + self.valid_dataset = StrokesDataset(dataset['valid'], self.max_seq_length, self.train_dataset.scale) + + # Create training data loader + self.train_loader = DataLoader(self.train_dataset, self.batch_size, shuffle=True) + # Create validation data loader + self.valid_loader = DataLoader(self.valid_dataset, self.batch_size) # Add hooks to monitor layer outputs on Tensorboard - hook_model_outputs(self.encoder, 'encoder') - hook_model_outputs(self.decoder, 'decoder') + hook_model_outputs(self.mode, self.encoder, 'encoder') + hook_model_outputs(self.mode, self.decoder, 'decoder') # Configure the tracker to print the total train/validation loss tracker.set_scalar("loss.total.*", True) - def prepare_for_iteration(self): - """ - Set models for training/evaluation - """ - if MODE_STATE.is_train: - self.encoder.train() - self.decoder.train() - else: - self.encoder.eval() - self.decoder.eval() + self.state_modules = [] - def process(self, batch: any, state: any): - """ - Process a batch - """ - data, mask = batch + def step(self, batch: Any, batch_idx: BatchIndex): + self.encoder.train(self.mode.is_train) + self.decoder.train(self.mode.is_train) - # Get model device - device = self.encoder.device # Move `data` and `mask` to device and swap the sequence and batch dimensions. 
        # `data` will have shape `[seq_len, batch_size, 5]` and
        # `mask` will have shape `[seq_len, batch_size]`.
-        data = data.to(device).transpose(0, 1)
-        mask = mask.to(device).transpose(0, 1)
+        data = batch[0].to(self.device).transpose(0, 1)
+        mask = batch[1].to(self.device).transpose(0, 1)
+
+        # Increment the global step by the batch size in training mode
+        if self.mode.is_train:
+            tracker.add_global_step(data.shape[1])  # batch dimension is 1 after the transpose
 
         # Encode the sequence of strokes
         with monit.section("encoder"):
@@ -527,16 +576,16 @@ class StrokesBatchStep(BatchStepProtocol):
         tracker.add("loss.reconstruction.", reconstruction_loss)
         tracker.add("loss.total.", loss)
 
-        # Run optimizer
-        with monit.section('optimize'):
-            # Only if we are in training state
-            if MODE_STATE.is_train:
+        # Only if we are in training state
+        if self.mode.is_train:
+            # Run optimizer
+            with monit.section('optimize'):
                 # Set `grad` to zero
                 self.optimizer.zero_grad()
                 # Compute gradients
                 loss.backward()
                 # Log model parameters and gradients
-                if MODE_STATE.is_log_parameters:
+                if batch_idx.is_last:
                     pytorch_utils.store_model_indicators(self.encoder, 'encoder')
                     pytorch_utils.store_model_indicators(self.decoder, 'decoder')
                 # Clip gradients
@@ -545,84 +594,7 @@ class StrokesBatchStep(BatchStepProtocol):
                 # Optimize
                 self.optimizer.step()
 
-        #
-        return {'samples': len(data)}, None
-
-
-class Configs(TrainValidConfigs):
-    """
-    ## Configurations
-
-    These are default configurations which can be later adjusted by passing a `dict`.
-    """
-
-    # Device configurations to pick the device to run the experiment
-    device: torch.device = DeviceConfigs()
-    #
-    encoder: EncoderRNN
-    decoder: DecoderRNN
-    optimizer: optim.Adam = 'setup_all'
-    sampler: Sampler
-
-    dataset_name: str
-    train_loader = 'setup_all'
-    valid_loader = 'setup_all'
-    train_dataset: StrokesDataset
-    valid_dataset: StrokesDataset
-
-    batch_step = 'strokes_batch_step'
-
-    # Encoder and decoder sizes
-    enc_hidden_size = 256
-    dec_hidden_size = 512
-
-    # Batch size
-    batch_size = 100
-
-    # Number of features in $z$
-    d_z = 128
-    # Number of distributions in the mixture, $M$
-    n_distributions = 20
-
-    # Weight of KL divergence loss, $w_{KL}$
-    kl_div_loss_weight = 0.5
-    # Gradient clipping
-    grad_clip = 1.
-    # Temperature $\tau$ for sampling
-    temperature = 0.4
-
-    # Filter out stroke sequences longer than $200$
-    max_seq_length = 200
-
-    epochs = 100
-
-    def initialize(self):
-        # Initialize encoder & decoder
-        self.encoder = EncoderRNN(self.d_z, self.enc_hidden_size).to(self.device)
-        self.decoder = DecoderRNN(self.d_z, self.dec_hidden_size, self.n_distributions).to(self.device)
-
-        # Set optimizer, things like type of optimizer and learning rate are configurable
-        optimizer = OptimizerConfigs()
-        optimizer.parameters = list(self.encoder.parameters()) + list(self.decoder.parameters())
-        self.optimizer = optimizer
-
-        # Create sampler
-        self.sampler = Sampler(self.encoder, self.decoder)
-
-        # `npz` file path is `data/sketch/[DATASET NAME].npz`
-        path = lab.get_data_path() / 'sketch' / f'{self.dataset_name}.npz'
-        # Load the numpy file.
-        dataset = np.load(str(path), encoding='latin1', allow_pickle=True)
-
-        # Create training dataset
-        self.train_dataset = StrokesDataset(dataset['train'], self.max_seq_length)
-        # Create validation dataset
-        self.valid_dataset = StrokesDataset(dataset['valid'], self.max_seq_length, self.train_dataset.scale)
-
-        # Create training data loader
-        self.train_loader = DataLoader(self.train_dataset, self.batch_size, shuffle=True)
-        # Create validation data loader
-        self.valid_loader = DataLoader(self.valid_dataset, self.batch_size)
+        tracker.save()
 
     def sample(self):
         # Randomly pick a sample from the validation dataset to encode
@@ -633,12 +605,6 @@ class Configs(TrainValidConfigs):
         self.sampler.sample(data, self.temperature)
 
 
-@option(Configs.batch_step)
-def strokes_batch_step(c: Configs):
-    """Set Strokes Training and Validation module"""
-    return StrokesBatchStep(c.encoder, c.decoder, c.optimizer, c.kl_div_loss_weight, c.grad_clip)
-
-
 def main():
     configs = Configs()
    experiment.create(name="sketch_rnn")
@@ -655,8 +621,6 @@ def main():
         'inner_iterations': 10
     })
 
-    configs.initialize()
-
     with experiment.start():
         # Run the experiment
         configs.run()
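
The pattern shared by all three files above: each experiment implements `step(batch, batch_idx)` directly on its config class, drives train/eval behaviour from `self.mode`, and stores model indicators only on the last batch of each epoch via `batch_idx.is_last`, instead of on a configurable interval. A minimal sketch of that shape, assembled only from calls that appear in this patch (`ExampleConfigs`, `model`, and `loss_func` are illustrative placeholders, not names from these files):

from typing import Any

import torch
import torch.nn as nn

from labml import tracker
from labml.utils import pytorch as pytorch_utils
from labml_helpers.device import DeviceConfigs
from labml_helpers.train_valid import TrainValidConfigs, BatchIndex


class ExampleConfigs(TrainValidConfigs):
    # Device to run on, picked the same way as the sketch_rnn `Configs` above
    device: torch.device = DeviceConfigs()
    # Illustrative members; a real config would build these in `init()`
    model: nn.Module
    optimizer: torch.optim.Adam
    loss_func = nn.CrossEntropyLoss()

    def step(self, batch: Any, batch_idx: BatchIndex):
        # Switch the model between train and eval to match the current mode
        self.model.train(self.mode.is_train)

        # Move the batch to the configured device
        data, target = batch[0].to(self.device), batch[1].to(self.device)

        # Advance the global step by the number of samples, only while training
        if self.mode.is_train:
            tracker.add_global_step(len(data))

        # Forward pass and loss
        loss = self.loss_func(self.model(data), target)
        tracker.add("loss.", loss)

        if self.mode.is_train:
            self.optimizer.zero_grad()
            loss.backward()
            # Store parameter/gradient indicators once per epoch, on the last batch
            if batch_idx.is_last:
                pytorch_utils.store_model_indicators(self.model)
            self.optimizer.step()

        # Flush the tracked metrics
        tracker.save()

Logging on `batch_idx.is_last` yields exactly one indicator snapshot per epoch, which is why `log_params_updates` can be deleted from the GAN experiment and the remaining interval settings no longer need to be read or tuned.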