diff --git a/labml_nn/gan/cycle_gan.py b/labml_nn/gan/cycle_gan.py index ac6ac51a..fe88fcb0 100644 --- a/labml_nn/gan/cycle_gan.py +++ b/labml_nn/gan/cycle_gan.py @@ -611,22 +611,29 @@ def setup_dataloader(self: Configs): def train(): + """ + ## Train Cycle GAN + """ + # Create configurations conf = Configs() + # Create an experiment experiment.create(name='cycle_gan') - experiment.configs(conf, { - 'dataset_name': 'summer2winter_yosemite' - }, 'run') + # Calculate configurations. + # It will calculate `conf.run` and all other configs required by it. + experiment.configs(conf, {'dataset_name': 'summer2winter_yosemite'}, 'run') # Register models for saving and loading. # `get_modules` gives a dictionary of `nn.Modules` in `conf`. # You can also specify a custom dictionary of models. experiment.add_pytorch_models(get_modules(conf)) + # Start and watch the experiment with experiment.start(): + # Run the training conf.run() def plot_image(img: torch.Tensor): """ - Plots an image with matplotlib + ### Plots an image with matplotlib """ from matplotlib import pyplot as plt @@ -641,7 +648,10 @@ def plot_image(img: torch.Tensor): plt.show() -def sample(): +def evaluate(): + """ + ## Evaluate trained Cycle GAN + """ # Set the run uuid from the training run trained_run_uuid = 'f73c1164184711eb9190b74249275441' # Create configs object @@ -704,4 +714,4 @@ def sample(): if __name__ == '__main__': train() - # sample() + # evaluate()