📚 notes

This commit is contained in:
Varuna Jayasiri
2020-10-28 11:35:06 +05:30
parent bb297cf761
commit fda13baee7

View File

@ -611,22 +611,29 @@ def setup_dataloader(self: Configs):
def train(): def train():
"""
## Train Cycle GAN
"""
# Create configurations
conf = Configs() conf = Configs()
# Create an experiment
experiment.create(name='cycle_gan') experiment.create(name='cycle_gan')
experiment.configs(conf, { # Calculate configurations.
'dataset_name': 'summer2winter_yosemite' # It will calculate `conf.run` and all other configs required by it.
}, 'run') experiment.configs(conf, {'dataset_name': 'summer2winter_yosemite'}, 'run')
# Register models for saving and loading. # Register models for saving and loading.
# `get_modules` gives a dictionary of `nn.Modules` in `conf`. # `get_modules` gives a dictionary of `nn.Modules` in `conf`.
# You can also specify a custom dictionary of models. # You can also specify a custom dictionary of models.
experiment.add_pytorch_models(get_modules(conf)) experiment.add_pytorch_models(get_modules(conf))
# Start and watch the experiment
with experiment.start(): with experiment.start():
# Run the training
conf.run() conf.run()
def plot_image(img: torch.Tensor): def plot_image(img: torch.Tensor):
""" """
Plots an image with matplotlib ### Plots an image with matplotlib
""" """
from matplotlib import pyplot as plt from matplotlib import pyplot as plt
@ -641,7 +648,10 @@ def plot_image(img: torch.Tensor):
plt.show() plt.show()
def sample(): def evaluate():
"""
## Evaluate trained Cycle GAN
"""
# Set the run uuid from the training run # Set the run uuid from the training run
trained_run_uuid = 'f73c1164184711eb9190b74249275441' trained_run_uuid = 'f73c1164184711eb9190b74249275441'
# Create configs object # Create configs object
@ -704,4 +714,4 @@ def sample():
if __name__ == '__main__': if __name__ == '__main__':
train() train()
# sample() # evaluate()