mirror of
				https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
				synced 2025-10-30 10:18:50 +08:00 
			
		
		
		
	🧹 cleanup
This commit is contained in:
		| @ -5,7 +5,7 @@ This is an annotated implementation of the paper | ||||
| Download data from [Quick, Draw! Dataset](https://github.com/googlecreativelab/quickdraw-dataset). | ||||
| There is a link to download `npz` files in *Sketch-RNN QuickDraw Dataset* section of the readme. | ||||
| Place the downloaded `npz` file(s) in `data/sketch` folder. | ||||
| This code is configured to use `bycle` dataset. | ||||
| This code is configured to use `bicycle` dataset. | ||||
| You can change this in configurations. | ||||
|  | ||||
| ### Acknowledgements | ||||
| @ -14,7 +14,7 @@ Took help from [PyTorch Sketch RNN)(https://github.com/alexis-jacq/Pytorch-Sketc | ||||
| """ | ||||
|  | ||||
| import math | ||||
| from typing import Optional | ||||
| from typing import Optional, Tuple | ||||
|  | ||||
| import einops | ||||
| import numpy as np | ||||
| @ -41,12 +41,7 @@ class StrokesDataset(Dataset): | ||||
|     This class load and pre-process the data. | ||||
|     """ | ||||
|  | ||||
|     def __init__(self, dataset_name: str, max_seq_length): | ||||
|         # `npz` file path is `data/sketch/[DATASET NAME].npz` | ||||
|         path = lab.get_data_path() / 'sketch' / f'{dataset_name}.npz' | ||||
|         # Load the numpy file. | ||||
|         dataset = np.load(str(path), encoding='latin1', allow_pickle=True) | ||||
|  | ||||
|     def __init__(self, dataset: np.array, max_seq_length: int, scale: Optional[float] = None): | ||||
|         # Filter and convert training sequences to floats. | ||||
|         data = [] | ||||
|         # `dataset['train']` is a list of numpy arrays of shape [seq_len, 3]. | ||||
| @ -54,10 +49,10 @@ class StrokesDataset(Dataset): | ||||
|         # 3 integers. | ||||
|         # First two are the displacements along x and y ($\Delta x$, $\Delta y$) | ||||
|         # And the last integer represents the state of the pen - 1 if it's touching | ||||
|         # the paper and 0 otherwise, and 2 if it's end of sequence. | ||||
|         # the paper and 0 otherwise. | ||||
|         # | ||||
|         # We iterate through each of the sequences | ||||
|         for seq in dataset['train']: | ||||
|         for seq in dataset: | ||||
|             # Filter if the length of the the sequence of strokes is within our range | ||||
|             if 10 < len(seq) <= max_seq_length: | ||||
|                 # Clamp $\Delta x$, $\Delta y$ to $[-1000, 1000]$ | ||||
| @ -69,17 +64,20 @@ class StrokesDataset(Dataset): | ||||
|  | ||||
|         # We then normalize all ($\Delta x$, $\Delta y$) by their standard deviation. | ||||
|         # This calculates the standard deviations for ($\Delta x$, $\Delta y$) combined. | ||||
|         # Paper notes that the mean is not adjusted for simplicity since the mean is anyway close to $0$. | ||||
|         std = np.std(np.concatenate([np.ravel(s[:, 0:2]) for s in data])) | ||||
|         # Paper notes that the mean is not adjusted for simplicity, | ||||
|         # since the mean is anyway close to $0$. | ||||
|         if scale is None: | ||||
|             scale = np.std(np.concatenate([np.ravel(s[:, 0:2]) for s in data])) | ||||
|         self.scale = scale | ||||
|         for s in data: | ||||
|             # Adjust by standard deviation | ||||
|             s[:, 0:2] /= std | ||||
|             s[:, 0:2] /= scale | ||||
|  | ||||
|         # Get the longest sequence length among all sequences | ||||
|         longest_seq_len = max([len(seq) for seq in data]) | ||||
|  | ||||
|         # Initialize PyTorch data array | ||||
|         self.data = torch.zeros(len(data), longest_seq_len, 5, dtype=torch.float) | ||||
|         self.data = torch.zeros(len(data), longest_seq_len + 2, 5, dtype=torch.float) | ||||
|         # Initialize mask array. Mask has an extra step because the model predicts | ||||
|         # end of sequence at the end. | ||||
|         self.mask = torch.zeros(len(data), longest_seq_len + 1) | ||||
| @ -88,21 +86,20 @@ class StrokesDataset(Dataset): | ||||
|             seq = torch.from_numpy(seq) | ||||
|             len_seq = len(seq) | ||||
|             # set x, y | ||||
|             self.data[i, :len_seq, :2] = seq[:, :2] | ||||
|             self.data[i, 1:len_seq + 1, :2] = seq[:, :2] | ||||
|             # set pen status | ||||
|             self.data[i, :len_seq - 1, 2] = 1 - seq[:-1, 2] | ||||
|             self.data[i, :len_seq - 1, 3] = seq[:-1, 2] | ||||
|             self.data[i, len_seq - 1:, 4] = 1 | ||||
|             self.mask[i, :len_seq] = 1 | ||||
|             self.data[i, 1:len_seq + 1, 2] = 1 - seq[:, 2] | ||||
|             self.data[i, 1:len_seq + 1, 3] = seq[:, 2] | ||||
|             self.data[i, len_seq + 1:, 4] = 1 | ||||
|             self.mask[i, :len_seq + 1] = 1 | ||||
|  | ||||
|         eos = torch.zeros(len(data), 1, 5) | ||||
|         self.target = torch.cat([self.data, eos], 1) | ||||
|         self.data[:, 0, 2] = 1 | ||||
|  | ||||
|     def __len__(self): | ||||
|         return len(self.data) | ||||
|  | ||||
|     def __getitem__(self, idx: int): | ||||
|         return self.data[idx], self.target[idx], self.mask[idx] | ||||
|         return self.data[idx], self.mask[idx] | ||||
|  | ||||
|  | ||||
| class EncoderRNN(Module): | ||||
| @ -139,22 +136,17 @@ class DecoderRNN(Module): | ||||
|         self.n_mixtures = n_mixtures | ||||
|         self.q_log_softmax = nn.LogSoftmax(-1) | ||||
|  | ||||
|     def __call__(self, inputs, z: torch.Tensor, state=None): | ||||
|     def __call__(self, x: torch.Tensor, z: torch.Tensor, state: Optional[Tuple[torch.Tensor, torch.Tensor]]): | ||||
|         if state is None: | ||||
|             hidden, cell = torch.split(torch.tanh(self.init_state(z)), self.dec_hidden_size, 1) | ||||
|             state = (hidden.unsqueeze(0).contiguous(), cell.unsqueeze(0).contiguous()) | ||||
|  | ||||
|         outputs, (hidden, cell) = self.lstm(inputs, state) | ||||
|         if not self.training: | ||||
|             # We dont have to change shape since hidden has shape [1, batch_size, hidden_size] | ||||
|             # and outputs is of shape [seq_len, batch_size, hidden_size], since we are | ||||
|             # using a single direction one layer lstm | ||||
|             outputs = hidden | ||||
|         _, (hidden, cell) = self.lstm(x, state) | ||||
|  | ||||
|         q_logits = self.q_log_softmax(self.q_head(outputs)) | ||||
|         q_logits = self.q_log_softmax(self.q_head(hidden)) | ||||
|  | ||||
|         pi_logits, mu_x, mu_y, sigma_x, sigma_y, rho_xy = \ | ||||
|             torch.split(self.mixtures(outputs), self.n_mixtures, 2) | ||||
|             torch.split(self.mixtures(hidden), self.n_mixtures, 2) | ||||
|  | ||||
|         dist = BivariateGaussianMixture(pi_logits, mu_x, mu_y, | ||||
|                                         torch.exp(sigma_x), torch.exp(sigma_y), torch.tanh(rho_xy)) | ||||
| @ -185,30 +177,32 @@ class Sampler: | ||||
|  | ||||
|     def sample(self, data: torch.Tensor, temperature: float): | ||||
|         z, _, _ = self.encoder(data) | ||||
|         sos = data.new_tensor([[[0, 0, 1, 0, 0]]]) | ||||
|         sos = data.new_tensor([0, 0, 1, 0, 0]) | ||||
|         seq_len = len(data) | ||||
|         s = sos | ||||
|         seq_x = [] | ||||
|         seq_y = [] | ||||
|         seq_z = [] | ||||
|         seq = [s] | ||||
|         state = None | ||||
|         with torch.no_grad(): | ||||
|             for i in range(seq_len): | ||||
|                 data = torch.cat([s, z.unsqueeze(0)], 2) | ||||
|                 data = torch.cat([s.view(1, 1, -1), z.unsqueeze(0)], 2) | ||||
|                 dist, q_logits, state = self.decoder(data, z, state) | ||||
|                 s, xy, pen_down, eos = self._sample_step(dist, q_logits, temperature) | ||||
|                 seq_x.append(xy[0].item()) | ||||
|                 seq_y.append(xy[1].item()) | ||||
|                 seq_z.append(pen_down) | ||||
|                 if eos: | ||||
|                 s = self._sample_step(dist, q_logits, temperature) | ||||
|                 seq.append(s) | ||||
|                 if s[4] == 1: | ||||
|                     print(i) | ||||
|                     break | ||||
|  | ||||
|         x_sample = np.cumsum(seq_x, 0) | ||||
|         y_sample = np.cumsum(seq_y, 0) | ||||
|         z_sample = np.array(seq_z) | ||||
|         sequence = np.stack([x_sample, y_sample, z_sample]).T | ||||
|         seq = torch.stack(seq) | ||||
|  | ||||
|         strokes = np.split(sequence, np.where(sequence[:, 2] > 0)[0] + 1) | ||||
|         self.plot(seq) | ||||
|  | ||||
|     @staticmethod | ||||
|     def plot(seq: torch.Tensor): | ||||
|         seq[:, 0:2] = torch.cumsum(seq[:, 0:2], dim=0) | ||||
|         seq[:, 2] = seq[:, 3] == 1 | ||||
|         seq = seq[:, 0:3].detach().cpu().numpy() | ||||
|  | ||||
|         strokes = np.split(seq, np.where(seq[:, 2] > 0)[0] + 1) | ||||
|         for s in strokes: | ||||
|             plt.plot(s[:, 0], -s[:, 1]) | ||||
|         plt.axis('off') | ||||
| @ -224,10 +218,10 @@ class Sampler: | ||||
|         q_idx = q.sample()[0, 0] | ||||
|  | ||||
|         xy = mix.sample()[0, 0, idx] | ||||
|         next_pos = q_logits.new_zeros(1, 1, 5) | ||||
|         next_pos[0, 0, :2] = xy | ||||
|         next_pos[0, 0, q_idx + 2] = 1 | ||||
|         return next_pos, xy, q_idx == 1, q_idx == 2 | ||||
|         next_pos = q_logits.new_zeros(5) | ||||
|         next_pos[:2] = xy | ||||
|         next_pos[q_idx + 2] = 1 | ||||
|         return next_pos | ||||
|  | ||||
|  | ||||
| class Configs(TrainValidConfigs): | ||||
| @ -238,8 +232,10 @@ class Configs(TrainValidConfigs): | ||||
|     sampler: Sampler | ||||
|  | ||||
|     dataset_name: str | ||||
|     dataset: StrokesDataset = 'setup_all' | ||||
|     train_loader = 'setup_all' | ||||
|     valid_loader = 'setup_all' | ||||
|     train_dataset: StrokesDataset | ||||
|     valid_dataset: StrokesDataset | ||||
|  | ||||
|     enc_hidden_size = 256 | ||||
|     dec_hidden_size = 512 | ||||
| @ -256,18 +252,17 @@ class Configs(TrainValidConfigs): | ||||
|     temperature = 0.4 | ||||
|     max_seq_length = 200 | ||||
|  | ||||
|     validator = None | ||||
|     valid_loader = None | ||||
|     epochs = 100 | ||||
|  | ||||
|     def sample(self): | ||||
|         data, *_ = self.dataset[np.random.choice(len(self.dataset))] | ||||
|         data, *_ = self.train_dataset[np.random.choice(len(self.train_dataset))] | ||||
|         data = data.unsqueeze(1).to(self.device) | ||||
|         self.sampler.sample(data, self.temperature) | ||||
|  | ||||
|  | ||||
| @setup([Configs.encoder, Configs.decoder, Configs.optimizer, Configs.sampler, | ||||
|         Configs.dataset, Configs.train_loader]) | ||||
|         Configs.train_dataset, Configs.train_loader, | ||||
|         Configs.valid_dataset, Configs.valid_loader]) | ||||
| def setup_all(self: Configs): | ||||
|     self.encoder = EncoderRNN(self.d_z, self.enc_hidden_size).to(self.device) | ||||
|     self.decoder = DecoderRNN(self.d_z, self.dec_hidden_size, self.n_mixtures).to(self.device) | ||||
| @ -276,9 +271,16 @@ def setup_all(self: Configs): | ||||
|     self.optimizer.parameters = list(self.encoder.parameters()) + list(self.decoder.parameters()) | ||||
|  | ||||
|     self.sampler = Sampler(self.encoder, self.decoder) | ||||
|     # `npz` file path is `data/sketch/[DATASET NAME].npz` | ||||
|     path = lab.get_data_path() / 'sketch' / f'{self.dataset_name}.npz' | ||||
|     # Load the numpy file. | ||||
|     dataset = np.load(str(path), encoding='latin1', allow_pickle=True) | ||||
|  | ||||
|     self.dataset = StrokesDataset(self.dataset_name, self.max_seq_length) | ||||
|     self.train_loader = DataLoader(self.dataset, self.batch_size, shuffle=True) | ||||
|     self.train_dataset = StrokesDataset(dataset['train'], self.max_seq_length) | ||||
|     self.valid_dataset = StrokesDataset(dataset['valid'], self.max_seq_length, self.train_dataset.scale) | ||||
|  | ||||
|     self.train_loader = DataLoader(self.train_dataset, self.batch_size, shuffle=True) | ||||
|     self.valid_loader = DataLoader(self.valid_dataset, self.batch_size, shuffle=True) | ||||
|  | ||||
|  | ||||
| class StrokesBatchStep(BatchStepProtocol): | ||||
| @ -296,39 +298,34 @@ class StrokesBatchStep(BatchStepProtocol): | ||||
|         hook_model_outputs(self.encoder, 'encoder') | ||||
|         hook_model_outputs(self.decoder, 'decoder') | ||||
|         tracker.set_scalar("loss.*", True) | ||||
|         tracker.set_image("generated", True) | ||||
|  | ||||
|     def prepare_for_iteration(self): | ||||
|         if MODE_STATE.is_train: | ||||
|             self.encoder.train() | ||||
|             self.decoder.train() | ||||
|         else: | ||||
|             self.encoder.eval() | ||||
|             self.decoder.eval() | ||||
|             self.encoder.train() | ||||
|             self.decoder.train() | ||||
|             # self.encoder.eval() | ||||
|             # self.decoder.eval() | ||||
|  | ||||
|     def process(self, batch: any, state: any): | ||||
|         device = self.encoder.device | ||||
|         data, target, mask = batch | ||||
|         data, mask = batch | ||||
|         data = data.to(device).transpose(0, 1) | ||||
|         target = target.to(device).transpose(0, 1) | ||||
|         mask = mask.to(device).transpose(0, 1) | ||||
|         batch_size = data.shape[1] | ||||
|         seq_len = data.shape[0] | ||||
|  | ||||
|         with monit.section("encoder"): | ||||
|             z, mu, sigma = self.encoder(data) | ||||
|  | ||||
|         with monit.section("decoder"): | ||||
|             sos = torch.stack([torch.tensor([0, 0, 1, 0, 0])] * batch_size). \ | ||||
|                 unsqueeze(0).to(device) | ||||
|             batch_init = torch.cat([sos, data], 0) | ||||
|             z_stack = torch.stack([z] * (seq_len + 1)) | ||||
|             inputs = torch.cat([batch_init, z_stack], 2) | ||||
|             dist, q_logits, _ = self.decoder(inputs, z) | ||||
|             z_stack = z.unsqueeze(0).expand(data.shape[0] - 1, -1, -1) | ||||
|             inputs = torch.cat([data[:-1], z_stack], 2) | ||||
|             dist, q_logits, _ = self.decoder(inputs, z, None) | ||||
|  | ||||
|         with monit.section('loss'): | ||||
|             kl_loss = self.kl_div_loss(sigma, mu) | ||||
|             reconstruction_loss = self.reconstruction_loss(mask, target, dist, q_logits) | ||||
|             reconstruction_loss = self.reconstruction_loss(mask, data[1:], dist, q_logits) | ||||
|             loss = self.kl_div_loss_weight * kl_loss + reconstruction_loss | ||||
|  | ||||
|             tracker.add("loss.kl.", kl_loss) | ||||
| @ -346,8 +343,6 @@ class StrokesBatchStep(BatchStepProtocol): | ||||
|                 nn.utils.clip_grad_norm_(self.decoder.parameters(), self.grad_clip) | ||||
|                 self.optimizer.step() | ||||
|  | ||||
|             # tracker.add('generated', generated_images[0:5]) | ||||
|  | ||||
|         return {'samples': len(data)}, None | ||||
|  | ||||
|  | ||||
| @ -395,7 +390,8 @@ def main(): | ||||
|     experiment.configs(configs, { | ||||
|         'optimizer.optimizer': 'Adam', | ||||
|         'optimizer.learning_rate': 1e-3, | ||||
|         'dataset_name': 'bicycle' | ||||
|         'dataset_name': 'bicycle', | ||||
|         'inner_iterations': 10 | ||||
|     }, 'run') | ||||
|     experiment.start() | ||||
|  | ||||
|  | ||||
		Reference in New Issue
	
	Block a user
	 Varuna Jayasiri
					Varuna Jayasiri