check loaded weights

This commit is contained in:
lakshith
2024-08-21 09:54:53 +05:30
parent 9e1b35716d
commit 24bd64af7c

View File

@ -110,7 +110,11 @@ class Trainer(BaseConfigs):
new_state_dict[layer] = torch.transpose(new_state_dict[layer], 0, 1)
# Load out model. We use `strict = False` because the state does not have LoRA weights
self.model.load_state_dict(new_state_dict, strict=False)
missing_keys, unexpected_keys = self.model.load_state_dict(new_state_dict, strict=False)
# make sure that only lora weights are not loaded
assert all('lora' in key for key in missing_keys)
assert not unexpected_keys
def initialize(self):
"""