mirror of
https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
synced 2025-08-06 15:22:21 +08:00
check loaded weights
@@ -110,7 +110,11 @@ class Trainer(BaseConfigs):
             new_state_dict[layer] = torch.transpose(new_state_dict[layer], 0, 1)
 
         # Load our model. We use `strict = False` because the state does not have LoRA weights
-        self.model.load_state_dict(new_state_dict, strict=False)
+        missing_keys, unexpected_keys = self.model.load_state_dict(new_state_dict, strict=False)
+
+        # make sure that the only weights missing from the loaded state are LoRA weights
+        assert all('lora' in key for key in missing_keys)
+        assert not unexpected_keys
 
     def initialize(self):
         """
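For context, here is a minimal standalone sketch of what this change verifies: in PyTorch, `Module.load_state_dict(..., strict=False)` returns a result with `missing_keys` and `unexpected_keys`, so the two assertions confirm that the only parameters absent from the pretrained checkpoint are the newly added LoRA ones. The `TinyModel` class and the `lora_a` / `lora_b` parameter names below are illustrative assumptions, not the repository's actual classes.

import torch
import torch.nn as nn

# Hypothetical module: a pretrained weight plus LoRA low-rank factors.
# The names `lora_a` / `lora_b` are illustrative only.
class TinyModel(nn.Module):
    def __init__(self, d=8, r=2):
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(d, d))  # present in the pretrained checkpoint
        self.lora_a = nn.Parameter(torch.zeros(r, d))  # LoRA factor, not in the checkpoint
        self.lora_b = nn.Parameter(torch.zeros(d, r))  # LoRA factor, not in the checkpoint

model = TinyModel()

# A checkpoint that only contains the pretrained weight, with no LoRA parameters
checkpoint = {'weight': torch.randn(8, 8)}

# `strict=False` tolerates the mismatch and reports it instead of raising an error
missing_keys, unexpected_keys = model.load_state_dict(checkpoint, strict=False)

# Only the LoRA parameters should be missing, and nothing extra should be present
assert all('lora' in key for key in missing_keys)
assert not unexpected_keys
print(missing_keys)  # ['lora_a', 'lora_b']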