mirror of
				https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
				synced 2025-10-31 02:39:16 +08:00 
			
		
		
		
	📚 notes
This commit is contained in:
		| @ -611,22 +611,29 @@ def setup_dataloader(self: Configs): | |||||||
|  |  | ||||||
|  |  | ||||||
| def train(): | def train(): | ||||||
|  |     """ | ||||||
|  |     ## Train Cycle GAN | ||||||
|  |     """ | ||||||
|  |     # Create configurations | ||||||
|     conf = Configs() |     conf = Configs() | ||||||
|  |     # Create an experiment | ||||||
|     experiment.create(name='cycle_gan') |     experiment.create(name='cycle_gan') | ||||||
|     experiment.configs(conf, { |     # Calculate configurations. | ||||||
|         'dataset_name': 'summer2winter_yosemite' |     # It will calculate `conf.run` and all other configs required by it. | ||||||
|     }, 'run') |     experiment.configs(conf, {'dataset_name': 'summer2winter_yosemite'}, 'run') | ||||||
|     # Register models for saving and loading. |     # Register models for saving and loading. | ||||||
|     # `get_modules` gives a dictionary of `nn.Modules` in `conf`. |     # `get_modules` gives a dictionary of `nn.Modules` in `conf`. | ||||||
|     # You can also specify a custom dictionary of models. |     # You can also specify a custom dictionary of models. | ||||||
|     experiment.add_pytorch_models(get_modules(conf)) |     experiment.add_pytorch_models(get_modules(conf)) | ||||||
|  |     # Start and watch the experiment | ||||||
|     with experiment.start(): |     with experiment.start(): | ||||||
|  |         # Run the training | ||||||
|         conf.run() |         conf.run() | ||||||
|  |  | ||||||
|  |  | ||||||
| def plot_image(img: torch.Tensor): | def plot_image(img: torch.Tensor): | ||||||
|     """ |     """ | ||||||
|     Plots an image with matplotlib |     ### Plots an image with matplotlib | ||||||
|     """ |     """ | ||||||
|     from matplotlib import pyplot as plt |     from matplotlib import pyplot as plt | ||||||
|  |  | ||||||
| @ -641,7 +648,10 @@ def plot_image(img: torch.Tensor): | |||||||
|     plt.show() |     plt.show() | ||||||
|  |  | ||||||
|  |  | ||||||
| def sample(): | def evaluate(): | ||||||
|  |     """ | ||||||
|  |     ## Evaluate trained Cycle GAN | ||||||
|  |     """ | ||||||
|     # Set the run uuid from the training run |     # Set the run uuid from the training run | ||||||
|     trained_run_uuid = 'f73c1164184711eb9190b74249275441' |     trained_run_uuid = 'f73c1164184711eb9190b74249275441' | ||||||
|     # Create configs object |     # Create configs object | ||||||
| @ -704,4 +714,4 @@ def sample(): | |||||||
|  |  | ||||||
| if __name__ == '__main__': | if __name__ == '__main__': | ||||||
|     train() |     train() | ||||||
|     # sample() |     # evaluate() | ||||||
|  | |||||||
		Reference in New Issue
	
	Block a user
	 Varuna Jayasiri
					Varuna Jayasiri