This is an annotated PyTorch implementation of the paper A Neural Representation of Sketch Drawings.
Sketch RNN is a sequence-to-sequence variational auto-encoder. Both the encoder and the decoder are recurrent neural networks. It learns to reconstruct simple stroke-based drawings by predicting a series of strokes. The decoder predicts each stroke as a mixture of bi-variate Gaussians.
Download data from the Quick, Draw! Dataset. There is a link to download the npz files in the Sketch-RNN QuickDraw Dataset section of the readme. Place the downloaded npz file(s) in the data/sketch folder. This code is configured to use the bicycle dataset; you can change this in the configurations.

This implementation took help from the PyTorch Sketch RNN project by Alexis David Jacq.
```python
import math
from typing import Optional, Tuple, Any

import numpy as np
import torch
import torch.nn as nn
from matplotlib import pyplot as plt
from torch import optim
from torch.utils.data import Dataset, DataLoader

import einops
from labml import lab, experiment, tracker, monit
from labml_helpers.device import DeviceConfigs
from labml_helpers.module import Module
from labml_helpers.optimizer import OptimizerConfigs
from labml_helpers.train_valid import TrainValidConfigs, hook_model_outputs, BatchIndex
```

`StrokesDataset` pre-processes the stroke sequences and makes them available as a PyTorch `Dataset`. `dataset` is a list of numpy arrays of shape `[seq_len, 3]`. Each array is a sequence of strokes, and each stroke is represented by 3 integers. The first two are the displacements along the x and y axes ($\Delta x$, $\Delta y$), and the last integer is the pen state: $0$ if the pen stays on the paper after this stroke and $1$ if the pen is lifted (i.e. this point ends a stroke).
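For intuition, here is a hypothetical toy sequence in this format (made-up values, not from the dataset):

```python
import numpy as np

# A made-up drawing of 4 strokes: the pen stays on the paper for the first
# three displacements and is lifted after the fourth
toy_seq = np.array([
    [5, 0, 0],    # move right, pen stays down
    [0, 7, 0],    # move down, pen stays down
    [-5, 0, 0],   # move left, pen stays down
    [0, -7, 1],   # move up, then lift the pen
])
```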
```python
class StrokesDataset(Dataset):
    def __init__(self, dataset: np.array, max_seq_length: int, scale: Optional[float] = None):
        data = []
        # We iterate through each of the sequences and filter
        for seq in dataset:
            # Filter if the length of the sequence of strokes is within our range
            if 10 < len(seq) <= max_seq_length:
                # Clamp $\Delta x$, $\Delta y$ to $[-1000, 1000]$
                seq = np.minimum(seq, 1000)
                seq = np.maximum(seq, -1000)
                # Convert to a floating point array and add to `data`
                seq = np.array(seq, dtype=np.float32)
                data.append(seq)
```

We then calculate the scaling factor, which is the standard deviation of ($\Delta x$, $\Delta y$) combined. The paper notes that the mean is not adjusted for simplicity, since the mean is anyway close to $0$.
```python
        if scale is None:
            scale = np.std(np.concatenate([np.ravel(s[:, 0:2]) for s in data]))
        self.scale = scale

        # Get the longest sequence length among all sequences
        longest_seq_len = max([len(seq) for seq in data])
```

We initialize the PyTorch data array with two extra steps for start-of-sequence (sos) and end-of-sequence (eos). Each step is a vector $(\Delta x, \Delta y, p_1, p_2, p_3)$. Only one of $p_1, p_2, p_3$ is $1$ and the others are $0$. They represent pen down, pen up and end-of-sequence in that order. $p_1$ is $1$ if the pen touches the paper in the next step, $p_2$ is $1$ if the pen doesn't touch the paper in the next step, and $p_3$ is $1$ if it is the end of the drawing.
```python
        self.data = torch.zeros(len(data), longest_seq_len + 2, 5, dtype=torch.float)
        # The mask array needs only one extra step since it is for the outputs of the
        # decoder, which takes in `data[:-1]` and predicts the next step
        self.mask = torch.zeros(len(data), longest_seq_len + 1)

        for i, seq in enumerate(data):
            seq = torch.from_numpy(seq)
            len_seq = len(seq)
            # Scale and set $\Delta x, \Delta y$
            self.data[i, 1:len_seq + 1, :2] = seq[:, :2] / scale
            # $p_1$
            self.data[i, 1:len_seq + 1, 2] = 1 - seq[:, 2]
            # $p_2$
            self.data[i, 1:len_seq + 1, 3] = seq[:, 2]
            # $p_3$
            self.data[i, len_seq + 1:, 4] = 1
            # Mask is on until the end of the sequence
            self.mask[i, :len_seq + 1] = 1

        # Start-of-sequence is $(0, 0, 1, 0, 0)$
        self.data[:, 0, 2] = 1
```
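The following is a minimal, hypothetical shape check (not part of the original code). It assumes the `StrokesDataset` class above is in scope, builds a dataset from one fake 15-stroke drawing, and verifies the padded shapes:

```python
import numpy as np

rng = np.random.default_rng(0)
# One fake drawing with 15 strokes: random displacements, pen lifted at the end
seq = np.zeros((15, 3), dtype=np.int64)
seq[:, :2] = rng.integers(-20, 20, size=(15, 2))
seq[-1, 2] = 1

ds = StrokesDataset([seq], max_seq_length=200)
data, mask = ds[0]
# Two extra steps for start-of-sequence and end-of-sequence
assert data.shape == (15 + 2, 5)
# The mask has only one extra step (it is for the decoder outputs)
assert mask.shape == (15 + 1,)
```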
```python
    def __len__(self):
        # Size of the dataset
        return len(self.data)

    def __getitem__(self, idx: int):
        # Get a sample
        return self.data[idx], self.mask[idx]
```

`BivariateGaussianMixture` represents the mixture of bi-variate Gaussians predicted by the decoder. The mixture is represented by $\Pi$ (the mixture weights) and $\mathcal{N}(\mu_x, \mu_y, \sigma_x, \sigma_y, \rho_{xy})$ for each component. This class adjusts temperatures and creates the categorical and Gaussian distributions from the parameters.
```python
class BivariateGaussianMixture:
    def __init__(self, pi_logits: torch.Tensor, mu_x: torch.Tensor, mu_y: torch.Tensor,
                 sigma_x: torch.Tensor, sigma_y: torch.Tensor, rho_xy: torch.Tensor):
        self.pi_logits = pi_logits
        self.mu_x = mu_x
        self.mu_y = mu_y
        self.sigma_x = sigma_x
        self.sigma_y = sigma_y
        self.rho_xy = rho_xy

    @property
    def n_distributions(self):
        # Number of distributions in the mixture, $M$
        return self.pi_logits.shape[-1]
```
```python
    def set_temperature(self, temperature: float):
        # Adjust by temperature $\tau$
        self.pi_logits /= temperature
        self.sigma_x *= math.sqrt(temperature)
        self.sigma_y *= math.sqrt(temperature)
```
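In terms of the distribution parameters, this is (roughly) the sampling-temperature adjustment from the paper: the logits of $\Pi$ are divided by $\tau$ and the variances are scaled by $\tau$,

$$\hat{\Pi}_k \leftarrow \frac{\hat{\Pi}_k}{\tau}, \qquad \sigma_x^2 \leftarrow \sigma_x^2 \, \tau, \qquad \sigma_y^2 \leftarrow \sigma_y^2 \, \tau$$

where $\hat{\Pi}_k$ are the mixture-weight logits and $\tau \in (0, 1]$; smaller $\tau$ gives less random samples.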
```python
    def get_distribution(self):
        # Clamp $\sigma_x$, $\sigma_y$ and $\rho_{xy}$ to avoid getting NaNs
        sigma_x = torch.clamp_min(self.sigma_x, 1e-5)
        sigma_y = torch.clamp_min(self.sigma_y, 1e-5)
        rho_xy = torch.clamp(self.rho_xy, -1 + 1e-5, 1 - 1e-5)

        # Get means
        mean = torch.stack([self.mu_x, self.mu_y], -1)
        # Get the covariance matrix
        cov = torch.stack([
            sigma_x * sigma_x, rho_xy * sigma_x * sigma_y,
            rho_xy * sigma_x * sigma_y, sigma_y * sigma_y
        ], -1)
        cov = cov.view(*sigma_y.shape, 2, 2)
```

📝 It would be more efficient to create the distribution from a `scale_tril` matrix of the form `[[a, 0], [b, c]]`, where $a = \sigma_x$, $b = \rho_{xy} \sigma_y$ and $c = \sigma_y \sqrt{1 - \rho_{xy}^2}$ (a sketch of this is shown after the class below). But for simplicity we use the covariance matrix. This is a good resource if you want to read up more about bi-variate distributions, their covariance matrices, and probability density functions.

```python
        # Create the bi-variate normal distribution
        multi_dist = torch.distributions.MultivariateNormal(mean, covariance_matrix=cov)
        # Create the categorical distribution $\Pi$ from the logits
        cat_dist = torch.distributions.Categorical(logits=self.pi_logits)

        return cat_dist, multi_dist
```
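Here is a minimal, hypothetical sketch of the `scale_tril` alternative mentioned in the note above (not part of the original implementation). It builds the lower-triangular Cholesky factor $L = [[a, 0], [b, c]]$ so that $L L^\top$ equals the covariance matrix used above; the dummy tensors are stand-ins for the decoder outputs.

```python
import torch

# Dummy mixture parameters with shape [seq_len, batch_size, n_distributions]
sigma_x = torch.rand(3, 4, 20) + 0.1
sigma_y = torch.rand(3, 4, 20) + 0.1
rho_xy = torch.rand(3, 4, 20) * 0.8 - 0.4
mu_x, mu_y = torch.randn(3, 4, 20), torch.randn(3, 4, 20)

mean = torch.stack([mu_x, mu_y], -1)
# Entries of the Cholesky factor [[a, 0], [b, c]]
a = sigma_x
b = rho_xy * sigma_y
c = sigma_y * torch.sqrt(1 - rho_xy ** 2)
scale_tril = torch.stack([a, torch.zeros_like(a), b, c], -1).view(*a.shape, 2, 2)
# Same distribution as above, without materializing the full covariance matrix
multi_dist = torch.distributions.MultivariateNormal(mean, scale_tril=scale_tril)
```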
`EncoderRNN` is the encoder: a bidirectional LSTM that encodes a sequence of strokes into the parameters $\mu$ and $\hat{\sigma}$ of the latent distribution, from which $z$ is sampled.

```python
class EncoderRNN(Module):
    def __init__(self, d_z: int, enc_hidden_size: int):
        super().__init__()
        # Create a bidirectional LSTM that takes a sequence of
        # $(\Delta x, \Delta y, p_1, p_2, p_3)$ as input
        self.lstm = nn.LSTM(5, enc_hidden_size, bidirectional=True)
        # Head to get $\mu$
        self.mu_head = nn.Linear(2 * enc_hidden_size, d_z)
        # Head to get $\hat{\sigma}$
        self.sigma_head = nn.Linear(2 * enc_hidden_size, d_z)

    def forward(self, inputs: torch.Tensor, state=None):
        # The hidden state of the bidirectional LSTM is the concatenation of the
        # output of the last token in the forward direction and of the first token
        # in the reverse direction, which is what we want
        _, (hidden, cell) = self.lstm(inputs.float(), state)
        # The state has shape `[2, batch_size, hidden_size]`, where the first
        # dimension is the direction. We rearrange it to `[batch_size, 2 * hidden_size]`
        hidden = einops.rearrange(hidden, 'fb b h -> b (fb h)')
        # $\mu$
        mu = self.mu_head(hidden)
        # $\hat{\sigma}$
        sigma_hat = self.sigma_head(hidden)
        # $\sigma = \exp(\hat{\sigma} / 2)$
        sigma = torch.exp(sigma_hat / 2.)
        # Sample $z = \mu + \sigma \cdot \epsilon$ where $\epsilon \sim \mathcal{N}(0, I)$
        z = mu + sigma * torch.normal(mu.new_zeros(mu.shape), mu.new_ones(mu.shape))

        return z, mu, sigma_hat
```
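As a quick, hypothetical illustration (not part of the original code) of the rearrange pattern and the reparameterization step above:

```python
import torch
import einops

batch_size, hidden_size, d_z = 4, 256, 128
# Dummy bidirectional hidden state: [directions, batch_size, hidden_size]
hidden = torch.randn(2, batch_size, hidden_size)
hidden = einops.rearrange(hidden, 'fb b h -> b (fb h)')
assert hidden.shape == (batch_size, 2 * hidden_size)

# Reparameterization: z = mu + sigma * eps with eps ~ N(0, I)
mu, sigma_hat = torch.randn(batch_size, d_z), torch.randn(batch_size, d_z)
z = mu + torch.exp(sigma_hat / 2.) * torch.randn_like(mu)
assert z.shape == (batch_size, d_z)
```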
`DecoderRNN` is the autoregressive decoder. At each step it takes the previous stroke concatenated with $z$ and outputs the parameters of a mixture of bi-variate Gaussians for $(\Delta x, \Delta y)$ and the logits $\hat{q} = (\hat{q}_1, \hat{q}_2, \hat{q}_3)$ for the pen state.

```python
class DecoderRNN(Module):
    def __init__(self, d_z: int, dec_hidden_size: int, n_distributions: int):
        super().__init__()
        # LSTM takes $[z; (\Delta x, \Delta y, p_1, p_2, p_3)]$ as input
        self.lstm = nn.LSTM(d_z + 5, dec_hidden_size)
        # Initial state of the LSTM is $[h_0; c_0] = \tanh(W_z z + b_z)$.
        # `init_state` is the linear transformation for this
        self.init_state = nn.Linear(d_z, 2 * dec_hidden_size)
        # This layer produces outputs for each of the `n_distributions`.
        # Each distribution needs six parameters
        # $(\hat{\Pi}_i, \mu_{x,i}, \mu_{y,i}, \hat{\sigma}_{x,i}, \hat{\sigma}_{y,i}, \hat{\rho}_{xy,i})$
        self.mixtures = nn.Linear(dec_hidden_size, 6 * n_distributions)
        # This head is for the pen-state logits $(\hat{q}_1, \hat{q}_2, \hat{q}_3)$
        self.q_head = nn.Linear(dec_hidden_size, 3)
        # This is to calculate $\log(q_k)$, where $q_k = \operatorname{softmax}(\hat{q})_k$
        self.q_log_softmax = nn.LogSoftmax(-1)
        # These parameters are stored for future reference
        self.n_distributions = n_distributions
        self.dec_hidden_size = dec_hidden_size

    def forward(self, x: torch.Tensor, z: torch.Tensor, state: Optional[Tuple[torch.Tensor, torch.Tensor]]):
        # Calculate the initial state
        if state is None:
            # $[h_0; c_0] = \tanh(W_z z + b_z)$
            h, c = torch.split(torch.tanh(self.init_state(z)), self.dec_hidden_size, 1)
            # `h` and `c` have shapes `[batch_size, lstm_size]`. We want to reshape them
            # to `[1, batch_size, lstm_size]` because that's the shape used in LSTM
            state = (h.unsqueeze(0).contiguous(), c.unsqueeze(0).contiguous())

        # Run the LSTM
        outputs, state = self.lstm(x, state)

        # Get $\log(q)$
        q_logits = self.q_log_softmax(self.q_head(outputs))

        # Get $(\hat{\Pi}_i, \mu_{x,i}, \mu_{y,i}, \hat{\sigma}_{x,i}, \hat{\sigma}_{y,i}, \hat{\rho}_{xy,i})$.
        # `torch.split` splits the output into 6 tensors of size `self.n_distributions`
        # across dimension `2`
        pi_logits, mu_x, mu_y, sigma_x, sigma_y, rho_xy = \
            torch.split(self.mixtures(outputs), self.n_distributions, 2)

        # Create a bi-variate Gaussian mixture with
        # $\sigma_{x,i} = \exp(\hat{\sigma}_{x,i})$, $\sigma_{y,i} = \exp(\hat{\sigma}_{y,i})$ and
        # $\rho_{xy,i} = \tanh(\hat{\rho}_{xy,i})$, where $\Pi_i = \operatorname{softmax}(\hat{\Pi})_i$ is
        # the categorical probability of choosing distribution $i$ out of the mixture
        dist = BivariateGaussianMixture(pi_logits, mu_x, mu_y,
                                        torch.exp(sigma_x), torch.exp(sigma_y), torch.tanh(rho_xy))

        return dist, q_logits, state
```
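A minimal, hypothetical shape check (not part of the original code), assuming the `DecoderRNN` and `BivariateGaussianMixture` classes above are in scope:

```python
import torch

d_z, dec_hidden_size, n_distributions = 128, 512, 20
decoder = DecoderRNN(d_z, dec_hidden_size, n_distributions)

seq_len, batch_size = 10, 4
z = torch.randn(batch_size, d_z)                 # latent vector
x = torch.randn(seq_len, batch_size, d_z + 5)    # [stroke; z] at each step
dist, q_logits, state = decoder(x, z, None)

# One set of mixture logits per step, per batch element
assert dist.pi_logits.shape == (seq_len, batch_size, n_distributions)
# Pen-state log-probabilities
assert q_logits.shape == (seq_len, batch_size, 3)
```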
`ReconstructionLoss` computes the stroke reconstruction loss.

```python
class ReconstructionLoss(Module):
    def forward(self, mask: torch.Tensor, target: torch.Tensor,
                dist: 'BivariateGaussianMixture', q_logits: torch.Tensor):
        # Get $\Pi$ and $\mathcal{N}(\mu_x, \mu_y, \sigma_x, \sigma_y, \rho_{xy})$
        pi, mix = dist.get_distribution()
```

`target` has shape `[seq_len, batch_size, 5]`, where the last dimension holds the features $(\Delta x, \Delta y, p_1, p_2, p_3)$. We take $(\Delta x, \Delta y)$ and evaluate its probability under each of the distributions in the mixture. `xy` will have shape `[seq_len, batch_size, n_distributions, 2]`.

```python
        xy = target[:, :, 0:2].unsqueeze(-2).expand(-1, -1, dist.n_distributions, -1)
        # Calculate the probabilities
        # $p(\Delta x, \Delta y) = \sum_{j} \Pi_j \mathcal{N}(\Delta x, \Delta y \mid \mu_{x,j}, \mu_{y,j}, \sigma_{x,j}, \sigma_{y,j}, \rho_{xy,j})$
        probs = torch.sum(pi.probs * torch.exp(mix.log_prob(xy)), 2)
```

Although `probs` has $N_{max}$ (`longest_seq_len`) elements, the sum is only taken up to $N_s$ (the actual sequence length) because the rest is masked out. It might feel like we should take the sum and divide by $N_s$ and not $N_{max}$, but that would give higher weight to individual predictions in shorter sequences. We give equal weight to each prediction when we divide by $N_{max}$.

```python
        # Stroke loss
        loss_stroke = -torch.mean(mask * torch.log(1e-5 + probs))
        # Pen-state loss
        loss_pen = -torch.mean(target[:, :, 2:] * q_logits)

        return loss_stroke + loss_pen
```

`KLDivLoss` calculates the KL divergence between the encoded latent distribution and the standard normal prior $\mathcal{N}(0, I)$.

```python
class KLDivLoss(Module):
    def forward(self, sigma_hat: torch.Tensor, mu: torch.Tensor):
        return -0.5 * torch.mean(1 + sigma_hat - mu ** 2 - torch.exp(sigma_hat))
```
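For reference, these terms correspond (up to the normalization constants; this implementation averages with `torch.mean`) to the objectives in the paper:

$$L_s = -\frac{1}{N_{max}} \sum_{i=1}^{N_s} \log \Big( \sum_{j=1}^{M} \Pi_{j,i}\, \mathcal{N}\big(\Delta x_i, \Delta y_i \mid \mu_{x,j,i}, \mu_{y,j,i}, \sigma_{x,j,i}, \sigma_{y,j,i}, \rho_{xy,j,i}\big) \Big)$$

$$L_p = -\frac{1}{N_{max}} \sum_{i=1}^{N_{max}} \sum_{k=1}^{3} p_{k,i} \log q_{k,i}, \qquad L_{KL} = -\frac{1}{2 N_z} \sum_{j=1}^{N_z} \big(1 + \hat{\sigma}_j - \mu_j^2 - \exp(\hat{\sigma}_j)\big)$$

The total training loss used later is $L_R + w_{KL} L_{KL}$ with $L_R = L_s + L_p$.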
`Sampler` takes an encoded drawing, samples a new sequence of strokes from the decoder, and plots it.

```python
class Sampler:
    def __init__(self, encoder: EncoderRNN, decoder: DecoderRNN):
        self.decoder = decoder
        self.encoder = encoder

    def sample(self, data: torch.Tensor, temperature: float):
        longest_seq_len = len(data)

        # Get $z$ from the encoder
        z, _, _ = self.encoder(data)

        # Start-of-sequence stroke is $(0, 0, 1, 0, 0)$
        s = data.new_tensor([0, 0, 1, 0, 0])
        seq = [s]
        # Initial decoder state is `None`.
        # The decoder will initialize it to $[h_0; c_0] = \tanh(W_z z + b_z)$
        state = None

        # We don't need gradients
        with torch.no_grad():
            # Sample strokes
            for i in range(longest_seq_len):
                # $[s; z]$ is the input to the decoder
                data = torch.cat([s.view(1, 1, -1), z.unsqueeze(0)], 2)
                # Get $\Pi$, $\mathcal{N}(\mu_x, \mu_y, \sigma_x, \sigma_y, \rho_{xy})$,
                # $\hat{q}$ and the next state from the decoder
                dist, q_logits, state = self.decoder(data, z, state)
                # Sample a stroke
                s = self._sample_step(dist, q_logits, temperature)
                # Add the new stroke to the sequence of strokes
                seq.append(s)
                # Stop sampling if $p_3 = 1$. This indicates that sketching has stopped
                if s[4] == 1:
                    break

        # Create a PyTorch tensor of the sequence of strokes
        seq = torch.stack(seq)

        # Plot the sequence of strokes
        self.plot(seq)

    @staticmethod
    def _sample_step(dist: 'BivariateGaussianMixture', q_logits: torch.Tensor, temperature: float):
        # Set the temperature $\tau$ for sampling.
        # This is implemented in the class `BivariateGaussianMixture`
        dist.set_temperature(temperature)
        # Get the temperature-adjusted $\Pi$ and $\mathcal{N}$
        pi, mix = dist.get_distribution()
        # Sample from $\Pi$ the index of the distribution to use from the mixture
        idx = pi.sample()[0, 0]

        # Create a categorical distribution $q$ from the (temperature-adjusted)
        # log-probabilities `q_logits`, i.e. $\hat{q}$
        q = torch.distributions.Categorical(logits=q_logits / temperature)
        # Sample the pen state from $q$
        q_idx = q.sample()[0, 0]

        # Sample $(\Delta x, \Delta y)$ from the normal distributions in the mixture
        # and pick the one indexed by `idx`
        xy = mix.sample()[0, 0, idx]

        # Create an empty stroke $(\Delta x, \Delta y, p_1, p_2, p_3)$
        stroke = q_logits.new_zeros(5)
        # Set $\Delta x, \Delta y$
        stroke[:2] = xy
        # Set the pen state
        stroke[q_idx + 2] = 1

        return stroke

    @staticmethod
    def plot(seq: torch.Tensor):
        # Take the cumulative sums of $(\Delta x, \Delta y)$ to get $(x, y)$
        seq[:, 0:2] = torch.cumsum(seq[:, 0:2], dim=0)
        # Create a new numpy array of the form $(x, y, p_2)$
        seq[:, 2] = seq[:, 3]
        seq = seq[:, 0:3].detach().cpu().numpy()

        # Split the array at points where $p_2$ is $1$, i.e. split the array of strokes
        # at the points where the pen is lifted from the paper.
        # This gives a list of stroke sequences
        strokes = np.split(seq, np.where(seq[:, 2] > 0)[0] + 1)
        # Plot each sequence of strokes
        for s in strokes:
            plt.plot(s[:, 0], -s[:, 1])
        # Don't show axes
        plt.axis('off')
        # Show the plot
        plt.show()
```
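A hypothetical usage sketch (not part of the original code), mirroring the `sample` method of the `Configs` class below: it assumes a trained `encoder` and `decoder` and a `valid_dataset` built with `StrokesDataset`, all on the same device.

```python
sampler = Sampler(encoder, decoder)
# Pick a random validation drawing; shape [seq_len, 5] -> [seq_len, 1, 5]
data, _mask = valid_dataset[np.random.choice(len(valid_dataset))]
data = data.unsqueeze(1)
# Lower temperatures give less random sketches
sampler.sample(data, temperature=0.4)
```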
`Configs` holds the configurations for the experiment.

```python
class Configs(TrainValidConfigs):
    # Device configuration to pick the device to run the experiment on
    device: torch.device = DeviceConfigs()

    encoder: EncoderRNN
    decoder: DecoderRNN
    optimizer: optim.Adam
    sampler: Sampler

    dataset_name: str
    train_loader: DataLoader
    valid_loader: DataLoader
    train_dataset: StrokesDataset
    valid_dataset: StrokesDataset

    # Encoder and decoder sizes
    enc_hidden_size = 256
    dec_hidden_size = 512

    # Batch size
    batch_size = 100

    # Number of features in $z$
    d_z = 128
    # Number of distributions in the mixture, $M$
    n_distributions = 20

    # Weight of the KL divergence loss, $w_{KL}$
    kl_div_loss_weight = 0.5
    # Gradient clipping
    grad_clip = 1.
    # Temperature $\tau$ for sampling
    temperature = 0.4

    # Filter out stroke sequences longer than $200$
    max_seq_length = 200

    epochs = 100

    kl_div_loss = KLDivLoss()
    reconstruction_loss = ReconstructionLoss()
```
```python
    def init(self):
        # Initialize encoder & decoder
        self.encoder = EncoderRNN(self.d_z, self.enc_hidden_size).to(self.device)
        self.decoder = DecoderRNN(self.d_z, self.dec_hidden_size, self.n_distributions).to(self.device)

        # Set the optimizer. Things like the type of optimizer and learning rate are configurable
        optimizer = OptimizerConfigs()
        optimizer.parameters = list(self.encoder.parameters()) + list(self.decoder.parameters())
        self.optimizer = optimizer

        # Create sampler
        self.sampler = Sampler(self.encoder, self.decoder)

        # `npz` file path is `data/sketch/[DATASET NAME].npz`
        path = lab.get_data_path() / 'sketch' / f'{self.dataset_name}.npz'
        # Load the numpy file
        dataset = np.load(str(path), encoding='latin1', allow_pickle=True)

        # Create the training dataset
        self.train_dataset = StrokesDataset(dataset['train'], self.max_seq_length)
        # Create the validation dataset
        self.valid_dataset = StrokesDataset(dataset['valid'], self.max_seq_length, self.train_dataset.scale)

        # Create the training data loader
        self.train_loader = DataLoader(self.train_dataset, self.batch_size, shuffle=True)
        # Create the validation data loader
        self.valid_loader = DataLoader(self.valid_dataset, self.batch_size)

        # Add hooks to monitor layer outputs on Tensorboard
        hook_model_outputs(self.mode, self.encoder, 'encoder')
        hook_model_outputs(self.mode, self.decoder, 'decoder')

        # Configure the tracker to print the total train/validation loss
        tracker.set_scalar("loss.total.*", True)

        self.state_modules = []
```
```python
    def step(self, batch: Any, batch_idx: BatchIndex):
        self.encoder.train(self.mode.is_train)
        self.decoder.train(self.mode.is_train)

        # Move `data` and `mask` to the device and swap the sequence and batch dimensions.
        # `data` will have shape `[seq_len, batch_size, 5]` and
        # `mask` will have shape `[seq_len, batch_size]`
        data = batch[0].to(self.device).transpose(0, 1)
        mask = batch[1].to(self.device).transpose(0, 1)

        # Increment step in training mode
        if self.mode.is_train:
            tracker.add_global_step(len(data))

        # Encode the sequence of strokes
        with monit.section("encoder"):
            # Get $z$, $\mu$, and $\hat{\sigma}$
            z, mu, sigma_hat = self.encoder(data)

        # Decode the mixture of distributions and $\hat{q}$
        with monit.section("decoder"):
            # Concatenate $z$ with the input strokes $(\Delta x, \Delta y, p_1, p_2, p_3)$
            z_stack = z.unsqueeze(0).expand(data.shape[0] - 1, -1, -1)
            inputs = torch.cat([data[:-1], z_stack], 2)
            # Get the mixture of distributions and $\hat{q}$
            dist, q_logits, _ = self.decoder(inputs, z, None)

        # Compute the loss
        with monit.section('loss'):
            # $L_{KL}$
            kl_loss = self.kl_div_loss(sigma_hat, mu)
            # $L_R$
            reconstruction_loss = self.reconstruction_loss(mask, data[1:], dist, q_logits)
            # $L_R + w_{KL} L_{KL}$
            loss = reconstruction_loss + self.kl_div_loss_weight * kl_loss

            # Track losses
            tracker.add("loss.kl.", kl_loss)
            tracker.add("loss.reconstruction.", reconstruction_loss)
            tracker.add("loss.total.", loss)

        # Only if we are in training mode
        if self.mode.is_train:
            # Run the optimizer
            with monit.section('optimize'):
                # Set `grad` to zero
                self.optimizer.zero_grad()
                # Compute gradients
                loss.backward()
                # Log model parameters and gradients
                if batch_idx.is_last:
                    tracker.add(encoder=self.encoder, decoder=self.decoder)
                # Clip gradients
                nn.utils.clip_grad_norm_(self.encoder.parameters(), self.grad_clip)
                nn.utils.clip_grad_norm_(self.decoder.parameters(), self.grad_clip)
                # Optimize
                self.optimizer.step()

        tracker.save()
```
```python
    def sample(self):
        # Randomly pick a sample from the validation dataset to encode
        data, *_ = self.valid_dataset[np.random.choice(len(self.valid_dataset))]
        # Add the batch dimension and move it to the device
        data = data.unsqueeze(1).to(self.device)
        # Sample
        self.sampler.sample(data, self.temperature)
```
```python
def main():
    configs = Configs()
    experiment.create(name="sketch_rnn")

    # Pass a dictionary of configurations
    experiment.configs(configs, {
        'optimizer.optimizer': 'Adam',
        # We use a learning rate of `1e-3` because we can see results faster.
        # The paper suggested `1e-4`
        'optimizer.learning_rate': 1e-3,
        # Name of the dataset
        'dataset_name': 'bicycle',
        # Number of inner iterations within an epoch to switch between training, validation and sampling
        'inner_iterations': 10
    })

    with experiment.start():
        # Run the experiment
        configs.run()


if __name__ == "__main__":
    main()
```