Mirror of https://github.com/labmlai/annotated_deep_learning_paper_implementations.git (synced 2025-10-30 10:18:50 +08:00)
🧹 cleanup
@@ -5,7 +5,7 @@ This is an annotated implementation of the paper
 Download data from the [Quick, Draw! Dataset](https://github.com/googlecreativelab/quickdraw-dataset).
 There is a link to download the `npz` files in the *Sketch-RNN QuickDraw Dataset* section of the readme.
 Place the downloaded `npz` file(s) in the `data/sketch` folder.
-This code is configured to use `bycle` dataset.
+This code is configured to use the `bicycle` dataset.
 You can change this in configurations.

 ### Acknowledgements
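For reference, the downloaded file can be inspected directly before training; a minimal sketch, assuming `bicycle.npz` has already been placed in `data/sketch`:

```python
import numpy as np

# Path assumed to follow the layout described above.
dataset = np.load('data/sketch/bicycle.npz', encoding='latin1', allow_pickle=True)
# Sketch-RNN `npz` files hold 'train', 'valid' and 'test' splits,
# each a list of integer arrays of shape [seq_len, 3].
for split in ['train', 'valid', 'test']:
    print(split, len(dataset[split]), dataset[split][0].shape)
```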
@@ -14,7 +14,7 @@ Took help from [PyTorch Sketch RNN](https://github.com/alexis-jacq/Pytorch-Sketc
 """

 import math
-from typing import Optional
+from typing import Optional, Tuple

 import einops
 import numpy as np
@@ -41,12 +41,7 @@ class StrokesDataset(Dataset):
     This class loads and pre-processes the data.
     """

-    def __init__(self, dataset_name: str, max_seq_length):
-        # `npz` file path is `data/sketch/[DATASET NAME].npz`
-        path = lab.get_data_path() / 'sketch' / f'{dataset_name}.npz'
-        # Load the numpy file.
-        dataset = np.load(str(path), encoding='latin1', allow_pickle=True)
+    def __init__(self, dataset: np.array, max_seq_length: int, scale: Optional[float] = None):
         # Filter and convert training sequences to floats.
         data = []
         # `dataset['train']` is a list of numpy arrays of shape [seq_len, 3].
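With the new signature, file loading moves out of the dataset class; a hedged usage sketch of the resulting wiring (it mirrors the `setup_all` change further down, paths assumed):

```python
import numpy as np
from labml import lab

# The caller loads the file and passes the raw splits in, sharing the
# training scale with validation.
path = lab.get_data_path() / 'sketch' / 'bicycle.npz'
raw = np.load(str(path), encoding='latin1', allow_pickle=True)
train_ds = StrokesDataset(raw['train'], max_seq_length=200)
valid_ds = StrokesDataset(raw['valid'], max_seq_length=200, scale=train_ds.scale)
```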
@@ -54,10 +49,10 @@ class StrokesDataset(Dataset):
         # 3 integers.
         # First two are the displacements along x and y ($\Delta x$, $\Delta y$)
         # And the last integer represents the state of the pen - 1 if it's touching
-        # the paper and 0 otherwise, and 2 if it's end of sequence.
+        # the paper and 0 otherwise.
         #
         # We iterate through each of the sequences
-        for seq in dataset['train']:
+        for seq in dataset:
             # Filter if the length of the sequence of strokes is within our range
             if 10 < len(seq) <= max_seq_length:
                 # Clamp $\Delta x$, $\Delta y$ to $[-1000, 1000]$
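A toy example of the stroke format described in these comments (values made up):

```python
import numpy as np

# Each row is (dx, dy, pen): pen is 1 while the pen touches the paper
# and 0 when it is lifted.
seq = np.array([[5, 0, 1],     # draw 5 units right
                [0, 7, 1],     # then 7 units down
                [-3, -2, 0],   # lift the pen and move
                [4, 4, 1]])    # touch down and draw again
# The clamping step from the pre-processing above:
seq[:, 0:2] = np.clip(seq[:, 0:2], -1000, 1000)
```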
@@ -69,17 +64,20 @@ class StrokesDataset(Dataset):

         # We then normalize all ($\Delta x$, $\Delta y$) by their standard deviation.
         # This calculates the standard deviation for ($\Delta x$, $\Delta y$) combined.
-        # Paper notes that the mean is not adjusted for simplicity since the mean is anyway close to $0$.
-        std = np.std(np.concatenate([np.ravel(s[:, 0:2]) for s in data]))
+        # Paper notes that the mean is not adjusted for simplicity,
+        # since the mean is anyway close to $0$.
+        if scale is None:
+            scale = np.std(np.concatenate([np.ravel(s[:, 0:2]) for s in data]))
+        self.scale = scale
         for s in data:
             # Adjust by standard deviation
-            s[:, 0:2] /= std
+            s[:, 0:2] /= scale

         # Get the longest sequence length among all sequences
         longest_seq_len = max([len(seq) for seq in data])

         # Initialize PyTorch data array
-        self.data = torch.zeros(len(data), longest_seq_len, 5, dtype=torch.float)
+        self.data = torch.zeros(len(data), longest_seq_len + 2, 5, dtype=torch.float)
         # Initialize mask array. Mask has an extra step because the model predicts
         # end of sequence at the end.
         self.mask = torch.zeros(len(data), longest_seq_len + 1)
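The shared-scale idea in isolation; a sketch, where `compute_scale` is an illustrative name rather than a function from this module:

```python
import numpy as np

def compute_scale(sequences):
    # Standard deviation of all (dx, dy) values pooled together;
    # the mean is left unadjusted since it is already close to 0.
    return np.std(np.concatenate([np.ravel(s[:, 0:2]) for s in sequences]))

train = [np.random.randn(20, 3) * 50 for _ in range(8)]
valid = [np.random.randn(25, 3) * 50 for _ in range(2)]
scale = compute_scale(train)  # computed once, from training data only
for s in train + valid:       # ...and applied to every split
    s[:, 0:2] /= scale
```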
@@ -88,21 +86,20 @@ class StrokesDataset(Dataset):
             seq = torch.from_numpy(seq)
             len_seq = len(seq)
             # set x, y
-            self.data[i, :len_seq, :2] = seq[:, :2]
+            self.data[i, 1:len_seq + 1, :2] = seq[:, :2]
             # set pen status
-            self.data[i, :len_seq - 1, 2] = 1 - seq[:-1, 2]
-            self.data[i, :len_seq - 1, 3] = seq[:-1, 2]
-            self.data[i, len_seq - 1:, 4] = 1
-            self.mask[i, :len_seq] = 1
+            self.data[i, 1:len_seq + 1, 2] = 1 - seq[:, 2]
+            self.data[i, 1:len_seq + 1, 3] = seq[:, 2]
+            self.data[i, len_seq + 1:, 4] = 1
+            self.mask[i, :len_seq + 1] = 1

-        eos = torch.zeros(len(data), 1, 5)
-        self.target = torch.cat([self.data, eos], 1)
+        self.data[:, 0, 2] = 1

     def __len__(self):
         return len(self.data)

     def __getitem__(self, idx: int):
-        return self.data[idx], self.target[idx], self.mask[idx]
+        return self.data[idx], self.mask[idx]


 class EncoderRNN(Module):
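A small stand-alone illustration of the 5-column layout built in this loop (toy values; `row` is a hypothetical name for one padded sequence):

```python
import torch

# One toy stroke-3 sequence: rows of (dx, dy, pen).
seq = torch.tensor([[5., 0., 1.],
                    [0., 7., 0.]])
len_seq = len(seq)

row = torch.zeros(len_seq + 2, 5)       # room for sos and eos steps
row[0, 2] = 1                           # start-of-sequence step, no movement
row[1:len_seq + 1, :2] = seq[:, :2]     # displacements
row[1:len_seq + 1, 2] = 1 - seq[:, 2]   # p1, from the pen flag
row[1:len_seq + 1, 3] = seq[:, 2]       # p2, its complement
row[len_seq + 1:, 4] = 1                # p3: end of sequence / padding
```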
@@ -139,22 +136,17 @@ class DecoderRNN(Module):
         self.n_mixtures = n_mixtures
         self.q_log_softmax = nn.LogSoftmax(-1)

-    def __call__(self, inputs, z: torch.Tensor, state=None):
+    def __call__(self, x: torch.Tensor, z: torch.Tensor, state: Optional[Tuple[torch.Tensor, torch.Tensor]]):
         if state is None:
             hidden, cell = torch.split(torch.tanh(self.init_state(z)), self.dec_hidden_size, 1)
             state = (hidden.unsqueeze(0).contiguous(), cell.unsqueeze(0).contiguous())

-        outputs, (hidden, cell) = self.lstm(inputs, state)
-        if not self.training:
-            # We don't have to change shape since hidden has shape [1, batch_size, hidden_size]
-            # and outputs is of shape [seq_len, batch_size, hidden_size], since we are
-            # using a single direction one layer lstm
-            outputs = hidden
+        _, (hidden, cell) = self.lstm(x, state)

-        q_logits = self.q_log_softmax(self.q_head(outputs))
+        q_logits = self.q_log_softmax(self.q_head(hidden))

         pi_logits, mu_x, mu_y, sigma_x, sigma_y, rho_xy = \
-            torch.split(self.mixtures(outputs), self.n_mixtures, 2)
+            torch.split(self.mixtures(hidden), self.n_mixtures, 2)

         dist = BivariateGaussianMixture(pi_logits, mu_x, mu_y,
                                         torch.exp(sigma_x), torch.exp(sigma_y), torch.tanh(rho_xy))
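How the initial LSTM state is derived from the latent vector, pulled out as a runnable sketch with assumed sizes:

```python
import torch
import torch.nn as nn

d_z, dec_hidden_size = 128, 512   # sizes assumed for illustration
init_state = nn.Linear(d_z, 2 * dec_hidden_size)

z = torch.randn(4, d_z)           # a batch of four latent vectors
hidden, cell = torch.split(torch.tanh(init_state(z)), dec_hidden_size, 1)
# nn.LSTM expects state of shape [num_layers, batch, hidden]; hence the unsqueeze.
state = (hidden.unsqueeze(0).contiguous(), cell.unsqueeze(0).contiguous())
```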
@@ -185,30 +177,32 @@ class Sampler:

     def sample(self, data: torch.Tensor, temperature: float):
         z, _, _ = self.encoder(data)
-        sos = data.new_tensor([[[0, 0, 1, 0, 0]]])
+        sos = data.new_tensor([0, 0, 1, 0, 0])
         seq_len = len(data)
         s = sos
-        seq_x = []
-        seq_y = []
-        seq_z = []
+        seq = [s]
         state = None
         with torch.no_grad():
             for i in range(seq_len):
-                data = torch.cat([s, z.unsqueeze(0)], 2)
+                data = torch.cat([s.view(1, 1, -1), z.unsqueeze(0)], 2)
                 dist, q_logits, state = self.decoder(data, z, state)
-                s, xy, pen_down, eos = self._sample_step(dist, q_logits, temperature)
-                seq_x.append(xy[0].item())
-                seq_y.append(xy[1].item())
-                seq_z.append(pen_down)
-                if eos:
+                s = self._sample_step(dist, q_logits, temperature)
+                seq.append(s)
+                if s[4] == 1:
+                    print(i)
                     break

-        x_sample = np.cumsum(seq_x, 0)
-        y_sample = np.cumsum(seq_y, 0)
-        z_sample = np.array(seq_z)
-        sequence = np.stack([x_sample, y_sample, z_sample]).T
+        seq = torch.stack(seq)

-        strokes = np.split(sequence, np.where(sequence[:, 2] > 0)[0] + 1)
+        self.plot(seq)

+    @staticmethod
+    def plot(seq: torch.Tensor):
+        seq[:, 0:2] = torch.cumsum(seq[:, 0:2], dim=0)
+        seq[:, 2] = seq[:, 3] == 1
+        seq = seq[:, 0:3].detach().cpu().numpy()
+
+        strokes = np.split(seq, np.where(seq[:, 2] > 0)[0] + 1)
         for s in strokes:
             plt.plot(s[:, 0], -s[:, 1])
         plt.axis('off')
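The plotting trick in isolation: cumulative sums turn displacements into absolute coordinates, and the array is split into individual strokes wherever the pen was lifted. A toy example:

```python
import numpy as np

# Toy sampled sequence: (dx, dy, pen-lift) rows, pen lifted after row 1.
seq = np.array([[1., 0., 0.],
                [1., 1., 1.],
                [2., 0., 0.],
                [0., 2., 0.]])
seq[:, 0:2] = np.cumsum(seq[:, 0:2], axis=0)       # displacements -> coordinates
strokes = np.split(seq, np.where(seq[:, 2] > 0)[0] + 1)
print([s.shape for s in strokes])                  # two strokes of two points each
```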
@@ -224,10 +218,10 @@ class Sampler:
         q_idx = q.sample()[0, 0]

         xy = mix.sample()[0, 0, idx]
-        next_pos = q_logits.new_zeros(1, 1, 5)
-        next_pos[0, 0, :2] = xy
-        next_pos[0, 0, q_idx + 2] = 1
-        return next_pos, xy, q_idx == 1, q_idx == 2
+        next_pos = q_logits.new_zeros(5)
+        next_pos[:2] = xy
+        next_pos[q_idx + 2] = 1
+        return next_pos


 class Configs(TrainValidConfigs):
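Temperature sampling of the pen state, as used in `_sample_step`, in a self-contained form (logits made up):

```python
import torch

q_logits = torch.tensor([[[2.0, 0.5, -1.0]]])  # log-probs over 3 pen states
for temperature in (0.1, 1.0):
    q = torch.distributions.Categorical(logits=q_logits / temperature)
    # A lower temperature sharpens the distribution toward the arg-max state.
    print(temperature, q.probs[0, 0])
```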
@@ -238,8 +232,10 @@ class Configs(TrainValidConfigs):
     sampler: Sampler

     dataset_name: str
-    dataset: StrokesDataset = 'setup_all'
     train_loader = 'setup_all'
+    valid_loader = 'setup_all'
+    train_dataset: StrokesDataset
+    valid_dataset: StrokesDataset

     enc_hidden_size = 256
     dec_hidden_size = 512
@@ -256,18 +252,17 @@ class Configs(TrainValidConfigs):
     temperature = 0.4
     max_seq_length = 200

-    validator = None
-    valid_loader = None
     epochs = 100

     def sample(self):
-        data, *_ = self.dataset[np.random.choice(len(self.dataset))]
+        data, *_ = self.train_dataset[np.random.choice(len(self.train_dataset))]
         data = data.unsqueeze(1).to(self.device)
         self.sampler.sample(data, self.temperature)


 @setup([Configs.encoder, Configs.decoder, Configs.optimizer, Configs.sampler,
-        Configs.dataset, Configs.train_loader])
+        Configs.train_dataset, Configs.train_loader,
+        Configs.valid_dataset, Configs.valid_loader])
 def setup_all(self: Configs):
     self.encoder = EncoderRNN(self.d_z, self.enc_hidden_size).to(self.device)
     self.decoder = DecoderRNN(self.d_z, self.dec_hidden_size, self.n_mixtures).to(self.device)
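Shape bookkeeping for `sample()` above: `__getitem__` returns a single `[seq_len, 5]` tensor, and `unsqueeze(1)` turns it into the `[seq_len, batch, 5]` layout the encoder expects. A two-line sketch with an assumed length:

```python
import torch

data = torch.zeros(102, 5)   # one padded sequence from the dataset
data = data.unsqueeze(1)     # -> [102, 1, 5]: a batch of one
```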
@@ -276,9 +271,16 @@ def setup_all(self: Configs):
     self.optimizer.parameters = list(self.encoder.parameters()) + list(self.decoder.parameters())

     self.sampler = Sampler(self.encoder, self.decoder)
+    # `npz` file path is `data/sketch/[DATASET NAME].npz`
+    path = lab.get_data_path() / 'sketch' / f'{self.dataset_name}.npz'
+    # Load the numpy file.
+    dataset = np.load(str(path), encoding='latin1', allow_pickle=True)

-    self.dataset = StrokesDataset(self.dataset_name, self.max_seq_length)
-    self.train_loader = DataLoader(self.dataset, self.batch_size, shuffle=True)
+    self.train_dataset = StrokesDataset(dataset['train'], self.max_seq_length)
+    self.valid_dataset = StrokesDataset(dataset['valid'], self.max_seq_length, self.train_dataset.scale)
+
+    self.train_loader = DataLoader(self.train_dataset, self.batch_size, shuffle=True)
+    self.valid_loader = DataLoader(self.valid_dataset, self.batch_size, shuffle=True)


 class StrokesBatchStep(BatchStepProtocol):
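A quick smoke test of the loader shapes set up above, using stand-in tensors rather than the real dataset:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Stand-ins: data is [N, longest + 2, 5] and the mask is [N, longest + 1],
# matching the shapes built by StrokesDataset.
data = torch.zeros(64, 102, 5)
mask = torch.zeros(64, 101)
train_loader = DataLoader(TensorDataset(data, mask), batch_size=16, shuffle=True)

batch_data, batch_mask = next(iter(train_loader))
assert batch_mask.shape[1] == batch_data.shape[1] - 1
```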
@@ -296,39 +298,34 @@ class StrokesBatchStep(BatchStepProtocol):
         hook_model_outputs(self.encoder, 'encoder')
         hook_model_outputs(self.decoder, 'decoder')
         tracker.set_scalar("loss.*", True)
         tracker.set_image("generated", True)

     def prepare_for_iteration(self):
-        if MODE_STATE.is_train:
-            self.encoder.train()
-            self.decoder.train()
-        else:
-            self.encoder.eval()
-            self.decoder.eval()
+        self.encoder.train()
+        self.decoder.train()
+        # self.encoder.eval()
+        # self.decoder.eval()

     def process(self, batch: any, state: any):
         device = self.encoder.device
-        data, target, mask = batch
+        data, mask = batch
         data = data.to(device).transpose(0, 1)
-        target = target.to(device).transpose(0, 1)
         mask = mask.to(device).transpose(0, 1)
-        batch_size = data.shape[1]
-        seq_len = data.shape[0]

         with monit.section("encoder"):
             z, mu, sigma = self.encoder(data)

         with monit.section("decoder"):
-            sos = torch.stack([torch.tensor([0, 0, 1, 0, 0])] * batch_size). \
-                unsqueeze(0).to(device)
-            batch_init = torch.cat([sos, data], 0)
-            z_stack = torch.stack([z] * (seq_len + 1))
-            inputs = torch.cat([batch_init, z_stack], 2)
-            dist, q_logits, _ = self.decoder(inputs, z)
+            z_stack = z.unsqueeze(0).expand(data.shape[0] - 1, -1, -1)
+            inputs = torch.cat([data[:-1], z_stack], 2)
+            dist, q_logits, _ = self.decoder(inputs, z, None)

         with monit.section('loss'):
             kl_loss = self.kl_div_loss(sigma, mu)
-            reconstruction_loss = self.reconstruction_loss(mask, target, dist, q_logits)
+            reconstruction_loss = self.reconstruction_loss(mask, data[1:], dist, q_logits)
             loss = self.kl_div_loss_weight * kl_loss + reconstruction_loss

         tracker.add("loss.kl.", kl_loss)
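The decoder-input construction above is teacher forcing; in miniature, with illustrative sizes:

```python
import torch

seq_len, batch_size, d_z = 12, 4, 128
data = torch.zeros(seq_len, batch_size, 5)
z = torch.randn(batch_size, d_z)

# The decoder sees steps 0..seq_len-2, conditioned on z at every step...
z_stack = z.unsqueeze(0).expand(seq_len - 1, -1, -1)
inputs = torch.cat([data[:-1], z_stack], 2)   # [seq_len - 1, batch, 5 + d_z]
# ...and is trained against steps 1..seq_len-1.
targets = data[1:]                            # [seq_len - 1, batch, 5]
```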
@@ -346,8 +343,6 @@ class StrokesBatchStep(BatchStepProtocol):
             nn.utils.clip_grad_norm_(self.decoder.parameters(), self.grad_clip)
             self.optimizer.step()

-        # tracker.add('generated', generated_images[0:5])
-
         return {'samples': len(data)}, None
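The `kl_div_loss` referenced in this step is not shown in this diff; for context, a sketch of the standard VAE closed form, assuming `sigma_hat` is the predicted log-variance as in the Sketch-RNN paper:

```python
import torch

def kl_div_loss(sigma_hat: torch.Tensor, mu: torch.Tensor) -> torch.Tensor:
    # KL(q(z|x) || N(0, I)), with sigma_hat playing the role of log sigma^2.
    return -0.5 * torch.mean(1 + sigma_hat - mu ** 2 - torch.exp(sigma_hat))
```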
@@ -395,7 +390,8 @@ def main():
     experiment.configs(configs, {
         'optimizer.optimizer': 'Adam',
         'optimizer.learning_rate': 1e-3,
-        'dataset_name': 'bicycle'
+        'dataset_name': 'bicycle',
+        'inner_iterations': 10
     }, 'run')
     experiment.start()