mirror of
				https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
				synced 2025-10-31 02:39:16 +08:00 
			
		
		
		
	♻️ remove setup
This commit is contained in:
		| @ -328,6 +328,62 @@ class Configs(BaseConfigs): | |||||||
|         # Save grid |         # Save grid | ||||||
|         save_image(image_grid, f"images/{self.dataset_name}/{n}.png", normalize=False) |         save_image(image_grid, f"images/{self.dataset_name}/{n}.png", normalize=False) | ||||||
|  |  | ||||||
|  |     def initialize(self): | ||||||
|  |         """ | ||||||
|  |         ## Initialize models and data loaders | ||||||
|  |         """ | ||||||
|  |         input_shape = (self.img_channels, self.img_height, self.img_width) | ||||||
|  |  | ||||||
|  |         # Create the models | ||||||
|  |         self.generator_xy = GeneratorResNet(self.img_channels, self.n_residual_blocks).to(self.device) | ||||||
|  |         self.generator_yx = GeneratorResNet(self.img_channels, self.n_residual_blocks).to(self.device) | ||||||
|  |         self.discriminator_x = Discriminator(input_shape).to(self.device) | ||||||
|  |         self.discriminator_y = Discriminator(input_shape).to(self.device) | ||||||
|  |  | ||||||
|  |         # Create the optmizers | ||||||
|  |         self.generator_optimizer = torch.optim.Adam( | ||||||
|  |             itertools.chain(self.generator_xy.parameters(), self.generator_yx.parameters()), | ||||||
|  |             lr=self.learning_rate, betas=self.adam_betas) | ||||||
|  |         self.discriminator_optimizer = torch.optim.Adam( | ||||||
|  |             itertools.chain(self.discriminator_x.parameters(), self.discriminator_y.parameters()), | ||||||
|  |             lr=self.learning_rate, betas=self.adam_betas) | ||||||
|  |  | ||||||
|  |         # Create the learning rate schedules. | ||||||
|  |         # The learning rate stars flat until `decay_start` epochs, | ||||||
|  |         # and then linearly reduces to $0$ at end of training. | ||||||
|  |         decay_epochs = self.epochs - self.decay_start | ||||||
|  |         self.generator_lr_scheduler = torch.optim.lr_scheduler.LambdaLR( | ||||||
|  |             self.generator_optimizer, lr_lambda=lambda e: 1.0 - max(0, e - self.decay_start) / decay_epochs) | ||||||
|  |         self.discriminator_lr_scheduler = torch.optim.lr_scheduler.LambdaLR( | ||||||
|  |             self.discriminator_optimizer, lr_lambda=lambda e: 1.0 - max(0, e - self.decay_start) / decay_epochs) | ||||||
|  |         # Location of the dataset | ||||||
|  |         images_path = lab.get_data_path() / 'cycle_gan' / self.dataset_name | ||||||
|  |  | ||||||
|  |         # Image transformations | ||||||
|  |         transforms_ = [ | ||||||
|  |             transforms.Resize(int(self.img_height * 1.12), Image.BICUBIC), | ||||||
|  |             transforms.RandomCrop((self.img_height, self.img_width)), | ||||||
|  |             transforms.RandomHorizontalFlip(), | ||||||
|  |             transforms.ToTensor(), | ||||||
|  |             transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), | ||||||
|  |         ] | ||||||
|  |  | ||||||
|  |         # Training data loader | ||||||
|  |         self.dataloader = DataLoader( | ||||||
|  |             ImageDataset(images_path, transforms_, True, 'train'), | ||||||
|  |             batch_size=self.batch_size, | ||||||
|  |             shuffle=True, | ||||||
|  |             num_workers=self.data_loader_workers, | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |         # Validation data loader | ||||||
|  |         self.valid_dataloader = DataLoader( | ||||||
|  |             ImageDataset(images_path, transforms_, True, "test"), | ||||||
|  |             batch_size=5, | ||||||
|  |             shuffle=True, | ||||||
|  |             num_workers=self.data_loader_workers, | ||||||
|  |         ) | ||||||
|  |  | ||||||
|     def run(self): |     def run(self): | ||||||
|         """ |         """ | ||||||
|         ## Training |         ## Training | ||||||
| @ -542,74 +598,6 @@ class Configs(BaseConfigs): | |||||||
|         tracker.add({'loss.discriminator': loss_discriminator}) |         tracker.add({'loss.discriminator': loss_discriminator}) | ||||||
|  |  | ||||||
|  |  | ||||||
| @configs.setup([Configs.generator_xy, Configs.generator_yx, Configs.discriminator_x, Configs.discriminator_y, |  | ||||||
|                 Configs.generator_optimizer, Configs.discriminator_optimizer, |  | ||||||
|                 Configs.generator_lr_scheduler, Configs.discriminator_lr_scheduler]) |  | ||||||
| def setup_models(self: Configs): |  | ||||||
|     """ |  | ||||||
|     ## setup the models |  | ||||||
|     """ |  | ||||||
|     input_shape = (self.img_channels, self.img_height, self.img_width) |  | ||||||
|  |  | ||||||
|     # Create the models |  | ||||||
|     self.generator_xy = GeneratorResNet(self.img_channels, self.n_residual_blocks).to(self.device) |  | ||||||
|     self.generator_yx = GeneratorResNet(self.img_channels, self.n_residual_blocks).to(self.device) |  | ||||||
|     self.discriminator_x = Discriminator(input_shape).to(self.device) |  | ||||||
|     self.discriminator_y = Discriminator(input_shape).to(self.device) |  | ||||||
|  |  | ||||||
|     # Create the optmizers |  | ||||||
|     self.generator_optimizer = torch.optim.Adam( |  | ||||||
|         itertools.chain(self.generator_xy.parameters(), self.generator_yx.parameters()), |  | ||||||
|         lr=self.learning_rate, betas=self.adam_betas) |  | ||||||
|     self.discriminator_optimizer = torch.optim.Adam( |  | ||||||
|         itertools.chain(self.discriminator_x.parameters(), self.discriminator_y.parameters()), |  | ||||||
|         lr=self.learning_rate, betas=self.adam_betas) |  | ||||||
|  |  | ||||||
|     # Create the learning rate schedules. |  | ||||||
|     # The learning rate stars flat until `decay_start` epochs, |  | ||||||
|     # and then linearly reduces to $0$ at end of training. |  | ||||||
|     decay_epochs = self.epochs - self.decay_start |  | ||||||
|     self.generator_lr_scheduler = torch.optim.lr_scheduler.LambdaLR( |  | ||||||
|         self.generator_optimizer, lr_lambda=lambda e: 1.0 - max(0, e - self.decay_start) / decay_epochs) |  | ||||||
|     self.discriminator_lr_scheduler = torch.optim.lr_scheduler.LambdaLR( |  | ||||||
|         self.discriminator_optimizer, lr_lambda=lambda e: 1.0 - max(0, e - self.decay_start) / decay_epochs) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| @configs.setup([Configs.dataloader, Configs.valid_dataloader]) |  | ||||||
| def setup_dataloader(self: Configs): |  | ||||||
|     """ |  | ||||||
|     ## setup the data loaders |  | ||||||
|     """ |  | ||||||
|  |  | ||||||
|     # Location of the dataset |  | ||||||
|     images_path = lab.get_data_path() / 'cycle_gan' / self.dataset_name |  | ||||||
|  |  | ||||||
|     # Image transformations |  | ||||||
|     transforms_ = [ |  | ||||||
|         transforms.Resize(int(self.img_height * 1.12), Image.BICUBIC), |  | ||||||
|         transforms.RandomCrop((self.img_height, self.img_width)), |  | ||||||
|         transforms.RandomHorizontalFlip(), |  | ||||||
|         transforms.ToTensor(), |  | ||||||
|         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), |  | ||||||
|     ] |  | ||||||
|  |  | ||||||
|     # Training data loader |  | ||||||
|     self.dataloader = DataLoader( |  | ||||||
|         ImageDataset(images_path, transforms_, True, 'train'), |  | ||||||
|         batch_size=self.batch_size, |  | ||||||
|         shuffle=True, |  | ||||||
|         num_workers=self.data_loader_workers, |  | ||||||
|     ) |  | ||||||
|  |  | ||||||
|     # Validation data loader |  | ||||||
|     self.valid_dataloader = DataLoader( |  | ||||||
|         ImageDataset(images_path, transforms_, True, "test"), |  | ||||||
|         batch_size=5, |  | ||||||
|         shuffle=True, |  | ||||||
|         num_workers=self.data_loader_workers, |  | ||||||
|     ) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def train(): | def train(): | ||||||
|     """ |     """ | ||||||
|     ## Train Cycle GAN |     ## Train Cycle GAN | ||||||
| @ -620,7 +608,9 @@ def train(): | |||||||
|     experiment.create(name='cycle_gan') |     experiment.create(name='cycle_gan') | ||||||
|     # Calculate configurations. |     # Calculate configurations. | ||||||
|     # It will calculate `conf.run` and all other configs required by it. |     # It will calculate `conf.run` and all other configs required by it. | ||||||
|     experiment.configs(conf, {'dataset_name': 'summer2winter_yosemite'}, 'run') |     experiment.configs(conf, {'dataset_name': 'summer2winter_yosemite'}) | ||||||
|  |     conf.initialize() | ||||||
|  |  | ||||||
|     # Register models for saving and loading. |     # Register models for saving and loading. | ||||||
|     # `get_modules` gives a dictionary of `nn.Modules` in `conf`. |     # `get_modules` gives a dictionary of `nn.Modules` in `conf`. | ||||||
|     # You can also specify a custom dictionary of models. |     # You can also specify a custom dictionary of models. | ||||||
| @ -668,7 +658,9 @@ def evaluate(): | |||||||
|     # If you want other parameters like `dataset_name` you should specify them here. |     # If you want other parameters like `dataset_name` you should specify them here. | ||||||
|     # If you specify nothing all the configurations will be calculated including data loaders. |     # If you specify nothing all the configurations will be calculated including data loaders. | ||||||
|     # Calculation of configurations and their dependencies will happen when you call `experiment.start` |     # Calculation of configurations and their dependencies will happen when you call `experiment.start` | ||||||
|     experiment.configs(conf, conf_dict, 'generator_xy', 'generator_yx') |     experiment.configs(conf, conf_dict) | ||||||
|  |     conf.initialize() | ||||||
|  |  | ||||||
|     # Register models for saving and loading. |     # Register models for saving and loading. | ||||||
|     # `get_modules` gives a dictionary of `nn.Modules` in `conf`. |     # `get_modules` gives a dictionary of `nn.Modules` in `conf`. | ||||||
|     # You can also specify a custom dictionary of models. |     # You can also specify a custom dictionary of models. | ||||||
|  | |||||||
| @ -131,6 +131,7 @@ class BivariateGaussianMixture: | |||||||
|     This class adjust temperatures and creates the categorical and gaussian |     This class adjust temperatures and creates the categorical and gaussian | ||||||
|     distributions from the parameters. |     distributions from the parameters. | ||||||
|     """ |     """ | ||||||
|  |  | ||||||
|     def __init__(self, pi_logits: torch.Tensor, mu_x: torch.Tensor, mu_y: torch.Tensor, |     def __init__(self, pi_logits: torch.Tensor, mu_x: torch.Tensor, mu_y: torch.Tensor, | ||||||
|                  sigma_x: torch.Tensor, sigma_y: torch.Tensor, rho_xy: torch.Tensor): |                  sigma_x: torch.Tensor, sigma_y: torch.Tensor, rho_xy: torch.Tensor): | ||||||
|         self.pi_logits = pi_logits |         self.pi_logits = pi_logits | ||||||
| @ -595,6 +596,34 @@ class Configs(TrainValidConfigs): | |||||||
|  |  | ||||||
|     epochs = 100 |     epochs = 100 | ||||||
|  |  | ||||||
|  |     def initialize(self): | ||||||
|  |         # Initialize encoder & decoder | ||||||
|  |         self.encoder = EncoderRNN(self.d_z, self.enc_hidden_size).to(self.device) | ||||||
|  |         self.decoder = DecoderRNN(self.d_z, self.dec_hidden_size, self.n_distributions).to(self.device) | ||||||
|  |  | ||||||
|  |         # Set optimizer, things like type of optimizer and learning rate are configurable | ||||||
|  |         optimizer = OptimizerConfigs() | ||||||
|  |         optimizer.parameters = list(self.encoder.parameters()) + list(self.decoder.parameters()) | ||||||
|  |         self.optimizer = optimizer | ||||||
|  |  | ||||||
|  |         # Create sampler | ||||||
|  |         self.sampler = Sampler(self.encoder, self.decoder) | ||||||
|  |  | ||||||
|  |         # `npz` file path is `data/sketch/[DATASET NAME].npz` | ||||||
|  |         path = lab.get_data_path() / 'sketch' / f'{self.dataset_name}.npz' | ||||||
|  |         # Load the numpy file. | ||||||
|  |         dataset = np.load(str(path), encoding='latin1', allow_pickle=True) | ||||||
|  |  | ||||||
|  |         # Create training dataset | ||||||
|  |         self.train_dataset = StrokesDataset(dataset['train'], self.max_seq_length) | ||||||
|  |         # Create validation dataset | ||||||
|  |         self.valid_dataset = StrokesDataset(dataset['valid'], self.max_seq_length, self.train_dataset.scale) | ||||||
|  |  | ||||||
|  |         # Create training data loader | ||||||
|  |         self.train_loader = DataLoader(self.train_dataset, self.batch_size, shuffle=True) | ||||||
|  |         # Create validation data loader | ||||||
|  |         self.valid_loader = DataLoader(self.valid_dataset, self.batch_size) | ||||||
|  |  | ||||||
|     def sample(self): |     def sample(self): | ||||||
|         # Randomly pick a sample from validation dataset to encoder |         # Randomly pick a sample from validation dataset to encoder | ||||||
|         data, *_ = self.valid_dataset[np.random.choice(len(self.valid_dataset))] |         data, *_ = self.valid_dataset[np.random.choice(len(self.valid_dataset))] | ||||||
| @ -604,37 +633,6 @@ class Configs(TrainValidConfigs): | |||||||
|         self.sampler.sample(data, self.temperature) |         self.sampler.sample(data, self.temperature) | ||||||
|  |  | ||||||
|  |  | ||||||
| @setup([Configs.encoder, Configs.decoder, Configs.optimizer, Configs.sampler, |  | ||||||
|         Configs.train_dataset, Configs.train_loader, |  | ||||||
|         Configs.valid_dataset, Configs.valid_loader]) |  | ||||||
| def setup_all(self: Configs): |  | ||||||
|     # Initialize encoder & decoder |  | ||||||
|     self.encoder = EncoderRNN(self.d_z, self.enc_hidden_size).to(self.device) |  | ||||||
|     self.decoder = DecoderRNN(self.d_z, self.dec_hidden_size, self.n_distributions).to(self.device) |  | ||||||
|  |  | ||||||
|     # Set optimizer, things like type of optimizer and learning rate are configurable |  | ||||||
|     self.optimizer = OptimizerConfigs() |  | ||||||
|     self.optimizer.parameters = list(self.encoder.parameters()) + list(self.decoder.parameters()) |  | ||||||
|  |  | ||||||
|     # Create sampler |  | ||||||
|     self.sampler = Sampler(self.encoder, self.decoder) |  | ||||||
|  |  | ||||||
|     # `npz` file path is `data/sketch/[DATASET NAME].npz` |  | ||||||
|     path = lab.get_data_path() / 'sketch' / f'{self.dataset_name}.npz' |  | ||||||
|     # Load the numpy file. |  | ||||||
|     dataset = np.load(str(path), encoding='latin1', allow_pickle=True) |  | ||||||
|  |  | ||||||
|     # Create training dataset |  | ||||||
|     self.train_dataset = StrokesDataset(dataset['train'], self.max_seq_length) |  | ||||||
|     # Create validation dataset |  | ||||||
|     self.valid_dataset = StrokesDataset(dataset['valid'], self.max_seq_length, self.train_dataset.scale) |  | ||||||
|  |  | ||||||
|     # Create training data loader |  | ||||||
|     self.train_loader = DataLoader(self.train_dataset, self.batch_size, shuffle=True) |  | ||||||
|     # Create validation data loader |  | ||||||
|     self.valid_loader = DataLoader(self.valid_dataset, self.batch_size) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| @option(Configs.batch_step) | @option(Configs.batch_step) | ||||||
| def strokes_batch_step(c: Configs): | def strokes_batch_step(c: Configs): | ||||||
|     """Set Strokes Training and Validation module""" |     """Set Strokes Training and Validation module""" | ||||||
| @ -655,7 +653,9 @@ def main(): | |||||||
|         'dataset_name': 'bicycle', |         'dataset_name': 'bicycle', | ||||||
|         # Number of inner iterations within an epoch to switch between training, validation and sampling. |         # Number of inner iterations within an epoch to switch between training, validation and sampling. | ||||||
|         'inner_iterations': 10 |         'inner_iterations': 10 | ||||||
|     }, 'run') |     }) | ||||||
|  |  | ||||||
|  |     configs.initialize() | ||||||
|  |  | ||||||
|     with experiment.start(): |     with experiment.start(): | ||||||
|         # Run the experiment |         # Run the experiment | ||||||
|  | |||||||
		Reference in New Issue
	
	Block a user
	 Varuna Jayasiri
					Varuna Jayasiri