Mirror of https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
Commit: 🧹 cleanup
@@ -5,7 +5,7 @@ This is an annotated implementation of the paper
 Download data from [Quick, Draw! Dataset](https://github.com/googlecreativelab/quickdraw-dataset).
 There is a link to download `npz` files in the *Sketch-RNN QuickDraw Dataset* section of the readme.
 Place the downloaded `npz` file(s) in the `data/sketch` folder.
-This code is configured to use `bycle` dataset.
+This code is configured to use the `bicycle` dataset.
 You can change this in configurations.
 
 ### Acknowledgements
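Note: a quick sanity check of a downloaded file before training (a minimal sketch; `data/sketch/bicycle.npz` is the path convention assumed from the docs above, and the archives contain the `train`/`valid`/`test` splits that `setup_all` below indexes into):

    import numpy as np

    # Load one downloaded QuickDraw archive (path assumed, per the docstring above).
    dataset = np.load('data/sketch/bicycle.npz', encoding='latin1', allow_pickle=True)
    print(dataset.files)              # expect ['train', 'valid', 'test']
    print(dataset['train'][0].shape)  # one sketch: [seq_len, 3] rows of (dx, dy, pen)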
@@ -14,7 +14,7 @@ Took help from [PyTorch Sketch RNN](https://github.com/alexis-jacq/Pytorch-Sketc
 """
 
 import math
-from typing import Optional
+from typing import Optional, Tuple
 
 import einops
 import numpy as np
@@ -41,12 +41,7 @@ class StrokesDataset(Dataset):
     This class loads and pre-processes the data.
     """
 
-    def __init__(self, dataset_name: str, max_seq_length):
-        # `npz` file path is `data/sketch/[DATASET NAME].npz`
-        path = lab.get_data_path() / 'sketch' / f'{dataset_name}.npz'
-        # Load the numpy file.
-        dataset = np.load(str(path), encoding='latin1', allow_pickle=True)
-
+    def __init__(self, dataset: np.array, max_seq_length: int, scale: Optional[float] = None):
         # Filter and convert training sequences to floats.
         data = []
-        # `dataset['train']` is a list of numpy arrays of shape [seq_len, 3].
+        # `dataset` is a list of numpy arrays of shape [seq_len, 3].
@@ -54,10 +49,10 @@ class StrokesDataset(Dataset):
         # 3 integers.
         # First two are the displacements along x and y ($\Delta x$, $\Delta y$)
-        # And the last integer represents the state of the pen - 1 if it's touching
-        # the paper and 0 otherwise, and 2 if it's end of sequence.
+        # And the last integer represents the state of the pen - 1 if it's lifted
+        # off the paper after this point and 0 otherwise.
         #
         # We iterate through each of the sequences
-        for seq in dataset['train']:
+        for seq in dataset:
             # Filter if the length of the sequence of strokes is within our range
             if 10 < len(seq) <= max_seq_length:
                 # Clamp $\Delta x$, $\Delta y$ to $[-1000, 1000]$
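For concreteness, a tiny made-up stroke-3 sequence in the format described above (far too short to pass the `10 < len(seq)` filter, but enough to show the encoding; it is reused in a note further down):

    import numpy as np

    # (dx, dy, pen): pen == 1 means the pen is lifted after this point,
    # i.e. the current stroke ends here.
    seq = np.array([[10.,   0., 0.],
                    [ 5.,  -5., 1.],   # end of the first stroke
                    [-20., 10., 0.],
                    [ 0.,  15., 1.]])  # end of the second stroke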
@@ -69,17 +64,20 @@ class StrokesDataset(Dataset):
 
         # We then normalize all ($\Delta x$, $\Delta y$) by their standard deviation.
         # This calculates the standard deviations for ($\Delta x$, $\Delta y$) combined.
-        # Paper notes that the mean is not adjusted for simplicity since the mean is anyway close to $0$.
-        std = np.std(np.concatenate([np.ravel(s[:, 0:2]) for s in data]))
+        # Paper notes that the mean is not adjusted for simplicity,
+        # since the mean is anyway close to $0$.
+        if scale is None:
+            scale = np.std(np.concatenate([np.ravel(s[:, 0:2]) for s in data]))
+        self.scale = scale
         for s in data:
             # Adjust by standard deviation
-            s[:, 0:2] /= std
+            s[:, 0:2] /= scale
 
         # Get the longest sequence length among all sequences
         longest_seq_len = max([len(seq) for seq in data])
 
         # Initialize PyTorch data array
-        self.data = torch.zeros(len(data), longest_seq_len, 5, dtype=torch.float)
+        self.data = torch.zeros(len(data), longest_seq_len + 2, 5, dtype=torch.float)
         # Initialize mask array. Mask has an extra step because the model predicts
         # end of sequence at the end.
         self.mask = torch.zeros(len(data), longest_seq_len + 1)
@@ -88,21 +86,20 @@ class StrokesDataset(Dataset):
             seq = torch.from_numpy(seq)
             len_seq = len(seq)
             # set x, y
-            self.data[i, :len_seq, :2] = seq[:, :2]
+            self.data[i, 1:len_seq + 1, :2] = seq[:, :2]
             # set pen status
-            self.data[i, :len_seq - 1, 2] = 1 - seq[:-1, 2]
-            self.data[i, :len_seq - 1, 3] = seq[:-1, 2]
-            self.data[i, len_seq - 1:, 4] = 1
-            self.mask[i, :len_seq] = 1
+            self.data[i, 1:len_seq + 1, 2] = 1 - seq[:, 2]
+            self.data[i, 1:len_seq + 1, 3] = seq[:, 2]
+            self.data[i, len_seq + 1:, 4] = 1
+            self.mask[i, :len_seq + 1] = 1
 
-        eos = torch.zeros(len(data), 1, 5)
-        self.target = torch.cat([self.data, eos], 1)
+        self.data[:, 0, 2] = 1
 
     def __len__(self):
         return len(self.data)
 
     def __getitem__(self, idx: int):
-        return self.data[idx], self.target[idx], self.mask[idx]
+        return self.data[idx], self.mask[idx]
 
 
 class EncoderRNN(Module):
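Under the new layout, the toy sequence from the earlier note would be packed as follows (an illustrative sketch that ignores the scale normalization and takes `longest_seq_len` as 4, so `self.data` has `longest_seq_len + 2 = 6` rows):

    # self.data[i]: (dx, dy, p1 = pen down, p2 = pen up, p3 = end of drawing)
    #  t    dx   dy  p1  p2  p3
    #  0     0    0   1   0   0   <- start-of-sequence row (self.data[:, 0, 2] = 1)
    #  1    10    0   1   0   0
    #  2     5   -5   0   1   0   <- pen lifted, first stroke ends
    #  3   -20   10   1   0   0
    #  4     0   15   0   1   0
    #  5     0    0   0   0   1   <- padding marked as end of sequence
    # self.mask[i] == [1, 1, 1, 1, 1]   (len_seq + 1 ones)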
@@ -139,22 +136,17 @@ class DecoderRNN(Module):
         self.n_mixtures = n_mixtures
         self.q_log_softmax = nn.LogSoftmax(-1)
 
-    def __call__(self, inputs, z: torch.Tensor, state=None):
+    def __call__(self, x: torch.Tensor, z: torch.Tensor, state: Optional[Tuple[torch.Tensor, torch.Tensor]]):
         if state is None:
             hidden, cell = torch.split(torch.tanh(self.init_state(z)), self.dec_hidden_size, 1)
             state = (hidden.unsqueeze(0).contiguous(), cell.unsqueeze(0).contiguous())
 
-        outputs, (hidden, cell) = self.lstm(inputs, state)
-        if not self.training:
-            # We dont have to change shape since hidden has shape [1, batch_size, hidden_size]
-            # and outputs is of shape [seq_len, batch_size, hidden_size], since we are
-            # using a single direction one layer lstm
-            outputs = hidden
+        _, (hidden, cell) = self.lstm(x, state)
 
-        q_logits = self.q_log_softmax(self.q_head(outputs))
+        q_logits = self.q_log_softmax(self.q_head(hidden))
 
         pi_logits, mu_x, mu_y, sigma_x, sigma_y, rho_xy = \
-            torch.split(self.mixtures(outputs), self.n_mixtures, 2)
+            torch.split(self.mixtures(hidden), self.n_mixtures, 2)
 
         dist = BivariateGaussianMixture(pi_logits, mu_x, mu_y,
                                         torch.exp(sigma_x), torch.exp(sigma_y), torch.tanh(rho_xy))
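The `BivariateGaussianMixture` assembled here is the mixture of bivariate normals from the Sketch-RNN paper: `torch.exp` keeps the standard deviations positive and `torch.tanh` keeps the correlation in $(-1, 1)$. In the paper's notation, with $M$ = `n_mixtures`:

    p(\Delta x, \Delta y) = \sum_{j=1}^{M} \Pi_j \,
        \mathcal{N}\left(\Delta x, \Delta y \,\middle|\,
        \mu_{x,j}, \mu_{y,j}, \sigma_{x,j}, \sigma_{y,j}, \rho_{xy,j}\right)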
@@ -185,30 +177,32 @@ class Sampler:
 
     def sample(self, data: torch.Tensor, temperature: float):
         z, _, _ = self.encoder(data)
-        sos = data.new_tensor([[[0, 0, 1, 0, 0]]])
+        sos = data.new_tensor([0, 0, 1, 0, 0])
         seq_len = len(data)
         s = sos
-        seq_x = []
-        seq_y = []
-        seq_z = []
+        seq = [s]
         state = None
         with torch.no_grad():
             for i in range(seq_len):
-                data = torch.cat([s, z.unsqueeze(0)], 2)
+                data = torch.cat([s.view(1, 1, -1), z.unsqueeze(0)], 2)
                 dist, q_logits, state = self.decoder(data, z, state)
-                s, xy, pen_down, eos = self._sample_step(dist, q_logits, temperature)
-                seq_x.append(xy[0].item())
-                seq_y.append(xy[1].item())
-                seq_z.append(pen_down)
-                if eos:
+                s = self._sample_step(dist, q_logits, temperature)
+                seq.append(s)
+                if s[4] == 1:
+                    print(i)
                     break
 
-        x_sample = np.cumsum(seq_x, 0)
-        y_sample = np.cumsum(seq_y, 0)
-        z_sample = np.array(seq_z)
-        sequence = np.stack([x_sample, y_sample, z_sample]).T
+        seq = torch.stack(seq)
 
-        strokes = np.split(sequence, np.where(sequence[:, 2] > 0)[0] + 1)
+        self.plot(seq)
+
+    @staticmethod
+    def plot(seq: torch.Tensor):
+        seq[:, 0:2] = torch.cumsum(seq[:, 0:2], dim=0)
+        seq[:, 2] = seq[:, 3] == 1
+        seq = seq[:, 0:3].detach().cpu().numpy()
+
+        strokes = np.split(seq, np.where(seq[:, 2] > 0)[0] + 1)
         for s in strokes:
             plt.plot(s[:, 0], -s[:, 1])
         plt.axis('off')
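A minimal way to exercise the sampler once the models are trained (this mirrors `Configs.sample` further down; `encoder`, `decoder` and `dataset` are assumed to be a trained `EncoderRNN`, a trained `DecoderRNN` and a `StrokesDataset`, with tensors moved to the right device as needed):

    sampler = Sampler(encoder, decoder)
    data, *_ = dataset[0]        # one padded stroke-5 sequence (mask is unused here)
    data = data.unsqueeze(1)     # add a batch dimension: [seq_len, 1, 5]
    sampler.sample(data, temperature=0.4)  # encodes, decodes step by step, plots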
@@ -224,10 +218,10 @@ class Sampler:
         q_idx = q.sample()[0, 0]
 
         xy = mix.sample()[0, 0, idx]
-        next_pos = q_logits.new_zeros(1, 1, 5)
-        next_pos[0, 0, :2] = xy
-        next_pos[0, 0, q_idx + 2] = 1
-        return next_pos, xy, q_idx == 1, q_idx == 2
+        next_pos = q_logits.new_zeros(5)
+        next_pos[:2] = xy
+        next_pos[q_idx + 2] = 1
+        return next_pos
 
 
 class Configs(TrainValidConfigs):
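`_sample_step` in the hunk above draws a mixture component and a pen decision at temperature $\tau$. The diff does not show how the distributions apply $\tau$ internally, but the paper's adjustment is to flatten or sharpen the categorical probabilities and scale the variances:

    \hat{q}_k \propto \exp\left(\frac{\log q_k}{\tau}\right), \qquad
    \hat{\Pi}_j \propto \exp\left(\frac{\log \Pi_j}{\tau}\right), \qquad
    \hat{\sigma}_x^2 = \sigma_x^2 \, \tau, \quad \hat{\sigma}_y^2 = \sigma_y^2 \, \tau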
@@ -238,8 +232,10 @@ class Configs(TrainValidConfigs):
     sampler: Sampler
 
     dataset_name: str
-    dataset: StrokesDataset = 'setup_all'
     train_loader = 'setup_all'
+    valid_loader = 'setup_all'
+    train_dataset: StrokesDataset
+    valid_dataset: StrokesDataset
 
     enc_hidden_size = 256
     dec_hidden_size = 512
@@ -256,18 +252,17 @@ class Configs(TrainValidConfigs):
     temperature = 0.4
     max_seq_length = 200
 
-    validator = None
-    valid_loader = None
     epochs = 100
 
     def sample(self):
-        data, *_ = self.dataset[np.random.choice(len(self.dataset))]
+        data, *_ = self.train_dataset[np.random.choice(len(self.train_dataset))]
         data = data.unsqueeze(1).to(self.device)
         self.sampler.sample(data, self.temperature)
 
 
 @setup([Configs.encoder, Configs.decoder, Configs.optimizer, Configs.sampler,
-        Configs.dataset, Configs.train_loader])
+        Configs.train_dataset, Configs.train_loader,
+        Configs.valid_dataset, Configs.valid_loader])
 def setup_all(self: Configs):
     self.encoder = EncoderRNN(self.d_z, self.enc_hidden_size).to(self.device)
     self.decoder = DecoderRNN(self.d_z, self.dec_hidden_size, self.n_mixtures).to(self.device)
@@ -276,9 +271,16 @@ def setup_all(self: Configs):
     self.optimizer.parameters = list(self.encoder.parameters()) + list(self.decoder.parameters())
 
     self.sampler = Sampler(self.encoder, self.decoder)
+    # `npz` file path is `data/sketch/[DATASET NAME].npz`
+    path = lab.get_data_path() / 'sketch' / f'{self.dataset_name}.npz'
+    # Load the numpy file.
+    dataset = np.load(str(path), encoding='latin1', allow_pickle=True)
 
-    self.dataset = StrokesDataset(self.dataset_name, self.max_seq_length)
-    self.train_loader = DataLoader(self.dataset, self.batch_size, shuffle=True)
+    self.train_dataset = StrokesDataset(dataset['train'], self.max_seq_length)
+    self.valid_dataset = StrokesDataset(dataset['valid'], self.max_seq_length, self.train_dataset.scale)
+
+    self.train_loader = DataLoader(self.train_dataset, self.batch_size, shuffle=True)
+    self.valid_loader = DataLoader(self.valid_dataset, self.batch_size, shuffle=True)
 
 
 class StrokesBatchStep(BatchStepProtocol):
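Note the scale sharing above: the validation `StrokesDataset` is built with `self.train_dataset.scale`, so both splits are normalized by the training set's standard deviation instead of each computing its own.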
@@ -296,39 +298,34 @@ class StrokesBatchStep(BatchStepProtocol):
         hook_model_outputs(self.encoder, 'encoder')
         hook_model_outputs(self.decoder, 'decoder')
         tracker.set_scalar("loss.*", True)
-        tracker.set_image("generated", True)
 
     def prepare_for_iteration(self):
         if MODE_STATE.is_train:
             self.encoder.train()
             self.decoder.train()
         else:
-            self.encoder.eval()
-            self.decoder.eval()
+            self.encoder.train()
+            self.decoder.train()
+            # self.encoder.eval()
+            # self.decoder.eval()
 
     def process(self, batch: any, state: any):
         device = self.encoder.device
-        data, target, mask = batch
+        data, mask = batch
         data = data.to(device).transpose(0, 1)
-        target = target.to(device).transpose(0, 1)
         mask = mask.to(device).transpose(0, 1)
-        batch_size = data.shape[1]
-        seq_len = data.shape[0]
 
         with monit.section("encoder"):
             z, mu, sigma = self.encoder(data)
 
         with monit.section("decoder"):
-            sos = torch.stack([torch.tensor([0, 0, 1, 0, 0])] * batch_size). \
-                unsqueeze(0).to(device)
-            batch_init = torch.cat([sos, data], 0)
-            z_stack = torch.stack([z] * (seq_len + 1))
-            inputs = torch.cat([batch_init, z_stack], 2)
-            dist, q_logits, _ = self.decoder(inputs, z)
+            z_stack = z.unsqueeze(0).expand(data.shape[0] - 1, -1, -1)
+            inputs = torch.cat([data[:-1], z_stack], 2)
+            dist, q_logits, _ = self.decoder(inputs, z, None)
 
         with monit.section('loss'):
             kl_loss = self.kl_div_loss(sigma, mu)
-            reconstruction_loss = self.reconstruction_loss(mask, target, dist, q_logits)
+            reconstruction_loss = self.reconstruction_loss(mask, data[1:], dist, q_logits)
             loss = self.kl_div_loss_weight * kl_loss + reconstruction_loss
 
             tracker.add("loss.kl.", kl_loss)
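Two things happen in the hunk above: the decoder is teacher-forced with `data[:-1]` (starting at the start-of-sequence row) against the targets `data[1:]`, so each step predicts the next point; and the total loss matches the paper, $L = L_R + w_{KL} L_{KL}$ with `kl_div_loss_weight` as $w_{KL}$. Assuming the encoder outputs the log-variance $\hat{\sigma}$ as in the paper, the KL term over $N_z$ latent dimensions is:

    L_{KL} = -\frac{1}{2 N_z} \sum_{i=1}^{N_z}
        \left(1 + \hat{\sigma}_i - \mu_i^2 - \exp(\hat{\sigma}_i)\right)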
@@ -346,8 +343,6 @@ class StrokesBatchStep(BatchStepProtocol):
                 nn.utils.clip_grad_norm_(self.decoder.parameters(), self.grad_clip)
                 self.optimizer.step()
 
-            # tracker.add('generated', generated_images[0:5])
-
         return {'samples': len(data)}, None
 
 
@@ -395,7 +390,8 @@ def main():
     experiment.configs(configs, {
         'optimizer.optimizer': 'Adam',
         'optimizer.learning_rate': 1e-3,
-        'dataset_name': 'bicycle'
+        'dataset_name': 'bicycle',
+        'inner_iterations': 10
     }, 'run')
     experiment.start()
 
Author: Varuna Jayasiri