mirror of
https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
synced 2025-10-29 09:38:56 +08:00
📚 notes
This commit is contained in:
@ -611,22 +611,29 @@ def setup_dataloader(self: Configs):
|
||||
|
||||
|
||||
def train():
|
||||
"""
|
||||
## Train Cycle GAN
|
||||
"""
|
||||
# Create configurations
|
||||
conf = Configs()
|
||||
# Create an experiment
|
||||
experiment.create(name='cycle_gan')
|
||||
experiment.configs(conf, {
|
||||
'dataset_name': 'summer2winter_yosemite'
|
||||
}, 'run')
|
||||
# Calculate configurations.
|
||||
# It will calculate `conf.run` and all other configs required by it.
|
||||
experiment.configs(conf, {'dataset_name': 'summer2winter_yosemite'}, 'run')
|
||||
# Register models for saving and loading.
|
||||
# `get_modules` gives a dictionary of `nn.Modules` in `conf`.
|
||||
# You can also specify a custom dictionary of models.
|
||||
experiment.add_pytorch_models(get_modules(conf))
|
||||
# Start and watch the experiment
|
||||
with experiment.start():
|
||||
# Run the training
|
||||
conf.run()
|
||||
|
||||
|
||||
def plot_image(img: torch.Tensor):
|
||||
"""
|
||||
Plots an image with matplotlib
|
||||
### Plots an image with matplotlib
|
||||
"""
|
||||
from matplotlib import pyplot as plt
|
||||
|
||||
@ -641,7 +648,10 @@ def plot_image(img: torch.Tensor):
|
||||
plt.show()
|
||||
|
||||
|
||||
def sample():
|
||||
def evaluate():
|
||||
"""
|
||||
## Evaluate trained Cycle GAN
|
||||
"""
|
||||
# Set the run uuid from the training run
|
||||
trained_run_uuid = 'f73c1164184711eb9190b74249275441'
|
||||
# Create configs object
|
||||
@ -704,4 +714,4 @@ def sample():
|
||||
|
||||
if __name__ == '__main__':
|
||||
train()
|
||||
# sample()
|
||||
# evaluate()
|
||||
|
||||
Reference in New Issue
Block a user