mirror of
https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
synced 2025-10-28 20:53:41 +08:00
Fix data path
This commit is contained in:
@ -8,7 +8,8 @@ summary: >
|
|||||||
# GPT-NeoX Checkpoints
|
# GPT-NeoX Checkpoints
|
||||||
|
|
||||||
"""
|
"""
|
||||||
from typing import Dict, Union, Tuple
|
from pathlib import Path
|
||||||
|
from typing import Dict, Union, Tuple, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
@ -19,12 +20,21 @@ from labml.utils.download import download_file
|
|||||||
|
|
||||||
# Parent url
|
# Parent url
|
||||||
CHECKPOINTS_URL = 'https://mystic.the-eye.eu/public/AI/models/GPT-NeoX-20B/slim_weights/'
|
CHECKPOINTS_URL = 'https://mystic.the-eye.eu/public/AI/models/GPT-NeoX-20B/slim_weights/'
|
||||||
# Download path
|
|
||||||
|
|
||||||
CHECKPOINTS_DOWNLOAD_PATH = lab.get_data_path() / 'neox_fast' / 'slim_weights'
|
_CHECKPOINTS_DOWNLOAD_PATH: Optional[Path] = None
|
||||||
if not CHECKPOINTS_DOWNLOAD_PATH.exists():
|
|
||||||
CHECKPOINTS_DOWNLOAD_PATH = lab.get_data_path() / 'neox' / 'slim_weights'
|
|
||||||
inspect(neox_checkpoint_path=CHECKPOINTS_DOWNLOAD_PATH)
|
# Download path
|
||||||
|
def get_checkpoints_download_path() -> Path:
    """
    ### Get the checkpoints download path

    The directory is resolved on the first call and cached in the module-level
    `_CHECKPOINTS_DOWNLOAD_PATH`, so later calls return the cached value without
    querying `lab.get_data_path()` again.

    :return: `<data path>/neox_fast/slim_weights` if that directory exists,
        otherwise `<data path>/neox/slim_weights`
    """
    global _CHECKPOINTS_DOWNLOAD_PATH

    # Already resolved by an earlier call
    if _CHECKPOINTS_DOWNLOAD_PATH is not None:
        return _CHECKPOINTS_DOWNLOAD_PATH

    # Prefer the `neox_fast` directory; fall back to `neox` when it does not exist
    _CHECKPOINTS_DOWNLOAD_PATH = lab.get_data_path() / 'neox_fast' / 'slim_weights'
    if not _CHECKPOINTS_DOWNLOAD_PATH.exists():
        _CHECKPOINTS_DOWNLOAD_PATH = lab.get_data_path() / 'neox' / 'slim_weights'
    # Report the path that was chosen
    inspect(neox_checkpoint_path=_CHECKPOINTS_DOWNLOAD_PATH)

    # The first call must hand back the path as well; without this return the
    # function yields `None` and callers fail on `None / <file name>`
    return _CHECKPOINTS_DOWNLOAD_PATH
|
||||||
|
|
||||||
|
|
||||||
def get_files_to_download(n_layers: int = 44):
|
def get_files_to_download(n_layers: int = 44):
|
||||||
@ -65,7 +75,7 @@ def download(n_layers: int = 44):
|
|||||||
# Log
|
# Log
|
||||||
logger.log(['Downloading ', (f'{i + 1 :3d}/{len(files)}', Text.meta), ': ', (f, Text.value)])
|
logger.log(['Downloading ', (f'{i + 1 :3d}/{len(files)}', Text.meta), ': ', (f, Text.value)])
|
||||||
# Download
|
# Download
|
||||||
download_file(CHECKPOINTS_URL + f, CHECKPOINTS_DOWNLOAD_PATH / f)
|
download_file(CHECKPOINTS_URL + f, get_checkpoints_download_path() / f)
|
||||||
|
|
||||||
|
|
||||||
def load_checkpoint_files(files: Tuple[str, str]):
|
def load_checkpoint_files(files: Tuple[str, str]):
|
||||||
@ -75,7 +85,7 @@ def load_checkpoint_files(files: Tuple[str, str]):
|
|||||||
:param files: pair of files to load
|
:param files: pair of files to load
|
||||||
:return: the loaded parameter tensors
|
:return: the loaded parameter tensors
|
||||||
"""
|
"""
|
||||||
checkpoint_path = CHECKPOINTS_DOWNLOAD_PATH / 'global_step150000'
|
checkpoint_path = get_checkpoints_download_path() / 'global_step150000'
|
||||||
with monit.section('Load checkpoint'):
|
with monit.section('Load checkpoint'):
|
||||||
data = [torch.load(checkpoint_path / f) for f in files]
|
data = [torch.load(checkpoint_path / f) for f in files]
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user