Mirror of https://github.com/hpcaitech/ColossalAI.git
[lora] lora support hybrid parallel plugin (#5956)
* lora support hybrid plugin
* fix
* fix
* fix
* fix
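For readers following the Booster workflow this PR plugs into: LoRA adapters are attached with peft before boosting, and the boosted model is later saved through the checkpoint IO method added in the diff below. Here is a minimal sketch of that flow, assuming the Booster-level `enable_lora` entry point, the `tp_size`/`pp_size` plugin arguments, and a hypothetical `build_model()` helper (none of which are part of this diff):

```python
# Hedged sketch: attaching LoRA adapters and boosting with HybridParallelPlugin.
# booster.enable_lora(...) and the exact plugin arguments are assumptions drawn
# from the surrounding ColossalAI API, not from this commit.
import torch
from peft import LoraConfig

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

colossalai.launch_from_torch()

plugin = HybridParallelPlugin(tp_size=2, pp_size=1)  # assumed parallel layout
booster = Booster(plugin=plugin)

model = build_model()  # hypothetical helper returning an nn.Module
lora_config = LoraConfig(r=8, lora_alpha=32, target_modules=["q_proj", "v_proj"])

# Inject LoRA adapters first, then boost; the boosted wrapper is what the new
# save_lora_as_pretrained method expects to receive.
model = booster.enable_lora(model, lora_config=lora_config)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
model, optimizer, *_ = booster.boost(model, optimizer)
```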
```diff
@@ -947,3 +947,17 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
             state_[k] = v.detach().clone().to(device)
 
         return state_
+
+    def save_lora_as_pretrained(self, model, checkpoint, use_safetensors):
+        if os.path.isfile(checkpoint):
+            logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
+            return
+        from peft import PeftModel
+
+        assert isinstance(model, ModelWrapper), "Please boost the model before saving!"
+        model._force_wait_all_gather()
+        peft_model = model.unwrap()
+        assert isinstance(
+            peft_model, PeftModel
+        ), "The model doesn't have lora adapters, please enable lora before saving."
+        return peft_model.save_pretrained(checkpoint, safe_serialization=use_safetensors)
```
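The new method waits for any pending all-gather on the boosted wrapper, unwraps it back to the peft `PeftModel`, and delegates to peft's `save_pretrained`, so only the adapter weights are written to the target directory. Continuing the sketch above, a hedged usage example of the saving path, assuming the Booster forwards `save_lora_as_pretrained` to the plugin's checkpoint IO (only the `HybridParallelCheckpointIO` method itself is part of this diff):

```python
# Hedged sketch: persisting only the LoRA adapter weights after training.
# booster.save_lora_as_pretrained(...) as a thin wrapper over the
# HybridParallelCheckpointIO method above is an assumption; the checkpoint
# argument must be a directory (a plain file path is rejected with an error),
# and use_safetensors selects .safetensors over .bin serialization in peft.
booster.save_lora_as_pretrained(model, "lora_ckpt_dir", use_safetensors=True)

# Roughly equivalent direct call against the checkpoint IO from the diff:
# plugin.get_checkpoint_io().save_lora_as_pretrained(model, "lora_ckpt_dir", True)
```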