diff --git a/labml_nn/sketch_rnn/__init__.py b/labml_nn/sketch_rnn/__init__.py
index b031d09c..9f621afb 100644
--- a/labml_nn/sketch_rnn/__init__.py
+++ b/labml_nn/sketch_rnn/__init__.py
@@ -42,16 +42,17 @@ class StrokesDataset(Dataset):
     """

     def __init__(self, dataset: np.array, max_seq_length: int, scale: Optional[float] = None):
-        # Filter and convert training sequences to floats.
+        """
+        `dataset` is a list of numpy arrays of shape [seq_len, 3].
+        It is a sequence of strokes, and each stroke is represented by
+        3 integers.
+        The first two are the displacements along x and y ($\Delta x$, $\Delta y$),
+        and the last integer represents the state of the pen - $1$ if it's touching
+        the paper and $0$ otherwise.
+        """
+
         data = []
-        # `dataset['train']` is a list of numpy arrays of shape [seq_len, 3].
-        # It is a sequence of strokes, and each stroke is represented by
-        # 3 integers.
-        # First two are the displacements along x and y ($\Delta x$, $\Delta y$)
-        # And the last integer represents the state of the pen - 1 if it's touching
-        # the paper and 0 otherwise.
-        #
-        # We iterate through each of the sequences
+        # We iterate through each of the sequences and filter
         for seq in dataset:
             # Filter if the length of the the sequence of strokes is within our range
             if 10 < len(seq) <= max_seq_length:
@@ -62,103 +63,163 @@ class StrokesDataset(Dataset):
                 seq = np.array(seq, dtype=np.float32)
                 data.append(seq)

-        # We then normalize all ($\Delta x$, $\Delta y$) by their standard deviation.
-        # This calculates the standard deviations for ($\Delta x$, $\Delta y$) combined.
+        # We then calculate the scaling factor, which is the
+        # standard deviation of ($\Delta x$, $\Delta y$) combined.
         # Paper notes that the mean is not adjusted for simplicity,
         # since the mean is anyway close to $0$.
         if scale is None:
             scale = np.std(np.concatenate([np.ravel(s[:, 0:2]) for s in data]))
         self.scale = scale
-        for s in data:
-            # Adjust by standard deviation
-            s[:, 0:2] /= scale

         # Get the longest sequence length among all sequences
         longest_seq_len = max([len(seq) for seq in data])

-        # Initialize PyTorch data array
+        # We initialize PyTorch data array with two extra steps for start-of-sequence (sos)
+        # and end-of-sequence (eos).
+        # Each step is a vector $(\Delta x, \Delta y, p_1, p_2, p_3)$.
+        # Only one of $p_1, p_2, p_3$ is $1$ and the others are $0$.
+        # They represent *pen down*, *pen up* and *end-of-sequence* in that order.
+        # $p_1$ is $1$ if the pen touches the paper in the next step.
+        # $p_2$ is $1$ if the pen doesn't touch the paper in the next step.
+        # $p_3$ is $1$ if it is the end of the drawing.
         self.data = torch.zeros(len(data), longest_seq_len + 2, 5, dtype=torch.float)
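As a quick illustration of this 5-column layout (not part of the diff), the sketch below converts a single toy stroke-3 sequence the same way the loop further down in this hunk does; the stroke values and the `scale` handling are purely illustrative.

```python
# Illustrative sketch only (not part of the diff): converting one toy stroke-3
# sequence into the padded 5-column form described above, mirroring the
# assignments made inside the loop that follows in the diff.
import numpy as np
import torch

seq = np.array([[5., 0., 0.],
                [0., 7., 0.],
                [-2., -3., 1.]], dtype=np.float32)  # toy (dx, dy, pen-state) steps

scale = np.std(seq[:, 0:2])          # std of (dx, dy); the mean is left unadjusted
len_seq = len(seq)
data = torch.zeros(len_seq + 2, 5)   # two extra steps for sos and eos

data[1:len_seq + 1, :2] = torch.from_numpy(seq[:, :2]) / scale  # scaled (dx, dy)
data[1:len_seq + 1, 2] = 1 - torch.from_numpy(seq[:, 2])        # p_1
data[1:len_seq + 1, 3] = torch.from_numpy(seq[:, 2])            # p_2
data[len_seq + 1:, 4] = 1                                       # p_3 (eos)
data[0, 2] = 1                                                  # sos step (0, 0, 1, 0, 0)
print(data)
```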
-        # Initialize mask array. Mask has an extra step because the model predicts
-        # end of sequence at the end.
+        # The mask array needs only one extra step since it is for the outputs of the
+        # decoder, which takes in `data[:-1]` and predicts the next step.
         self.mask = torch.zeros(len(data), longest_seq_len + 1)

         for i, seq in enumerate(data):
             seq = torch.from_numpy(seq)
             len_seq = len(seq)
-            # set x, y
-            self.data[i, 1:len_seq + 1, :2] = seq[:, :2]
-            # set pen status
+            # Scale and set $\Delta x, \Delta y$
+            self.data[i, 1:len_seq + 1, :2] = seq[:, :2] / scale
+            # $p_1$
             self.data[i, 1:len_seq + 1, 2] = 1 - seq[:, 2]
+            # $p_2$
             self.data[i, 1:len_seq + 1, 3] = seq[:, 2]
+            # $p_3$
             self.data[i, len_seq + 1:, 4] = 1
+            # Mask is on until the end of the sequence
             self.mask[i, :len_seq + 1] = 1

+        # Start-of-sequence is $(0, 0, 1, 0, 0)$
         self.data[:, 0, 2] = 1

     def __len__(self):
+        """Size of the dataset"""
         return len(self.data)

     def __getitem__(self, idx: int):
+        """Get a sample"""
         return self.data[idx], self.mask[idx]


 class EncoderRNN(Module):
+    """
+    ## Encoder module
+
+    This consists of a bidirectional LSTM
+    """
+
     def __init__(self, d_z: int, enc_hidden_size: int):
         super().__init__()
+        # Create a bidirectional LSTM that takes a sequence of
+        # $(\Delta x, \Delta y, p_1, p_2, p_3)$ as input.
         self.lstm = nn.LSTM(5, enc_hidden_size, bidirectional=True)
+        # Head to get $\mu$
         self.mu_head = nn.Linear(2 * enc_hidden_size, d_z)
+        # Head to get $\hat{\sigma}$
         self.sigma_head = nn.Linear(2 * enc_hidden_size, d_z)

     def __call__(self, inputs: torch.Tensor, state=None):
+        # The hidden state of the bidirectional LSTM is the concatenation of the
+        # output of the last token in the forward direction and
+        # the first token in the reverse direction,
+        # which is what we want.
+        # $$h_{\rightarrow} = encode_{\rightarrow}(S),
+        # h_{\leftarrow} = encode_{\leftarrow}(S_{reverse}),
+        # h = [h_{\rightarrow}; h_{\leftarrow}]$$
         _, (hidden, cell) = self.lstm(inputs.float(), state)
+        # The state has shape `[2, batch_size, hidden_size]`,
+        # where the first dimension is the direction.
+        # We rearrange it to get $h = [h_{\rightarrow}; h_{\leftarrow}]$
         hidden = einops.rearrange(hidden, 'fb b h -> b (fb h)')

+        # $\mu$
         mu = self.mu_head(hidden)
+        # $\hat{\sigma}$
         sigma_hat = self.sigma_head(hidden)
+        # $\sigma = \exp(\frac{\hat{\sigma}}{2})$
         sigma = torch.exp(sigma_hat / 2.)

-        z_size = mu.size()
-        z = mu + sigma * torch.normal(mu.new_zeros(z_size), mu.new_ones(z_size))
+        # Sample $z = \mu + \sigma \cdot \mathcal{N}(0, I)$
+        z = mu + sigma * torch.normal(mu.new_zeros(mu.shape), mu.new_ones(mu.shape))

         return z, mu, sigma_hat
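For reference, here is a minimal, self-contained sketch (not part of the diff) of what the encoder above computes: the forward and backward hidden states are concatenated and the latent $z$ is drawn with the reparameterisation trick. The sizes and the zero inputs are arbitrary placeholders.

```python
# Illustrative sketch only: bidirectional hidden-state handling and
# reparameterised sampling of z, with arbitrary sizes and dummy inputs.
import torch
import torch.nn as nn
import einops

enc_hidden_size, d_z, batch_size, seq_len = 256, 128, 4, 10

lstm = nn.LSTM(5, enc_hidden_size, bidirectional=True)
mu_head = nn.Linear(2 * enc_hidden_size, d_z)
sigma_head = nn.Linear(2 * enc_hidden_size, d_z)

inputs = torch.zeros(seq_len, batch_size, 5)             # (dx, dy, p1, p2, p3) steps
_, (hidden, cell) = lstm(inputs)                         # hidden: [2, batch_size, enc_hidden_size]
hidden = einops.rearrange(hidden, 'fb b h -> b (fb h)')  # concatenate both directions

mu = mu_head(hidden)                                     # [batch_size, d_z]
sigma_hat = sigma_head(hidden)
sigma = torch.exp(sigma_hat / 2.)
z = mu + sigma * torch.normal(mu.new_zeros(mu.shape), mu.new_ones(mu.shape))
print(z.shape)  # torch.Size([4, 128])
```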


 class DecoderRNN(Module):
-    def __init__(self, d_z: int, dec_hidden_size: int, n_mixtures: int):
-        super().__init__()
-        self.dec_hidden_size = dec_hidden_size
-        self.init_state = nn.Linear(d_z, 2 * dec_hidden_size)
+    """
+    ## Decoder module
+
+    This consists of an LSTM
+    """
+
+    def __init__(self, d_z: int, dec_hidden_size: int, n_distributions: int):
+        super().__init__()
+        # LSTM takes $[z; (\Delta x, \Delta y, p_1, p_2, p_3)]$ as input
         self.lstm = nn.LSTM(d_z + 5, dec_hidden_size)

-        self.mixtures = nn.Linear(dec_hidden_size, 6 * n_mixtures)
+        # Initial state of the LSTM is $[h_0; c_0] = \tanh(W_{z}z + b_z)$.
+        # `init_state` is the linear transformation for this
+        self.init_state = nn.Linear(d_z, 2 * dec_hidden_size)
+
+        # This layer produces outputs for each of the `n_distributions`.
+        # Each distribution needs six parameters:
+        # $(\hat{\Pi_i}, \mu_{x_i}, \mu_{y_i}, \hat{\sigma_{x_i}}, \hat{\sigma_{y_i}}, \hat{\rho_{xy_i}})$
+        self.mixtures = nn.Linear(dec_hidden_size, 6 * n_distributions)
+
+        # This head is for the logits $(\hat{q_1}, \hat{q_2}, \hat{q_3})$
         self.q_head = nn.Linear(dec_hidden_size, 3)
-        self.n_mixtures = n_mixtures
+        # This is to calculate $\log(q_k)$ where
+        # $$q_k = \frac{\exp(\hat{q_k})}{\sum_{j = 1}^3 \exp(\hat{q_j})}$$
         self.q_log_softmax = nn.LogSoftmax(-1)

+        # These parameters are stored for future reference
+        self.n_distributions = n_distributions
+        self.dec_hidden_size = dec_hidden_size
+
     def __call__(self, x: torch.Tensor, z: torch.Tensor, state: Optional[Tuple[torch.Tensor, torch.Tensor]]):
+        # Calculate the initial state
         if state is None:
-            hidden, cell = torch.split(torch.tanh(self.init_state(z)), self.dec_hidden_size, 1)
-            state = (hidden.unsqueeze(0).contiguous(), cell.unsqueeze(0).contiguous())
+            # $[h_0; c_0] = \tanh(W_{z}z + b_z)$
+            h, c = torch.split(torch.tanh(self.init_state(z)), self.dec_hidden_size, 1)
+            # `h` and `c` have shape `[batch_size, lstm_size]`. We want to reshape them
+            # to `[1, batch_size, lstm_size]` because that's the shape the LSTM expects.
+            state = (h.unsqueeze(0), c.unsqueeze(0))

-        outputs, (hidden, cell) = self.lstm(x, state)
+        # Run the LSTM
+        outputs, state = self.lstm(x, state)

+        # Get $\log(q)$
         q_logits = self.q_log_softmax(self.q_head(outputs))

+        # Get $(\hat{\Pi_i}, \mu_{x_i}, \mu_{y_i}, \hat{\sigma_{x_i}},
+        # \hat{\sigma_{y_i}}, \hat{\rho_{xy_i}})$.
+        # `torch.split` splits the output into 6 tensors of size `self.n_distributions`
+        # across dimension `2`.
         pi_logits, mu_x, mu_y, sigma_x, sigma_y, rho_xy = \
-            torch.split(self.mixtures(outputs), self.n_mixtures, 2)
+            torch.split(self.mixtures(outputs), self.n_distributions, 2)

+        # Create a bivariate Gaussian mixture
         dist = BivariateGaussianMixture(pi_logits, mu_x, mu_y,
                                         torch.exp(sigma_x), torch.exp(sigma_y), torch.tanh(rho_xy))

-        return dist, q_logits, (hidden, cell)
+        return dist, q_logits, state


 class ReconstructionLoss(Module):
     def __call__(self, mask: torch.Tensor, target: torch.Tensor,
                  dist: 'BivariateGaussianMixture', q_logits: torch.Tensor):
         pi, mix = dist.get_distribution()
-        xy = target[:, :, 0:2].unsqueeze(-2).expand(-1, -1, dist.n_mixtures, -1)
+        xy = target[:, :, 0:2].unsqueeze(-2).expand(-1, -1, dist.n_distributions, -1)
         probs = torch.sum(pi.probs * torch.exp(mix.log_prob(xy)), 2)
         loss_stroke = -torch.mean(mask * torch.log(1e-5 + probs))
         loss_pen = -torch.mean(target[:, :, 2:] * q_logits)
@@ -244,7 +305,7 @@ class Configs(TrainValidConfigs):
     batch_size = 100

     d_z = 128
-    n_mixtures = 20
+    n_distributions = 20

     kl_div_loss_weight = 0.5
     grad_clip = 1.
@@ -264,7 +325,7 @@
                Configs.valid_dataset, Configs.valid_loader])
 def setup_all(self: Configs):
     self.encoder = EncoderRNN(self.d_z, self.enc_hidden_size).to(self.device)
-    self.decoder = DecoderRNN(self.d_z, self.dec_hidden_size, self.n_mixtures).to(self.device)
+    self.decoder = DecoderRNN(self.d_z, self.dec_hidden_size, self.n_distributions).to(self.device)

     self.optimizer = OptimizerConfigs()
     self.optimizer.parameters = list(self.encoder.parameters()) + list(self.decoder.parameters())
@@ -359,7 +420,7 @@ class BivariateGaussianMixture:
         self.rho_xy = rho_xy

     @property
-    def n_mixtures(self):
+    def n_distributions(self):
         return self.pi_logits.shape[-1]

     def set_temperature(self, temperature: float):
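To make the renamed `n_distributions` output concrete, here is a rough, standalone sketch (not from the diff) of how the six split tensors can be turned into a mixture of bivariate Gaussians and a per-step stroke likelihood, in the spirit of `ReconstructionLoss`. It uses `torch.distributions` directly, and all shapes and values are arbitrary placeholders.

```python
# Rough, illustrative sketch (not part of the diff): splitting the decoder's
# 6 * n_distributions outputs and forming the per-step stroke likelihood,
# using torch.distributions directly. Shapes and values are arbitrary.
import torch
import torch.distributions as D

seq_len, batch_size, n_distributions = 7, 3, 20
# Stand-in for the output of the `self.mixtures` linear layer
outputs = 0.1 * torch.randn(seq_len, batch_size, 6 * n_distributions)

pi_logits, mu_x, mu_y, sigma_x, sigma_y, rho_xy = \
    torch.split(outputs, n_distributions, 2)
sigma_x, sigma_y, rho_xy = torch.exp(sigma_x), torch.exp(sigma_y), torch.tanh(rho_xy)

# Mixture weights and the bivariate Gaussian components
pi = D.Categorical(logits=pi_logits)
mean = torch.stack([mu_x, mu_y], -1)                       # [seq_len, batch, n_dist, 2]
cov = torch.stack([
    torch.stack([sigma_x ** 2, rho_xy * sigma_x * sigma_y], -1),
    torch.stack([rho_xy * sigma_x * sigma_y, sigma_y ** 2], -1),
], -2)                                                     # [seq_len, batch, n_dist, 2, 2]
mix = D.MultivariateNormal(mean, covariance_matrix=cov)

# Likelihood of target offsets under the mixture, as in ReconstructionLoss
target_xy = torch.randn(seq_len, batch_size, 2)
xy = target_xy.unsqueeze(-2).expand(-1, -1, n_distributions, -1)
probs = torch.sum(pi.probs * torch.exp(mix.log_prob(xy)), 2)   # [seq_len, batch]
loss_stroke = -torch.mean(torch.log(1e-5 + probs))             # mask omitted in this sketch
print(loss_stroke)
```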