🧹 cleanup

This commit is contained in:
Varuna Jayasiri
2020-10-19 14:49:35 +05:30
parent 8b03691d04
commit 04f698e56c

View File

@ -5,7 +5,7 @@ This is an annotated implementation of the paper
Download data from [Quick, Draw! Dataset](https://github.com/googlecreativelab/quickdraw-dataset).
There is a link to download `npz` files in *Sketch-RNN QuickDraw Dataset* section of the readme.
Place the downloaded `npz` file(s) in `data/sketch` folder.
This code is configured to use `bycle` dataset.
This code is configured to use `bicycle` dataset.
You can change this in configurations.
### Acknowledgements
@ -14,7 +14,7 @@ Took help from [PyTorch Sketch RNN)(https://github.com/alexis-jacq/Pytorch-Sketc
"""
import math
from typing import Optional
from typing import Optional, Tuple
import einops
import numpy as np
@ -41,12 +41,7 @@ class StrokesDataset(Dataset):
This class load and pre-process the data.
"""
def __init__(self, dataset_name: str, max_seq_length):
# `npz` file path is `data/sketch/[DATASET NAME].npz`
path = lab.get_data_path() / 'sketch' / f'{dataset_name}.npz'
# Load the numpy file.
dataset = np.load(str(path), encoding='latin1', allow_pickle=True)
def __init__(self, dataset: np.array, max_seq_length: int, scale: Optional[float] = None):
# Filter and convert training sequences to floats.
data = []
# `dataset['train']` is a list of numpy arrays of shape [seq_len, 3].
@ -54,10 +49,10 @@ class StrokesDataset(Dataset):
# 3 integers.
# First two are the displacements along x and y ($\Delta x$, $\Delta y$)
# And the last integer represents the state of the pen - 1 if it's touching
# the paper and 0 otherwise, and 2 if it's end of sequence.
# the paper and 0 otherwise.
#
# We iterate through each of the sequences
for seq in dataset['train']:
for seq in dataset:
# Filter if the length of the the sequence of strokes is within our range
if 10 < len(seq) <= max_seq_length:
# Clamp $\Delta x$, $\Delta y$ to $[-1000, 1000]$
@ -69,17 +64,20 @@ class StrokesDataset(Dataset):
# We then normalize all ($\Delta x$, $\Delta y$) by their standard deviation.
# This calculates the standard deviations for ($\Delta x$, $\Delta y$) combined.
# Paper notes that the mean is not adjusted for simplicity since the mean is anyway close to $0$.
std = np.std(np.concatenate([np.ravel(s[:, 0:2]) for s in data]))
# Paper notes that the mean is not adjusted for simplicity,
# since the mean is anyway close to $0$.
if scale is None:
scale = np.std(np.concatenate([np.ravel(s[:, 0:2]) for s in data]))
self.scale = scale
for s in data:
# Adjust by standard deviation
s[:, 0:2] /= std
s[:, 0:2] /= scale
# Get the longest sequence length among all sequences
longest_seq_len = max([len(seq) for seq in data])
# Initialize PyTorch data array
self.data = torch.zeros(len(data), longest_seq_len, 5, dtype=torch.float)
self.data = torch.zeros(len(data), longest_seq_len + 2, 5, dtype=torch.float)
# Initialize mask array. Mask has an extra step because the model predicts
# end of sequence at the end.
self.mask = torch.zeros(len(data), longest_seq_len + 1)
@ -88,21 +86,20 @@ class StrokesDataset(Dataset):
seq = torch.from_numpy(seq)
len_seq = len(seq)
# set x, y
self.data[i, :len_seq, :2] = seq[:, :2]
self.data[i, 1:len_seq + 1, :2] = seq[:, :2]
# set pen status
self.data[i, :len_seq - 1, 2] = 1 - seq[:-1, 2]
self.data[i, :len_seq - 1, 3] = seq[:-1, 2]
self.data[i, len_seq - 1:, 4] = 1
self.mask[i, :len_seq] = 1
self.data[i, 1:len_seq + 1, 2] = 1 - seq[:, 2]
self.data[i, 1:len_seq + 1, 3] = seq[:, 2]
self.data[i, len_seq + 1:, 4] = 1
self.mask[i, :len_seq + 1] = 1
eos = torch.zeros(len(data), 1, 5)
self.target = torch.cat([self.data, eos], 1)
self.data[:, 0, 2] = 1
def __len__(self):
return len(self.data)
def __getitem__(self, idx: int):
return self.data[idx], self.target[idx], self.mask[idx]
return self.data[idx], self.mask[idx]
class EncoderRNN(Module):
@ -139,22 +136,17 @@ class DecoderRNN(Module):
self.n_mixtures = n_mixtures
self.q_log_softmax = nn.LogSoftmax(-1)
def __call__(self, inputs, z: torch.Tensor, state=None):
def __call__(self, x: torch.Tensor, z: torch.Tensor, state: Optional[Tuple[torch.Tensor, torch.Tensor]]):
if state is None:
hidden, cell = torch.split(torch.tanh(self.init_state(z)), self.dec_hidden_size, 1)
state = (hidden.unsqueeze(0).contiguous(), cell.unsqueeze(0).contiguous())
outputs, (hidden, cell) = self.lstm(inputs, state)
if not self.training:
# We dont have to change shape since hidden has shape [1, batch_size, hidden_size]
# and outputs is of shape [seq_len, batch_size, hidden_size], since we are
# using a single direction one layer lstm
outputs = hidden
_, (hidden, cell) = self.lstm(x, state)
q_logits = self.q_log_softmax(self.q_head(outputs))
q_logits = self.q_log_softmax(self.q_head(hidden))
pi_logits, mu_x, mu_y, sigma_x, sigma_y, rho_xy = \
torch.split(self.mixtures(outputs), self.n_mixtures, 2)
torch.split(self.mixtures(hidden), self.n_mixtures, 2)
dist = BivariateGaussianMixture(pi_logits, mu_x, mu_y,
torch.exp(sigma_x), torch.exp(sigma_y), torch.tanh(rho_xy))
@ -185,30 +177,32 @@ class Sampler:
def sample(self, data: torch.Tensor, temperature: float):
z, _, _ = self.encoder(data)
sos = data.new_tensor([[[0, 0, 1, 0, 0]]])
sos = data.new_tensor([0, 0, 1, 0, 0])
seq_len = len(data)
s = sos
seq_x = []
seq_y = []
seq_z = []
seq = [s]
state = None
with torch.no_grad():
for i in range(seq_len):
data = torch.cat([s, z.unsqueeze(0)], 2)
data = torch.cat([s.view(1, 1, -1), z.unsqueeze(0)], 2)
dist, q_logits, state = self.decoder(data, z, state)
s, xy, pen_down, eos = self._sample_step(dist, q_logits, temperature)
seq_x.append(xy[0].item())
seq_y.append(xy[1].item())
seq_z.append(pen_down)
if eos:
s = self._sample_step(dist, q_logits, temperature)
seq.append(s)
if s[4] == 1:
print(i)
break
x_sample = np.cumsum(seq_x, 0)
y_sample = np.cumsum(seq_y, 0)
z_sample = np.array(seq_z)
sequence = np.stack([x_sample, y_sample, z_sample]).T
seq = torch.stack(seq)
strokes = np.split(sequence, np.where(sequence[:, 2] > 0)[0] + 1)
self.plot(seq)
@staticmethod
def plot(seq: torch.Tensor):
seq[:, 0:2] = torch.cumsum(seq[:, 0:2], dim=0)
seq[:, 2] = seq[:, 3] == 1
seq = seq[:, 0:3].detach().cpu().numpy()
strokes = np.split(seq, np.where(seq[:, 2] > 0)[0] + 1)
for s in strokes:
plt.plot(s[:, 0], -s[:, 1])
plt.axis('off')
@ -224,10 +218,10 @@ class Sampler:
q_idx = q.sample()[0, 0]
xy = mix.sample()[0, 0, idx]
next_pos = q_logits.new_zeros(1, 1, 5)
next_pos[0, 0, :2] = xy
next_pos[0, 0, q_idx + 2] = 1
return next_pos, xy, q_idx == 1, q_idx == 2
next_pos = q_logits.new_zeros(5)
next_pos[:2] = xy
next_pos[q_idx + 2] = 1
return next_pos
class Configs(TrainValidConfigs):
@ -238,8 +232,10 @@ class Configs(TrainValidConfigs):
sampler: Sampler
dataset_name: str
dataset: StrokesDataset = 'setup_all'
train_loader = 'setup_all'
valid_loader = 'setup_all'
train_dataset: StrokesDataset
valid_dataset: StrokesDataset
enc_hidden_size = 256
dec_hidden_size = 512
@ -256,18 +252,17 @@ class Configs(TrainValidConfigs):
temperature = 0.4
max_seq_length = 200
validator = None
valid_loader = None
epochs = 100
def sample(self):
data, *_ = self.dataset[np.random.choice(len(self.dataset))]
data, *_ = self.train_dataset[np.random.choice(len(self.train_dataset))]
data = data.unsqueeze(1).to(self.device)
self.sampler.sample(data, self.temperature)
@setup([Configs.encoder, Configs.decoder, Configs.optimizer, Configs.sampler,
Configs.dataset, Configs.train_loader])
Configs.train_dataset, Configs.train_loader,
Configs.valid_dataset, Configs.valid_loader])
def setup_all(self: Configs):
self.encoder = EncoderRNN(self.d_z, self.enc_hidden_size).to(self.device)
self.decoder = DecoderRNN(self.d_z, self.dec_hidden_size, self.n_mixtures).to(self.device)
@ -276,9 +271,16 @@ def setup_all(self: Configs):
self.optimizer.parameters = list(self.encoder.parameters()) + list(self.decoder.parameters())
self.sampler = Sampler(self.encoder, self.decoder)
# `npz` file path is `data/sketch/[DATASET NAME].npz`
path = lab.get_data_path() / 'sketch' / f'{self.dataset_name}.npz'
# Load the numpy file.
dataset = np.load(str(path), encoding='latin1', allow_pickle=True)
self.dataset = StrokesDataset(self.dataset_name, self.max_seq_length)
self.train_loader = DataLoader(self.dataset, self.batch_size, shuffle=True)
self.train_dataset = StrokesDataset(dataset['train'], self.max_seq_length)
self.valid_dataset = StrokesDataset(dataset['valid'], self.max_seq_length, self.train_dataset.scale)
self.train_loader = DataLoader(self.train_dataset, self.batch_size, shuffle=True)
self.valid_loader = DataLoader(self.valid_dataset, self.batch_size, shuffle=True)
class StrokesBatchStep(BatchStepProtocol):
@ -296,39 +298,34 @@ class StrokesBatchStep(BatchStepProtocol):
hook_model_outputs(self.encoder, 'encoder')
hook_model_outputs(self.decoder, 'decoder')
tracker.set_scalar("loss.*", True)
tracker.set_image("generated", True)
def prepare_for_iteration(self):
if MODE_STATE.is_train:
self.encoder.train()
self.decoder.train()
else:
self.encoder.eval()
self.decoder.eval()
self.encoder.train()
self.decoder.train()
# self.encoder.eval()
# self.decoder.eval()
def process(self, batch: any, state: any):
device = self.encoder.device
data, target, mask = batch
data, mask = batch
data = data.to(device).transpose(0, 1)
target = target.to(device).transpose(0, 1)
mask = mask.to(device).transpose(0, 1)
batch_size = data.shape[1]
seq_len = data.shape[0]
with monit.section("encoder"):
z, mu, sigma = self.encoder(data)
with monit.section("decoder"):
sos = torch.stack([torch.tensor([0, 0, 1, 0, 0])] * batch_size). \
unsqueeze(0).to(device)
batch_init = torch.cat([sos, data], 0)
z_stack = torch.stack([z] * (seq_len + 1))
inputs = torch.cat([batch_init, z_stack], 2)
dist, q_logits, _ = self.decoder(inputs, z)
z_stack = z.unsqueeze(0).expand(data.shape[0] - 1, -1, -1)
inputs = torch.cat([data[:-1], z_stack], 2)
dist, q_logits, _ = self.decoder(inputs, z, None)
with monit.section('loss'):
kl_loss = self.kl_div_loss(sigma, mu)
reconstruction_loss = self.reconstruction_loss(mask, target, dist, q_logits)
reconstruction_loss = self.reconstruction_loss(mask, data[1:], dist, q_logits)
loss = self.kl_div_loss_weight * kl_loss + reconstruction_loss
tracker.add("loss.kl.", kl_loss)
@ -346,8 +343,6 @@ class StrokesBatchStep(BatchStepProtocol):
nn.utils.clip_grad_norm_(self.decoder.parameters(), self.grad_clip)
self.optimizer.step()
# tracker.add('generated', generated_images[0:5])
return {'samples': len(data)}, None
@ -395,7 +390,8 @@ def main():
experiment.configs(configs, {
'optimizer.optimizer': 'Adam',
'optimizer.learning_rate': 1e-3,
'dataset_name': 'bicycle'
'dataset_name': 'bicycle',
'inner_iterations': 10
}, 'run')
experiment.start()