[shardformer] support pipeline for deepseek v3 and optimize lora save (#6188)

* [shardformer] support pipeline for deepseek v3

* [checkpointio] fix lora save

* [devops] update ci env

* [booster] optimize lora

* fix test

* fix test
Hongxin Liu
2025-02-14 14:48:54 +08:00
committed by GitHub
parent ec73f1b5e2
commit 014837e725
21 changed files with 478 additions and 91 deletions
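For context, the `save_lora_as_pretrained` change below is normally reached through the Booster API rather than called directly. The following is a minimal usage sketch, assuming the public `Booster`/`HybridParallelPlugin` entry points (`enable_lora`, `boost`, `save_lora_as_pretrained`); the model name, LoRA target modules, parallel sizes, and output path are hypothetical placeholders.

# Minimal sketch of the LoRA save path exercised by this change.
# Expected to run under torchrun; names and sizes are placeholders.
import colossalai
import torch
from peft import LoraConfig
from transformers import AutoModelForCausalLM

from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

colossalai.launch_from_torch()

plugin = HybridParallelPlugin(tp_size=1, pp_size=2)  # pipeline parallelism enabled
booster = Booster(plugin=plugin)

model = AutoModelForCausalLM.from_pretrained("your/base-model")  # placeholder
lora_config = LoraConfig(task_type="CAUSAL_LM", r=8, target_modules=["q_proj", "v_proj"])  # placeholder modules
model = booster.enable_lora(model, lora_config=lora_config)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
model, optimizer, *_ = booster.boost(model, optimizer)

# ... training loop ...

# With pp_size > 1, each pipeline stage holds only its own LoRA weights;
# the updated save_lora_as_pretrained gathers them before writing.
booster.save_lora_as_pretrained(model, "lora_ckpt", use_safetensors=True)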


@@ -33,6 +33,8 @@ from .utils import (
     async_save_state_dict_shards,
     create_pinned_state_dict,
     gather_distributed_param,
+    gather_state_dict_fast,
+    get_lora_state_dict,
     get_model_base_filenames,
     get_optimizer_base_filenames,
     is_safetensors_available,
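The two new imports, `gather_state_dict_fast` and `get_lora_state_dict`, are helpers used further down in the LoRA save path. Their actual implementations are not shown in this diff; purely as an illustration of the filtering step, a hypothetical helper that keeps only LoRA adapter tensors from a full state dict could look like this (PEFT names adapter weights with a `lora_` substring, e.g. `lora_A`/`lora_B`):

# Illustrative only -- not the actual get_lora_state_dict from this PR.
# Assumes PEFT's "lora_" naming convention for adapter parameters.
from typing import Dict

import torch


def filter_lora_state_dict(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    """Return only the LoRA adapter entries of a full model state dict."""
    return {name: tensor for name, tensor in state_dict.items() if "lora_" in name}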
@@ -1137,7 +1139,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
         return state_
 
-    def save_lora_as_pretrained(self, model, checkpoint, use_safetensors):
+    def save_lora_as_pretrained(self, model, checkpoint, use_safetensors, state_dict: Optional[dict] = None):
         if os.path.isfile(checkpoint):
             logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
             return
@@ -1145,12 +1147,21 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
         assert isinstance(model, ModelWrapper), "Please boost the model before saving!"
         model._force_wait_all_gather()
-        peft_model = model.unwrap()
+        peft_model = model.unwrap(unwrap_peft=False)
         assert isinstance(
             peft_model, PeftModel
         ), "The model doesn't have lora adapters, please enable lora before saving."
-        return peft_model.save_pretrained(
-            checkpoint,
-            safe_serialization=use_safetensors,
-            state_dict=tree_map(lambda x: x.data if torch.is_tensor(x) else x, peft_model.state_dict()),
-        )
+        if state_dict is None:
+            state_dict = tree_map(lambda x: x.data if torch.is_tensor(x) else x, peft_model.state_dict())
+        if self.pp_size > 1:
+            lora_state_dict = get_lora_state_dict(peft_model, state_dict)
+            gathered_lora_state_dict = gather_state_dict_fast(lora_state_dict, self.pp_group, device="cpu")
+            if self.pp_rank == 0:
+                state_dict.update(gathered_lora_state_dict)
+        state_dict = tree_map(lambda x: x.cpu() if torch.is_tensor(x) else x, state_dict)
+        if self.coordinator.is_master():
+            return peft_model.save_pretrained(
+                checkpoint,
+                safe_serialization=use_safetensors,
+                state_dict=state_dict,
+            )
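With pipeline parallelism enabled (`pp_size > 1`), each pipeline stage only holds the LoRA weights of its own layers, so the new code filters the LoRA entries, gathers them onto pipeline rank 0 via `gather_state_dict_fast`, moves everything to CPU, and writes the adapter from the master rank only. The real gather helper added in this PR is the optimized one; the sketch below is only a simplified stand-in, using torch.distributed object collectives, to illustrate the gathering idea.

# Simplified stand-in for the idea behind gather_state_dict_fast; the real
# helper in this PR is a faster, tensor-level implementation.
from typing import Dict

import torch
import torch.distributed as dist


def gather_state_dict_to_first_rank(
    state_dict: Dict[str, torch.Tensor],
    group: dist.ProcessGroup,
    device: str = "cpu",
) -> Dict[str, torch.Tensor]:
    """Collect every rank's (disjoint) state dict onto rank 0 of `group`."""
    local = {k: v.detach().to(device) for k, v in state_dict.items()}
    shards = [None] * dist.get_world_size(group)
    # all_gather_object pickles each shard; fine for a sketch, slow for large models.
    dist.all_gather_object(shards, local, group=group)
    if dist.get_rank(group) != 0:
        return {}
    merged: Dict[str, torch.Tensor] = {}
    for shard in shards:
        merged.update(shard)
    return merged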