sketch rnn step

Varuna Jayasiri
2020-11-18 10:52:50 +05:30
parent ae7218774b
commit 2a8631dae1
3 changed files with 95 additions and 134 deletions

File 1 of 3:

@@ -124,8 +124,7 @@ class Configs(MNISTConfigs, SimpleTrainValidConfigs):
         tracker.add_global_step(len(data))
         # Whether to log activations
-        is_log_activations = batch_idx.is_interval(self.log_activations_batches)
-        with self.mode.update(is_log_activations=is_log_activations):
+        with self.mode.update(is_log_activations=batch_idx.is_last):
             # Run the model
             caps, reconstructions, pred = self.model(data)
@@ -141,7 +140,7 @@ class Configs(MNISTConfigs, SimpleTrainValidConfigs):
             self.optimizer.step()
             # Log parameters and gradients
-            if batch_idx.is_interval(self.log_params_updates):
+            if batch_idx.is_last:
                 pytorch_utils.store_model_indicators(self.model)
             self.optimizer.zero_grad()
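Both hunks swap interval-based logging, `batch_idx.is_interval(self.log_params_updates)`, for logging only on the last batch of an epoch, `batch_idx.is_last`, which is what lets the interval configurations be deleted. A minimal sketch of how a batch index can expose that flag (illustrative only; the field names are assumptions, not the actual labml_helpers implementation):

class BatchIndex:
    """Position of a batch within an epoch (illustrative sketch)."""

    def __init__(self, idx: int, total: int):
        self.idx = idx      # current batch number, starting at 0
        self.total = total  # number of batches in the epoch

    @property
    def is_last(self) -> bool:
        # True only for the final batch, so indicators get stored once per epoch
        return self.idx + 1 == self.total

    def is_interval(self, interval: int) -> bool:
        # True every `interval` batches; the behavior this commit moves away from
        return (self.idx + 1) % interval == 0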

File 2 of 3:

@@ -93,8 +93,6 @@ class Configs(MNISTConfigs, TrainValidConfigs):
     label_smoothing: float = 0.2
     discriminator_k: int = 1
-    log_params_updates: int = 2 ** 32  # 0 if not
-
     def init(self):
         self.state_modules = []
         self.generator = Generator().to(self.device)
@@ -136,7 +134,7 @@ class Configs(MNISTConfigs, TrainValidConfigs):
         if self.mode.is_train:
             self.discriminator_optimizer.zero_grad()
             loss.backward()
-            if batch_idx.is_interval(self.log_params_updates):
+            if batch_idx.is_last:
                 pytorch_utils.store_model_indicators(self.discriminator, 'discriminator')
             self.discriminator_optimizer.step()
@@ -155,7 +153,7 @@ class Configs(MNISTConfigs, TrainValidConfigs):
         if self.mode.is_train:
             self.generator_optimizer.zero_grad()
             loss.backward()
-            if batch_idx.is_interval(self.log_params_updates):
+            if batch_idx.is_last:
                 pytorch_utils.store_model_indicators(self.generator, 'generator')
             self.generator_optimizer.step()
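All three files branch on flags held by `self.mode` (`is_train`, `is_log_activations`), and file 1 flips one temporarily with `with self.mode.update(...)`. A rough sketch of that context-manager pattern, with assumed names, showing why the flag reverts when the block exits:

from contextlib import contextmanager

class ModeState:
    """Run-time flags such as is_train / is_log_activations (illustrative sketch)."""

    def __init__(self):
        self.is_train = True
        self.is_log_activations = False

    @contextmanager
    def update(self, **flags):
        # Temporarily override the given flags and restore them afterwards
        saved = {name: getattr(self, name) for name in flags}
        for name, value in flags.items():
            setattr(self, name, value)
        try:
            yield self
        finally:
            for name, value in saved.items():
                setattr(self, name, value)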

File 3 of 3:

@@ -23,7 +23,7 @@ Took help from [PyTorch Sketch RNN](https://github.com/alexis-jacq/Pytorch-Sketc
 """
 import math
-from typing import Optional, Tuple
+from typing import Optional, Tuple, Any
 import numpy as np
 import torch
@@ -34,13 +34,11 @@ from torch.utils.data import Dataset, DataLoader
 import einops
 from labml import lab, experiment, tracker, monit
-from labml.configs import option
 from labml.utils import pytorch as pytorch_utils
 from labml_helpers.device import DeviceConfigs
 from labml_helpers.module import Module
 from labml_helpers.optimizer import OptimizerConfigs
-from labml_helpers.train_valid import TrainValidConfigs, BatchStepProtocol, hook_model_outputs, \
-    MODE_STATE
+from labml_helpers.train_valid import TrainValidConfigs, hook_model_outputs, BatchIndex


 class StrokesDataset(Dataset):
@@ -452,53 +450,104 @@ class Sampler:
         plt.show()


-class StrokesBatchStep(BatchStepProtocol):
+class Configs(TrainValidConfigs):
     """
-    ## Train/Validation modules
+    ## Configurations
+
+    These are default configurations which can be later adjusted by passing a `dict`.
     """
-
-    def __init__(self, encoder: EncoderRNN, decoder: DecoderRNN,
-                 optimizer: Optional[torch.optim.Adam],
-                 kl_div_loss_weight: float, grad_clip: float):
-        self.grad_clip = grad_clip
-        self.kl_div_loss_weight = kl_div_loss_weight
-        self.encoder = encoder
-        self.decoder = decoder
+    # Device configurations to pick the device to run the experiment
+    device: torch.device = DeviceConfigs()
+    #
+    encoder: EncoderRNN
+    decoder: DecoderRNN
+    optimizer: optim.Adam
+    sampler: Sampler
+
+    dataset_name: str
+    train_loader: DataLoader
+    valid_loader: DataLoader
+    train_dataset: StrokesDataset
+    valid_dataset: StrokesDataset
+
+    # Encoder and decoder sizes
+    enc_hidden_size = 256
+    dec_hidden_size = 512
+
+    # Batch size
+    batch_size = 100
+
+    # Number of features in $z$
+    d_z = 128
+    # Number of distributions in the mixture, $M$
+    n_distributions = 20
+
+    # Weight of KL divergence loss, $w_{KL}$
+    kl_div_loss_weight = 0.5
+    # Gradient clipping
+    grad_clip = 1.
+    # Temperature $\tau$ for sampling
+    temperature = 0.4
+
+    # Filter out stroke sequences longer than $200$
+    max_seq_length = 200
+
+    epochs = 100
+
+    kl_div_loss = KLDivLoss()
+    reconstruction_loss = ReconstructionLoss()
+
+    def init(self):
+        # Initialize encoder & decoder
+        self.encoder = EncoderRNN(self.d_z, self.enc_hidden_size).to(self.device)
+        self.decoder = DecoderRNN(self.d_z, self.dec_hidden_size, self.n_distributions).to(self.device)
+
+        # Set optimizer, things like type of optimizer and learning rate are configurable
+        optimizer = OptimizerConfigs()
+        optimizer.parameters = list(self.encoder.parameters()) + list(self.decoder.parameters())
         self.optimizer = optimizer
-        self.kl_div_loss = KLDivLoss()
-        self.reconstruction_loss = ReconstructionLoss()
+
+        # Create sampler
+        self.sampler = Sampler(self.encoder, self.decoder)
+
+        # `npz` file path is `data/sketch/[DATASET NAME].npz`
+        path = lab.get_data_path() / 'sketch' / f'{self.dataset_name}.npz'
+        # Load the numpy file.
+        dataset = np.load(str(path), encoding='latin1', allow_pickle=True)
+
+        # Create training dataset
+        self.train_dataset = StrokesDataset(dataset['train'], self.max_seq_length)
+        # Create validation dataset
+        self.valid_dataset = StrokesDataset(dataset['valid'], self.max_seq_length, self.train_dataset.scale)
+
+        # Create training data loader
+        self.train_loader = DataLoader(self.train_dataset, self.batch_size, shuffle=True)
+        # Create validation data loader
+        self.valid_loader = DataLoader(self.valid_dataset, self.batch_size)

         # Add hooks to monitor layer outputs on Tensorboard
-        hook_model_outputs(self.encoder, 'encoder')
-        hook_model_outputs(self.decoder, 'decoder')
+        hook_model_outputs(self.mode, self.encoder, 'encoder')
+        hook_model_outputs(self.mode, self.decoder, 'decoder')
         # Configure the tracker to print the total train/validation loss
         tracker.set_scalar("loss.total.*", True)

-    def prepare_for_iteration(self):
-        """
-        Set models for training/evaluation
-        """
-        if MODE_STATE.is_train:
-            self.encoder.train()
-            self.decoder.train()
-        else:
-            self.encoder.eval()
-            self.decoder.eval()
+        self.state_modules = []

-    def process(self, batch: any, state: any):
-        """
-        Process a batch
-        """
-        data, mask = batch
-        # Get model device
-        device = self.encoder.device
+    def step(self, batch: Any, batch_idx: BatchIndex):
+        self.encoder.train(self.mode.is_train)
+        self.decoder.train(self.mode.is_train)

         # Move `data` and `mask` to device and swap the sequence and batch dimensions.
         # `data` will have shape `[seq_len, batch_size, 5]` and
         # `mask` will have shape `[seq_len, batch_size]`.
-        data = data.to(device).transpose(0, 1)
-        mask = mask.to(device).transpose(0, 1)
+        data = batch[0].to(self.device).transpose(0, 1)
+        mask = batch[1].to(self.device).transpose(0, 1)
+
+        # Increment step in training mode
+        if self.mode.is_train:
+            tracker.add_global_step(len(data))

         # Encode the sequence of strokes
         with monit.section("encoder"):
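In the new `init()`, an `OptimizerConfigs` object is assigned straight to `self.optimizer`, and labml's config system later resolves it into a concrete `torch.optim` optimizer, so the optimizer type and learning rate remain overridable like any other option. A simplified sketch of that deferred-construction idea (an assumption about the mechanism, not labml's actual code):

import torch

class LazyOptimizer:
    """Collects optimizer settings first, builds the optimizer on demand (illustrative)."""

    def __init__(self):
        self.parameters = None        # set by the experiment before building
        self.optimizer = 'Adam'       # name of the torch.optim class to construct
        self.learning_rate = 1e-3

    def build(self) -> torch.optim.Optimizer:
        # Look up the optimizer class by name and instantiate it
        cls = getattr(torch.optim, self.optimizer)
        return cls(self.parameters, lr=self.learning_rate)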
@@ -527,16 +576,16 @@ class StrokesBatchStep(BatchStepProtocol):
         tracker.add("loss.reconstruction.", reconstruction_loss)
         tracker.add("loss.total.", loss)

-        # Run optimizer
-        with monit.section('optimize'):
-            # Only if we are in training state
-            if MODE_STATE.is_train:
+        # Only if we are in training state
+        if self.mode.is_train:
+            # Run optimizer
+            with monit.section('optimize'):
                 # Set `grad` to zero
                 self.optimizer.zero_grad()
                 # Compute gradients
                 loss.backward()
                 # Log model parameters and gradients
-                if MODE_STATE.is_log_parameters:
+                if batch_idx.is_last:
                     pytorch_utils.store_model_indicators(self.encoder, 'encoder')
                     pytorch_utils.store_model_indicators(self.decoder, 'decoder')
                 # Clip gradients
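The `# Clip gradients` context line pairs with the `grad_clip = 1.` configuration above: in PyTorch, clipping normally sits between `loss.backward()` and `optimizer.step()` via `clip_grad_norm_`. A sketch of the likely call, assuming encoder and decoder are clipped to the same norm:

from torch import nn
from torch.nn.utils import clip_grad_norm_

def clip_gradients(encoder: nn.Module, decoder: nn.Module, grad_clip: float):
    # Rescale each model's gradient norm so it does not exceed `grad_clip`
    clip_grad_norm_(encoder.parameters(), grad_clip)
    clip_grad_norm_(decoder.parameters(), grad_clip)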
@@ -545,84 +594,7 @@ class StrokesBatchStep(BatchStepProtocol):
                 # Optimize
                 self.optimizer.step()

-        # tracker.save()
-        return {'samples': len(data)}, None
-
-
-class Configs(TrainValidConfigs):
-    """
-    ## Configurations
-
-    These are default configurations which can be later adjusted by passing a `dict`.
-    """
-    # Device configurations to pick the device to run the experiment
-    device: torch.device = DeviceConfigs()
-    #
-    encoder: EncoderRNN
-    decoder: DecoderRNN
-    optimizer: optim.Adam = 'setup_all'
-    sampler: Sampler
-
-    dataset_name: str
-    train_loader = 'setup_all'
-    valid_loader = 'setup_all'
-    train_dataset: StrokesDataset
-    valid_dataset: StrokesDataset
-
-    batch_step = 'strokes_batch_step'
-
-    # Encoder and decoder sizes
-    enc_hidden_size = 256
-    dec_hidden_size = 512
-
-    # Batch size
-    batch_size = 100
-
-    # Number of features in $z$
-    d_z = 128
-    # Number of distributions in the mixture, $M$
-    n_distributions = 20
-
-    # Weight of KL divergence loss, $w_{KL}$
-    kl_div_loss_weight = 0.5
-    # Gradient clipping
-    grad_clip = 1.
-    # Temperature $\tau$ for sampling
-    temperature = 0.4
-
-    # Filter out stroke sequences longer than $200$
-    max_seq_length = 200
-
-    epochs = 100
-
-    def initialize(self):
-        # Initialize encoder & decoder
-        self.encoder = EncoderRNN(self.d_z, self.enc_hidden_size).to(self.device)
-        self.decoder = DecoderRNN(self.d_z, self.dec_hidden_size, self.n_distributions).to(self.device)
-
-        # Set optimizer, things like type of optimizer and learning rate are configurable
-        optimizer = OptimizerConfigs()
-        optimizer.parameters = list(self.encoder.parameters()) + list(self.decoder.parameters())
-        self.optimizer = optimizer
-
-        # Create sampler
-        self.sampler = Sampler(self.encoder, self.decoder)
-
-        # `npz` file path is `data/sketch/[DATASET NAME].npz`
-        path = lab.get_data_path() / 'sketch' / f'{self.dataset_name}.npz'
-        # Load the numpy file.
-        dataset = np.load(str(path), encoding='latin1', allow_pickle=True)
-
-        # Create training dataset
-        self.train_dataset = StrokesDataset(dataset['train'], self.max_seq_length)
-        # Create validation dataset
-        self.valid_dataset = StrokesDataset(dataset['valid'], self.max_seq_length, self.train_dataset.scale)
-
-        # Create training data loader
-        self.train_loader = DataLoader(self.train_dataset, self.batch_size, shuffle=True)
-        # Create validation data loader
-        self.valid_loader = DataLoader(self.valid_dataset, self.batch_size)

     def sample(self):
         # Randomly pick a sample from validation dataset to encoder
@@ -633,12 +605,6 @@ class Configs(TrainValidConfigs):
         self.sampler.sample(data, self.temperature)


-@option(Configs.batch_step)
-def strokes_batch_step(c: Configs):
-    """Set Strokes Training and Validation module"""
-    return StrokesBatchStep(c.encoder, c.decoder, c.optimizer, c.kl_div_loss_weight, c.grad_clip)
-
-
 def main():
     configs = Configs()
     experiment.create(name="sketch_rnn")
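The context line `self.sampler.sample(data, self.temperature)` decodes a randomly chosen validation sequence at temperature $\tau$. A plausible expansion of `sample()` around that line; the indexing, batch dimension, and device handling here are assumptions:

def sample(self):
    # Randomly pick a sample from validation dataset to encoder
    data, *_ = self.valid_dataset[np.random.choice(len(self.valid_dataset))]
    # Add a batch dimension and move it to the model's device (assumed shape handling)
    data = data.unsqueeze(1).to(self.device)
    # Decode it into a sketch at the configured temperature
    self.sampler.sample(data, self.temperature)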
@@ -655,8 +621,6 @@ def main():
         'inner_iterations': 10
     })

-    configs.initialize()
-
     with experiment.start():
         # Run the experiment
         configs.run()