mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-07 12:01:39 +00:00
[shardformer] support pipeline for deepseek v3 and optimize lora save (#6188)
* [shardformer] support pipeline for deepseek v3
* [checkpointio] fix lora save
* [devops] update ci env
* [booster] optimize lora
* fix test
* fix test
This commit is contained in:
@@ -33,6 +33,8 @@ from .utils import (
|
||||
async_save_state_dict_shards,
|
||||
create_pinned_state_dict,
|
||||
gather_distributed_param,
|
||||
gather_state_dict_fast,
|
||||
get_lora_state_dict,
|
||||
get_model_base_filenames,
|
||||
get_optimizer_base_filenames,
|
||||
is_safetensors_available,
|
||||
@@ -1137,7 +1139,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
||||
|
||||
return state_
|
||||
|
||||
def save_lora_as_pretrained(self, model, checkpoint, use_safetensors):
|
||||
def save_lora_as_pretrained(self, model, checkpoint, use_safetensors, state_dict: Optional[dict] = None):
|
||||
if os.path.isfile(checkpoint):
|
||||
logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
|
||||
return
|
||||
@@ -1145,12 +1147,21 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
||||
|
||||
assert isinstance(model, ModelWrapper), "Please boost the model before saving!"
|
||||
model._force_wait_all_gather()
|
||||
peft_model = model.unwrap()
|
||||
peft_model = model.unwrap(unwrap_peft=False)
|
||||
assert isinstance(
|
||||
peft_model, PeftModel
|
||||
), "The model doesn't have lora adapters, please enable lora before saving."
|
||||
return peft_model.save_pretrained(
|
||||
checkpoint,
|
||||
safe_serialization=use_safetensors,
|
||||
state_dict=tree_map(lambda x: x.data if torch.is_tensor(x) else x, peft_model.state_dict()),
|
||||
)
|
||||
if state_dict is None:
|
||||
state_dict = tree_map(lambda x: x.data if torch.is_tensor(x) else x, peft_model.state_dict())
|
||||
if self.pp_size > 1:
|
||||
lora_state_dict = get_lora_state_dict(peft_model, state_dict)
|
||||
gathered_lora_state_dict = gather_state_dict_fast(lora_state_dict, self.pp_group, device="cpu")
|
||||
if self.pp_rank == 0:
|
||||
state_dict.update(gathered_lora_state_dict)
|
||||
state_dict = tree_map(lambda x: x.cpu() if torch.is_tensor(x) else x, state_dict)
|
||||
if self.coordinator.is_master():
|
||||
return peft_model.save_pretrained(
|
||||
checkpoint,
|
||||
safe_serialization=use_safetensors,
|
||||
state_dict=state_dict,
|
||||
)
|
||||
|
Reference in New Issue
Block a user