mirror of
https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
synced 2025-10-31 02:39:16 +08:00
📚 notes
This commit is contained in:
@ -611,22 +611,29 @@ def setup_dataloader(self: Configs):
|
|||||||
|
|
||||||
|
|
||||||
def train():
|
def train():
|
||||||
|
"""
|
||||||
|
## Train Cycle GAN
|
||||||
|
"""
|
||||||
|
# Create configurations
|
||||||
conf = Configs()
|
conf = Configs()
|
||||||
|
# Create an experiment
|
||||||
experiment.create(name='cycle_gan')
|
experiment.create(name='cycle_gan')
|
||||||
experiment.configs(conf, {
|
# Calculate configurations.
|
||||||
'dataset_name': 'summer2winter_yosemite'
|
# It will calculate `conf.run` and all other configs required by it.
|
||||||
}, 'run')
|
experiment.configs(conf, {'dataset_name': 'summer2winter_yosemite'}, 'run')
|
||||||
# Register models for saving and loading.
|
# Register models for saving and loading.
|
||||||
# `get_modules` gives a dictionary of `nn.Modules` in `conf`.
|
# `get_modules` gives a dictionary of `nn.Modules` in `conf`.
|
||||||
# You can also specify a custom dictionary of models.
|
# You can also specify a custom dictionary of models.
|
||||||
experiment.add_pytorch_models(get_modules(conf))
|
experiment.add_pytorch_models(get_modules(conf))
|
||||||
|
# Start and watch the experiment
|
||||||
with experiment.start():
|
with experiment.start():
|
||||||
|
# Run the training
|
||||||
conf.run()
|
conf.run()
|
||||||
|
|
||||||
|
|
||||||
def plot_image(img: torch.Tensor):
|
def plot_image(img: torch.Tensor):
|
||||||
"""
|
"""
|
||||||
Plots an image with matplotlib
|
### Plots an image with matplotlib
|
||||||
"""
|
"""
|
||||||
from matplotlib import pyplot as plt
|
from matplotlib import pyplot as plt
|
||||||
|
|
||||||
@ -641,7 +648,10 @@ def plot_image(img: torch.Tensor):
|
|||||||
plt.show()
|
plt.show()
|
||||||
|
|
||||||
|
|
||||||
def sample():
|
def evaluate():
|
||||||
|
"""
|
||||||
|
## Evaluate trained Cycle GAN
|
||||||
|
"""
|
||||||
# Set the run uuid from the training run
|
# Set the run uuid from the training run
|
||||||
trained_run_uuid = 'f73c1164184711eb9190b74249275441'
|
trained_run_uuid = 'f73c1164184711eb9190b74249275441'
|
||||||
# Create configs object
|
# Create configs object
|
||||||
@ -704,4 +714,4 @@ def sample():
|
|||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
train()
|
train()
|
||||||
# sample()
|
# evaluate()
|
||||||
|
|||||||
Reference in New Issue
Block a user