mirror of
				https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
				synced 2025-10-31 18:58:43 +08:00 
			
		
		
		
	comments
This commit is contained in:
		| @ -229,7 +229,8 @@ class ReplayBuffer: | ||||
|         self.max_size = max_size | ||||
|         self.data = [] | ||||
|  | ||||
|     def push_and_pop(self, data): | ||||
|     def push_and_pop(self, data: torch.Tensor): | ||||
|         """Add/retrieve an image""" | ||||
|         data = data.detach() | ||||
|         res = [] | ||||
|         for element in data: | ||||
| @ -247,14 +248,17 @@ class ReplayBuffer: | ||||
|  | ||||
|  | ||||
| class Configs(BaseConfigs): | ||||
|     """This is the configurations for the experiment""" | ||||
|     """## Configurations""" | ||||
|  | ||||
|     # `DeviceConfigs` will pick a GPU if available | ||||
|     device: torch.device = DeviceConfigs() | ||||
|  | ||||
|     # Hyper-parameters | ||||
|     epochs: int = 200 | ||||
|     dataset_name: str = 'monet2photo' | ||||
|     batch_size: int = 1 | ||||
|  | ||||
|     data_loader_workers = 8 | ||||
|     is_save_models = True | ||||
|  | ||||
|     learning_rate = 0.0002 | ||||
|     adam_betas = (0.5, 0.999) | ||||
| @ -268,30 +272,35 @@ class Configs(BaseConfigs): | ||||
|     cycle_loss = torch.nn.L1Loss() | ||||
|     identity_loss = torch.nn.L1Loss() | ||||
|  | ||||
|     batch_step = 'cycle_gan_batch_step' | ||||
|  | ||||
|     # Image dimensions | ||||
|     img_height = 256 | ||||
|     img_width = 256 | ||||
|     img_channels = 3 | ||||
|  | ||||
|     # Number of residual blocks in the generator | ||||
|     n_residual_blocks = 9 | ||||
|  | ||||
|     # Loss coefficients | ||||
|     cyclic_loss_coefficient = 10.0 | ||||
|     identity_loss_coefficient = 5. | ||||
|  | ||||
|     sample_interval = 500 | ||||
|  | ||||
|     # Models | ||||
|     generator_xy: GeneratorResNet | ||||
|     generator_yx: GeneratorResNet | ||||
|     discriminator_x: Discriminator | ||||
|     discriminator_y: Discriminator | ||||
|  | ||||
|     # Optimizers | ||||
|     generator_optimizer: torch.optim.Adam | ||||
|     discriminator_optimizer: torch.optim.Adam | ||||
|  | ||||
|     # Learning rate schedules | ||||
|     generator_lr_scheduler: torch.optim.lr_scheduler.LambdaLR | ||||
|     discriminator_lr_scheduler: torch.optim.lr_scheduler.LambdaLR | ||||
|  | ||||
|     # Data loaders | ||||
|     dataloader: DataLoader | ||||
|     valid_dataloader: DataLoader | ||||
|  | ||||
| @ -311,12 +320,16 @@ class Configs(BaseConfigs): | ||||
|             gen_x = make_grid(gen_x, nrow=5, normalize=True) | ||||
|             gen_y = make_grid(gen_y, nrow=5, normalize=True) | ||||
|  | ||||
|             # arange images along y-axis | ||||
|             # Arrange images along y-axis | ||||
|             image_grid = torch.cat((data_x, gen_y, data_y, gen_x), 1) | ||||
|  | ||||
|         # Save grid | ||||
|         save_image(image_grid, f"images/{self.dataset_name}/{n}.png", normalize=False) | ||||
|  | ||||
|     def run(self): | ||||
|         """ | ||||
|         ## Training | ||||
|  | ||||
|         We aim to solve: | ||||
|         $$G^{*}, F^{*} = \arg \min_{G,F} \max_{D_X, D_Y} \mathcal{L}(G, F, D_X, D_Y)$$ | ||||
|  | ||||
| @ -447,7 +460,7 @@ class Configs(BaseConfigs): | ||||
|  | ||||
|     def optimize_generators(self, data_x: torch.Tensor, data_y: torch.Tensor, true_labels: torch.Tensor): | ||||
|         """ | ||||
|         Optimize the generators with identity, gan and cycle losses. | ||||
|         ### Optimize the generators with identity, gan and cycle losses. | ||||
|         """ | ||||
|  | ||||
|         #  Change to training mode | ||||
| @ -501,7 +514,7 @@ class Configs(BaseConfigs): | ||||
|                                gen_x: torch.Tensor, gen_y: torch.Tensor, | ||||
|                                true_labels: torch.Tensor, false_labels: torch.Tensor): | ||||
|         """ | ||||
|         Optimize the discriminators with gan loss. | ||||
|         ### Optimize the discriminators with gan loss. | ||||
|         """ | ||||
|         # GAN Loss | ||||
|         # \begin{align} | ||||
| @ -529,7 +542,7 @@ class Configs(BaseConfigs): | ||||
|                 Configs.generator_lr_scheduler, Configs.discriminator_lr_scheduler]) | ||||
| def setup_models(self: Configs): | ||||
|     """ | ||||
|     This method setup the models | ||||
|     ## setup the models | ||||
|     """ | ||||
|     input_shape = (self.img_channels, self.img_height, self.img_width) | ||||
|  | ||||
| @ -560,7 +573,7 @@ def setup_models(self: Configs): | ||||
| @configs.setup([Configs.dataloader, Configs.valid_dataloader]) | ||||
| def setup_dataloader(self: Configs): | ||||
|     """ | ||||
|     This method setup the data loaders | ||||
|     ## setup the data loaders | ||||
|     """ | ||||
|  | ||||
|     # Location of the dataset | ||||
|  | ||||
		Reference in New Issue
	
	Block a user
	 Varuna Jayasiri
					Varuna Jayasiri