This is an annotated implementation of the paper A Neural Representation of Sketch Drawings.
Sketch RNN is a sequence-to-sequence variational auto-encoder. Both the encoder and decoder are recurrent neural networks. It learns to reconstruct simple stroke-based drawings by predicting a sequence of strokes. The decoder predicts each stroke as a mixture of bivariate Gaussians.
Download data from the Quick, Draw! Dataset. There is a link to download the npz files in the Sketch-RNN QuickDraw Dataset section of the readme. Place the downloaded npz file(s) in the data/sketch folder. This code is configured to use the bicycle dataset. You can change this in the configurations.
This implementation took help from the PyTorch Sketch RNN project by Alexis David Jacq.
32import math
33from typing import Optional, Tuple, Any
34
35import numpy as np
36import torch
37import torch.nn as nn
38from matplotlib import pyplot as plt
39from torch import optim
40from torch.utils.data import Dataset, DataLoader
41
42import einops
43from labml import lab, experiment, tracker, monit
44from labml_helpers.device import DeviceConfigs
45from labml_helpers.module import Module
46from labml_helpers.optimizer import OptimizerConfigs
47from labml_helpers.train_valid import TrainValidConfigs, hook_model_outputs, BatchIndex
50class StrokesDataset(Dataset):
dataset is a list of numpy arrays of shape [seq_len, 3]. Each array is a sequence of strokes, and each stroke is represented by 3 integers. The first two are the displacements along x and y ($\Delta x$, $\Delta y$), and the last is a binary pen state: $1$ if the pen is lifted from the paper after this point (i.e., this point ends a stroke) and $0$ otherwise.
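For intuition, here is a small made-up sequence in this format (not taken from the dataset); the pen-state flag follows the convention the constructor below relies on:

# A hypothetical drawing in stroke-3 format: (dx, dy, pen_lifted_after_this_point)
toy_seq = np.array([
    [ 10,   0, 0],   # move right, pen stays on the paper
    [  0,  10, 0],   # move up, pen stays on the paper
    [-10,   0, 1],   # move left, then lift the pen (end of the first stroke)
    [  5,   5, 0],   # a second stroke starts here
    [  5,  -5, 1],   # finish the drawing and lift the pen
], dtype=np.int16)
print(toy_seq.shape)   # (5, 3) -> [seq_len, 3]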
57 def __init__(self, dataset: np.array, max_seq_length: int, scale: Optional[float] = None):
67 data = []
We iterate through each of the sequences and filter
69 for seq in dataset:
Keep the sequence of strokes only if its length is within our range
71 if 10 < len(seq) <= max_seq_length:
Clamp $\Delta x$, $\Delta y$ to $[-1000, 1000]$
73 seq = np.minimum(seq, 1000)
74 seq = np.maximum(seq, -1000)
Convert to a floating point array and add to data
76 seq = np.array(seq, dtype=np.float32)
77 data.append(seq)
We then calculate the scaling factor, which is the standard deviation of the combined ($\Delta x$, $\Delta y$) values. The paper notes that the mean is not adjusted for simplicity, since it is anyway close to $0$.
83 if scale is None:
84 scale = np.std(np.concatenate([np.ravel(s[:, 0:2]) for s in data]))
85 self.scale = scale
Get the longest sequence length among all sequences
88 longest_seq_len = max([len(seq) for seq in data])
We initialize PyTorch data array with two extra steps for start-of-sequence (sos) and end-of-sequence (eos). Each step is a vector $(\Delta x, \Delta y, p_1, p_2, p_3)$. Only one of $p_1, p_2, p_3$ is $1$ and the others are $0$. They represent pen down, pen up and end-of-sequence in that order. $p_1$ is $1$ if the pen touches the paper in the next step. $p_2$ is $1$ if the pen doesn’t touch the paper in the next step. $p_3$ is $1$ if it is the end of the drawing.
98 self.data = torch.zeros(len(data), longest_seq_len + 2, 5, dtype=torch.float)
The mask array needs only one extra step since it is for the outputs of the decoder, which takes in data[:-1] and predicts the next step.
101 self.mask = torch.zeros(len(data), longest_seq_len + 1)
102
103 for i, seq in enumerate(data):
104 seq = torch.from_numpy(seq)
105 len_seq = len(seq)
Scale and set $\Delta x, \Delta y$
107 self.data[i, 1:len_seq + 1, :2] = seq[:, :2] / scale
$p_1$
109 self.data[i, 1:len_seq + 1, 2] = 1 - seq[:, 2]
$p_2$
111 self.data[i, 1:len_seq + 1, 3] = seq[:, 2]
$p_3$
113 self.data[i, len_seq + 1:, 4] = 1
Mask is on until end of sequence
115 self.mask[i, :len_seq + 1] = 1
Start-of-sequence is $(0, 0, 1, 0, 0)$
118 self.data[:, 0, 2] = 1
Size of the dataset
120 def __len__(self):
122 return len(self.data)
Get a sample
124 def __getitem__(self, idx: int):
126 return self.data[idx], self.mask[idx]
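As a quick sanity check, a minimal sketch of constructing the dataset from two made-up sequences instead of a real npz split; the sequence contents and sizes are arbitrary:

# The constructor only iterates over `dataset`, so a plain list of arrays behaves
# the same way as the object array loaded from the npz file.
toy = [
    np.array([[10, 0, 0]] * 12, dtype=np.int16),
    np.array([[0, 10, 0]] * 15, dtype=np.int16),
]
ds = StrokesDataset(toy, max_seq_length=200)

data, mask = ds[0]
print(len(ds))        # 2
print(data.shape)     # [longest_seq_len + 2, 5] -> torch.Size([17, 5])
print(mask.shape)     # [longest_seq_len + 1]    -> torch.Size([16])
print(data[0])        # start-of-sequence step: tensor([0., 0., 1., 0., 0.])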
The mixture is represented by $\Pi$ and $\mathcal{N}(\mu_{x}, \mu_{y}, \sigma_{x}, \sigma_{y}, \rho_{xy})$. This class adjusts the temperature and creates the categorical and Gaussian distributions from the parameters.
129class BivariateGaussianMixture:
139 def __init__(self, pi_logits: torch.Tensor, mu_x: torch.Tensor, mu_y: torch.Tensor,
140 sigma_x: torch.Tensor, sigma_y: torch.Tensor, rho_xy: torch.Tensor):
141 self.pi_logits = pi_logits
142 self.mu_x = mu_x
143 self.mu_y = mu_y
144 self.sigma_x = sigma_x
145 self.sigma_y = sigma_y
146 self.rho_xy = rho_xy
Number of distributions in the mixture, $M$
148 @property
149 def n_distributions(self):
151 return self.pi_logits.shape[-1]
Adjust by temperature $\tau$
153 def set_temperature(self, temperature: float):
158 self.pi_logits /= temperature
160 self.sigma_x *= math.sqrt(temperature)
162 self.sigma_y *= math.sqrt(temperature)
164 def get_distribution(self):
Clamp $\sigma_x$, $\sigma_y$ and $\rho_{xy}$ to avoid getting NaNs
166 sigma_x = torch.clamp_min(self.sigma_x, 1e-5)
167 sigma_y = torch.clamp_min(self.sigma_y, 1e-5)
168 rho_xy = torch.clamp(self.rho_xy, -1 + 1e-5, 1 - 1e-5)
Get means
171 mean = torch.stack([self.mu_x, self.mu_y], -1)
Get covariance matrix
173 cov = torch.stack([
174 sigma_x * sigma_x, rho_xy * sigma_x * sigma_y,
175 rho_xy * sigma_x * sigma_y, sigma_y * sigma_y
176 ], -1)
177 cov = cov.view(*sigma_y.shape, 2, 2)
Create the bi-variate normal distribution.
📝 It would be more efficient to use the scale_tril matrix [[a, 0], [b, c]] where $a = \sigma_x$, $b = \rho_{xy} \sigma_y$ and $c = \sigma_y \sqrt{1 - \rho^2_{xy}}$. But for simplicity we use the covariance matrix. This is a good resource if you want to read up more about bi-variate distributions, their covariance matrix, and probability density function.
188 multi_dist = torch.distributions.MultivariateNormal(mean, covariance_matrix=cov)
Create categorical distribution $\Pi$ from logits
191 cat_dist = torch.distributions.Categorical(logits=self.pi_logits)
194 return cat_dist, multi_dist
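As a hedged aside on the note above, this is a sketch of how the same distribution could be built with scale_tril instead of the full covariance matrix; mixture_with_scale_tril is a hypothetical helper, not part of this implementation.

# Cholesky-factor alternative: for cov = [[sx^2, r*sx*sy], [r*sx*sy, sy^2]] the
# lower-triangular factor is L = [[sx, 0], [r*sy, sy*sqrt(1 - r^2)]], so L @ L.T == cov.
def mixture_with_scale_tril(mean, sigma_x, sigma_y, rho_xy):
    zeros = torch.zeros_like(sigma_x)
    scale_tril = torch.stack([
        sigma_x, zeros,
        rho_xy * sigma_y, sigma_y * torch.sqrt(1 - rho_xy ** 2),
    ], -1).view(*sigma_x.shape, 2, 2)
    return torch.distributions.MultivariateNormal(mean, scale_tril=scale_tril)

# Quick numerical check against the covariance-matrix version used above
sx, sy, r = torch.tensor(1.5), torch.tensor(0.5), torch.tensor(0.3)
mu = torch.tensor([0.2, -0.1])
cov = torch.stack([torch.stack([sx * sx, r * sx * sy]),
                   torch.stack([r * sx * sy, sy * sy])])
d_cov = torch.distributions.MultivariateNormal(mu, covariance_matrix=cov)
d_tril = mixture_with_scale_tril(mu, sx, sy, r)
x = torch.tensor([0.3, 0.4])
print(torch.allclose(d_cov.log_prob(x), d_tril.log_prob(x)))   # True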
197class EncoderRNN(Module):
204 def __init__(self, d_z: int, enc_hidden_size: int):
205 super().__init__()
Create a bidirectional LSTM that takes a sequence of $(\Delta x, \Delta y, p_1, p_2, p_3)$ as input.
208 self.lstm = nn.LSTM(5, enc_hidden_size, bidirectional=True)
Head to get $\mu$
210 self.mu_head = nn.Linear(2 * enc_hidden_size, d_z)
Head to get $\hat{\sigma}$
212 self.sigma_head = nn.Linear(2 * enc_hidden_size, d_z)
214 def __call__(self, inputs: torch.Tensor, state=None):
The hidden state of the bidirectional LSTM is the concatenation of the hidden state after the last token in the forward direction and the hidden state after the first token in the reverse direction, which is what we want.
222 _, (hidden, cell) = self.lstm(inputs.float(), state)
The state has shape [2, batch_size, hidden_size]
where the first dimension is the direction.
We rearrange it to get $h = [h_{\rightarrow}; h_{\leftarrow}]$
226 hidden = einops.rearrange(hidden, 'fb b h -> b (fb h)')
$\mu$
229 mu = self.mu_head(hidden)
$\hat{\sigma}$
231 sigma_hat = self.sigma_head(hidden)
$\sigma = \exp(\frac{\hat{\sigma}}{2})$
233 sigma = torch.exp(sigma_hat / 2.)
Sample $z = \mu + \sigma \cdot \mathcal{N}(0, I)$
236 z = mu + sigma * torch.normal(mu.new_zeros(mu.shape), mu.new_ones(mu.shape))
239 return z, mu, sigma_hat
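A minimal shape check for the encoder with random inputs and untrained weights; the sizes (d_z = 8, enc_hidden_size = 16) are made up for illustration.

encoder = EncoderRNN(d_z=8, enc_hidden_size=16)
seq_len, batch_size = 5, 3
x = torch.rand(seq_len, batch_size, 5)       # a batch of stroke sequences
z, mu, sigma_hat = encoder(x)
print(z.shape, mu.shape, sigma_hat.shape)    # each is [batch_size, d_z] -> torch.Size([3, 8])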
242class DecoderRNN(Module):
249 def __init__(self, d_z: int, dec_hidden_size: int, n_distributions: int):
250 super().__init__()
LSTM takes $[(\Delta x, \Delta y, p_1, p_2, p_3); z]$ as input
252 self.lstm = nn.LSTM(d_z + 5, dec_hidden_size)
Initial state of the LSTM is $[h_0; c_0] = \tanh(W_{z}z + b_z)$.
init_state
is the linear transformation for this
256 self.init_state = nn.Linear(d_z, 2 * dec_hidden_size)
This layer produces outputs for each of the n_distributions.
Each distribution needs six parameters
$(\hat{\Pi_i}, \mu_{x,i}, \mu_{y,i}, \hat{\sigma_{x,i}}, \hat{\sigma_{y,i}}, \hat{\rho_{xy,i}})$
261 self.mixtures = nn.Linear(dec_hidden_size, 6 * n_distributions)
This head is for the logits $(\hat{q_1}, \hat{q_2}, \hat{q_3})$
264 self.q_head = nn.Linear(dec_hidden_size, 3)
This is to calculate $\log(q_k)$ where $q_k = \frac{\exp(\hat{q_k})}{\sum_{j=1}^{3} \exp(\hat{q_j})}$
267 self.q_log_softmax = nn.LogSoftmax(-1)
These parameters are stored for future reference
270 self.n_distributions = n_distributions
271 self.dec_hidden_size = dec_hidden_size
273 def __call__(self, x: torch.Tensor, z: torch.Tensor, state: Optional[Tuple[torch.Tensor, torch.Tensor]]):
Calculate the initial state
275 if state is None:
$[h_0; c_0] = \tanh(W_{z}z + b_z)$
277 h, c = torch.split(torch.tanh(self.init_state(z)), self.dec_hidden_size, 1)
h and c have shape [batch_size, lstm_size]. We want to reshape them to [1, batch_size, lstm_size] because that is the shape the LSTM expects.
280 state = (h.unsqueeze(0).contiguous(), c.unsqueeze(0).contiguous())
Run the LSTM
283 outputs, state = self.lstm(x, state)
Get $\log(q)$
286 q_logits = self.q_log_softmax(self.q_head(outputs))
Get $(\hat{\Pi_i}, \mu_{x,i}, \mu_{y,i}, \hat{\sigma_{x,i}}, \hat{\sigma_{y,i}}, \hat{\rho_{xy,i}})$.
torch.split splits the output into 6 tensors of size self.n_distributions across dimension 2.
292 pi_logits, mu_x, mu_y, sigma_x, sigma_y, rho_xy = \
293 torch.split(self.mixtures(outputs), self.n_distributions, 2)
Create a bi-variate Gaussian mixture $\Pi$ and $\mathcal{N}(\mu_{x}, \mu_{y}, \sigma_{x}, \sigma_{y}, \rho_{xy})$ where $\sigma_{x} = \exp(\hat{\sigma}_{x})$, $\sigma_{y} = \exp(\hat{\sigma}_{y})$ and $\rho_{xy} = \tanh(\hat{\rho}_{xy})$, and $\Pi$ is the categorical probabilities of choosing a distribution out of the mixture $\mathcal{N}(\mu_{x}, \mu_{y}, \sigma_{x}, \sigma_{y}, \rho_{xy})$.
306 dist = BivariateGaussianMixture(pi_logits, mu_x, mu_y,
307 torch.exp(sigma_x), torch.exp(sigma_y), torch.tanh(rho_xy))
310 return dist, q_logits, state
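A minimal sketch of driving the decoder one step at a time with made-up sizes and untrained weights, threading the LSTM state the same way Sampler.sample does further below.

decoder = DecoderRNN(d_z=8, dec_hidden_size=16, n_distributions=4)
z = torch.rand(1, 8)                      # a single latent vector
s = torch.zeros(1, 1, 5)
s[0, 0, 2] = 1                            # start-of-sequence stroke (0, 0, 1, 0, 0)
state = None                              # first call builds [h0; c0] = tanh(W_z z + b_z)
for _ in range(3):
    x = torch.cat([s, z.unsqueeze(0)], 2)         # [1, 1, d_z + 5]
    dist, q_logits, state = decoder(x, z, state)
    print(q_logits.shape, state[0].shape)         # [1, 1, 3] and [1, 1, 16]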
313class ReconstructionLoss(Module):
318 def __call__(self, mask: torch.Tensor, target: torch.Tensor,
319 dist: 'BivariateGaussianMixture', q_logits: torch.Tensor):
Get $\Pi$ and $\mathcal{N}(\mu_{x}, \mu_{y}, \sigma_{x}, \sigma_{y}, \rho_{xy})$
321 pi, mix = dist.get_distribution()
target has shape [seq_len, batch_size, 5] where the last dimension is the features $(\Delta x, \Delta y, p_1, p_2, p_3)$. We want to get $\Delta x, \Delta y$ and get the probabilities from each of the distributions in the mixture $\mathcal{N}(\mu_{x}, \mu_{y}, \sigma_{x}, \sigma_{y}, \rho_{xy})$.
xy
will have shape [seq_len, batch_size, n_distributions, 2]
328 xy = target[:, :, 0:2].unsqueeze(-2).expand(-1, -1, dist.n_distributions, -1)
Calculate the probabilities
334 probs = torch.sum(pi.probs * torch.exp(mix.log_prob(xy)), 2)
Although probs has $N_{max}$ (longest_seq_len) elements, the sum is only taken up to $N_s$ because the rest are masked out.
It might feel like we should be taking the sum and dividing by $N_s$ and not $N_{max}$, but that would give a higher weight to individual predictions in shorter sequences. We give equal weight to each prediction $p(\Delta x, \Delta y)$ when we divide by $N_{max}$.
343 loss_stroke = -torch.mean(mask * torch.log(1e-5 + probs))
346 loss_pen = -torch.mean(target[:, :, 2:] * q_logits)
349 return loss_stroke + loss_pen
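A small plumbing check for the reconstruction loss using an untrained decoder and made-up targets; the loss value itself is meaningless here, only the shapes matter.

seq_len, batch_size, d_z, n_dist = 5, 3, 8, 4
decoder = DecoderRNN(d_z, 16, n_dist)
z = torch.rand(batch_size, d_z)
inputs = torch.rand(seq_len, batch_size, d_z + 5)
dist, q_logits, _ = decoder(inputs, z, None)

target = torch.zeros(seq_len, batch_size, 5)
target[:, :, 0:2] = torch.randn(seq_len, batch_size, 2)   # (dx, dy)
target[:, :, 2] = 1                                       # pen down at every step
mask = torch.ones(seq_len, batch_size)

loss = ReconstructionLoss()(mask, target, dist, q_logits)
print(loss.item())                                        # a single scalar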
This calculates the KL divergence between a given normal distribution and $\mathcal{N}(0, 1)$
352class KLDivLoss(Module):
359 def __call__(self, sigma_hat: torch.Tensor, mu: torch.Tensor):
361 return -0.5 * torch.mean(1 + sigma_hat - mu ** 2 - torch.exp(sigma_hat))
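A quick numerical check, with made-up values, that this matches the KL divergence computed by torch.distributions; remember that the encoder produces $\hat{\sigma} = \log \sigma^2$ (since $\sigma = \exp(\hat{\sigma} / 2)$).

mu = torch.tensor([[0.5, -1.0], [0.2, 0.0]])
sigma = torch.tensor([[1.2, 0.7], [0.9, 2.0]])
sigma_hat = torch.log(sigma ** 2)

analytic = KLDivLoss()(sigma_hat, mu)
reference = torch.distributions.kl_divergence(
    torch.distributions.Normal(mu, sigma),
    torch.distributions.Normal(torch.zeros_like(mu), torch.ones_like(mu)),
).mean()
print(torch.allclose(analytic, reference))   # True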
364class Sampler:
371 def __init__(self, encoder: EncoderRNN, decoder: DecoderRNN):
372 self.decoder = decoder
373 self.encoder = encoder
375 def sample(self, data: torch.Tensor, temperature: float):
$N_{max}$
377 longest_seq_len = len(data)
Get $z$ from the encoder
380 z, _, _ = self.encoder(data)
Start-of-sequence stroke is $(0, 0, 1, 0, 0)$
383 s = data.new_tensor([0, 0, 1, 0, 0])
384 seq = [s]
The initial decoder state is None.
The decoder will initialize it to $[h_0; c_0] = \tanh(W_{z}z + b_z)$
387 state = None
We don’t need gradients
390 with torch.no_grad():
Sample $N_{max}$ strokes
392 for i in range(longest_seq_len):
$[(\Delta x, \Delta y, p_1, p_2, p_3); z]$ is the input to the decoder
394 data = torch.cat([s.view(1, 1, -1), z.unsqueeze(0)], 2)
Get $\Pi$, $\mathcal{N}(\mu_{x}, \mu_{y}, \sigma_{x}, \sigma_{y}, \rho_{xy})$, $q$ and the next state from the decoder
397 dist, q_logits, state = self.decoder(data, z, state)
Sample a stroke
399 s = self._sample_step(dist, q_logits, temperature)
Add the new stroke to the sequence of strokes
401 seq.append(s)
Stop sampling if $p_3 = 1$. This indicates that sketching has stopped
403 if s[4] == 1:
404 break
Create a PyTorch tensor of the sequence of strokes
407 seq = torch.stack(seq)
Plot the sequence of strokes
410 self.plot(seq)
412 @staticmethod
413 def _sample_step(dist: 'BivariateGaussianMixture', q_logits: torch.Tensor, temperature: float):
Set temperature $\tau$ for sampling. This is implemented in class BivariateGaussianMixture
.
415 dist.set_temperature(temperature)
Get temperature adjusted $\Pi$ and $\mathcal{N}(\mu_{x}, \mu_{y}, \sigma_{x}, \sigma_{y}, \rho_{xy})$
417 pi, mix = dist.get_distribution()
Sample from $\Pi$ the index of the distribution to use from the mixture
419 idx = pi.sample()[0, 0]
Create categorical distribution $q$ with log-probabilities q_logits
or $\hat{q}$
422 q = torch.distributions.Categorical(logits=q_logits / temperature)
Sample from $q$
424 q_idx = q.sample()[0, 0]
Sample from the normal distributions in the mixture and pick the one indexed by idx
427 xy = mix.sample()[0, 0, idx]
Create an empty stroke $(\Delta x, \Delta y, q_1, q_2, q_3)$
430 stroke = q_logits.new_zeros(5)
Set $\Delta x, \Delta y$
432 stroke[:2] = xy
Set $q_1, q_2, q_3$
434 stroke[q_idx + 2] = 1
436 return stroke
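A small illustration, with made-up logits, of what the temperature does to the categorical part of the mixture: $\tau < 1$ sharpens the distribution and $\tau > 1$ flattens it.

pi_logits = torch.tensor([2.0, 1.0, 0.0])    # made-up logits for a 3-component mixture
for tau in [0.1, 0.4, 1.0, 2.0]:
    probs = torch.softmax(pi_logits / tau, dim=-1)
    print(tau, probs.numpy().round(3))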
438 @staticmethod
439 def plot(seq: torch.Tensor):
Take the cumulative sums of $(\Delta x, \Delta y)$ to get $(x, y)$
441 seq[:, 0:2] = torch.cumsum(seq[:, 0:2], dim=0)
Create a new numpy array of the form $(x, y, q_2)$
443 seq[:, 2] = seq[:, 3]
444 seq = seq[:, 0:3].detach().cpu().numpy()
Split the array at points where $q_2$ is $1$. That is split the array of strokes at the points where the pen is lifted from the paper. This gives a list of sequence of strokes.
449 strokes = np.split(seq, np.where(seq[:, 2] > 0)[0] + 1)
Plot each sequence of strokes
451 for s in strokes:
452 plt.plot(s[:, 0], -s[:, 1])
Don’t show axes
454 plt.axis('off')
Show the plot
456 plt.show()
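To make the splitting logic concrete, here is a tiny made-up array of absolute points (after the cumulative sum) and where np.split cuts it.

seq = np.array([
    [0.0, 0.0, 0.0],
    [1.0, 0.0, 0.0],
    [1.0, 1.0, 1.0],   # pen lifted here -> end of the first stroke
    [3.0, 3.0, 0.0],
    [4.0, 3.0, 1.0],   # end of the second stroke
])
strokes = np.split(seq, np.where(seq[:, 2] > 0)[0] + 1)
print([len(s) for s in strokes])   # [3, 2, 0]; the trailing empty chunk plots nothing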
459class Configs(TrainValidConfigs):
Device configurations to pick the device to run the experiment
467 device: torch.device = DeviceConfigs()
469 encoder: EncoderRNN
470 decoder: DecoderRNN
471 optimizer: optim.Adam
472 sampler: Sampler
473
474 dataset_name: str
475 train_loader: DataLoader
476 valid_loader: DataLoader
477 train_dataset: StrokesDataset
478 valid_dataset: StrokesDataset
Encoder and decoder sizes
481 enc_hidden_size = 256
482 dec_hidden_size = 512
Batch size
485 batch_size = 100
Number of features in $z$
488 d_z = 128
Number of distributions in the mixture, $M$
490 n_distributions = 20
Weight of KL divergence loss, $w_{KL}$
493 kl_div_loss_weight = 0.5
Gradient clipping
495 grad_clip = 1.
Temperature $\tau$ for sampling
497 temperature = 0.4
Filter out stroke sequences longer than $200$
500 max_seq_length = 200
501
502 epochs = 100
503
504 kl_div_loss = KLDivLoss()
505 reconstruction_loss = ReconstructionLoss()
507 def init(self):
Initialize encoder & decoder
509 self.encoder = EncoderRNN(self.d_z, self.enc_hidden_size).to(self.device)
510 self.decoder = DecoderRNN(self.d_z, self.dec_hidden_size, self.n_distributions).to(self.device)
Set up the optimizer; the type of optimizer and the learning rate are configurable.
513 optimizer = OptimizerConfigs()
514 optimizer.parameters = list(self.encoder.parameters()) + list(self.decoder.parameters())
515 self.optimizer = optimizer
Create sampler
518 self.sampler = Sampler(self.encoder, self.decoder)
npz
file path is data/sketch/[DATASET NAME].npz
521 path = lab.get_data_path() / 'sketch' / f'{self.dataset_name}.npz'
Load the numpy file.
523 dataset = np.load(str(path), encoding='latin1', allow_pickle=True)
Create training dataset
526 self.train_dataset = StrokesDataset(dataset['train'], self.max_seq_length)
Create validation dataset
528 self.valid_dataset = StrokesDataset(dataset['valid'], self.max_seq_length, self.train_dataset.scale)
Create training data loader
531 self.train_loader = DataLoader(self.train_dataset, self.batch_size, shuffle=True)
Create validation data loader
533 self.valid_loader = DataLoader(self.valid_dataset, self.batch_size)
Add hooks to monitor layer outputs on Tensorboard
536 hook_model_outputs(self.mode, self.encoder, 'encoder')
537 hook_model_outputs(self.mode, self.decoder, 'decoder')
Configure the tracker to print the total train/validation loss
540 tracker.set_scalar("loss.total.*", True)
541
542 self.state_modules = []
544 def step(self, batch: Any, batch_idx: BatchIndex):
545 self.encoder.train(self.mode.is_train)
546 self.decoder.train(self.mode.is_train)
Move data
and mask
to device and swap the sequence and batch dimensions.
data
will have shape [seq_len, batch_size, 5]
and
mask
will have shape [seq_len, batch_size]
.
551 data = batch[0].to(self.device).transpose(0, 1)
552 mask = batch[1].to(self.device).transpose(0, 1)
Increment step in training mode
555 if self.mode.is_train:
556 tracker.add_global_step(len(data))
Encode the sequence of strokes
559 with monit.section("encoder"):
Get $z$, $\mu$, and $\hat{\sigma}$
561 z, mu, sigma_hat = self.encoder(data)
Decode the mixture of distributions and $\hat{q}$
564 with monit.section("decoder"):
Concatenate $[(\Delta x, \Delta y, p_1, p_2, p_3); z]$
566 z_stack = z.unsqueeze(0).expand(data.shape[0] - 1, -1, -1)
567 inputs = torch.cat([data[:-1], z_stack], 2)
Get mixture of distributions and $\hat{q}$
569 dist, q_logits, _ = self.decoder(inputs, z, None)
Compute the loss
572 with monit.section('loss'):
$L_{KL}$
574 kl_loss = self.kl_div_loss(sigma_hat, mu)
$L_R$
576 reconstruction_loss = self.reconstruction_loss(mask, data[1:], dist, q_logits)
$Loss = L_R + w_{KL} L_{KL}$
578 loss = reconstruction_loss + self.kl_div_loss_weight * kl_loss
Track losses
581 tracker.add("loss.kl.", kl_loss)
582 tracker.add("loss.reconstruction.", reconstruction_loss)
583 tracker.add("loss.total.", loss)
Only if we are in training state
586 if self.mode.is_train:
Run optimizer
588 with monit.section('optimize'):
Set grad
to zero
590 self.optimizer.zero_grad()
Compute gradients
592 loss.backward()
Log model parameters and gradients
594 if batch_idx.is_last:
595 tracker.add(encoder=self.encoder, decoder=self.decoder)
Clip gradients
597 nn.utils.clip_grad_norm_(self.encoder.parameters(), self.grad_clip)
598 nn.utils.clip_grad_norm_(self.decoder.parameters(), self.grad_clip)
Optimize
600 self.optimizer.step()
601
602 tracker.save()
604 def sample(self):
Randomly pick a sample from the validation dataset to encode
606 data, *_ = self.valid_dataset[np.random.choice(len(self.valid_dataset))]
Add batch dimension and move it to device
608 data = data.unsqueeze(1).to(self.device)
Sample
610 self.sampler.sample(data, self.temperature)
613def main():
614 configs = Configs()
615 experiment.create(name="sketch_rnn")
Pass a dictionary of configurations
618 experiment.configs(configs, {
619 'optimizer.optimizer': 'Adam',
We use a learning rate of 1e-3 because we can see results faster.
The paper suggested 1e-4.
622 'optimizer.learning_rate': 1e-3,
Name of the dataset
624 'dataset_name': 'bicycle',
Number of inner iterations within an epoch to switch between training, validation and sampling.
626 'inner_iterations': 10
627 })
628
629 with experiment.start():
Run the experiment
631 configs.run()
632
633
634if __name__ == "__main__":
635 main()