fix lora ckpt save format (ColoTensor to Tensor)

This commit is contained in:
BurkeHulk
2024-10-21 13:55:43 +08:00
parent 5ddad486ca
commit b10339df7c
3 changed files with 11 additions and 3 deletions

View File

@@ -290,7 +290,8 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
assert isinstance(
peft_model, PeftModel
), "The model doesn't have lora adapters, please enable lora before saving."
return peft_model.save_pretrained(checkpoint, safe_serialization=use_safetensors)
return peft_model.save_pretrained(checkpoint, safe_serialization=use_safetensors,
state_dict = tree_map(lambda x: x.data if torch.is_tensor(x) else x, peft_model.state_dict()))
class LowLevelZeroPlugin(DPPluginBase):