mirror of
https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
synced 2025-10-30 02:08:50 +08:00
sketch rnn step
This commit is contained in:
@ -124,8 +124,7 @@ class Configs(MNISTConfigs, SimpleTrainValidConfigs):
|
||||
tracker.add_global_step(len(data))
|
||||
|
||||
# Whether to log activations
|
||||
is_log_activations = batch_idx.is_interval(self.log_activations_batches)
|
||||
with self.mode.update(is_log_activations=is_log_activations):
|
||||
with self.mode.update(is_log_activations=batch_idx.is_last):
|
||||
# Run the model
|
||||
caps, reconstructions, pred = self.model(data)
|
||||
|
||||
@ -141,7 +140,7 @@ class Configs(MNISTConfigs, SimpleTrainValidConfigs):
|
||||
|
||||
self.optimizer.step()
|
||||
# Log parameters and gradients
|
||||
if batch_idx.is_interval(self.log_params_updates):
|
||||
if batch_idx.is_last:
|
||||
pytorch_utils.store_model_indicators(self.model)
|
||||
self.optimizer.zero_grad()
|
||||
|
||||
|
||||
@ -93,8 +93,6 @@ class Configs(MNISTConfigs, TrainValidConfigs):
|
||||
label_smoothing: float = 0.2
|
||||
discriminator_k: int = 1
|
||||
|
||||
log_params_updates: int = 2 ** 32 # 0 if not
|
||||
|
||||
def init(self):
|
||||
self.state_modules = []
|
||||
self.generator = Generator().to(self.device)
|
||||
@ -136,7 +134,7 @@ class Configs(MNISTConfigs, TrainValidConfigs):
|
||||
if self.mode.is_train:
|
||||
self.discriminator_optimizer.zero_grad()
|
||||
loss.backward()
|
||||
if batch_idx.is_interval(self.log_params_updates):
|
||||
if batch_idx.is_last:
|
||||
pytorch_utils.store_model_indicators(self.discriminator, 'discriminator')
|
||||
self.discriminator_optimizer.step()
|
||||
|
||||
@ -155,7 +153,7 @@ class Configs(MNISTConfigs, TrainValidConfigs):
|
||||
if self.mode.is_train:
|
||||
self.generator_optimizer.zero_grad()
|
||||
loss.backward()
|
||||
if batch_idx.is_interval(self.log_params_updates):
|
||||
if batch_idx.is_last:
|
||||
pytorch_utils.store_model_indicators(self.generator, 'generator')
|
||||
self.generator_optimizer.step()
|
||||
|
||||
|
||||
@ -23,7 +23,7 @@ Took help from [PyTorch Sketch RNN](https://github.com/alexis-jacq/Pytorch-Sketc
|
||||
"""
|
||||
|
||||
import math
|
||||
from typing import Optional, Tuple
|
||||
from typing import Optional, Tuple, Any
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@ -34,13 +34,11 @@ from torch.utils.data import Dataset, DataLoader
|
||||
|
||||
import einops
|
||||
from labml import lab, experiment, tracker, monit
|
||||
from labml.configs import option
|
||||
from labml.utils import pytorch as pytorch_utils
|
||||
from labml_helpers.device import DeviceConfigs
|
||||
from labml_helpers.module import Module
|
||||
from labml_helpers.optimizer import OptimizerConfigs
|
||||
from labml_helpers.train_valid import TrainValidConfigs, BatchStepProtocol, hook_model_outputs, \
|
||||
MODE_STATE
|
||||
from labml_helpers.train_valid import TrainValidConfigs, hook_model_outputs, BatchIndex
|
||||
|
||||
|
||||
class StrokesDataset(Dataset):
|
||||
@ -452,53 +450,104 @@ class Sampler:
|
||||
plt.show()
|
||||
|
||||
|
||||
class StrokesBatchStep(BatchStepProtocol):
|
||||
class Configs(TrainValidConfigs):
|
||||
"""
|
||||
## Train/Validation modules
|
||||
## Configurations
|
||||
|
||||
These are default configurations which can be later adjusted by passing a `dict`.
|
||||
"""
|
||||
|
||||
def __init__(self, encoder: EncoderRNN, decoder: DecoderRNN,
|
||||
optimizer: Optional[torch.optim.Adam],
|
||||
kl_div_loss_weight: float, grad_clip: float):
|
||||
self.grad_clip = grad_clip
|
||||
self.kl_div_loss_weight = kl_div_loss_weight
|
||||
self.encoder = encoder
|
||||
self.decoder = decoder
|
||||
# Device configurations to pick the device to run the experiment
|
||||
device: torch.device = DeviceConfigs()
|
||||
#
|
||||
encoder: EncoderRNN
|
||||
decoder: DecoderRNN
|
||||
optimizer: optim.Adam
|
||||
sampler: Sampler
|
||||
|
||||
dataset_name: str
|
||||
train_loader: DataLoader
|
||||
valid_loader: DataLoader
|
||||
train_dataset: StrokesDataset
|
||||
valid_dataset: StrokesDataset
|
||||
|
||||
# Encoder and decoder sizes
|
||||
enc_hidden_size = 256
|
||||
dec_hidden_size = 512
|
||||
|
||||
# Batch size
|
||||
batch_size = 100
|
||||
|
||||
# Number of features in $z$
|
||||
d_z = 128
|
||||
# Number of distributions in the mixture, $M$
|
||||
n_distributions = 20
|
||||
|
||||
# Weight of KL divergence loss, $w_{KL}$
|
||||
kl_div_loss_weight = 0.5
|
||||
# Gradient clipping
|
||||
grad_clip = 1.
|
||||
# Temperature $\tau$ for sampling
|
||||
temperature = 0.4
|
||||
|
||||
# Filter out stroke sequences longer than $200$
|
||||
max_seq_length = 200
|
||||
|
||||
epochs = 100
|
||||
|
||||
kl_div_loss = KLDivLoss()
|
||||
reconstruction_loss = ReconstructionLoss()
|
||||
|
||||
def init(self):
|
||||
# Initialize encoder & decoder
|
||||
self.encoder = EncoderRNN(self.d_z, self.enc_hidden_size).to(self.device)
|
||||
self.decoder = DecoderRNN(self.d_z, self.dec_hidden_size, self.n_distributions).to(self.device)
|
||||
|
||||
# Set optimizer, things like type of optimizer and learning rate are configurable
|
||||
optimizer = OptimizerConfigs()
|
||||
optimizer.parameters = list(self.encoder.parameters()) + list(self.decoder.parameters())
|
||||
self.optimizer = optimizer
|
||||
self.kl_div_loss = KLDivLoss()
|
||||
self.reconstruction_loss = ReconstructionLoss()
|
||||
|
||||
# Create sampler
|
||||
self.sampler = Sampler(self.encoder, self.decoder)
|
||||
|
||||
# `npz` file path is `data/sketch/[DATASET NAME].npz`
|
||||
path = lab.get_data_path() / 'sketch' / f'{self.dataset_name}.npz'
|
||||
# Load the numpy file.
|
||||
dataset = np.load(str(path), encoding='latin1', allow_pickle=True)
|
||||
|
||||
# Create training dataset
|
||||
self.train_dataset = StrokesDataset(dataset['train'], self.max_seq_length)
|
||||
# Create validation dataset
|
||||
self.valid_dataset = StrokesDataset(dataset['valid'], self.max_seq_length, self.train_dataset.scale)
|
||||
|
||||
# Create training data loader
|
||||
self.train_loader = DataLoader(self.train_dataset, self.batch_size, shuffle=True)
|
||||
# Create validation data loader
|
||||
self.valid_loader = DataLoader(self.valid_dataset, self.batch_size)
|
||||
|
||||
# Add hooks to monitor layer outputs on Tensorboard
|
||||
hook_model_outputs(self.encoder, 'encoder')
|
||||
hook_model_outputs(self.decoder, 'decoder')
|
||||
hook_model_outputs(self.mode, self.encoder, 'encoder')
|
||||
hook_model_outputs(self.mode, self.decoder, 'decoder')
|
||||
|
||||
# Configure the tracker to print the total train/validation loss
|
||||
tracker.set_scalar("loss.total.*", True)
|
||||
|
||||
def prepare_for_iteration(self):
|
||||
"""
|
||||
Set models for training/evaluation
|
||||
"""
|
||||
if MODE_STATE.is_train:
|
||||
self.encoder.train()
|
||||
self.decoder.train()
|
||||
else:
|
||||
self.encoder.eval()
|
||||
self.decoder.eval()
|
||||
self.state_modules = []
|
||||
|
||||
def process(self, batch: any, state: any):
|
||||
"""
|
||||
Process a batch
|
||||
"""
|
||||
data, mask = batch
|
||||
def step(self, batch: Any, batch_idx: BatchIndex):
|
||||
self.encoder.train(self.mode.is_train)
|
||||
self.decoder.train(self.mode.is_train)
|
||||
|
||||
# Get model device
|
||||
device = self.encoder.device
|
||||
# Move `data` and `mask` to device and swap the sequence and batch dimensions.
|
||||
# `data` will have shape `[seq_len, batch_size, 5]` and
|
||||
# `mask` will have shape `[seq_len, batch_size]`.
|
||||
data = data.to(device).transpose(0, 1)
|
||||
mask = mask.to(device).transpose(0, 1)
|
||||
data = batch[0].to(self.device).transpose(0, 1)
|
||||
mask = batch[1].to(self.device).transpose(0, 1)
|
||||
|
||||
# Increment step in training mode
|
||||
if self.mode.is_train:
|
||||
tracker.add_global_step(len(data))
|
||||
|
||||
# Encode the sequence of strokes
|
||||
with monit.section("encoder"):
|
||||
@ -527,16 +576,16 @@ class StrokesBatchStep(BatchStepProtocol):
|
||||
tracker.add("loss.reconstruction.", reconstruction_loss)
|
||||
tracker.add("loss.total.", loss)
|
||||
|
||||
# Only if we are in training state
|
||||
if self.mode.is_train:
|
||||
# Run optimizer
|
||||
with monit.section('optimize'):
|
||||
# Only if we are in training state
|
||||
if MODE_STATE.is_train:
|
||||
# Set `grad` to zero
|
||||
self.optimizer.zero_grad()
|
||||
# Compute gradients
|
||||
loss.backward()
|
||||
# Log model parameters and gradients
|
||||
if MODE_STATE.is_log_parameters:
|
||||
if batch_idx.is_last:
|
||||
pytorch_utils.store_model_indicators(self.encoder, 'encoder')
|
||||
pytorch_utils.store_model_indicators(self.decoder, 'decoder')
|
||||
# Clip gradients
|
||||
@ -545,84 +594,7 @@ class StrokesBatchStep(BatchStepProtocol):
|
||||
# Optimize
|
||||
self.optimizer.step()
|
||||
|
||||
#
|
||||
return {'samples': len(data)}, None
|
||||
|
||||
|
||||
class Configs(TrainValidConfigs):
|
||||
"""
|
||||
## Configurations
|
||||
|
||||
These are default configurations which can be later adjusted by passing a `dict`.
|
||||
"""
|
||||
|
||||
# Device configurations to pick the device to run the experiment
|
||||
device: torch.device = DeviceConfigs()
|
||||
#
|
||||
encoder: EncoderRNN
|
||||
decoder: DecoderRNN
|
||||
optimizer: optim.Adam = 'setup_all'
|
||||
sampler: Sampler
|
||||
|
||||
dataset_name: str
|
||||
train_loader = 'setup_all'
|
||||
valid_loader = 'setup_all'
|
||||
train_dataset: StrokesDataset
|
||||
valid_dataset: StrokesDataset
|
||||
|
||||
batch_step = 'strokes_batch_step'
|
||||
|
||||
# Encoder and decoder sizes
|
||||
enc_hidden_size = 256
|
||||
dec_hidden_size = 512
|
||||
|
||||
# Batch size
|
||||
batch_size = 100
|
||||
|
||||
# Number of features in $z$
|
||||
d_z = 128
|
||||
# Number of distributions in the mixture, $M$
|
||||
n_distributions = 20
|
||||
|
||||
# Weight of KL divergence loss, $w_{KL}$
|
||||
kl_div_loss_weight = 0.5
|
||||
# Gradient clipping
|
||||
grad_clip = 1.
|
||||
# Temperature $\tau$ for sampling
|
||||
temperature = 0.4
|
||||
|
||||
# Filter out stroke sequences longer than $200$
|
||||
max_seq_length = 200
|
||||
|
||||
epochs = 100
|
||||
|
||||
def initialize(self):
|
||||
# Initialize encoder & decoder
|
||||
self.encoder = EncoderRNN(self.d_z, self.enc_hidden_size).to(self.device)
|
||||
self.decoder = DecoderRNN(self.d_z, self.dec_hidden_size, self.n_distributions).to(self.device)
|
||||
|
||||
# Set optimizer, things like type of optimizer and learning rate are configurable
|
||||
optimizer = OptimizerConfigs()
|
||||
optimizer.parameters = list(self.encoder.parameters()) + list(self.decoder.parameters())
|
||||
self.optimizer = optimizer
|
||||
|
||||
# Create sampler
|
||||
self.sampler = Sampler(self.encoder, self.decoder)
|
||||
|
||||
# `npz` file path is `data/sketch/[DATASET NAME].npz`
|
||||
path = lab.get_data_path() / 'sketch' / f'{self.dataset_name}.npz'
|
||||
# Load the numpy file.
|
||||
dataset = np.load(str(path), encoding='latin1', allow_pickle=True)
|
||||
|
||||
# Create training dataset
|
||||
self.train_dataset = StrokesDataset(dataset['train'], self.max_seq_length)
|
||||
# Create validation dataset
|
||||
self.valid_dataset = StrokesDataset(dataset['valid'], self.max_seq_length, self.train_dataset.scale)
|
||||
|
||||
# Create training data loader
|
||||
self.train_loader = DataLoader(self.train_dataset, self.batch_size, shuffle=True)
|
||||
# Create validation data loader
|
||||
self.valid_loader = DataLoader(self.valid_dataset, self.batch_size)
|
||||
tracker.save()
|
||||
|
||||
def sample(self):
|
||||
# Randomly pick a sample from validation dataset to encoder
|
||||
@ -633,12 +605,6 @@ class Configs(TrainValidConfigs):
|
||||
self.sampler.sample(data, self.temperature)
|
||||
|
||||
|
||||
@option(Configs.batch_step)
|
||||
def strokes_batch_step(c: Configs):
|
||||
"""Set Strokes Training and Validation module"""
|
||||
return StrokesBatchStep(c.encoder, c.decoder, c.optimizer, c.kl_div_loss_weight, c.grad_clip)
|
||||
|
||||
|
||||
def main():
|
||||
configs = Configs()
|
||||
experiment.create(name="sketch_rnn")
|
||||
@ -655,8 +621,6 @@ def main():
|
||||
'inner_iterations': 10
|
||||
})
|
||||
|
||||
configs.initialize()
|
||||
|
||||
with experiment.start():
|
||||
# Run the experiment
|
||||
configs.run()
|
||||
|
||||
Reference in New Issue
Block a user