sketch rnn step

Varuna Jayasiri
2020-11-18 10:52:50 +05:30
parent ae7218774b
commit 2a8631dae1
3 changed files with 95 additions and 134 deletions

View File

@@ -124,8 +124,7 @@ class Configs(MNISTConfigs, SimpleTrainValidConfigs):
tracker.add_global_step(len(data))
# Whether to log activations
- is_log_activations = batch_idx.is_interval(self.log_activations_batches)
- with self.mode.update(is_log_activations=is_log_activations):
+ with self.mode.update(is_log_activations=batch_idx.is_last):
# Run the model
caps, reconstructions, pred = self.model(data)
@@ -141,7 +140,7 @@ class Configs(MNISTConfigs, SimpleTrainValidConfigs):
self.optimizer.step()
# Log parameters and gradients
- if batch_idx.is_interval(self.log_params_updates):
+ if batch_idx.is_last:
pytorch_utils.store_model_indicators(self.model)
self.optimizer.zero_grad()
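For reference, a minimal sketch of the two logging gates used above (interval-based versus last batch of the epoch). The `BatchIndex` stand-in below is hypothetical and only mirrors the `is_interval`/`is_last` semantics implied by the calls in this diff, not the actual labml_helpers class.

class BatchIndex:
    """Hypothetical stand-in mirroring the is_interval/is_last calls above."""

    def __init__(self, idx: int, total: int):
        self.idx = idx        # current batch number within the epoch
        self.total = total    # number of batches in the epoch

    def is_interval(self, interval: int) -> bool:
        # Fires once every `interval` batches
        return (self.idx + 1) % interval == 0

    @property
    def is_last(self) -> bool:
        # Fires only on the final batch of the epoch
        return self.idx + 1 == self.total


for i in range(10):
    batch_idx = BatchIndex(i, total=10)
    if batch_idx.is_interval(4):
        print(f'interval log at batch {i}')      # batches 3 and 7
    if batch_idx.is_last:
        print(f'end-of-epoch log at batch {i}')  # batch 9 only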

View File

@@ -93,8 +93,6 @@ class Configs(MNISTConfigs, TrainValidConfigs):
label_smoothing: float = 0.2
discriminator_k: int = 1
- log_params_updates: int = 2 ** 32 # 0 if not
def init(self):
self.state_modules = []
self.generator = Generator().to(self.device)
@@ -136,7 +134,7 @@ class Configs(MNISTConfigs, TrainValidConfigs):
if self.mode.is_train:
self.discriminator_optimizer.zero_grad()
loss.backward()
- if batch_idx.is_interval(self.log_params_updates):
+ if batch_idx.is_last:
pytorch_utils.store_model_indicators(self.discriminator, 'discriminator')
self.discriminator_optimizer.step()
@@ -155,7 +153,7 @@ class Configs(MNISTConfigs, TrainValidConfigs):
if self.mode.is_train:
self.generator_optimizer.zero_grad()
loss.backward()
- if batch_idx.is_interval(self.log_params_updates):
+ if batch_idx.is_last:
pytorch_utils.store_model_indicators(self.generator, 'generator')
self.generator_optimizer.step()
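The configuration above declares `label_smoothing: float = 0.2`. Below is a minimal, self-contained sketch of one-sided label smoothing on the discriminator's real targets; the BCE wiring is an assumption for illustration and not this repository's loss classes.

import torch
import torch.nn.functional as F

# Illustrative only: soften the "real" targets from 1.0 to 1.0 - label_smoothing
label_smoothing = 0.2
logits_real = torch.randn(100, 1)   # discriminator outputs on real images
logits_fake = torch.randn(100, 1)   # discriminator outputs on generated images

targets_real = torch.full_like(logits_real, 1.0 - label_smoothing)  # 0.8 instead of 1.0
targets_fake = torch.zeros_like(logits_fake)                        # fake targets stay at 0.0

loss = (F.binary_cross_entropy_with_logits(logits_real, targets_real) +
        F.binary_cross_entropy_with_logits(logits_fake, targets_fake))
print(loss.item())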

View File

@@ -23,7 +23,7 @@ Took help from [PyTorch Sketch RNN](https://github.com/alexis-jacq/Pytorch-Sketc
"""
import math
- from typing import Optional, Tuple
+ from typing import Optional, Tuple, Any
import numpy as np
import torch
@@ -34,13 +34,11 @@ from torch.utils.data import Dataset, DataLoader
import einops
from labml import lab, experiment, tracker, monit
from labml.configs import option
from labml.utils import pytorch as pytorch_utils
from labml_helpers.device import DeviceConfigs
from labml_helpers.module import Module
from labml_helpers.optimizer import OptimizerConfigs
- from labml_helpers.train_valid import TrainValidConfigs, BatchStepProtocol, hook_model_outputs, \
- MODE_STATE
+ from labml_helpers.train_valid import TrainValidConfigs, hook_model_outputs, BatchIndex
class StrokesDataset(Dataset):
@@ -452,53 +450,104 @@ class Sampler:
plt.show()
- class StrokesBatchStep(BatchStepProtocol):
+ class Configs(TrainValidConfigs):
"""
- ## Train/Validation modules
+ ## Configurations
+ These are default configurations which can be later adjusted by passing a `dict`.
"""
- def __init__(self, encoder: EncoderRNN, decoder: DecoderRNN,
- optimizer: Optional[torch.optim.Adam],
- kl_div_loss_weight: float, grad_clip: float):
- self.grad_clip = grad_clip
- self.kl_div_loss_weight = kl_div_loss_weight
- self.encoder = encoder
- self.decoder = decoder
+ # Device configurations to pick the device to run the experiment
+ device: torch.device = DeviceConfigs()
+ #
+ encoder: EncoderRNN
+ decoder: DecoderRNN
+ optimizer: optim.Adam
+ sampler: Sampler
+ dataset_name: str
+ train_loader: DataLoader
+ valid_loader: DataLoader
+ train_dataset: StrokesDataset
+ valid_dataset: StrokesDataset
+ # Encoder and decoder sizes
+ enc_hidden_size = 256
+ dec_hidden_size = 512
+ # Batch size
+ batch_size = 100
+ # Number of features in $z$
+ d_z = 128
+ # Number of distributions in the mixture, $M$
+ n_distributions = 20
+ # Weight of KL divergence loss, $w_{KL}$
+ kl_div_loss_weight = 0.5
+ # Gradient clipping
+ grad_clip = 1.
+ # Temperature $\tau$ for sampling
+ temperature = 0.4
+ # Filter out stroke sequences longer than $200$
+ max_seq_length = 200
+ epochs = 100
+ kl_div_loss = KLDivLoss()
+ reconstruction_loss = ReconstructionLoss()
+ def init(self):
+ # Initialize encoder & decoder
+ self.encoder = EncoderRNN(self.d_z, self.enc_hidden_size).to(self.device)
+ self.decoder = DecoderRNN(self.d_z, self.dec_hidden_size, self.n_distributions).to(self.device)
+ # Set optimizer, things like type of optimizer and learning rate are configurable
+ optimizer = OptimizerConfigs()
+ optimizer.parameters = list(self.encoder.parameters()) + list(self.decoder.parameters())
self.optimizer = optimizer
- self.kl_div_loss = KLDivLoss()
- self.reconstruction_loss = ReconstructionLoss()
+ # Create sampler
+ self.sampler = Sampler(self.encoder, self.decoder)
+ # `npz` file path is `data/sketch/[DATASET NAME].npz`
+ path = lab.get_data_path() / 'sketch' / f'{self.dataset_name}.npz'
+ # Load the numpy file.
+ dataset = np.load(str(path), encoding='latin1', allow_pickle=True)
+ # Create training dataset
+ self.train_dataset = StrokesDataset(dataset['train'], self.max_seq_length)
+ # Create validation dataset
+ self.valid_dataset = StrokesDataset(dataset['valid'], self.max_seq_length, self.train_dataset.scale)
+ # Create training data loader
+ self.train_loader = DataLoader(self.train_dataset, self.batch_size, shuffle=True)
+ # Create validation data loader
+ self.valid_loader = DataLoader(self.valid_dataset, self.batch_size)
# Add hooks to monitor layer outputs on Tensorboard
- hook_model_outputs(self.encoder, 'encoder')
- hook_model_outputs(self.decoder, 'decoder')
+ hook_model_outputs(self.mode, self.encoder, 'encoder')
+ hook_model_outputs(self.mode, self.decoder, 'decoder')
# Configure the tracker to print the total train/validation loss
tracker.set_scalar("loss.total.*", True)
- def prepare_for_iteration(self):
- """
- Set models for training/evaluation
- """
- if MODE_STATE.is_train:
- self.encoder.train()
- self.decoder.train()
- else:
- self.encoder.eval()
- self.decoder.eval()
+ self.state_modules = []
- def process(self, batch: any, state: any):
- """
- Process a batch
- """
- data, mask = batch
+ def step(self, batch: Any, batch_idx: BatchIndex):
+ self.encoder.train(self.mode.is_train)
+ self.decoder.train(self.mode.is_train)
- # Get model device
- device = self.encoder.device
# Move `data` and `mask` to device and swap the sequence and batch dimensions.
# `data` will have shape `[seq_len, batch_size, 5]` and
# `mask` will have shape `[seq_len, batch_size]`.
- data = data.to(device).transpose(0, 1)
- mask = mask.to(device).transpose(0, 1)
+ data = batch[0].to(self.device).transpose(0, 1)
+ mask = batch[1].to(self.device).transpose(0, 1)
+ # Increment step in training mode
+ if self.mode.is_train:
+ tracker.add_global_step(len(data))
# Encode the sequence of strokes
with monit.section("encoder"):
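A standalone shape check for the device move and transpose in `step()` above; the sizes are illustrative (`batch_size = 100` from the config, a placeholder sequence length of 200).

import torch

data = torch.zeros(100, 200, 5)   # [batch_size, seq_len, 5] as yielded by the DataLoader
mask = torch.zeros(100, 200)      # [batch_size, seq_len]

data = data.transpose(0, 1)       # -> [seq_len, batch_size, 5]
mask = mask.transpose(0, 1)       # -> [seq_len, batch_size]
print(data.shape, mask.shape)     # torch.Size([200, 100, 5]) torch.Size([200, 100])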
@@ -527,16 +576,16 @@ class StrokesBatchStep(BatchStepProtocol):
tracker.add("loss.reconstruction.", reconstruction_loss)
tracker.add("loss.total.", loss)
+ # Only if we are in training state
+ if self.mode.is_train:
# Run optimizer
with monit.section('optimize'):
- # Only if we are in training state
- if MODE_STATE.is_train:
# Set `grad` to zero
self.optimizer.zero_grad()
# Compute gradients
loss.backward()
# Log model parameters and gradients
- if MODE_STATE.is_log_parameters:
+ if batch_idx.is_last:
pytorch_utils.store_model_indicators(self.encoder, 'encoder')
pytorch_utils.store_model_indicators(self.decoder, 'decoder')
# Clip gradients
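A generic sketch of the gradient-norm clipping step named by the comment above, using PyTorch's `nn.utils.clip_grad_norm_` and the `grad_clip = 1.` default from the config; it illustrates the technique and is not verbatim code from this file.

import torch
import torch.nn as nn

model = nn.Linear(8, 8)                      # stand-in for the encoder/decoder parameters
optimizer = torch.optim.Adam(model.parameters())
grad_clip = 1.

optimizer.zero_grad()
loss = model(torch.randn(4, 8)).sum()
loss.backward()
# Rescale gradients so their total norm does not exceed `grad_clip`
nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
optimizer.step()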
@@ -545,84 +594,7 @@ class StrokesBatchStep(BatchStepProtocol):
# Optimize
self.optimizer.step()
+ #
- return {'samples': len(data)}, None
- class Configs(TrainValidConfigs):
- """
- ## Configurations
- These are default configurations which can be later adjusted by passing a `dict`.
- """
- # Device configurations to pick the device to run the experiment
- device: torch.device = DeviceConfigs()
- #
- encoder: EncoderRNN
- decoder: DecoderRNN
- optimizer: optim.Adam = 'setup_all'
- sampler: Sampler
- dataset_name: str
- train_loader = 'setup_all'
- valid_loader = 'setup_all'
- train_dataset: StrokesDataset
- valid_dataset: StrokesDataset
- batch_step = 'strokes_batch_step'
- # Encoder and decoder sizes
- enc_hidden_size = 256
- dec_hidden_size = 512
- # Batch size
- batch_size = 100
- # Number of features in $z$
- d_z = 128
- # Number of distributions in the mixture, $M$
- n_distributions = 20
- # Weight of KL divergence loss, $w_{KL}$
- kl_div_loss_weight = 0.5
- # Gradient clipping
- grad_clip = 1.
- # Temperature $\tau$ for sampling
- temperature = 0.4
- # Filter out stroke sequences longer than $200$
- max_seq_length = 200
- epochs = 100
- def initialize(self):
- # Initialize encoder & decoder
- self.encoder = EncoderRNN(self.d_z, self.enc_hidden_size).to(self.device)
- self.decoder = DecoderRNN(self.d_z, self.dec_hidden_size, self.n_distributions).to(self.device)
- # Set optimizer, things like type of optimizer and learning rate are configurable
- optimizer = OptimizerConfigs()
- optimizer.parameters = list(self.encoder.parameters()) + list(self.decoder.parameters())
- self.optimizer = optimizer
- # Create sampler
- self.sampler = Sampler(self.encoder, self.decoder)
- # `npz` file path is `data/sketch/[DATASET NAME].npz`
- path = lab.get_data_path() / 'sketch' / f'{self.dataset_name}.npz'
- # Load the numpy file.
- dataset = np.load(str(path), encoding='latin1', allow_pickle=True)
- # Create training dataset
- self.train_dataset = StrokesDataset(dataset['train'], self.max_seq_length)
- # Create validation dataset
- self.valid_dataset = StrokesDataset(dataset['valid'], self.max_seq_length, self.train_dataset.scale)
- # Create training data loader
- self.train_loader = DataLoader(self.train_dataset, self.batch_size, shuffle=True)
- # Create validation data loader
- self.valid_loader = DataLoader(self.valid_dataset, self.batch_size)
+ tracker.save()
def sample(self):
# Randomly pick a sample from the validation dataset to encode
@@ -633,12 +605,6 @@ class Configs(TrainValidConfigs):
self.sampler.sample(data, self.temperature)
- @option(Configs.batch_step)
- def strokes_batch_step(c: Configs):
- """Set Strokes Training and Validation module"""
- return StrokesBatchStep(c.encoder, c.decoder, c.optimizer, c.kl_div_loss_weight, c.grad_clip)
def main():
configs = Configs()
experiment.create(name="sketch_rnn")
@@ -655,8 +621,6 @@ def main():
'inner_iterations': 10
})
- configs.initialize()
with experiment.start():
# Run the experiment
configs.run()
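The `Configs` docstring notes that the defaults can be adjusted by passing a `dict`, and `main()` above passes one (only `'inner_iterations': 10` is visible in the diff). A minimal usage sketch follows; `experiment.configs` and the extra override keys are assumptions for illustration, not confirmed by this commit.

from labml import experiment

configs = Configs()
experiment.create(name="sketch_rnn")
# Override selected defaults with a plain dict (keys below are illustrative)
experiment.configs(configs, {
    'dataset_name': 'bicycle',   # hypothetical `.npz` file under data/sketch/
    'batch_size': 64,            # override the class default of 100
    'inner_iterations': 10,
})
with experiment.start():
    configs.run()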