📚 encoder/decoder

Varuna Jayasiri
2020-10-19 16:29:48 +05:30
parent 73bee48be9
commit 17e26b0bd1


@@ -42,16 +42,17 @@ class StrokesDataset(Dataset):
"""
def __init__(self, dataset: np.array, max_seq_length: int, scale: Optional[float] = None):
"""
`dataset` is a list of numpy arrays of shape [seq_len, 3].
It is a sequence of strokes, and each stroke is represented by
3 integers.
The first two are the displacements along x and y ($\Delta x$, $\Delta y$),
and the last integer represents the state of the pen: $1$ if it's touching
the paper and $0$ otherwise.
"""
data = []
# We iterate through each of the sequences and filter
for seq in dataset:
# Keep the sequence only if the number of strokes is within our range
if 10 < len(seq) <= max_seq_length:
@@ -62,103 +63,163 @@ class StrokesDataset(Dataset):
seq = np.array(seq, dtype=np.float32)
data.append(seq)
# We then calculate the scaling factor, which is the
# standard deviation of ($\Delta x$, $\Delta y$) combined.
# The paper notes that the mean is not adjusted for simplicity,
# since it is close to $0$ anyway.
if scale is None:
scale = np.std(np.concatenate([np.ravel(s[:, 0:2]) for s in data]))
self.scale = scale
# Get the longest sequence length among all sequences
longest_seq_len = max([len(seq) for seq in data])
# We initialize PyTorch data array with two extra steps for start-of-sequence (sos)
# and end-of-sequence (eos).
# Each step is a vector $(\Delta x, \Delta y, p_1, p_2, p_3)$.
# Only one of $p_1, p_2, p_3$ is $1$ and the others are $0$.
# They represent *pen down*, *pen up* and *end-of-sequence* in that order.
# $p_1$ is $1$ if the pen touches the paper in the next step.
# $p_2$ is $1$ if the pen doesn't touch the paper in the next step.
# $p_3$ is $1$ if it is the end of the drawing.
self.data = torch.zeros(len(data), longest_seq_len + 2, 5, dtype=torch.float)
# The mask array needs only one extra step since it is for the outputs of the
# decoder, which takes in `data[:-1]` and predicts the next step.
self.mask = torch.zeros(len(data), longest_seq_len + 1)
for i, seq in enumerate(data):
seq = torch.from_numpy(seq)
len_seq = len(seq)
# Scale and set $\Delta x, \Delta y$
self.data[i, 1:len_seq + 1, :2] = seq[:, :2] / scale
# $p_1$
self.data[i, 1:len_seq + 1, 2] = 1 - seq[:, 2]
# $p_2$
self.data[i, 1:len_seq + 1, 3] = seq[:, 2]
# $p_3$
self.data[i, len_seq + 1:, 4] = 1
# Mask is on until end of sequence
self.mask[i, :len_seq + 1] = 1
# Start-of-sequence is $(0, 0, 1, 0, 0)$
self.data[:, 0, 2] = 1
def __len__(self):
"""Size of the dataset"""
return len(self.data)
def __getitem__(self, idx: int):
"""Get a sample"""
return self.data[idx], self.mask[idx]
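To make the $(\Delta x, \Delta y, p_1, p_2, p_3)$ layout concrete, here is a minimal, self-contained sketch that pads a single made-up stroke-3 sequence the same way (the sequence and `max_len` are invented for illustration; scaling by the standard deviation is omitted):
import numpy as np
import torch
# A made-up 3-step sequence in the [seq_len, 3] format described above
seq = np.array([[5., 0., 0.], [0., 3., 0.], [2., 2., 1.]], dtype=np.float32)
max_len = 5
data = torch.zeros(1, max_len + 2, 5)
n = len(seq)
s = torch.from_numpy(seq)
# $\Delta x, \Delta y$
data[0, 1:n + 1, :2] = s[:, :2]
# $p_1$ and $p_2$ from the pen state column
data[0, 1:n + 1, 2] = 1 - s[:, 2]
data[0, 1:n + 1, 3] = s[:, 2]
# $p_3$ marks every step after the drawing ends
data[0, n + 1:, 4] = 1
# Start-of-sequence step $(0, 0, 1, 0, 0)$
data[0, 0, 2] = 1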
class EncoderRNN(Module):
"""
## Encoder module
This consists of a bidirectional LSTM
"""
def __init__(self, d_z: int, enc_hidden_size: int):
super().__init__()
# Create a bidirectional LSTM that takes a sequence of
# $(\Delta x, \Delta y, p_1, p_2, p_3)$ as input.
self.lstm = nn.LSTM(5, enc_hidden_size, bidirectional=True)
# Head to get $\mu$
self.mu_head = nn.Linear(2 * enc_hidden_size, d_z)
# Head to get $\hat{\sigma}$
self.sigma_head = nn.Linear(2 * enc_hidden_size, d_z)
def __call__(self, inputs: torch.Tensor, state=None):
# The hidden state of the bidirectional LSTM is the concatenation of the
# output of the last token in the forward direction and
# the first token in the reverse direction, which is what we want.
# $$h_{\rightarrow} = encode_{\rightarrow}(S),
# h_{\leftarrow} = encode_{\leftarrow}(S_{reverse}),
# h = [h_{\rightarrow}; h_{\leftarrow}]$$
_, (hidden, cell) = self.lstm(inputs.float(), state)
# The state has shape `[2, batch_size, hidden_size]`
# where the first dimension is the direction.
# We rearrange it to get $h = [h_{\rightarrow}; h_{\leftarrow}]$
hidden = einops.rearrange(hidden, 'fb b h -> b (fb h)')
# $\mu$
mu = self.mu_head(hidden)
# $\hat{\sigma}$
sigma_hat = self.sigma_head(hidden)
# $\sigma = \exp(\frac{\hat{\sigma}}{2})$
sigma = torch.exp(sigma_hat / 2.)
# Sample $z = \mu + \sigma \cdot \mathcal{N}(0, I)$
z = mu + sigma * torch.normal(mu.new_zeros(mu.shape), mu.new_ones(mu.shape))
return z, mu, sigma_hat
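As a rough shape check (the sizes here are made up; in training the strokes come from `StrokesDataset` and are arranged as `[seq_len, batch_size, 5]`), the encoder maps a batch of padded sequences to one `d_z`-dimensional latent per sequence:
encoder = EncoderRNN(d_z=128, enc_hidden_size=256)
# A dummy batch: sequence length 202, batch size 4, 5 features per step
strokes = torch.zeros(202, 4, 5)
z, mu, sigma_hat = encoder(strokes)
# One latent vector per sequence in the batch
assert z.shape == mu.shape == sigma_hat.shape == (4, 128)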
class DecoderRNN(Module):
"""
## Decoder module
This consists of an LSTM
"""
def __init__(self, d_z: int, dec_hidden_size: int, n_distributions: int):
super().__init__()
# LSTM takes $[z; (\Delta x, \Delta y, p_1, p_2, p_3)]$ as input
self.lstm = nn.LSTM(d_z + 5, dec_hidden_size)
# Initial state of the LSTM is $[h_0; c_0] = \tanh(W_{z}z + b_z)$.
# `init_state` is the linear transformation for this
self.init_state = nn.Linear(d_z, 2 * dec_hidden_size)
# This layer produces outputs for each of the `n_distributions`.
# Each distribution needs six parameters
# $(\hat{\Pi_i}, \mu_{x_i}, \mu_{y_i}, \hat{\sigma_{x_i}}, \hat{\sigma_{y_i}}, \hat{\rho_{xy_i}})$
self.mixtures = nn.Linear(dec_hidden_size, 6 * n_distributions)
# This head is for the logits $(\hat{q_1}, \hat{q_2}, \hat{q_3})$
self.q_head = nn.Linear(dec_hidden_size, 3)
# This is to calculate $\log(q_k)$ where
# $$q_k = \frac{\exp(\hat{q_k})}{\sum_{j = 1}^3 \exp(\hat{q_j})}$$
self.q_log_softmax = nn.LogSoftmax(-1)
# These parameters are stored for future reference
self.n_distributions = n_distributions
self.dec_hidden_size = dec_hidden_size
def __call__(self, x: torch.Tensor, z: torch.Tensor, state: Optional[Tuple[torch.Tensor, torch.Tensor]]):
# Calculate the initial state
if state is None:
# $[h_0; c_0] = \tanh(W_{z}z + b_z)$
h, c = torch.split(torch.tanh(self.init_state(z)), self.dec_hidden_size, 1)
# `h` and `c` have shapes `[batch_size, lstm_size]`. We reshape them
# to `[1, batch_size, lstm_size]` because that's the shape the LSTM expects.
state = (h.unsqueeze(0).contiguous(), c.unsqueeze(0).contiguous())
# Run the LSTM
outputs, state = self.lstm(x, state)
# Get $\log(q)$
q_logits = self.q_log_softmax(self.q_head(outputs))
# Get $(\hat{\Pi_i}, \mu_{x_i}, \mu_{y_i}, \hat{\sigma_{x_i}},
# \hat{\sigma_{y_i}}, \hat{\rho_{xy_i}})$.
# `torch.split` splits the output into 6 tensors of size `self.n_distributions`
# across dimension `2`.
pi_logits, mu_x, mu_y, sigma_x, sigma_y, rho_xy = \
torch.split(self.mixtures(outputs), self.n_distributions, 2)
# Create a bivariate gaussian mixture
dist = BivariateGaussianMixture(pi_logits, mu_x, mu_y,
torch.exp(sigma_x), torch.exp(sigma_y), torch.tanh(rho_xy))
return dist, q_logits, state
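A minimal sketch of a single decoder step with made-up sizes (in training, $z$ comes from the encoder and the input is $z$ concatenated with the shifted target strokes):
decoder = DecoderRNN(d_z=128, dec_hidden_size=512, n_distributions=20)
z = torch.zeros(4, 128)
# A start-of-sequence step $(0, 0, 1, 0, 0)$ for each of the 4 sequences
sos = torch.zeros(1, 4, 5)
sos[:, :, 2] = 1
# Concatenate $z$ with the stroke so each step has $d_z + 5$ features
x = torch.cat([sos, z.unsqueeze(0)], dim=2)
dist, q_logits, state = decoder(x, z, None)
# Log-probabilities for $(q_1, q_2, q_3)$ at each step
assert q_logits.shape == (1, 4, 3)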
class ReconstructionLoss(Module):
def __call__(self, mask: torch.Tensor, target: torch.Tensor,
dist: 'BivariateGaussianMixture', q_logits: torch.Tensor):
pi, mix = dist.get_distribution()
# Expand the target $(\Delta x, \Delta y)$ to match each of the `n_distributions` components
xy = target[:, :, 0:2].unsqueeze(-2).expand(-1, -1, dist.n_distributions, -1)
# Probability of $(\Delta x, \Delta y)$ under the mixture
probs = torch.sum(pi.probs * torch.exp(mix.log_prob(xy)), 2)
# Stroke loss; $10^{-5}$ is added for numerical stability
loss_stroke = -torch.mean(mask * torch.log(1e-5 + probs))
# Pen-state loss
loss_pen = -torch.mean(target[:, :, 2:] * q_logits)
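In the notation of the Sketch-RNN paper, and up to the masking and averaging details above, these two terms correspond to the reconstruction loss $L_R = L_s + L_p$ with $M$ mixture components (`n_distributions`):
$$L_s = -\frac{1}{N_{max}} \sum_i \log \bigg( \sum_{j=1}^{M} \Pi_{j,i} \, \mathcal{N}\big(\Delta x_i, \Delta y_i \mid \mu_{x,j,i}, \mu_{y,j,i}, \sigma_{x,j,i}, \sigma_{y,j,i}, \rho_{xy,j,i}\big) \bigg)$$
$$L_p = -\frac{1}{N_{max}} \sum_i \sum_{k=1}^{3} p_{k,i} \log q_{k,i}$$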
@@ -244,7 +305,7 @@ class Configs(TrainValidConfigs):
batch_size = 100
d_z = 128
n_distributions = 20
kl_div_loss_weight = 0.5
grad_clip = 1.
@@ -264,7 +325,7 @@ class Configs(TrainValidConfigs):
Configs.valid_dataset, Configs.valid_loader])
def setup_all(self: Configs):
self.encoder = EncoderRNN(self.d_z, self.enc_hidden_size).to(self.device)
self.decoder = DecoderRNN(self.d_z, self.dec_hidden_size, self.n_distributions).to(self.device)
self.optimizer = OptimizerConfigs()
self.optimizer.parameters = list(self.encoder.parameters()) + list(self.decoder.parameters())
@@ -359,7 +420,7 @@ class BivariateGaussianMixture:
self.rho_xy = rho_xy
@property
def n_distributions(self):
return self.pi_logits.shape[-1]
def set_temperature(self, temperature: float):