Fix data path

Varuna Jayasiri
2022-09-13 17:05:52 +05:30
committed by GitHub
parent 7d1550dd67
commit a5686f4709

@@ -8,7 +8,8 @@ summary: >
 # GPT-NeoX Checkpoints
 """
-from typing import Dict, Union, Tuple
+from pathlib import Path
+from typing import Dict, Union, Tuple, Optional

 import torch
 from torch import nn
@@ -19,12 +20,21 @@ from labml.utils.download import download_file
 # Parent url
 CHECKPOINTS_URL = 'https://mystic.the-eye.eu/public/AI/models/GPT-NeoX-20B/slim_weights/'
-# Download path
-CHECKPOINTS_DOWNLOAD_PATH = lab.get_data_path() / 'neox_fast' / 'slim_weights'
-if not CHECKPOINTS_DOWNLOAD_PATH.exists():
-    CHECKPOINTS_DOWNLOAD_PATH = lab.get_data_path() / 'neox' / 'slim_weights'
-inspect(neox_checkpoint_path=CHECKPOINTS_DOWNLOAD_PATH)
+
+_CHECKPOINTS_DOWNLOAD_PATH: Optional[Path] = None
+
+
+# Download path
+def get_checkpoints_download_path():
+    global _CHECKPOINTS_DOWNLOAD_PATH
+
+    if _CHECKPOINTS_DOWNLOAD_PATH is not None:
+        return _CHECKPOINTS_DOWNLOAD_PATH
+
+    _CHECKPOINTS_DOWNLOAD_PATH = lab.get_data_path() / 'neox_fast' / 'slim_weights'
+    if not _CHECKPOINTS_DOWNLOAD_PATH.exists():
+        _CHECKPOINTS_DOWNLOAD_PATH = lab.get_data_path() / 'neox' / 'slim_weights'
+    inspect(neox_checkpoint_path=_CHECKPOINTS_DOWNLOAD_PATH)
+
+    return _CHECKPOINTS_DOWNLOAD_PATH


 def get_files_to_download(n_layers: int = 44):
@@ -65,7 +75,7 @@ def download(n_layers: int = 44):
         # Log
         logger.log(['Downloading ', (f'{i + 1 :3d}/{len(files)}', Text.meta), ': ', (f, Text.value)])
         # Download
-        download_file(CHECKPOINTS_URL + f, CHECKPOINTS_DOWNLOAD_PATH / f)
+        download_file(CHECKPOINTS_URL + f, get_checkpoints_download_path() / f)


 def load_checkpoint_files(files: Tuple[str, str]):
@@ -75,7 +85,7 @@ def load_checkpoint_files(files: Tuple[str, str]):
     :param files: pair of files to load
     :return: the loaded parameter tensors
     """
-    checkpoint_path = CHECKPOINTS_DOWNLOAD_PATH / 'global_step150000'
+    checkpoint_path = get_checkpoints_download_path() / 'global_step150000'
     with monit.section('Load checkpoint'):
         data = [torch.load(checkpoint_path / f) for f in files]
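
Note: the change replaces the module-level CHECKPOINTS_DOWNLOAD_PATH constant with get_checkpoints_download_path(), which resolves the directory on first use instead of at import time: it prefers the 'neox_fast' data directory, falls back to 'neox' when that does not exist, logs the chosen path with inspect(), and caches the result for later calls. Callers such as download() and load_checkpoint_files() now call the helper instead of reading the constant. A minimal standalone sketch of the same lazy-fallback pattern, using a hypothetical DATA_PATH constant in place of lab.get_data_path() and without labml's inspect logging:

from pathlib import Path
from typing import Optional

# Hypothetical stand-in for lab.get_data_path(); adjust to your own data directory.
DATA_PATH = Path.home() / 'labml' / 'data'

_checkpoints_path: Optional[Path] = None


def get_checkpoints_download_path() -> Path:
    """Resolve the checkpoint directory once and cache it."""
    global _checkpoints_path
    # Reuse the cached result on subsequent calls
    if _checkpoints_path is not None:
        return _checkpoints_path
    # Prefer the 'neox_fast' directory; fall back to 'neox' if it is missing
    _checkpoints_path = DATA_PATH / 'neox_fast' / 'slim_weights'
    if not _checkpoints_path.exists():
        _checkpoints_path = DATA_PATH / 'neox' / 'slim_weights'
    return _checkpoints_path


if __name__ == '__main__':
    # Both calls resolve to the same cached path
    print(get_checkpoints_download_path())
    print(get_checkpoints_download_path())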