Mirror of https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
Synced 2025-10-30 02:08:50 +08:00

Commit: sketch rnn step

This commit folds the Sketch RNN training step into the configuration class: the separate StrokesBatchStep (BatchStepProtocol) class becomes a `step` method on `Configs(TrainValidConfigs)`. It also switches the capsule-network and GAN configurations from interval-based parameter/activation logging (`batch_idx.is_interval(...)`) to logging only on the last batch of each epoch (`batch_idx.is_last`).
@@ -124,8 +124,7 @@ class Configs(MNISTConfigs, SimpleTrainValidConfigs):
         tracker.add_global_step(len(data))

         # Whether to log activations
-        is_log_activations = batch_idx.is_interval(self.log_activations_batches)
-        with self.mode.update(is_log_activations=is_log_activations):
+        with self.mode.update(is_log_activations=batch_idx.is_last):
             # Run the model
             caps, reconstructions, pred = self.model(data)

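Note on the hunk above: the interval-driven `is_log_activations` flag is gone; activations are now logged only on the last batch of each epoch, toggled through the mode context manager. A minimal sketch of how a `mode.update(...)` context manager of this kind can work; this is an assumption for illustration, not `labml_helpers`' actual source:

# Sketch only: a guess at the shape of the mode object, not labml_helpers' code.
class ModeState:
    def __init__(self):
        self.is_train = False
        self.is_log_activations = False

    def update(self, **kwargs):
        return _ModeUpdate(self, kwargs)

class _ModeUpdate:
    def __init__(self, mode, updates):
        self.mode, self.updates, self.saved = mode, updates, {}

    def __enter__(self):
        # Save the current flags and apply the temporary ones
        for k, v in self.updates.items():
            self.saved[k] = getattr(self.mode, k)
            setattr(self.mode, k, v)

    def __exit__(self, *exc):
        # Restore the previous flags on exit
        for k, v in self.saved.items():
            setattr(self.mode, k, v)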
@@ -141,7 +140,7 @@ class Configs(MNISTConfigs, SimpleTrainValidConfigs):

             self.optimizer.step()
             # Log parameters and gradients
-            if batch_idx.is_interval(self.log_params_updates):
+            if batch_idx.is_last:
                 pytorch_utils.store_model_indicators(self.model)
             self.optimizer.zero_grad()

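The same substitution recurs throughout this commit: `batch_idx.is_interval(n)` fires on every `n`-th batch, while `batch_idx.is_last` fires only on the final batch of the epoch, so a tunable logging frequency becomes once-per-epoch logging. A hedged sketch of those semantics; the real `BatchIndex` lives in `labml_helpers.train_valid` and its fields may differ:

# Illustrative stand-in for labml_helpers' BatchIndex; names are assumptions.
class BatchIndexSketch:
    def __init__(self, idx: int, total: int):
        self.idx = idx      # current batch number within the epoch
        self.total = total  # number of batches in the epoch

    def is_interval(self, interval: int) -> bool:
        # True on every `interval`-th batch (hypothetical implementation)
        return (self.idx + 1) % interval == 0

    @property
    def is_last(self) -> bool:
        return self.idx + 1 == self.total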
@@ -93,8 +93,6 @@ class Configs(MNISTConfigs, TrainValidConfigs):
    label_smoothing: float = 0.2
    discriminator_k: int = 1

-    log_params_updates: int = 2 ** 32  # 0 if not
-
    def init(self):
        self.state_modules = []
        self.generator = Generator().to(self.device)
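Aside: the deleted `log_params_updates: int = 2 ** 32  # 0 if not` was a "never" sentinel. With an interval that large, `batch_idx.is_interval(self.log_params_updates)` is effectively never true, so parameter logging was disabled unless overridden. A self-contained illustration with a hypothetical `is_interval`:

# "Never" in practice: with a 2 ** 32 interval, an interval check like this
# hypothetical one is never true in any realistic run.
def is_interval(idx: int, interval: int) -> bool:
    return (idx + 1) % interval == 0

assert not any(is_interval(i, 2 ** 32) for i in range(1_000_000))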
@@ -136,7 +134,7 @@ class Configs(MNISTConfigs, TrainValidConfigs):
        if self.mode.is_train:
            self.discriminator_optimizer.zero_grad()
            loss.backward()
-            if batch_idx.is_interval(self.log_params_updates):
+            if batch_idx.is_last:
                pytorch_utils.store_model_indicators(self.discriminator, 'discriminator')
            self.discriminator_optimizer.step()

@@ -155,7 +153,7 @@ class Configs(MNISTConfigs, TrainValidConfigs):
        if self.mode.is_train:
            self.generator_optimizer.zero_grad()
            loss.backward()
-            if batch_idx.is_interval(self.log_params_updates):
+            if batch_idx.is_last:
                pytorch_utils.store_model_indicators(self.generator, 'generator')
            self.generator_optimizer.step()

@@ -23,7 +23,7 @@ Took help from [PyTorch Sketch RNN](https://github.com/alexis-jacq/Pytorch-Sketc
 """

 import math
-from typing import Optional, Tuple
+from typing import Optional, Tuple, Any

 import numpy as np
 import torch
@@ -34,13 +34,11 @@ from torch.utils.data import Dataset, DataLoader

 import einops
 from labml import lab, experiment, tracker, monit
-from labml.configs import option
 from labml.utils import pytorch as pytorch_utils
 from labml_helpers.device import DeviceConfigs
 from labml_helpers.module import Module
 from labml_helpers.optimizer import OptimizerConfigs
-from labml_helpers.train_valid import TrainValidConfigs, BatchStepProtocol, hook_model_outputs, \
-    MODE_STATE
+from labml_helpers.train_valid import TrainValidConfigs, hook_model_outputs, BatchIndex


 class StrokesDataset(Dataset):
@@ -452,53 +450,104 @@ class Sampler:
         plt.show()


-class StrokesBatchStep(BatchStepProtocol):
+class Configs(TrainValidConfigs):
     """
-    ## Train/Validation modules
+    ## Configurations
+
+    These are default configurations which can be later adjusted by passing a `dict`.
     """

-    def __init__(self, encoder: EncoderRNN, decoder: DecoderRNN,
-                 optimizer: Optional[torch.optim.Adam],
-                 kl_div_loss_weight: float, grad_clip: float):
-        self.grad_clip = grad_clip
-        self.kl_div_loss_weight = kl_div_loss_weight
-        self.encoder = encoder
-        self.decoder = decoder
-
+    # Device configurations to pick the device to run the experiment
+    device: torch.device = DeviceConfigs()
+    #
+    encoder: EncoderRNN
+    decoder: DecoderRNN
+    optimizer: optim.Adam
+    sampler: Sampler
+
+    dataset_name: str
+    train_loader: DataLoader
+    valid_loader: DataLoader
+    train_dataset: StrokesDataset
+    valid_dataset: StrokesDataset
+
+    # Encoder and decoder sizes
+    enc_hidden_size = 256
+    dec_hidden_size = 512
+
+    # Batch size
+    batch_size = 100
+
+    # Number of features in $z$
+    d_z = 128
+    # Number of distributions in the mixture, $M$
+    n_distributions = 20
+
+    # Weight of KL divergence loss, $w_{KL}$
+    kl_div_loss_weight = 0.5
+    # Gradient clipping
+    grad_clip = 1.
+    # Temperature $\tau$ for sampling
+    temperature = 0.4
+
+    # Filter out stroke sequences longer than $200$
+    max_seq_length = 200
+
+    epochs = 100
+
+    kl_div_loss = KLDivLoss()
+    reconstruction_loss = ReconstructionLoss()
+
+    def init(self):
+        # Initialize encoder & decoder
+        self.encoder = EncoderRNN(self.d_z, self.enc_hidden_size).to(self.device)
+        self.decoder = DecoderRNN(self.d_z, self.dec_hidden_size, self.n_distributions).to(self.device)
+
+        # Set optimizer, things like type of optimizer and learning rate are configurable
+        optimizer = OptimizerConfigs()
+        optimizer.parameters = list(self.encoder.parameters()) + list(self.decoder.parameters())
         self.optimizer = optimizer
-        self.kl_div_loss = KLDivLoss()
-        self.reconstruction_loss = ReconstructionLoss()
+
+        # Create sampler
+        self.sampler = Sampler(self.encoder, self.decoder)
+
+        # `npz` file path is `data/sketch/[DATASET NAME].npz`
+        path = lab.get_data_path() / 'sketch' / f'{self.dataset_name}.npz'
+        # Load the numpy file.
+        dataset = np.load(str(path), encoding='latin1', allow_pickle=True)
+
+        # Create training dataset
+        self.train_dataset = StrokesDataset(dataset['train'], self.max_seq_length)
+        # Create validation dataset
+        self.valid_dataset = StrokesDataset(dataset['valid'], self.max_seq_length, self.train_dataset.scale)
+
+        # Create training data loader
+        self.train_loader = DataLoader(self.train_dataset, self.batch_size, shuffle=True)
+        # Create validation data loader
+        self.valid_loader = DataLoader(self.valid_dataset, self.batch_size)

         # Add hooks to monitor layer outputs on Tensorboard
-        hook_model_outputs(self.encoder, 'encoder')
-        hook_model_outputs(self.decoder, 'decoder')
+        hook_model_outputs(self.mode, self.encoder, 'encoder')
+        hook_model_outputs(self.mode, self.decoder, 'decoder')

         # Configure the tracker to print the total train/validation loss
         tracker.set_scalar("loss.total.*", True)

-    def prepare_for_iteration(self):
-        """
-        Set models for training/evaluation
-        """
-        if MODE_STATE.is_train:
-            self.encoder.train()
-            self.decoder.train()
-        else:
-            self.encoder.eval()
-            self.decoder.eval()
+        self.state_modules = []

-    def process(self, batch: any, state: any):
-        """
-        Process a batch
-        """
-        data, mask = batch
-
-        # Get model device
-        device = self.encoder.device
+    def step(self, batch: Any, batch_idx: BatchIndex):
+        self.encoder.train(self.mode.is_train)
+        self.decoder.train(self.mode.is_train)
+
         # Move `data` and `mask` to device and swap the sequence and batch dimensions.
         # `data` will have shape `[seq_len, batch_size, 5]` and
         # `mask` will have shape `[seq_len, batch_size]`.
-        data = data.to(device).transpose(0, 1)
-        mask = mask.to(device).transpose(0, 1)
+        data = batch[0].to(self.device).transpose(0, 1)
+        mask = batch[1].to(self.device).transpose(0, 1)
+
+        # Increment step in training mode
+        if self.mode.is_train:
+            tracker.add_global_step(len(data))

         # Encode the sequence of strokes
         with monit.section("encoder"):
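Note how the removed `prepare_for_iteration` (which branched on `MODE_STATE.is_train` to call `.train()` or `.eval()`) collapses into two lines at the top of `step`: `nn.Module.train` takes a boolean, and `.eval()` is just `.train(False)`. This is standard PyTorch behaviour:

import torch.nn as nn

model = nn.Linear(4, 2)
model.train(True)   # same as model.train()
assert model.training
model.train(False)  # same as model.eval()
assert not model.training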
@@ -527,16 +576,16 @@ class StrokesBatchStep(BatchStepProtocol):
         tracker.add("loss.reconstruction.", reconstruction_loss)
         tracker.add("loss.total.", loss)

-        # Run optimizer
-        with monit.section('optimize'):
-            # Only if we are in training state
-            if MODE_STATE.is_train:
+        # Only if we are in training state
+        if self.mode.is_train:
+            # Run optimizer
+            with monit.section('optimize'):
                 # Set `grad` to zero
                 self.optimizer.zero_grad()
                 # Compute gradients
                 loss.backward()
                 # Log model parameters and gradients
-                if MODE_STATE.is_log_parameters:
+                if batch_idx.is_last:
                     pytorch_utils.store_model_indicators(self.encoder, 'encoder')
                     pytorch_utils.store_model_indicators(self.decoder, 'decoder')
                 # Clip gradients
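The `# Clip gradients` context line at the end of this hunk leads into clipping calls that sit just outside it. Those exact lines are not shown in the diff, but with `grad_clip = 1.` from the config above they presumably use the standard PyTorch utility, along these lines:

import torch
import torch.nn as nn

# Assumed shape of the clipping step that follows "# Clip gradients";
# the exact lines sit outside this hunk. max_norm=1. matches grad_clip above.
encoder, decoder = nn.Linear(8, 8), nn.Linear(8, 8)
loss = encoder(torch.randn(2, 8)).sum() + decoder(torch.randn(2, 8)).sum()
loss.backward()
nn.utils.clip_grad_norm_(encoder.parameters(), max_norm=1.)
nn.utils.clip_grad_norm_(decoder.parameters(), max_norm=1.)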
@@ -545,84 +594,7 @@ class StrokesBatchStep(BatchStepProtocol):
                 # Optimize
                 self.optimizer.step()

-        #
-        return {'samples': len(data)}, None
+        tracker.save()

-
-class Configs(TrainValidConfigs):
-    """
-    ## Configurations
-
-    These are default configurations which can be later adjusted by passing a `dict`.
-    """
-
-    # Device configurations to pick the device to run the experiment
-    device: torch.device = DeviceConfigs()
-    #
-    encoder: EncoderRNN
-    decoder: DecoderRNN
-    optimizer: optim.Adam = 'setup_all'
-    sampler: Sampler
-
-    dataset_name: str
-    train_loader = 'setup_all'
-    valid_loader = 'setup_all'
-    train_dataset: StrokesDataset
-    valid_dataset: StrokesDataset
-
-    batch_step = 'strokes_batch_step'
-
-    # Encoder and decoder sizes
-    enc_hidden_size = 256
-    dec_hidden_size = 512
-
-    # Batch size
-    batch_size = 100
-
-    # Number of features in $z$
-    d_z = 128
-    # Number of distributions in the mixture, $M$
-    n_distributions = 20
-
-    # Weight of KL divergence loss, $w_{KL}$
-    kl_div_loss_weight = 0.5
-    # Gradient clipping
-    grad_clip = 1.
-    # Temperature $\tau$ for sampling
-    temperature = 0.4
-
-    # Filter out stroke sequences longer than $200$
-    max_seq_length = 200
-
-    epochs = 100
-
-    def initialize(self):
-        # Initialize encoder & decoder
-        self.encoder = EncoderRNN(self.d_z, self.enc_hidden_size).to(self.device)
-        self.decoder = DecoderRNN(self.d_z, self.dec_hidden_size, self.n_distributions).to(self.device)
-
-        # Set optimizer, things like type of optimizer and learning rate are configurable
-        optimizer = OptimizerConfigs()
-        optimizer.parameters = list(self.encoder.parameters()) + list(self.decoder.parameters())
-        self.optimizer = optimizer
-
-        # Create sampler
-        self.sampler = Sampler(self.encoder, self.decoder)
-
-        # `npz` file path is `data/sketch/[DATASET NAME].npz`
-        path = lab.get_data_path() / 'sketch' / f'{self.dataset_name}.npz'
-        # Load the numpy file.
-        dataset = np.load(str(path), encoding='latin1', allow_pickle=True)
-
-        # Create training dataset
-        self.train_dataset = StrokesDataset(dataset['train'], self.max_seq_length)
-        # Create validation dataset
-        self.valid_dataset = StrokesDataset(dataset['valid'], self.max_seq_length, self.train_dataset.scale)
-
-        # Create training data loader
-        self.train_loader = DataLoader(self.train_dataset, self.batch_size, shuffle=True)
-        # Create validation data loader
-        self.valid_loader = DataLoader(self.valid_dataset, self.batch_size)

     def sample(self):
         # Randomly pick a sample from validation dataset to encoder
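The old protocol result `return {'samples': len(data)}, None` is gone; instead `step` ends with `tracker.save()`, which writes out the indicators queued earlier with `tracker.add(...)` (the loss values above) plus any stored model indicators. A hedged usage sketch of that tracker pattern; labml's actual buffering details may differ:

from labml import tracker

# Queue values under names, then write them out for the current global step.
tracker.add("loss.total.", 0.42)
tracker.add("loss.reconstruction.", 0.17)
tracker.save()  # flush the queued indicators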
@@ -633,12 +605,6 @@ class Configs(TrainValidConfigs):
         self.sampler.sample(data, self.temperature)


-@option(Configs.batch_step)
-def strokes_batch_step(c: Configs):
-    """Set Strokes Training and Validation module"""
-    return StrokesBatchStep(c.encoder, c.decoder, c.optimizer, c.kl_div_loss_weight, c.grad_clip)
-
-
 def main():
     configs = Configs()
     experiment.create(name="sketch_rnn")
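With `step` now a plain method on `Configs`, the `batch_step` config item and its `@option` provider are unnecessary. For reference, `@option` registers a lazy calculator for a config item, along the lines of this minimal sketch (hypothetical names, following labml.configs conventions):

from labml.configs import BaseConfigs, option

class DemoConfigs(BaseConfigs):
    greeting: str

@option(DemoConfigs.greeting)
def default_greeting(c: DemoConfigs):
    # Runs lazily, only when `greeting` is first accessed
    return 'hello'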
@@ -655,8 +621,6 @@ def main():
        'inner_iterations': 10
    })

-    configs.initialize()
-
    with experiment.start():
        # Run the experiment
        configs.run()