mirror of
https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
synced 2025-08-14 17:41:37 +08:00
📚 encoder/decoder
@@ -42,16 +42,17 @@ class StrokesDataset(Dataset):
    """

    def __init__(self, dataset: np.array, max_seq_length: int, scale: Optional[float] = None):
        # Filter and convert training sequences to floats.
        """
        `dataset` is a list of numpy arrays of shape [seq_len, 3].
        It is a sequence of strokes, and each stroke is represented by
        3 integers.
        The first two are the displacements along x and y ($\Delta x$, $\Delta y$),
        and the last integer represents the state of the pen - $1$ if it's touching
        the paper and $0$ otherwise.
        """

        data = []
        # `dataset['train']` is a list of numpy arrays of shape [seq_len, 3].
        # It is a sequence of strokes, and each stroke is represented by
        # 3 integers.
        # The first two are the displacements along x and y ($\Delta x$, $\Delta y$),
        # and the last integer represents the state of the pen - 1 if it's touching
        # the paper and 0 otherwise.
        #
        # We iterate through each of the sequences
        # We iterate through each of the sequences and filter
        for seq in dataset:
            # Filter if the length of the sequence of strokes is within our range
            if 10 < len(seq) <= max_seq_length:
@@ -62,103 +63,163 @@ class StrokesDataset(Dataset):
                seq = np.array(seq, dtype=np.float32)
                data.append(seq)

        # We then normalize all ($\Delta x$, $\Delta y$) by their standard deviation.
        # This calculates the standard deviations for ($\Delta x$, $\Delta y$) combined.
        # We then calculate the scaling factor, which is the
        # standard deviation of ($\Delta x$, $\Delta y$) combined.
        # The paper notes that the mean is not adjusted for simplicity,
        # since the mean is anyway close to $0$.
        if scale is None:
            scale = np.std(np.concatenate([np.ravel(s[:, 0:2]) for s in data]))
        self.scale = scale
        for s in data:
            # Adjust by standard deviation
            s[:, 0:2] /= scale

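To see what this normalization does in isolation, here is a minimal sketch with synthetic strokes (the random `data` list below is made up purely for illustration):

```python
import numpy as np

# Toy stand-in for the filtered `data` list: a few stroke sequences of shape [seq_len, 3]
rng = np.random.default_rng(0)
data = [rng.normal(size=(n, 3)).astype(np.float32) for n in (12, 30, 25)]

# One scale factor from all (dx, dy) values combined, then divide in place
scale = np.std(np.concatenate([np.ravel(s[:, 0:2]) for s in data]))
for s in data:
    s[:, 0:2] /= scale

# After scaling, the combined std of (dx, dy) is ~1
print(np.std(np.concatenate([np.ravel(s[:, 0:2]) for s in data])))
```
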
        # Get the longest sequence length among all sequences
        longest_seq_len = max([len(seq) for seq in data])

        # Initialize PyTorch data array
        # We initialize the PyTorch data array with two extra steps for start-of-sequence (sos)
        # and end-of-sequence (eos).
        # Each step is a vector $(\Delta x, \Delta y, p_1, p_2, p_3)$.
        # Only one of $p_1, p_2, p_3$ is $1$ and the others are $0$.
        # They represent *pen down*, *pen up* and *end-of-sequence* in that order.
        # $p_1$ is $1$ if the pen touches the paper in the next step.
        # $p_2$ is $1$ if the pen doesn't touch the paper in the next step.
        # $p_3$ is $1$ if it is the end of the drawing.
        self.data = torch.zeros(len(data), longest_seq_len + 2, 5, dtype=torch.float)
        # Initialize mask array. Mask has an extra step because the model predicts
        # end of sequence at the end.
        # The mask array needs only one extra step since it is for the outputs of the
        # decoder, which takes in `data[:-1]` and predicts the next step.
        self.mask = torch.zeros(len(data), longest_seq_len + 1)

        for i, seq in enumerate(data):
            seq = torch.from_numpy(seq)
            len_seq = len(seq)
            # set x, y
            self.data[i, 1:len_seq + 1, :2] = seq[:, :2]
            # set pen status
            # Scale and set $\Delta x, \Delta y$
            self.data[i, 1:len_seq + 1, :2] = seq[:, :2] / scale
            # $p_1$
            self.data[i, 1:len_seq + 1, 2] = 1 - seq[:, 2]
            # $p_2$
            self.data[i, 1:len_seq + 1, 3] = seq[:, 2]
            # $p_3$
            self.data[i, len_seq + 1:, 4] = 1
            # Mask is on until end of sequence
            self.mask[i, :len_seq + 1] = 1

        # Start-of-sequence is $(0, 0, 1, 0, 0)$
        self.data[:, 0, 2] = 1

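As a standalone illustration of the $(\Delta x, \Delta y, p_1, p_2, p_3)$ layout above, here is a toy sequence padded the same way (the numbers and lengths are made up; this is not the dataset code itself):

```python
import torch

# One toy stroke sequence: (dx, dy, pen) rows; pen=1 means the pen lifts after this point
seq = torch.tensor([[1., 0., 0.],
                    [0., 1., 0.],
                    [-1., 0., 1.]])
max_len = 5  # pretend this is the longest sequence in the dataset

data = torch.zeros(max_len + 2, 5)   # two extra steps: sos and eos
mask = torch.zeros(max_len + 1)
n = len(seq)
data[1:n + 1, :2] = seq[:, :2]       # dx, dy (already scaled here)
data[1:n + 1, 2] = 1 - seq[:, 2]     # p1: pen touching paper
data[1:n + 1, 3] = seq[:, 2]         # p2: pen lifted
data[n + 1:, 4] = 1                  # p3: end of drawing
data[0, 2] = 1                       # sos step is (0, 0, 1, 0, 0)
mask[:n + 1] = 1                     # valid decoder output steps

print(data)
print(mask)
```
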
    def __len__(self):
        """Size of the dataset"""
        return len(self.data)

    def __getitem__(self, idx: int):
        """Get a sample"""
        return self.data[idx], self.mask[idx]

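A hedged usage sketch: feeding `StrokesDataset` synthetic stroke arrays and batching it with a `DataLoader` (the random sequences, batch size and `max_seq_length` are arbitrary; in the repo the real data comes from the Sketch-RNN dataset files):

```python
import numpy as np
import torch
from torch.utils.data import DataLoader

# Synthetic stand-in for dataset['train']: a list of [seq_len, 3] integer stroke arrays
rng = np.random.default_rng(0)
train = [np.stack([rng.integers(-5, 5, size=n),
                   rng.integers(-5, 5, size=n),
                   rng.integers(0, 2, size=n)], axis=1)
         for n in rng.integers(11, 50, size=32)]

dataset = StrokesDataset(train, max_seq_length=200)
loader = DataLoader(dataset, batch_size=8, shuffle=True)

data, mask = next(iter(loader))
print(data.shape)   # [batch_size, longest_seq_len + 2, 5]
print(mask.shape)   # [batch_size, longest_seq_len + 1]
```
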
class EncoderRNN(Module):
    """
    ## Encoder module

    This consists of a bidirectional LSTM
    """
    def __init__(self, d_z: int, enc_hidden_size: int):
        super().__init__()
        # Create a bidirectional LSTM that takes a sequence of
        # $(\Delta x, \Delta y, p_1, p_2, p_3)$ as input.
        self.lstm = nn.LSTM(5, enc_hidden_size, bidirectional=True)
        # Head to get $\mu$
        self.mu_head = nn.Linear(2 * enc_hidden_size, d_z)
        # Head to get $\hat{\sigma}$
        self.sigma_head = nn.Linear(2 * enc_hidden_size, d_z)

    def __call__(self, inputs: torch.Tensor, state=None):
        # The hidden state of the bidirectional LSTM is the concatenation of the
        # output of the last token in the forward direction and
        # the first token in the reverse direction,
        # which is what we want.
        # $$h_{\rightarrow} = encode_{\rightarrow}(S),
        # h_{\leftarrow} = encode_{\leftarrow}(S_{reverse}),
        # h = [h_{\rightarrow}; h_{\leftarrow}]$$
        _, (hidden, cell) = self.lstm(inputs.float(), state)
        # The state has shape `[2, batch_size, hidden_size]`,
        # where the first dimension is the direction.
        # We rearrange it to get $h = [h_{\rightarrow}; h_{\leftarrow}]$
        hidden = einops.rearrange(hidden, 'fb b h -> b (fb h)')

        # $\mu$
        mu = self.mu_head(hidden)
        # $\hat{\sigma}$
        sigma_hat = self.sigma_head(hidden)
        # $\sigma = \exp(\frac{\hat{\sigma}}{2})$
        sigma = torch.exp(sigma_hat / 2.)

        z_size = mu.size()
        z = mu + sigma * torch.normal(mu.new_zeros(z_size), mu.new_ones(z_size))
        # Sample $z = \mu + \sigma \cdot \mathcal{N}(0, I)$
        z = mu + sigma * torch.normal(mu.new_zeros(mu.shape), mu.new_ones(mu.shape))

        return z, mu, sigma_hat

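A minimal shape check of the two steps above, the forward/backward state concatenation and the reparameterized sampling, with hypothetical sizes:

```python
import torch
import einops

batch_size, enc_hidden_size, d_z = 4, 256, 128

# Final hidden state of a bidirectional LSTM: [directions, batch, hidden]
hidden = torch.randn(2, batch_size, enc_hidden_size)
# Concatenate forward and backward states per sample: [batch, 2 * hidden]
h = einops.rearrange(hidden, 'fb b h -> b (fb h)')
assert h.shape == (batch_size, 2 * enc_hidden_size)
# The same thing with plain torch, for comparison
assert torch.equal(h, torch.cat([hidden[0], hidden[1]], dim=-1))

# Reparameterization: z = mu + sigma * eps, with eps ~ N(0, I)
mu = torch.randn(batch_size, d_z)
sigma_hat = torch.randn(batch_size, d_z)
sigma = torch.exp(sigma_hat / 2.)
z = mu + sigma * torch.randn_like(mu)
assert z.shape == (batch_size, d_z)
```
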
class DecoderRNN(Module):
    def __init__(self, d_z: int, dec_hidden_size: int, n_mixtures: int):
        super().__init__()
        self.dec_hidden_size = dec_hidden_size
        self.init_state = nn.Linear(d_z, 2 * dec_hidden_size)
    """
    ## Decoder module

    This consists of an LSTM
    """
    def __init__(self, d_z: int, dec_hidden_size: int, n_distributions: int):
        super().__init__()
        # LSTM takes $[z; (\Delta x, \Delta y, p_1, p_2, p_3)]$ as input
        self.lstm = nn.LSTM(d_z + 5, dec_hidden_size)

        self.mixtures = nn.Linear(dec_hidden_size, 6 * n_mixtures)
        # Initial state of the LSTM is $[h_0; c_0] = \tanh(W_{z}z + b_z)$.
        # `init_state` is the linear transformation for this
        self.init_state = nn.Linear(d_z, 2 * dec_hidden_size)

        # This layer produces outputs for each of the `n_distributions`.
        # Each distribution needs six parameters
        # $(\hat{\Pi_i}, \mu_{x_i}, \mu_{y_i}, \hat{\sigma_{x_i}}, \hat{\sigma_{y_i}}, \hat{\rho_{xy_i}})$
        self.mixtures = nn.Linear(dec_hidden_size, 6 * n_distributions)

        # This head is for the logits $(\hat{q_1}, \hat{q_2}, \hat{q_3})$
        self.q_head = nn.Linear(dec_hidden_size, 3)
        self.n_mixtures = n_mixtures
        # This is to calculate $\log(q_k)$ where
        # $$q_k = \frac{\exp(\hat{q_k})}{\sum_{j = 1}^3 \exp(\hat{q_j})}$$
        self.q_log_softmax = nn.LogSoftmax(-1)

        # These parameters are stored for future reference
        self.n_distributions = n_distributions
        self.dec_hidden_size = dec_hidden_size

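A quick numeric check of the $q_k$ formula with made-up logits:

```python
import torch
import torch.nn as nn

q_hat = torch.tensor([[2.0, 0.5, -1.0]])      # made-up logits (q1, q2, q3)
log_q = nn.LogSoftmax(dim=-1)(q_hat)          # log(q_k), as used by the decoder

# The same thing written out explicitly: q_k = exp(q_hat_k) / sum_j exp(q_hat_j)
q = torch.exp(q_hat) / torch.exp(q_hat).sum(dim=-1, keepdim=True)
assert torch.allclose(log_q, torch.log(q))
```
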
    def __call__(self, x: torch.Tensor, z: torch.Tensor, state: Optional[Tuple[torch.Tensor, torch.Tensor]]):
        # Calculate the initial state
        if state is None:
            hidden, cell = torch.split(torch.tanh(self.init_state(z)), self.dec_hidden_size, 1)
            state = (hidden.unsqueeze(0).contiguous(), cell.unsqueeze(0).contiguous())
            # $[h_0; c_0] = \tanh(W_{z}z + b_z)$
            h, c = torch.split(torch.tanh(self.init_state(z)), self.dec_hidden_size, 1)
            # `h` and `c` have shapes `[batch_size, lstm_size]`. We want to convert them
            # to shape `[1, batch_size, lstm_size]` because that's the shape used in LSTM.
            state = (h.unsqueeze(0), c.unsqueeze(0))

        outputs, (hidden, cell) = self.lstm(x, state)
        # Run the LSTM
        outputs, state = self.lstm(x, state)

        # Get $\log(q)$
        q_logits = self.q_log_softmax(self.q_head(outputs))

        # Get $(\hat{\Pi_i}, \mu_{x_i}, \mu_{y_i}, \hat{\sigma_{x_i}},
        # \hat{\sigma_{y_i}}, \hat{\rho_{xy_i}})$.
        # `torch.split` splits the output into 6 tensors of size `self.n_distributions`
        # across dimension `2`.
        pi_logits, mu_x, mu_y, sigma_x, sigma_y, rho_xy = \
            torch.split(self.mixtures(outputs), self.n_mixtures, 2)
            torch.split(self.mixtures(outputs), self.n_distributions, 2)

        # Create a bivariate Gaussian mixture
        dist = BivariateGaussianMixture(pi_logits, mu_x, mu_y,
                                        torch.exp(sigma_x), torch.exp(sigma_y), torch.tanh(rho_xy))

        return dist, q_logits, (hidden, cell)
        return dist, q_logits, state

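A shape sketch of the two `torch.split` calls above, using hypothetical sizes for `d_z`, `dec_hidden_size` and `n_distributions`:

```python
import torch
import torch.nn as nn

batch_size, seq_len, d_z, dec_hidden_size, n_distributions = 4, 10, 128, 512, 20

# Initial state: one linear layer produces [h0; c0], which is split in half
init_state = nn.Linear(d_z, 2 * dec_hidden_size)
z = torch.randn(batch_size, d_z)
h, c = torch.split(torch.tanh(init_state(z)), dec_hidden_size, dim=1)
state = (h.unsqueeze(0), c.unsqueeze(0))   # [1, batch_size, dec_hidden_size] each

# Mixture head: 6 parameters per distribution, split into 6 tensors along the last dim
mixtures = nn.Linear(dec_hidden_size, 6 * n_distributions)
outputs = torch.randn(seq_len, batch_size, dec_hidden_size)
pi_logits, mu_x, mu_y, sigma_x, sigma_y, rho_xy = \
    torch.split(mixtures(outputs), n_distributions, dim=2)
assert pi_logits.shape == (seq_len, batch_size, n_distributions)
```
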
class ReconstructionLoss(Module):
    def __call__(self, mask: torch.Tensor, target: torch.Tensor,
                 dist: 'BivariateGaussianMixture', q_logits: torch.Tensor):
        pi, mix = dist.get_distribution()
        xy = target[:, :, 0:2].unsqueeze(-2).expand(-1, -1, dist.n_mixtures, -1)
        xy = target[:, :, 0:2].unsqueeze(-2).expand(-1, -1, dist.n_distributions, -1)
        probs = torch.sum(pi.probs * torch.exp(mix.log_prob(xy)), 2)
        loss_stroke = -torch.mean(mask * torch.log(1e-5 + probs))
        loss_pen = -torch.mean(target[:, :, 2:] * q_logits)
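To make the masked stroke loss concrete, here is a self-contained sketch with a toy mixture built directly from `torch.distributions` (all sizes and tensors are made up; it mirrors how `pi` and `mix` are used above via `pi.probs` and `mix.log_prob`, not necessarily the exact objects `get_distribution()` returns):

```python
import torch
import torch.distributions as D

seq_len, batch_size, n_dist = 7, 2, 3

# Toy per-step mixture over (dx, dy): mixture weights and bivariate normals
pi = D.Categorical(logits=torch.randn(seq_len, batch_size, n_dist))
mix = D.MultivariateNormal(torch.randn(seq_len, batch_size, n_dist, 2),
                           covariance_matrix=torch.eye(2).expand(seq_len, batch_size, n_dist, 2, 2))

target = torch.randn(seq_len, batch_size, 5)   # (dx, dy, p1, p2, p3) targets
mask = torch.ones(seq_len, batch_size)         # 1 for valid steps, 0 for padding

# Probability of the target (dx, dy) under the mixture, summed over components
xy = target[:, :, 0:2].unsqueeze(-2).expand(-1, -1, n_dist, -1)
probs = torch.sum(pi.probs * torch.exp(mix.log_prob(xy)), 2)
loss_stroke = -torch.mean(mask * torch.log(1e-5 + probs))

# Pen-state term: cross entropy between the p targets and log q
# (in the real data the p targets are one-hot; random values here are just for shapes)
q_logits = torch.log_softmax(torch.randn(seq_len, batch_size, 3), dim=-1)
loss_pen = -torch.mean(target[:, :, 2:] * q_logits)
```
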
@@ -244,7 +305,7 @@ class Configs(TrainValidConfigs):
    batch_size = 100

    d_z = 128
    n_mixtures = 20
    n_distributions = 20

    kl_div_loss_weight = 0.5
    grad_clip = 1.
@@ -264,7 +325,7 @@ class Configs(TrainValidConfigs):
         Configs.valid_dataset, Configs.valid_loader])
def setup_all(self: Configs):
    self.encoder = EncoderRNN(self.d_z, self.enc_hidden_size).to(self.device)
    self.decoder = DecoderRNN(self.d_z, self.dec_hidden_size, self.n_mixtures).to(self.device)
    self.decoder = DecoderRNN(self.d_z, self.dec_hidden_size, self.n_distributions).to(self.device)

    self.optimizer = OptimizerConfigs()
    self.optimizer.parameters = list(self.encoder.parameters()) + list(self.decoder.parameters())
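`OptimizerConfigs` comes from labml; a plain-PyTorch equivalent of optimizing both modules with a single optimizer would look roughly like this (the sizes and learning rate are hypothetical):

```python
import itertools
import torch

# Hypothetical module sizes, just for illustration
encoder = EncoderRNN(d_z=128, enc_hidden_size=256)
decoder = DecoderRNN(d_z=128, dec_hidden_size=512, n_distributions=20)

# One optimizer over the concatenated parameter lists of both modules
optimizer = torch.optim.Adam(itertools.chain(encoder.parameters(), decoder.parameters()), lr=1e-3)
```
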
@@ -359,7 +420,7 @@ class BivariateGaussianMixture:
        self.rho_xy = rho_xy

    @property
    def n_mixtures(self):
    def n_distributions(self):
        return self.pi_logits.shape[-1]

    def set_temperature(self, temperature: float):
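For reference, a minimal sketch of how such a bivariate Gaussian mixture can be assembled from `torch.distributions`; this matches how `pi` and `mix` are consumed in `ReconstructionLoss` above, but it is an assumption rather than the repo's exact `get_distribution` implementation:

```python
import torch
import torch.distributions as D

def bivariate_gaussian_mixture(pi_logits, mu_x, mu_y, sigma_x, sigma_y, rho_xy):
    """Build a (Categorical, MultivariateNormal) pair from per-component parameters.

    All inputs have shape [seq_len, batch_size, n_distributions].
    """
    # Mixture weights
    pi = D.Categorical(logits=pi_logits)
    # Means: [..., n_distributions, 2]
    mean = torch.stack([mu_x, mu_y], dim=-1)
    # Covariance matrices built from sigma_x, sigma_y and the correlation rho_xy
    cov_xy = rho_xy * sigma_x * sigma_y
    cov = torch.stack([torch.stack([sigma_x ** 2, cov_xy], dim=-1),
                       torch.stack([cov_xy, sigma_y ** 2], dim=-1)], dim=-2)
    mix = D.MultivariateNormal(mean, covariance_matrix=cov)
    return pi, mix

# Toy usage with made-up parameter tensors
shape = (7, 2, 20)
pi, mix = bivariate_gaussian_mixture(torch.randn(shape),
                                     torch.randn(shape), torch.randn(shape),
                                     torch.rand(shape) + 0.1, torch.rand(shape) + 0.1,
                                     torch.tanh(torch.randn(shape)) * 0.9)
print(pi.probs.shape, mix.mean.shape)   # [7, 2, 20] and [7, 2, 20, 2]
```
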