mirror of
https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
synced 2025-10-29 17:57:14 +08:00
♻️ experiment.configs
This commit is contained in:
@ -153,8 +153,7 @@ def main():
|
|||||||
conf = Configs()
|
conf = Configs()
|
||||||
experiment.create(name='mnist_latest')
|
experiment.create(name='mnist_latest')
|
||||||
experiment.configs(conf, {'optimizer.optimizer': 'Adam',
|
experiment.configs(conf, {'optimizer.optimizer': 'Adam',
|
||||||
'device.cuda_device': 1},
|
'device.cuda_device': 1})
|
||||||
'run')
|
|
||||||
with experiment.start():
|
with experiment.start():
|
||||||
conf.run()
|
conf.run()
|
||||||
|
|
||||||
|
|||||||
@ -106,8 +106,7 @@ def main():
|
|||||||
experiment.configs(conf,
|
experiment.configs(conf,
|
||||||
{'discriminator': 'cnn',
|
{'discriminator': 'cnn',
|
||||||
'generator': 'cnn',
|
'generator': 'cnn',
|
||||||
'label_smoothing': 0.01},
|
'label_smoothing': 0.01})
|
||||||
'run')
|
|
||||||
with experiment.start():
|
with experiment.start():
|
||||||
conf.run()
|
conf.run()
|
||||||
|
|
||||||
|
|||||||
@ -233,8 +233,7 @@ def main():
|
|||||||
conf = Configs()
|
conf = Configs()
|
||||||
experiment.create(name='mnist_gan', comment='test')
|
experiment.create(name='mnist_gan', comment='test')
|
||||||
experiment.configs(conf,
|
experiment.configs(conf,
|
||||||
{'label_smoothing': 0.01},
|
{'label_smoothing': 0.01})
|
||||||
'run')
|
|
||||||
with experiment.start():
|
with experiment.start():
|
||||||
conf.run()
|
conf.run()
|
||||||
|
|
||||||
|
|||||||
@ -33,7 +33,7 @@ def load_experiment(run_uuid: str, checkpoint: Optional[int] = None):
|
|||||||
# This experiment is just an evaluation; i.e. nothing is tracked or saved
|
# This experiment is just an evaluation; i.e. nothing is tracked or saved
|
||||||
experiment.evaluate()
|
experiment.evaluate()
|
||||||
# Initialize configurations
|
# Initialize configurations
|
||||||
experiment.configs(conf, conf_dict, 'run')
|
experiment.configs(conf, conf_dict)
|
||||||
# Set models for saving/loading
|
# Set models for saving/loading
|
||||||
experiment.add_pytorch_models(get_modules(conf))
|
experiment.add_pytorch_models(get_modules(conf))
|
||||||
# Specify the experiment to load from
|
# Specify the experiment to load from
|
||||||
|
|||||||
@ -331,10 +331,7 @@ def main():
|
|||||||
'transformer.d_model': 256,
|
'transformer.d_model': 256,
|
||||||
'transformer.d_ff': 1024,
|
'transformer.d_ff': 1024,
|
||||||
'transformer.n_heads': 8,
|
'transformer.n_heads': 8,
|
||||||
'transformer.n_layers': 6},
|
'transformer.n_layers': 6})
|
||||||
# We need to load the function `TrainValidConfigs.run` and
|
|
||||||
# everything that it's dependent on
|
|
||||||
'run')
|
|
||||||
|
|
||||||
# Set models for saving and loading
|
# Set models for saving and loading
|
||||||
experiment.add_pytorch_models(get_modules(conf))
|
experiment.add_pytorch_models(get_modules(conf))
|
||||||
|
|||||||
Reference in New Issue
Block a user