mirror of
				https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
				synced 2025-10-30 10:18:50 +08:00 
			
		
		
		
	♻️ remove setup
This commit is contained in:
		| @ -328,6 +328,62 @@ class Configs(BaseConfigs): | ||||
|         # Save grid | ||||
|         save_image(image_grid, f"images/{self.dataset_name}/{n}.png", normalize=False) | ||||
|  | ||||
|     def initialize(self): | ||||
|         """ | ||||
|         ## Initialize models and data loaders | ||||
|         """ | ||||
|         input_shape = (self.img_channels, self.img_height, self.img_width) | ||||
|  | ||||
|         # Create the models | ||||
|         self.generator_xy = GeneratorResNet(self.img_channels, self.n_residual_blocks).to(self.device) | ||||
|         self.generator_yx = GeneratorResNet(self.img_channels, self.n_residual_blocks).to(self.device) | ||||
|         self.discriminator_x = Discriminator(input_shape).to(self.device) | ||||
|         self.discriminator_y = Discriminator(input_shape).to(self.device) | ||||
|  | ||||
|         # Create the optmizers | ||||
|         self.generator_optimizer = torch.optim.Adam( | ||||
|             itertools.chain(self.generator_xy.parameters(), self.generator_yx.parameters()), | ||||
|             lr=self.learning_rate, betas=self.adam_betas) | ||||
|         self.discriminator_optimizer = torch.optim.Adam( | ||||
|             itertools.chain(self.discriminator_x.parameters(), self.discriminator_y.parameters()), | ||||
|             lr=self.learning_rate, betas=self.adam_betas) | ||||
|  | ||||
|         # Create the learning rate schedules. | ||||
|         # The learning rate stars flat until `decay_start` epochs, | ||||
|         # and then linearly reduces to $0$ at end of training. | ||||
|         decay_epochs = self.epochs - self.decay_start | ||||
|         self.generator_lr_scheduler = torch.optim.lr_scheduler.LambdaLR( | ||||
|             self.generator_optimizer, lr_lambda=lambda e: 1.0 - max(0, e - self.decay_start) / decay_epochs) | ||||
|         self.discriminator_lr_scheduler = torch.optim.lr_scheduler.LambdaLR( | ||||
|             self.discriminator_optimizer, lr_lambda=lambda e: 1.0 - max(0, e - self.decay_start) / decay_epochs) | ||||
|         # Location of the dataset | ||||
|         images_path = lab.get_data_path() / 'cycle_gan' / self.dataset_name | ||||
|  | ||||
|         # Image transformations | ||||
|         transforms_ = [ | ||||
|             transforms.Resize(int(self.img_height * 1.12), Image.BICUBIC), | ||||
|             transforms.RandomCrop((self.img_height, self.img_width)), | ||||
|             transforms.RandomHorizontalFlip(), | ||||
|             transforms.ToTensor(), | ||||
|             transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), | ||||
|         ] | ||||
|  | ||||
|         # Training data loader | ||||
|         self.dataloader = DataLoader( | ||||
|             ImageDataset(images_path, transforms_, True, 'train'), | ||||
|             batch_size=self.batch_size, | ||||
|             shuffle=True, | ||||
|             num_workers=self.data_loader_workers, | ||||
|         ) | ||||
|  | ||||
|         # Validation data loader | ||||
|         self.valid_dataloader = DataLoader( | ||||
|             ImageDataset(images_path, transforms_, True, "test"), | ||||
|             batch_size=5, | ||||
|             shuffle=True, | ||||
|             num_workers=self.data_loader_workers, | ||||
|         ) | ||||
|  | ||||
|     def run(self): | ||||
|         """ | ||||
|         ## Training | ||||
| @ -542,74 +598,6 @@ class Configs(BaseConfigs): | ||||
|         tracker.add({'loss.discriminator': loss_discriminator}) | ||||
|  | ||||
|  | ||||
| @configs.setup([Configs.generator_xy, Configs.generator_yx, Configs.discriminator_x, Configs.discriminator_y, | ||||
|                 Configs.generator_optimizer, Configs.discriminator_optimizer, | ||||
|                 Configs.generator_lr_scheduler, Configs.discriminator_lr_scheduler]) | ||||
| def setup_models(self: Configs): | ||||
|     """ | ||||
|     ## setup the models | ||||
|     """ | ||||
|     input_shape = (self.img_channels, self.img_height, self.img_width) | ||||
|  | ||||
|     # Create the models | ||||
|     self.generator_xy = GeneratorResNet(self.img_channels, self.n_residual_blocks).to(self.device) | ||||
|     self.generator_yx = GeneratorResNet(self.img_channels, self.n_residual_blocks).to(self.device) | ||||
|     self.discriminator_x = Discriminator(input_shape).to(self.device) | ||||
|     self.discriminator_y = Discriminator(input_shape).to(self.device) | ||||
|  | ||||
|     # Create the optmizers | ||||
|     self.generator_optimizer = torch.optim.Adam( | ||||
|         itertools.chain(self.generator_xy.parameters(), self.generator_yx.parameters()), | ||||
|         lr=self.learning_rate, betas=self.adam_betas) | ||||
|     self.discriminator_optimizer = torch.optim.Adam( | ||||
|         itertools.chain(self.discriminator_x.parameters(), self.discriminator_y.parameters()), | ||||
|         lr=self.learning_rate, betas=self.adam_betas) | ||||
|  | ||||
|     # Create the learning rate schedules. | ||||
|     # The learning rate stars flat until `decay_start` epochs, | ||||
|     # and then linearly reduces to $0$ at end of training. | ||||
|     decay_epochs = self.epochs - self.decay_start | ||||
|     self.generator_lr_scheduler = torch.optim.lr_scheduler.LambdaLR( | ||||
|         self.generator_optimizer, lr_lambda=lambda e: 1.0 - max(0, e - self.decay_start) / decay_epochs) | ||||
|     self.discriminator_lr_scheduler = torch.optim.lr_scheduler.LambdaLR( | ||||
|         self.discriminator_optimizer, lr_lambda=lambda e: 1.0 - max(0, e - self.decay_start) / decay_epochs) | ||||
|  | ||||
|  | ||||
| @configs.setup([Configs.dataloader, Configs.valid_dataloader]) | ||||
| def setup_dataloader(self: Configs): | ||||
|     """ | ||||
|     ## setup the data loaders | ||||
|     """ | ||||
|  | ||||
|     # Location of the dataset | ||||
|     images_path = lab.get_data_path() / 'cycle_gan' / self.dataset_name | ||||
|  | ||||
|     # Image transformations | ||||
|     transforms_ = [ | ||||
|         transforms.Resize(int(self.img_height * 1.12), Image.BICUBIC), | ||||
|         transforms.RandomCrop((self.img_height, self.img_width)), | ||||
|         transforms.RandomHorizontalFlip(), | ||||
|         transforms.ToTensor(), | ||||
|         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), | ||||
|     ] | ||||
|  | ||||
|     # Training data loader | ||||
|     self.dataloader = DataLoader( | ||||
|         ImageDataset(images_path, transforms_, True, 'train'), | ||||
|         batch_size=self.batch_size, | ||||
|         shuffle=True, | ||||
|         num_workers=self.data_loader_workers, | ||||
|     ) | ||||
|  | ||||
|     # Validation data loader | ||||
|     self.valid_dataloader = DataLoader( | ||||
|         ImageDataset(images_path, transforms_, True, "test"), | ||||
|         batch_size=5, | ||||
|         shuffle=True, | ||||
|         num_workers=self.data_loader_workers, | ||||
|     ) | ||||
|  | ||||
|  | ||||
| def train(): | ||||
|     """ | ||||
|     ## Train Cycle GAN | ||||
| @ -620,7 +608,9 @@ def train(): | ||||
|     experiment.create(name='cycle_gan') | ||||
|     # Calculate configurations. | ||||
|     # It will calculate `conf.run` and all other configs required by it. | ||||
|     experiment.configs(conf, {'dataset_name': 'summer2winter_yosemite'}, 'run') | ||||
|     experiment.configs(conf, {'dataset_name': 'summer2winter_yosemite'}) | ||||
|     conf.initialize() | ||||
|  | ||||
|     # Register models for saving and loading. | ||||
|     # `get_modules` gives a dictionary of `nn.Modules` in `conf`. | ||||
|     # You can also specify a custom dictionary of models. | ||||
| @ -668,7 +658,9 @@ def evaluate(): | ||||
|     # If you want other parameters like `dataset_name` you should specify them here. | ||||
|     # If you specify nothing all the configurations will be calculated including data loaders. | ||||
|     # Calculation of configurations and their dependencies will happen when you call `experiment.start` | ||||
|     experiment.configs(conf, conf_dict, 'generator_xy', 'generator_yx') | ||||
|     experiment.configs(conf, conf_dict) | ||||
|     conf.initialize() | ||||
|  | ||||
|     # Register models for saving and loading. | ||||
|     # `get_modules` gives a dictionary of `nn.Modules` in `conf`. | ||||
|     # You can also specify a custom dictionary of models. | ||||
|  | ||||
| @ -131,6 +131,7 @@ class BivariateGaussianMixture: | ||||
|     This class adjust temperatures and creates the categorical and gaussian | ||||
|     distributions from the parameters. | ||||
|     """ | ||||
|  | ||||
|     def __init__(self, pi_logits: torch.Tensor, mu_x: torch.Tensor, mu_y: torch.Tensor, | ||||
|                  sigma_x: torch.Tensor, sigma_y: torch.Tensor, rho_xy: torch.Tensor): | ||||
|         self.pi_logits = pi_logits | ||||
| @ -595,6 +596,34 @@ class Configs(TrainValidConfigs): | ||||
|  | ||||
|     epochs = 100 | ||||
|  | ||||
|     def initialize(self): | ||||
|         # Initialize encoder & decoder | ||||
|         self.encoder = EncoderRNN(self.d_z, self.enc_hidden_size).to(self.device) | ||||
|         self.decoder = DecoderRNN(self.d_z, self.dec_hidden_size, self.n_distributions).to(self.device) | ||||
|  | ||||
|         # Set optimizer, things like type of optimizer and learning rate are configurable | ||||
|         optimizer = OptimizerConfigs() | ||||
|         optimizer.parameters = list(self.encoder.parameters()) + list(self.decoder.parameters()) | ||||
|         self.optimizer = optimizer | ||||
|  | ||||
|         # Create sampler | ||||
|         self.sampler = Sampler(self.encoder, self.decoder) | ||||
|  | ||||
|         # `npz` file path is `data/sketch/[DATASET NAME].npz` | ||||
|         path = lab.get_data_path() / 'sketch' / f'{self.dataset_name}.npz' | ||||
|         # Load the numpy file. | ||||
|         dataset = np.load(str(path), encoding='latin1', allow_pickle=True) | ||||
|  | ||||
|         # Create training dataset | ||||
|         self.train_dataset = StrokesDataset(dataset['train'], self.max_seq_length) | ||||
|         # Create validation dataset | ||||
|         self.valid_dataset = StrokesDataset(dataset['valid'], self.max_seq_length, self.train_dataset.scale) | ||||
|  | ||||
|         # Create training data loader | ||||
|         self.train_loader = DataLoader(self.train_dataset, self.batch_size, shuffle=True) | ||||
|         # Create validation data loader | ||||
|         self.valid_loader = DataLoader(self.valid_dataset, self.batch_size) | ||||
|  | ||||
|     def sample(self): | ||||
|         # Randomly pick a sample from validation dataset to encoder | ||||
|         data, *_ = self.valid_dataset[np.random.choice(len(self.valid_dataset))] | ||||
| @ -604,37 +633,6 @@ class Configs(TrainValidConfigs): | ||||
|         self.sampler.sample(data, self.temperature) | ||||
|  | ||||
|  | ||||
| @setup([Configs.encoder, Configs.decoder, Configs.optimizer, Configs.sampler, | ||||
|         Configs.train_dataset, Configs.train_loader, | ||||
|         Configs.valid_dataset, Configs.valid_loader]) | ||||
| def setup_all(self: Configs): | ||||
|     # Initialize encoder & decoder | ||||
|     self.encoder = EncoderRNN(self.d_z, self.enc_hidden_size).to(self.device) | ||||
|     self.decoder = DecoderRNN(self.d_z, self.dec_hidden_size, self.n_distributions).to(self.device) | ||||
|  | ||||
|     # Set optimizer, things like type of optimizer and learning rate are configurable | ||||
|     self.optimizer = OptimizerConfigs() | ||||
|     self.optimizer.parameters = list(self.encoder.parameters()) + list(self.decoder.parameters()) | ||||
|  | ||||
|     # Create sampler | ||||
|     self.sampler = Sampler(self.encoder, self.decoder) | ||||
|  | ||||
|     # `npz` file path is `data/sketch/[DATASET NAME].npz` | ||||
|     path = lab.get_data_path() / 'sketch' / f'{self.dataset_name}.npz' | ||||
|     # Load the numpy file. | ||||
|     dataset = np.load(str(path), encoding='latin1', allow_pickle=True) | ||||
|  | ||||
|     # Create training dataset | ||||
|     self.train_dataset = StrokesDataset(dataset['train'], self.max_seq_length) | ||||
|     # Create validation dataset | ||||
|     self.valid_dataset = StrokesDataset(dataset['valid'], self.max_seq_length, self.train_dataset.scale) | ||||
|  | ||||
|     # Create training data loader | ||||
|     self.train_loader = DataLoader(self.train_dataset, self.batch_size, shuffle=True) | ||||
|     # Create validation data loader | ||||
|     self.valid_loader = DataLoader(self.valid_dataset, self.batch_size) | ||||
|  | ||||
|  | ||||
| @option(Configs.batch_step) | ||||
| def strokes_batch_step(c: Configs): | ||||
|     """Set Strokes Training and Validation module""" | ||||
| @ -655,7 +653,9 @@ def main(): | ||||
|         'dataset_name': 'bicycle', | ||||
|         # Number of inner iterations within an epoch to switch between training, validation and sampling. | ||||
|         'inner_iterations': 10 | ||||
|     }, 'run') | ||||
|     }) | ||||
|  | ||||
|     configs.initialize() | ||||
|  | ||||
|     with experiment.start(): | ||||
|         # Run the experiment | ||||
|  | ||||
		Reference in New Issue
	
	Block a user
	 Varuna Jayasiri
					Varuna Jayasiri