[lora] lora support hybrid parallel plugin (#5956)

* lora support hybrid plugin

* fix

* fix

* fix

* fix
Author: Wang Binluo
Date: 2024-08-02 10:36:58 +08:00
Committed by: GitHub
Parent: 19d1510ea2
Commit: 75c963686f
4 changed files with 44 additions and 5 deletions

@@ -947,3 +947,17 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
                state_[k] = v.detach().clone().to(device)

        return state_

    def save_lora_as_pretrained(self, model, checkpoint, use_safetensors):
        if os.path.isfile(checkpoint):
            logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
            return
        from peft import PeftModel

        assert isinstance(model, ModelWrapper), "Please boost the model before saving!"
        model._force_wait_all_gather()
        peft_model = model.unwrap()
        assert isinstance(
            peft_model, PeftModel
        ), "The model doesn't have lora adapters, please enable lora before saving."
        return peft_model.save_pretrained(checkpoint, safe_serialization=use_safetensors)
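
For context, a minimal usage sketch of how this new checkpoint path is typically reached through the Booster API. The model choice, plugin sizes, and hyperparameters below are illustrative assumptions and are not taken from this PR.

    # Hypothetical end-to-end sketch: enable LoRA, boost with HybridParallelPlugin,
    # then save only the adapter weights via the method added in this diff.
    import torch
    import colossalai
    from colossalai.booster import Booster
    from colossalai.booster.plugin import HybridParallelPlugin
    from peft import LoraConfig
    from transformers import AutoModelForCausalLM

    colossalai.launch_from_torch()

    model = AutoModelForCausalLM.from_pretrained("gpt2")  # any HF model; illustrative choice
    plugin = HybridParallelPlugin(tp_size=2, pp_size=1, precision="bf16")  # assumed sizes
    booster = Booster(plugin=plugin)

    # Wrap the base model with LoRA adapters before boosting.
    lora_config = LoraConfig(task_type="CAUSAL_LM", r=8, lora_alpha=32, lora_dropout=0.1)
    model = booster.enable_lora(model, lora_config=lora_config)

    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
    model, optimizer, *_ = booster.boost(model, optimizer)

    # ... training loop ...

    # Saves only the adapter weights and adapter config through peft's save_pretrained,
    # which dispatches to HybridParallelCheckpointIO.save_lora_as_pretrained above.
    booster.save_lora_as_pretrained(model, "checkpoint/lora", use_safetensors=True)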