diff --git a/labml_nn/sketch_rnn/__init__.py b/labml_nn/sketch_rnn/__init__.py
index 388771a0..7e9af4d2 100644
--- a/labml_nn/sketch_rnn/__init__.py
+++ b/labml_nn/sketch_rnn/__init__.py
@@ -5,7 +5,7 @@ This is an annotated implementation of the paper
 Download data from [Quick, Draw! Dataset](https://github.com/googlecreativelab/quickdraw-dataset).
 There is a link to download `npz` files in *Sketch-RNN QuickDraw Dataset* section of the readme.
 Place the downloaded `npz` file(s) in `data/sketch` folder.
-This code is configured to use `bycle` dataset.
+This code is configured to use `bicycle` dataset.
 You can change this in configurations.
 
 ### Acknowledgements
@@ -14,7 +14,7 @@ Took help from [PyTorch Sketch RNN](https://github.com/alexis-jacq/Pytorch-Sketch-RNN)
 """
 
 import math
-from typing import Optional
+from typing import Optional, Tuple
 
 import einops
 import numpy as np
@@ -41,12 +41,7 @@ class StrokesDataset(Dataset):
     This class loads and pre-processes the data.
     """
 
-    def __init__(self, dataset_name: str, max_seq_length):
-        # `npz` file path is `data/sketch/[DATASET NAME].npz`
-        path = lab.get_data_path() / 'sketch' / f'{dataset_name}.npz'
-        # Load the numpy file.
-        dataset = np.load(str(path), encoding='latin1', allow_pickle=True)
-
+    def __init__(self, dataset: np.ndarray, max_seq_length: int, scale: Optional[float] = None):
         # Filter and convert training sequences to floats.
         data = []
-        # `dataset['train']` is a list of numpy arrays of shape [seq_len, 3].
+        # `dataset` is a list of numpy arrays of shape [seq_len, 3].
@@ -54,10 +49,10 @@
         # 3 integers.
         # First two are the displacements along x and y ($\Delta x$, $\Delta y$)
         # And the last integer represents the state of the pen - 1 if it's touching
-        # the paper and 0 otherwise, and 2 if it's end of sequence.
+        # the paper and 0 otherwise.
         #
         # We iterate through each of the sequences
-        for seq in dataset['train']:
+        for seq in dataset:
             # Filter if the length of the sequence of strokes is within our range
             if 10 < len(seq) <= max_seq_length:
                 # Clamp $\Delta x$, $\Delta y$ to $[-1000, 1000]$
@@ -69,17 +64,20 @@
         # We then normalize all ($\Delta x$, $\Delta y$) by their standard deviation.
         # This calculates the standard deviations for ($\Delta x$, $\Delta y$) combined.
-        # Paper notes that the mean is not adjusted for simplicity since the mean is anyway close to $0$.
-        std = np.std(np.concatenate([np.ravel(s[:, 0:2]) for s in data]))
+        # Paper notes that the mean is not adjusted for simplicity,
+        # since the mean is anyway close to $0$.
+        if scale is None:
+            scale = np.std(np.concatenate([np.ravel(s[:, 0:2]) for s in data]))
+        self.scale = scale
         for s in data:
             # Adjust by standard deviation
-            s[:, 0:2] /= std
+            s[:, 0:2] /= scale
 
         # Get the longest sequence length among all sequences
         longest_seq_len = max([len(seq) for seq in data])
 
         # Initialize PyTorch data array
-        self.data = torch.zeros(len(data), longest_seq_len, 5, dtype=torch.float)
+        self.data = torch.zeros(len(data), longest_seq_len + 2, 5, dtype=torch.float)
         # Initialize mask array. Mask has an extra step because the model predicts
         # end of sequence at the end.
         self.mask = torch.zeros(len(data), longest_seq_len + 1)
 
@@ -88,21 +86,20 @@
             seq = torch.from_numpy(seq)
             len_seq = len(seq)
             # set x, y
-            self.data[i, :len_seq, :2] = seq[:, :2]
+            self.data[i, 1:len_seq + 1, :2] = seq[:, :2]
             # set pen status
-            self.data[i, :len_seq - 1, 2] = 1 - seq[:-1, 2]
-            self.data[i, :len_seq - 1, 3] = seq[:-1, 2]
-            self.data[i, len_seq - 1:, 4] = 1
-            self.mask[i, :len_seq] = 1
+            self.data[i, 1:len_seq + 1, 2] = 1 - seq[:, 2]
+            self.data[i, 1:len_seq + 1, 3] = seq[:, 2]
+            self.data[i, len_seq + 1:, 4] = 1
+            self.mask[i, :len_seq + 1] = 1
 
-        eos = torch.zeros(len(data), 1, 5)
-        self.target = torch.cat([self.data, eos], 1)
+        self.data[:, 0, 2] = 1
 
     def __len__(self):
         return len(self.data)
 
     def __getitem__(self, idx: int):
-        return self.data[idx], self.target[idx], self.mask[idx]
+        return self.data[idx], self.mask[idx]
 
 
 class EncoderRNN(Module):
@@ -139,22 +136,17 @@ class DecoderRNN(Module):
         self.n_mixtures = n_mixtures
         self.q_log_softmax = nn.LogSoftmax(-1)
 
-    def __call__(self, inputs, z: torch.Tensor, state=None):
+    def __call__(self, x: torch.Tensor, z: torch.Tensor, state: Optional[Tuple[torch.Tensor, torch.Tensor]]):
         if state is None:
             hidden, cell = torch.split(torch.tanh(self.init_state(z)), self.dec_hidden_size, 1)
             state = (hidden.unsqueeze(0).contiguous(), cell.unsqueeze(0).contiguous())
 
-        outputs, (hidden, cell) = self.lstm(inputs, state)
-        if not self.training:
-            # We dont have to change shape since hidden has shape [1, batch_size, hidden_size]
-            # and outputs is of shape [seq_len, batch_size, hidden_size], since we are
-            # using a single direction one layer lstm
-            outputs = hidden
+        _, (hidden, cell) = self.lstm(x, state)
 
-        q_logits = self.q_log_softmax(self.q_head(outputs))
+        q_logits = self.q_log_softmax(self.q_head(hidden))
 
         pi_logits, mu_x, mu_y, sigma_x, sigma_y, rho_xy = \
-            torch.split(self.mixtures(outputs), self.n_mixtures, 2)
+            torch.split(self.mixtures(hidden), self.n_mixtures, 2)
 
         dist = BivariateGaussianMixture(pi_logits, mu_x, mu_y,
                                         torch.exp(sigma_x), torch.exp(sigma_y), torch.tanh(rho_xy))
@@ -185,30 +177,31 @@ class Sampler:
 
     def sample(self, data: torch.Tensor, temperature: float):
         z, _, _ = self.encoder(data)
-        sos = data.new_tensor([[[0, 0, 1, 0, 0]]])
+        sos = data.new_tensor([0, 0, 1, 0, 0])
         seq_len = len(data)
         s = sos
-        seq_x = []
-        seq_y = []
-        seq_z = []
+        seq = [s]
         state = None
         with torch.no_grad():
             for i in range(seq_len):
-                data = torch.cat([s, z.unsqueeze(0)], 2)
+                data = torch.cat([s.view(1, 1, -1), z.unsqueeze(0)], 2)
                 dist, q_logits, state = self.decoder(data, z, state)
-                s, xy, pen_down, eos = self._sample_step(dist, q_logits, temperature)
-                seq_x.append(xy[0].item())
-                seq_y.append(xy[1].item())
-                seq_z.append(pen_down)
-                if eos:
+                s = self._sample_step(dist, q_logits, temperature)
+                seq.append(s)
+                if s[4] == 1:
                     break
 
-        x_sample = np.cumsum(seq_x, 0)
-        y_sample = np.cumsum(seq_y, 0)
-        z_sample = np.array(seq_z)
-        sequence = np.stack([x_sample, y_sample, z_sample]).T
+        seq = torch.stack(seq)
 
-        strokes = np.split(sequence, np.where(sequence[:, 2] > 0)[0] + 1)
+        self.plot(seq)
+
+    @staticmethod
+    def plot(seq: torch.Tensor):
+        seq[:, 0:2] = torch.cumsum(seq[:, 0:2], dim=0)
+        seq[:, 2] = seq[:, 3] == 1
+        seq = seq[:, 0:3].detach().cpu().numpy()
+
+        strokes = np.split(seq, np.where(seq[:, 2] > 0)[0] + 1)
         for s in strokes:
             plt.plot(s[:, 0], -s[:, 1])
         plt.axis('off')
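Note: the `StrokesDataset` changes above are easiest to see on a toy example. The sketch below (standalone, made-up values; not part of the patch) mirrors what the new `__init__` writes into `self.data` and `self.mask`: step 0 is a start-of-sequence step, steps 1..len_seq hold the strokes, and every step after that carries the end-of-sequence flag.

    import torch

    # Toy sequence of 4 strokes; columns are (dx, dy, p) as in the `npz` files.
    seq = torch.tensor([[1., 0., 0.],
                        [0., 2., 1.],
                        [3., 1., 0.],
                        [1., 1., 1.]])

    longest_seq_len = 6  # assume the longest sequence in the toy dataset has 6 strokes
    data = torch.zeros(1, longest_seq_len + 2, 5)  # extra steps for sos and eos
    mask = torch.zeros(1, longest_seq_len + 1)     # aligned with the targets data[1:]

    len_seq = len(seq)
    data[0, 1:len_seq + 1, :2] = seq[:, :2]    # dx, dy go to steps 1..len_seq
    data[0, 1:len_seq + 1, 2] = 1 - seq[:, 2]  # pen state, column 2
    data[0, 1:len_seq + 1, 3] = seq[:, 2]      # pen state, column 3
    data[0, len_seq + 1:, 4] = 1               # eos flag on all remaining steps
    mask[0, :len_seq + 1] = 1                  # stroke steps plus the first eos step
    data[0, 0, 2] = 1                          # sos step: zero displacement

    print(data[0])
    print(mask[0])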
@@ -224,10 +218,10 @@ class Sampler:
         q_idx = q.sample()[0, 0]
         xy = mix.sample()[0, 0, idx]
 
-        next_pos = q_logits.new_zeros(1, 1, 5)
-        next_pos[0, 0, :2] = xy
-        next_pos[0, 0, q_idx + 2] = 1
-        return next_pos, xy, q_idx == 1, q_idx == 2
+        next_pos = q_logits.new_zeros(5)
+        next_pos[:2] = xy
+        next_pos[q_idx + 2] = 1
+        return next_pos
 
 
 class Configs(TrainValidConfigs):
@@ -238,8 +232,10 @@
     sampler: Sampler
     dataset_name: str
 
-    dataset: StrokesDataset = 'setup_all'
     train_loader = 'setup_all'
+    valid_loader = 'setup_all'
+    train_dataset: StrokesDataset
+    valid_dataset: StrokesDataset
 
     enc_hidden_size = 256
     dec_hidden_size = 512
@@ -256,18 +252,17 @@
     temperature = 0.4
     max_seq_length = 200
 
-    validator = None
-    valid_loader = None
     epochs = 100
 
     def sample(self):
-        data, *_ = self.dataset[np.random.choice(len(self.dataset))]
+        data, *_ = self.train_dataset[np.random.choice(len(self.train_dataset))]
         data = data.unsqueeze(1).to(self.device)
         self.sampler.sample(data, self.temperature)
 
 
 @setup([Configs.encoder, Configs.decoder, Configs.optimizer, Configs.sampler,
-        Configs.dataset, Configs.train_loader])
+        Configs.train_dataset, Configs.train_loader,
+        Configs.valid_dataset, Configs.valid_loader])
 def setup_all(self: Configs):
     self.encoder = EncoderRNN(self.d_z, self.enc_hidden_size).to(self.device)
     self.decoder = DecoderRNN(self.d_z, self.dec_hidden_size, self.n_mixtures).to(self.device)
@@ -276,9 +271,16 @@ def setup_all(self: Configs):
     self.optimizer.parameters = list(self.encoder.parameters()) + list(self.decoder.parameters())
 
     self.sampler = Sampler(self.encoder, self.decoder)
+    # `npz` file path is `data/sketch/[DATASET NAME].npz`
+    path = lab.get_data_path() / 'sketch' / f'{self.dataset_name}.npz'
+    # Load the numpy file.
+    dataset = np.load(str(path), encoding='latin1', allow_pickle=True)
 
-    self.dataset = StrokesDataset(self.dataset_name, self.max_seq_length)
-    self.train_loader = DataLoader(self.dataset, self.batch_size, shuffle=True)
+    self.train_dataset = StrokesDataset(dataset['train'], self.max_seq_length)
+    self.valid_dataset = StrokesDataset(dataset['valid'], self.max_seq_length, self.train_dataset.scale)
+
+    self.train_loader = DataLoader(self.train_dataset, self.batch_size, shuffle=True)
+    self.valid_loader = DataLoader(self.valid_dataset, self.batch_size, shuffle=True)
 
 
 class StrokesBatchStep(BatchStepProtocol):
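Note: the optional `scale` argument threaded through `StrokesDataset` exists so that `setup_all` can normalize the validation split with the training split's statistics (`self.train_dataset.scale`). A minimal numpy sketch of that contract, with toy arrays standing in for the real `npz` splits:

    import numpy as np

    # Toy stand-ins for dataset['train'] and dataset['valid']: lists of [seq_len, 3] arrays.
    train = [np.array([[10., 0., 0.], [0., 20., 1.]]),
             np.array([[5., 5., 0.], [15., -5., 1.]])]
    valid = [np.array([[8., 2., 0.], [-3., 7., 1.]])]

    # The scale is the combined standard deviation of all (dx, dy) in the training split...
    scale = np.std(np.concatenate([np.ravel(s[:, 0:2]) for s in train]))

    # ...and the very same value is reused to normalize the validation sequences.
    for s in train + valid:
        s[:, 0:2] /= scale

    print(scale)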
@@ -296,39 +298,34 @@
         hook_model_outputs(self.encoder, 'encoder')
         hook_model_outputs(self.decoder, 'decoder')
         tracker.set_scalar("loss.*", True)
-        tracker.set_image("generated", True)
 
     def prepare_for_iteration(self):
         if MODE_STATE.is_train:
             self.encoder.train()
             self.decoder.train()
         else:
-            self.encoder.eval()
-            self.decoder.eval()
+            self.encoder.train()
+            self.decoder.train()
+            # self.encoder.eval()
+            # self.decoder.eval()
 
     def process(self, batch: any, state: any):
         device = self.encoder.device
-        data, target, mask = batch
+        data, mask = batch
         data = data.to(device).transpose(0, 1)
-        target = target.to(device).transpose(0, 1)
         mask = mask.to(device).transpose(0, 1)
-        batch_size = data.shape[1]
-        seq_len = data.shape[0]
 
         with monit.section("encoder"):
             z, mu, sigma = self.encoder(data)
 
         with monit.section("decoder"):
-            sos = torch.stack([torch.tensor([0, 0, 1, 0, 0])] * batch_size). \
-                unsqueeze(0).to(device)
-            batch_init = torch.cat([sos, data], 0)
-            z_stack = torch.stack([z] * (seq_len + 1))
-            inputs = torch.cat([batch_init, z_stack], 2)
-            dist, q_logits, _ = self.decoder(inputs, z)
+            z_stack = z.unsqueeze(0).expand(data.shape[0] - 1, -1, -1)
+            inputs = torch.cat([data[:-1], z_stack], 2)
+            dist, q_logits, _ = self.decoder(inputs, z, None)
 
         with monit.section('loss'):
             kl_loss = self.kl_div_loss(sigma, mu)
-            reconstruction_loss = self.reconstruction_loss(mask, target, dist, q_logits)
+            reconstruction_loss = self.reconstruction_loss(mask, data[1:], dist, q_logits)
             loss = self.kl_div_loss_weight * kl_loss + reconstruction_loss
 
             tracker.add("loss.kl.", kl_loss)
@@ -346,8 +343,6 @@
             nn.utils.clip_grad_norm_(self.decoder.parameters(), self.grad_clip)
             self.optimizer.step()
 
-            # tracker.add('generated', generated_images[0:5])
-
         return {'samples': len(data)}, None
 
 
@@ -395,7 +390,8 @@ def main():
     experiment.configs(configs, {
         'optimizer.optimizer': 'Adam',
         'optimizer.learning_rate': 1e-3,
-        'dataset_name': 'bicycle'
+        'dataset_name': 'bicycle',
+        'inner_iterations': 10
     }, 'run')
 
     experiment.start()
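Note: with the sos step stored in the dataset itself, the decoder section of `process` reduces to a plain one-step shift: `data[:-1]` (with `z` concatenated to every step) is the teacher-forced input and `data[1:]` is the prediction target, which is why the dataset now prepends the sos step. A shape-only sketch with toy dimensions (not part of the patch):

    import torch

    seq_len, batch_size, d_z = 6, 2, 3
    data = torch.randn(seq_len, batch_size, 5)  # step 0 plays the role of sos
    z = torch.randn(batch_size, d_z)

    # Repeat z along the time axis and attach it to every input step.
    z_stack = z.unsqueeze(0).expand(data.shape[0] - 1, -1, -1)
    inputs = torch.cat([data[:-1], z_stack], 2)  # steps 0..seq_len-2, plus z
    target = data[1:]                            # steps 1..seq_len-1

    print(inputs.shape, target.shape)  # torch.Size([5, 2, 8]) torch.Size([5, 2, 5])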