Sketch RNN

This is an annotated PyTorch implementation of the paper A Neural Representation of Sketch Drawings.

Sketch RNN is a sequence-to-sequence variational auto-encoder. Both the encoder and the decoder are recurrent neural network models. It learns to reconstruct simple stroke-based drawings by predicting a series of strokes. The decoder predicts each stroke as a mixture of bi-variate Gaussians.

Getting data

Download data from the Quick, Draw! Dataset. There is a link to download the npz files in the Sketch-RNN QuickDraw Dataset section of the readme. Place the downloaded npz file(s) in the data/sketch folder. This code is configured to use the bicycle dataset. You can change this in the configurations.
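As a quick sanity check after downloading, you can inspect the npz file with a short script like the one below (a minimal sketch; the data/sketch/bicycle.npz path matches the default configuration used later in this file, and the train/valid/test split names are the expected contents of the QuickDraw sketch-rnn files):

from pathlib import Path
import numpy as np

# Hypothetical check script; adjust the path if your data folder differs.
path = Path('data/sketch/bicycle.npz')
dataset = np.load(str(path), encoding='latin1', allow_pickle=True)
print(dataset.files)              # e.g. ['train', 'valid', 'test']
print(dataset['train'][0].shape)  # one sketch: [seq_len, 3] in stroke-3 format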

Acknowledgements

This implementation took help from the PyTorch Sketch RNN project by Alexis David Jacq.

32import math
33from typing import Optional, Tuple, Any
34
35import numpy as np
36import torch
37import torch.nn as nn
38from matplotlib import pyplot as plt
39from torch import optim
40from torch.utils.data import Dataset, DataLoader
41
42import einops
43from labml import lab, experiment, tracker, monit
44from labml_helpers.device import DeviceConfigs
45from labml_helpers.module import Module
46from labml_helpers.optimizer import OptimizerConfigs
47from labml_helpers.train_valid import TrainValidConfigs, hook_model_outputs, BatchIndex

Dataset

This class loads and pre-processes the data.

50class StrokesDataset(Dataset):

dataset is a list of numpy arrays of shape [seq_len, 3]. Each array is a sequence of strokes, and each stroke is represented by 3 integers. The first two are the displacements along x and y ($\Delta x$, $\Delta y$), and the last integer is the pen state: $1$ if the pen is lifted from the paper after this point and $0$ if it stays on the paper.
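For illustration, a tiny made-up fragment of a sketch in this stroke-3 format could look like this (the numbers are arbitrary; numpy is imported at the top of the file):

# Hypothetical stroke-3 fragment: columns are (dx, dy, pen state),
# where pen state 1 means the pen is lifted after this point.
seq = np.array([[ 5,  0, 0],
                [ 3, -2, 1],
                [-4,  6, 0]])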

57    def __init__(self, dataset: np.array, max_seq_length: int, scale: Optional[float] = None):
67        data = []

We iterate through each of the sequences and filter them by length

69        for seq in dataset:

Keep the sequence only if the length of the stroke sequence is within our range

71            if 10 < len(seq) <= max_seq_length:

Clamp $\Delta x$, $\Delta y$ to $[-1000, 1000]$

73                seq = np.minimum(seq, 1000)
74                seq = np.maximum(seq, -1000)

Convert to a floating point array and add to data

76                seq = np.array(seq, dtype=np.float32)
77                data.append(seq)

We then calculate the scaling factor, which is the standard deviation of the combined ($\Delta x$, $\Delta y$) values. The paper notes that the mean is not adjusted for simplicity, since it is close to $0$ anyway.

83        if scale is None:
84            scale = np.std(np.concatenate([np.ravel(s[:, 0:2]) for s in data]))
85        self.scale = scale

Get the longest sequence length among all sequences

88        longest_seq_len = max([len(seq) for seq in data])

We initialize PyTorch data array with two extra steps for start-of-sequence (sos) and end-of-sequence (eos). Each step is a vector $(\Delta x, \Delta y, p_1, p_2, p_3)$. Only one of $p_1, p_2, p_3$ is $1$ and the others are $0$. They represent pen down, pen up and end-of-sequence in that order. $p_1$ is $1$ if the pen touches the paper in the next step. $p_2$ is $1$ if the pen doesn’t touch the paper in the next step. $p_3$ is $1$ if it is the end of the drawing.

98        self.data = torch.zeros(len(data), longest_seq_len + 2, 5, dtype=torch.float)

The mask array needs only one extra step since it is for the outputs of the decoder, which takes in data[:-1] and predicts the next step.

101        self.mask = torch.zeros(len(data), longest_seq_len + 1)
102
103        for i, seq in enumerate(data):
104            seq = torch.from_numpy(seq)
105            len_seq = len(seq)

Scale and set $\Delta x, \Delta y$

107            self.data[i, 1:len_seq + 1, :2] = seq[:, :2] / scale

$p_1$

109            self.data[i, 1:len_seq + 1, 2] = 1 - seq[:, 2]

$p_2$

111            self.data[i, 1:len_seq + 1, 3] = seq[:, 2]

$p_3$

113            self.data[i, len_seq + 1:, 4] = 1

Mask is on until end of sequence

115            self.mask[i, :len_seq + 1] = 1

Start-of-sequence is $(0, 0, 1, 0, 0)$

118        self.data[:, 0, 2] = 1

Size of the dataset

120    def __len__(self):
122        return len(self.data)

Get a sample

124    def __getitem__(self, idx: int):
126        return self.data[idx], self.mask[idx]
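A toy usage of this class, to see the layout described above (a minimal sketch; the values are made up, scale is fixed at 1 so the numbers stay readable, and extra pen-down points are appended because sequences of 10 points or fewer are filtered out):

toy = np.array([[5, 0, 0], [3, -2, 1], [-4, 6, 0]] + [[1, 0, 0]] * 9)
ds = StrokesDataset([toy], max_seq_length=200, scale=1.)
print(ds.data[0, :4])
# tensor([[ 0.,  0.,  1.,  0.,  0.],    <- start-of-sequence
#         [ 5.,  0.,  1.,  0.,  0.],    <- pen touching the paper
#         [ 3., -2.,  0.,  1.,  0.],    <- pen lifted after this point
#         [-4.,  6.,  1.,  0.,  0.]])
print(ds.mask[0])  # all ones here, since this is the only (and longest) sequence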

Bi-variate Gaussian mixture

The mixture is represented by $\Pi$ and $\mathcal{N}(\mu_{x}, \mu_{y}, \sigma_{x}, \sigma_{y}, \rho_{xy})$. This class adjusts the temperature and creates the categorical and Gaussian distributions from the parameters.

129class BivariateGaussianMixture:
139    def __init__(self, pi_logits: torch.Tensor, mu_x: torch.Tensor, mu_y: torch.Tensor,
140                 sigma_x: torch.Tensor, sigma_y: torch.Tensor, rho_xy: torch.Tensor):
141        self.pi_logits = pi_logits
142        self.mu_x = mu_x
143        self.mu_y = mu_y
144        self.sigma_x = sigma_x
145        self.sigma_y = sigma_y
146        self.rho_xy = rho_xy

Number of distributions in the mixture, $M$

148    @property
149    def n_distributions(self):
151        return self.pi_logits.shape[-1]

Adjust by temperature $\tau$

153    def set_temperature(self, temperature: float):

158        self.pi_logits /= temperature

160        self.sigma_x *= math.sqrt(temperature)

162        self.sigma_y *= math.sqrt(temperature)
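In terms of the mixture parameters, set_temperature corresponds to

$$\hat{\Pi}_k \leftarrow \frac{\hat{\Pi}_k}{\tau}, \qquad \sigma_x \leftarrow \sigma_x \sqrt{\tau}, \qquad \sigma_y \leftarrow \sigma_y \sqrt{\tau}$$

so the variances $\sigma_x^2, \sigma_y^2$ scale by $\tau$; a lower temperature concentrates the categorical distribution and shrinks the Gaussians, making sampling less random.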
164    def get_distribution(self):

Clamp $\sigma_x$, $\sigma_y$ and $\rho_{xy}$ to avoid getting NaNs

166        sigma_x = torch.clamp_min(self.sigma_x, 1e-5)
167        sigma_y = torch.clamp_min(self.sigma_y, 1e-5)
168        rho_xy = torch.clamp(self.rho_xy, -1 + 1e-5, 1 - 1e-5)

Get means

171        mean = torch.stack([self.mu_x, self.mu_y], -1)

Get covariance matrix

173        cov = torch.stack([
174            sigma_x * sigma_x, rho_xy * sigma_x * sigma_y,
175            rho_xy * sigma_x * sigma_y, sigma_y * sigma_y
176        ], -1)
177        cov = cov.view(*sigma_y.shape, 2, 2)

Create bi-variate normal distribution.

📝 It would be more efficient to use the scale_tril matrix [[a, 0], [b, c]], where $a = \sigma_x$, $b = \rho_{xy} \sigma_y$ and $c = \sigma_y \sqrt{1 - \rho_{xy}^2}$. But for simplicity we use the co-variance matrix. This is a good resource if you want to read up more about bi-variate distributions, their co-variance matrix, and probability density function.
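A rough sketch of that alternative (not what is used below), assuming the same clamped sigma_x, sigma_y and rho_xy tensors and the mean computed above:

# Sketch only: build the Cholesky factor of the covariance matrix directly.
a = sigma_x
b = rho_xy * sigma_y
c = sigma_y * torch.sqrt(1 - rho_xy ** 2)
zeros = torch.zeros_like(a)
scale_tril = torch.stack([a, zeros, b, c], -1).view(*sigma_y.shape, 2, 2)
multi_dist = torch.distributions.MultivariateNormal(mean, scale_tril=scale_tril)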

188        multi_dist = torch.distributions.MultivariateNormal(mean, covariance_matrix=cov)

Create categorical distribution $\Pi$ from logits

191        cat_dist = torch.distributions.Categorical(logits=self.pi_logits)
194        return cat_dist, multi_dist

Encoder module

This consists of a bidirectional LSTM

197class EncoderRNN(Module):
204    def __init__(self, d_z: int, enc_hidden_size: int):
205        super().__init__()

Create a bidirectional LSTM that takes a sequence of $(\Delta x, \Delta y, p_1, p_2, p_3)$ as input.

208        self.lstm = nn.LSTM(5, enc_hidden_size, bidirectional=True)

Head to get $\mu$

210        self.mu_head = nn.Linear(2 * enc_hidden_size, d_z)

Head to get $\hat{\sigma}$

212        self.sigma_head = nn.Linear(2 * enc_hidden_size, d_z)
214    def __call__(self, inputs: torch.Tensor, state=None):

The hidden state of the bidirectional LSTM is the concatenation of the output at the last token in the forward direction and the output at the first token in the reverse direction, which is what we want.

222        _, (hidden, cell) = self.lstm(inputs.float(), state)

The state has shape [2, batch_size, hidden_size] where the first dimension is the direction. We rearrange it to get $h = [h_{\rightarrow}; h_{\leftarrow}]$

226        hidden = einops.rearrange(hidden, 'fb b h -> b (fb h)')

$\mu$

229        mu = self.mu_head(hidden)

$\hat{\sigma}$

231        sigma_hat = self.sigma_head(hidden)

$\sigma = \exp(\frac{\hat{\sigma}}{2})$

233        sigma = torch.exp(sigma_hat / 2.)

Sample $z = \mu + \sigma \cdot \mathcal{N}(0, I)$

236        z = mu + sigma * torch.normal(mu.new_zeros(mu.shape), mu.new_ones(mu.shape))
239        return z, mu, sigma_hat

Decoder module

This consists of an LSTM

242class DecoderRNN(Module):
249    def __init__(self, d_z: int, dec_hidden_size: int, n_distributions: int):
250        super().__init__()

LSTM takes $[(\Delta x, \Delta y, p_1, p_2, p_3); z]$ as input

252        self.lstm = nn.LSTM(d_z + 5, dec_hidden_size)

Initial state of the LSTM is $[h_0; c_0] = \tanh(W_{z}z + b_z)$. init_state is the linear transformation for this

256        self.init_state = nn.Linear(d_z, 2 * dec_hidden_size)

This layer produces outputs for each of the n_distributions. Each distribution needs six parameters $(\hat{\Pi_i}, \mu_{x,i}, \mu_{y,i}, \hat{\sigma_{x,i}}, \hat{\sigma_{y,i}}, \hat{\rho_{xy,i}})$

261        self.mixtures = nn.Linear(dec_hidden_size, 6 * n_distributions)

This head is for the logits $(\hat{q_1}, \hat{q_2}, \hat{q_3})$

264        self.q_head = nn.Linear(dec_hidden_size, 3)

This is to calculate $\log(q_k)$ where $q_k = \frac{\exp(\hat{q_k})}{\sum_{j=1}^3 \exp(\hat{q_j})}$

267        self.q_log_softmax = nn.LogSoftmax(-1)

These parameters are stored for future reference

270        self.n_distributions = n_distributions
271        self.dec_hidden_size = dec_hidden_size
273    def __call__(self, x: torch.Tensor, z: torch.Tensor, state: Optional[Tuple[torch.Tensor, torch.Tensor]]):

Calculate the initial state

275        if state is None:

$[h_0; c_0] = \tanh(W_{z}z + b_z)$

277            h, c = torch.split(torch.tanh(self.init_state(z)), self.dec_hidden_size, 1)

h and c have shape [batch_size, lstm_size]. We want to reshape them to [1, batch_size, lstm_size] because that's the shape the LSTM expects.

280            state = (h.unsqueeze(0).contiguous(), c.unsqueeze(0).contiguous())

Run the LSTM

283        outputs, state = self.lstm(x, state)

Get $\log(q)$

286        q_logits = self.q_log_softmax(self.q_head(outputs))

Get $(\hat{\Pi_i}, \mu_{x,i}, \mu_{y,i}, \hat{\sigma_{x,i}}, \hat{\sigma_{y,i}}, \hat{\rho_{xy,i}})$. torch.split splits the output into 6 tensors of size self.n_distributions across dimension 2.

292        pi_logits, mu_x, mu_y, sigma_x, sigma_y, rho_xy = \
293            torch.split(self.mixtures(outputs), self.n_distributions, 2)

Create a bi-variate Gaussian mixture $\Pi$ and $\mathcal{N}(\mu_{x,i}, \mu_{y,i}, \sigma_{x,i}, \sigma_{y,i}, \rho_{xy,i})$ where $\sigma_{x,i} = \exp(\hat{\sigma_{x,i}})$, $\sigma_{y,i} = \exp(\hat{\sigma_{y,i}})$, $\rho_{xy,i} = \tanh(\hat{\rho_{xy,i}})$ and $\Pi_i = \operatorname{softmax}(\hat{\Pi})_i$.

$\Pi$ gives the categorical probabilities of choosing a distribution out of the mixture $\mathcal{N}(\mu_{x}, \mu_{y}, \sigma_{x}, \sigma_{y}, \rho_{xy})$.

306        dist = BivariateGaussianMixture(pi_logits, mu_x, mu_y,
307                                        torch.exp(sigma_x), torch.exp(sigma_y), torch.tanh(rho_xy))
310        return dist, q_logits, state

Reconstruction Loss

313class ReconstructionLoss(Module):
318    def __call__(self, mask: torch.Tensor, target: torch.Tensor,
319                 dist: 'BivariateGaussianMixture', q_logits: torch.Tensor):

Get $\Pi$ and $\mathcal{N}(\mu_{x}, \mu_{y}, \sigma_{x}, \sigma_{y}, \rho_{xy})$

321        pi, mix = dist.get_distribution()

target has shape [seq_len, batch_size, 5], where the last dimension holds the features $(\Delta x, \Delta y, p_1, p_2, p_3)$. We want to take $\Delta x, \Delta y$ and get the probabilities from each of the distributions in the mixture $\mathcal{N}(\mu_{x}, \mu_{y}, \sigma_{x}, \sigma_{y}, \rho_{xy})$.

xy will have shape [seq_len, batch_size, n_distributions, 2]

328        xy = target[:, :, 0:2].unsqueeze(-2).expand(-1, -1, dist.n_distributions, -1)

Calculate the probabilities

334        probs = torch.sum(pi.probs * torch.exp(mix.log_prob(xy)), 2)

Although probs has $N_{max}$ (longest_seq_len) elements, the sum is only taken up to $N_s$ because the rest are masked out.

It might feel like we should take the sum and divide by $N_s$ rather than $N_{max}$, but that would give a higher weight to individual predictions in shorter sequences. Dividing by $N_{max}$ gives equal weight to every prediction $p(\Delta x, \Delta y)$.
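Concretely, the two terms computed below are (written per sequence, as in the paper; the code uses torch.mean, which also averages over the batch and, for $L_p$, over the three pen states):

$$L_s = -\frac{1}{N_{max}} \sum_{i=1}^{N_s} \log \Big( \sum_{j=1}^{M} \Pi_{j,i} \, \mathcal{N}\big(\Delta x_i, \Delta y_i \,\big|\, \mu_{x,j,i}, \mu_{y,j,i}, \sigma_{x,j,i}, \sigma_{y,j,i}, \rho_{xy,j,i}\big) \Big)$$

$$L_p = -\frac{1}{N_{max}} \sum_{i=1}^{N_{max}} \sum_{k=1}^{3} p_{k,i} \log q_{k,i}$$

and the reconstruction loss is $L_R = L_s + L_p$.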

343        loss_stroke = -torch.mean(mask * torch.log(1e-5 + probs))

346        loss_pen = -torch.mean(target[:, :, 2:] * q_logits)

349        return loss_stroke + loss_pen

KL-Divergence loss

This calculates the KL divergence between a given normal distribution and $\mathcal{N}(0, 1)$
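For a diagonal Gaussian with parameters $\mu$ and $\hat{\sigma}$ (where $\sigma = \exp(\frac{\hat{\sigma}}{2})$), this is

$$L_{KL} = -\frac{1}{2 N_z} \sum_{i=1}^{N_z} \big(1 + \hat{\sigma}_i - \mu_i^2 - \exp(\hat{\sigma}_i)\big)$$

which is what the code below computes (averaged over the batch as well).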

352class KLDivLoss(Module):
359    def __call__(self, sigma_hat: torch.Tensor, mu: torch.Tensor):

361        return -0.5 * torch.mean(1 + sigma_hat - mu ** 2 - torch.exp(sigma_hat))

Sampler

This samples a sketch from the decoder and plots it

364class Sampler:
371    def __init__(self, encoder: EncoderRNN, decoder: DecoderRNN):
372        self.decoder = decoder
373        self.encoder = encoder
375    def sample(self, data: torch.Tensor, temperature: float):

$N_{max}$

377        longest_seq_len = len(data)

Get $z$ from the encoder

380        z, _, _ = self.encoder(data)

Start-of-sequence stroke is $(0, 0, 1, 0, 0)$

383        s = data.new_tensor([0, 0, 1, 0, 0])
384        seq = [s]

The initial decoder state is None. The decoder will initialize it to $[h_0; c_0] = \tanh(W_{z}z + b_z)$

387        state = None

We don’t need gradients

390        with torch.no_grad():

Sample $N_{max}$ strokes

392            for i in range(longest_seq_len):

$[(\Delta x, \Delta y, p_1, p_2, p_3); z]$ is the input to the decoder

394                data = torch.cat([s.view(1, 1, -1), z.unsqueeze(0)], 2)

Get $\Pi$, $\mathcal{N}(\mu_{x}, \mu_{y}, \sigma_{x}, \sigma_{y}, \rho_{xy})$, $q$ and the next state from the decoder

397                dist, q_logits, state = self.decoder(data, z, state)

Sample a stroke

399                s = self._sample_step(dist, q_logits, temperature)

Add the new stroke to the sequence of strokes

401                seq.append(s)

Stop sampling if $p_3 = 1$. This indicates that sketching has stopped

403                if s[4] == 1:
404                    break

Create a PyTorch tensor of the sequence of strokes

407        seq = torch.stack(seq)

Plot the sequence of strokes

410        self.plot(seq)
412    @staticmethod
413    def _sample_step(dist: 'BivariateGaussianMixture', q_logits: torch.Tensor, temperature: float):

Set temperature $\tau$ for sampling. This is implemented in class BivariateGaussianMixture.

415        dist.set_temperature(temperature)

Get temperature adjusted $\Pi$ and $\mathcal{N}(\mu_{x}, \mu_{y}, \sigma_{x}, \sigma_{y}, \rho_{xy})$

417        pi, mix = dist.get_distribution()

Sample from $\Pi$ the index of the distribution to use from the mixture

419        idx = pi.sample()[0, 0]

Create categorical distribution $q$ with log-probabilities q_logits or $\hat{q}$

422        q = torch.distributions.Categorical(logits=q_logits / temperature)

Sample from $q$

424        q_idx = q.sample()[0, 0]

Sample from the normal distributions in the mixture and pick the one indexed by idx

427        xy = mix.sample()[0, 0, idx]

Create an empty stroke $(\Delta x, \Delta y, q_1, q_2, q_3)$

430        stroke = q_logits.new_zeros(5)

Set $\Delta x, \Delta y$

432        stroke[:2] = xy

Set $q_1, q_2, q_3$

434        stroke[q_idx + 2] = 1
436        return stroke
438    @staticmethod
439    def plot(seq: torch.Tensor):

Take the cumulative sums of $(\Delta x, \Delta y)$ to get $(x, y)$

441        seq[:, 0:2] = torch.cumsum(seq[:, 0:2], dim=0)

Create a new numpy array of the form $(x, y, q_2)$

443        seq[:, 2] = seq[:, 3]
444        seq = seq[:, 0:3].detach().cpu().numpy()

Split the array at points where $q_2$ is $1$, i.e. at the points where the pen is lifted from the paper. This gives a list of stroke sequences.

449        strokes = np.split(seq, np.where(seq[:, 2] > 0)[0] + 1)
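As a quick reminder of np.split semantics (splitting happens before each given index, which is why + 1 is added above so that every pen-up point is the last point of its chunk):

a = np.arange(6)
np.split(a, [2, 4])  # [array([0, 1]), array([2, 3]), array([4, 5])]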

Plot each sequence of strokes

451        for s in strokes:
452            plt.plot(s[:, 0], -s[:, 1])

Don’t show axes

454        plt.axis('off')

Show the plot

456        plt.show()

Configurations

These are default configurations which can be later adjusted by passing a dict.

459class Configs(TrainValidConfigs):

Device configurations to pick the device to run the experiment

467    device: torch.device = DeviceConfigs()
469    encoder: EncoderRNN
470    decoder: DecoderRNN
471    optimizer: optim.Adam
472    sampler: Sampler
473
474    dataset_name: str
475    train_loader: DataLoader
476    valid_loader: DataLoader
477    train_dataset: StrokesDataset
478    valid_dataset: StrokesDataset

Encoder and decoder sizes

481    enc_hidden_size = 256
482    dec_hidden_size = 512

Batch size

485    batch_size = 100

Number of features in $z$

488    d_z = 128

Number of distributions in the mixture, $M$

490    n_distributions = 20

Weight of KL divergence loss, $w_{KL}$

493    kl_div_loss_weight = 0.5

Gradient clipping

495    grad_clip = 1.

Temperature $\tau$ for sampling

497    temperature = 0.4

Filter out stroke sequences longer than $200$

500    max_seq_length = 200
501
502    epochs = 100
503
504    kl_div_loss = KLDivLoss()
505    reconstruction_loss = ReconstructionLoss()
507    def init(self):

Initialize encoder & decoder

509        self.encoder = EncoderRNN(self.d_z, self.enc_hidden_size).to(self.device)
510        self.decoder = DecoderRNN(self.d_z, self.dec_hidden_size, self.n_distributions).to(self.device)

Set up the optimizer; things like the type of optimizer and the learning rate are configurable

513        optimizer = OptimizerConfigs()
514        optimizer.parameters = list(self.encoder.parameters()) + list(self.decoder.parameters())
515        self.optimizer = optimizer

Create sampler

518        self.sampler = Sampler(self.encoder, self.decoder)

npz file path is data/sketch/[DATASET NAME].npz

521        path = lab.get_data_path() / 'sketch' / f'{self.dataset_name}.npz'

Load the numpy file.

523        dataset = np.load(str(path), encoding='latin1', allow_pickle=True)

Create training dataset

526        self.train_dataset = StrokesDataset(dataset['train'], self.max_seq_length)

Create validation dataset

528        self.valid_dataset = StrokesDataset(dataset['valid'], self.max_seq_length, self.train_dataset.scale)

Create training data loader

531        self.train_loader = DataLoader(self.train_dataset, self.batch_size, shuffle=True)

Create validation data loader

533        self.valid_loader = DataLoader(self.valid_dataset, self.batch_size)

Add hooks to monitor layer outputs on Tensorboard

536        hook_model_outputs(self.mode, self.encoder, 'encoder')
537        hook_model_outputs(self.mode, self.decoder, 'decoder')

Configure the tracker to print the total train/validation loss

540        tracker.set_scalar("loss.total.*", True)
541
542        self.state_modules = []
544    def step(self, batch: Any, batch_idx: BatchIndex):
545        self.encoder.train(self.mode.is_train)
546        self.decoder.train(self.mode.is_train)

Move data and mask to device and swap the sequence and batch dimensions. data will have shape [seq_len, batch_size, 5] and mask will have shape [seq_len, batch_size].

551        data = batch[0].to(self.device).transpose(0, 1)
552        mask = batch[1].to(self.device).transpose(0, 1)

Increment step in training mode

555        if self.mode.is_train:
556            tracker.add_global_step(len(data))

Encode the sequence of strokes

559        with monit.section("encoder"):

Get $z$, $\mu$, and $\hat{\sigma}$

561            z, mu, sigma_hat = self.encoder(data)

Decode to get the mixture of distributions and $\hat{q}$

564        with monit.section("decoder"):

Concatenate $[(\Delta x, \Delta y, p_1, p_2, p_3); z]$

566            z_stack = z.unsqueeze(0).expand(data.shape[0] - 1, -1, -1)
567            inputs = torch.cat([data[:-1], z_stack], 2)

Get mixture of distributions and $\hat{q}$

569            dist, q_logits, _ = self.decoder(inputs, z, None)

Compute the loss

572        with monit.section('loss'):

$L_{KL}$

574            kl_loss = self.kl_div_loss(sigma_hat, mu)

$L_R$

576            reconstruction_loss = self.reconstruction_loss(mask, data[1:], dist, q_logits)

$Loss = L_R + w_{KL} L_{KL}$

578            loss = reconstruction_loss + self.kl_div_loss_weight * kl_loss

Track losses

581            tracker.add("loss.kl.", kl_loss)
582            tracker.add("loss.reconstruction.", reconstruction_loss)
583            tracker.add("loss.total.", loss)

Only if we are in training state

586        if self.mode.is_train:

Run optimizer

588            with monit.section('optimize'):

Set grad to zero

590                self.optimizer.zero_grad()

Compute gradients

592                loss.backward()

Log model parameters and gradients

594                if batch_idx.is_last:
595                    tracker.add(encoder=self.encoder, decoder=self.decoder)

Clip gradients

597                nn.utils.clip_grad_norm_(self.encoder.parameters(), self.grad_clip)
598                nn.utils.clip_grad_norm_(self.decoder.parameters(), self.grad_clip)

Optimize

600                self.optimizer.step()
601
602        tracker.save()
604    def sample(self):

Randomly pick a sample from the validation dataset to encode

606        data, *_ = self.valid_dataset[np.random.choice(len(self.valid_dataset))]

Add batch dimension and move it to device

608        data = data.unsqueeze(1).to(self.device)

Sample

610        self.sampler.sample(data, self.temperature)
613def main():
614    configs = Configs()
615    experiment.create(name="sketch_rnn")

Pass a dictionary of configurations

618    experiment.configs(configs, {
619        'optimizer.optimizer': 'Adam',

We use a learning rate of 1e-3 because we can see results faster. The paper suggested 1e-4.

622        'optimizer.learning_rate': 1e-3,

Name of the dataset

624        'dataset_name': 'bicycle',

Number of inner iterations within an epoch to switch between training, validation and sampling.

626        'inner_iterations': 10
627    })
628
629    with experiment.start():

Run the experiment

631        configs.run()
632
633
634if __name__ == "__main__":
635    main()