[shardformer] support pipeline for deepseek v3 and optimize lora save (#6188)

* [shardformer] support pipeline for deepseek v3

* [checkpointio] fix lora save

* [devops] update ci env

* [booster] optimize lora

* fix test

* fix test
Author: Hongxin Liu
Date: 2025-02-14 14:48:54 +08:00 (committed by GitHub)
Parent: ec73f1b5e2
Commit: 014837e725
21 changed files with 478 additions and 91 deletions


@@ -359,23 +359,10 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
             model, checkpoint_path, gather_dtensor, prefix, max_shard_size, use_safetensors, use_async=use_async
         )
 
-    def save_lora_as_pretrained(self, model, checkpoint, use_safetensors):
-        if os.path.isfile(checkpoint):
-            self.logger.error(f"Provided path ({checkpoint}) should be a directory, not a file", ranks=[0])
-            return
-        from peft import PeftModel
-
-        assert isinstance(model, ModelWrapper), "Please boost the model before saving!"
+    def save_lora_as_pretrained(self, model, checkpoint, use_safetensors, state_dict: Optional[dict] = None):
+        assert isinstance(model, LowLevelZeroModel), "Please boost the model before saving!"
         model._force_wait_all_gather()
-        peft_model = model.unwrap()
-        assert isinstance(
-            peft_model, PeftModel
-        ), "The model doesn't have lora adapters, please enable lora before saving."
-        return peft_model.save_pretrained(
-            checkpoint,
-            safe_serialization=use_safetensors,
-            state_dict=tree_map(lambda x: x.data if torch.is_tensor(x) else x, peft_model.state_dict()),
-        )
+        super().save_lora_as_pretrained(model, checkpoint, use_safetensors, state_dict=state_dict)
 
 
 class LowLevelZeroPlugin(DPPluginBase):
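
With this change, LowLevelZeroCheckpointIO.save_lora_as_pretrained keeps only the ZeRO-specific steps (asserting the model was boosted into a LowLevelZeroModel and waiting for pending all-gathers) and delegates adapter serialization to the parent checkpoint IO, which now also accepts an optional pre-built state_dict. The following is a minimal sketch of how this save path is typically exercised through the Booster API. It assumes the Booster.enable_lora / Booster.save_lora_as_pretrained entry points and LowLevelZeroPlugin from recent ColossalAI releases plus peft's LoraConfig; the base-model name, LoRA hyperparameters, and output directory are illustrative, not taken from this diff.

# Hedged usage sketch: exercises the LoRA save path simplified by this diff.
# The model name, LoRA settings, and output directory below are hypothetical.
import colossalai
import torch
from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin
from peft import LoraConfig
from transformers import AutoModelForCausalLM

colossalai.launch_from_torch()  # assumes launch via torchrun; older releases take config={}

plugin = LowLevelZeroPlugin(stage=2, precision="bf16")
booster = Booster(plugin=plugin)

model = AutoModelForCausalLM.from_pretrained("my-base-model")  # hypothetical checkpoint
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
lora_config = LoraConfig(r=8, lora_alpha=16, target_modules=["q_proj", "v_proj"])

# enable_lora wraps the model with peft adapters; boost then wraps it in LowLevelZeroModel
model = booster.enable_lora(model, lora_config=lora_config)
model, optimizer, *_ = booster.boost(model, optimizer)

# ... training loop elided ...

# Saving goes through LowLevelZeroCheckpointIO.save_lora_as_pretrained, which waits for
# pending ZeRO all-gathers and delegates serialization to the parent checkpoint IO
# (optionally with a caller-provided state_dict).
booster.save_lora_as_pretrained(model, "lora_ckpt_dir", use_safetensors=True)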