1import json
2import pathlib
3from typing import Dict
4
5from labml import experiment
from labml_nn.cfr import InfoSet


class InfoSetSaver(experiment.ModelSaver):
    """
    Checkpoint saver that persists a dictionary of information sets as JSON.

    On save, each `InfoSet` is serialized with `InfoSet.to_dict`. On load, it is
    rebuilt with `InfoSet.from_dict`.
    """

    def __init__(self, infosets: Dict[str, InfoSet]):
        """
        :param infosets: mapping from info-set key to `InfoSet`. The dictionary
            is held by reference, so `load` updates the caller's dictionary in
            place.
        """
        self.infosets = infosets

    def save(self, checkpoint_path: pathlib.Path) -> str:
        """
        Write all information sets to `infosets.json` inside `checkpoint_path`.

        :param checkpoint_path: checkpoint directory (assumed to already exist).
        :return: the name of the file written, relative to `checkpoint_path`.
            This is the value `load` expects as its `file_name` argument.
        """
        data = {key: infoset.to_dict() for key, infoset in self.infosets.items()}
        file_name = "infosets.json"

        with open(checkpoint_path / file_name, 'w', encoding='utf-8') as f:
            json.dump(data, f)

        return file_name

    def load(self, checkpoint_path: pathlib.Path, file_name: str) -> None:
        """
        Read information sets saved by `save` back into `self.infosets`.

        Entries from the file are inserted into (or overwrite keys in) the
        existing dictionary. Keys that are not present in the file are left
        untouched.

        :param checkpoint_path: checkpoint directory.
        :param file_name: file name returned by `save`.
        """
        # Must open for reading. Mode 'w' would truncate the checkpoint, and the
        # subsequent read would fail with io.UnsupportedOperation.
        with open(checkpoint_path / file_name, 'r', encoding='utf-8') as f:
            data = json.load(f)

        for key, d in data.items():
            self.infosets[key] = InfoSet.from_dict(d)