1import json
2import pathlib
3from typing import Dict
4
5from labml import experiment
6from labml_nn.cfr import InfoSet
class InfoSetSaver(experiment.ModelSaver):
    """
    Checkpoint saver for CFR information sets.

    Serializes a dictionary of `InfoSet` objects (keyed by string) to a single
    JSON file inside the checkpoint directory, and restores them from it.
    """

    def __init__(self, infosets: Dict[str, InfoSet]):
        # Keep a reference to the caller's dictionary (not a copy) so that
        # `load` repopulates the very dictionary the caller is using.
        self.infosets = infosets

    def save(self, checkpoint_path: pathlib.Path) -> str:
        """
        Write every information set to `infosets.json` in `checkpoint_path`.

        :param checkpoint_path: directory of the checkpoint being written.
        :return: the file name (relative to `checkpoint_path`); `load` takes
            the same name as its `file_name` argument.
        """
        # assumes `InfoSet.to_dict()` returns JSON-serializable data — TODO confirm
        data = {key: infoset.to_dict() for key, infoset in self.infosets.items()}
        file_name = "infosets.json"

        with open(str(checkpoint_path / file_name), 'w', encoding='utf-8') as f:
            json.dump(data, f)

        return file_name

    def load(self, checkpoint_path: pathlib.Path, file_name: str):
        """
        Read information sets from `checkpoint_path / file_name` into `self.infosets`.

        Entries from the file are inserted into (or overwrite entries of) the
        existing dictionary in place; keys not present in the file are left
        untouched.

        :param checkpoint_path: directory of the checkpoint being restored.
        :param file_name: name of the JSON file written by `save`.
        """
        # Must open for reading: mode 'w' would truncate the checkpoint file
        # and then fail on read.
        with open(str(checkpoint_path / file_name), 'r', encoding='utf-8') as f:
            data = json.load(f)

        for key, d in data.items():
            self.infosets[key] = InfoSet.from_dict(d)