diff --git a/labml_nn/neox/checkpoint.py b/labml_nn/neox/checkpoint.py index 28915359..2aac8b1a 100644 --- a/labml_nn/neox/checkpoint.py +++ b/labml_nn/neox/checkpoint.py @@ -36,6 +36,8 @@ def get_checkpoints_download_path(): _CHECKPOINTS_DOWNLOAD_PATH = lab.get_data_path() / 'neox' / 'slim_weights' inspect(neox_checkpoint_path=_CHECKPOINTS_DOWNLOAD_PATH) + return _CHECKPOINTS_DOWNLOAD_PATH + def get_files_to_download(n_layers: int = 44): """