mirror of
https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
synced 2025-08-14 09:31:42 +08:00
28 lines
814 B
Python
28 lines
814 B
Python
import json
|
|
import pathlib
|
|
from typing import Dict
|
|
|
|
from labml import experiment
|
|
from labml_nn.cfr import InfoSet
|
|
|
|
|
|
class InfoSetSaver(experiment.ModelSaver):
|
|
def __init__(self, infosets: Dict[str, InfoSet]):
|
|
self.infosets = infosets
|
|
|
|
def save(self, checkpoint_path: pathlib.Path) -> any:
|
|
data = {key: infoset.to_dict() for key, infoset in self.infosets.items()}
|
|
file_name = f"infosets.json"
|
|
|
|
with open(str(checkpoint_path / file_name), 'w') as f:
|
|
f.write(json.dumps(data))
|
|
|
|
return file_name
|
|
|
|
def load(self, checkpoint_path: pathlib.Path, file_name: str):
|
|
with open(str(checkpoint_path / file_name), 'w') as f:
|
|
data = json.loads(f.read())
|
|
|
|
for key, d in data.items():
|
|
self.infosets[key] = InfoSet.from_dict(d)
|