mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-06 11:32:10 +00:00
[shardformer] support pipeline for deepseek v3 and optimize lora save (#6188)
* [shardformer] support pipeline for deepseek v3 * [checkpointio] fix lora save * [devops] update ci env * [booster] optimize lora * fix test * fix test
This commit is contained in:
@@ -359,23 +359,10 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
|
||||
model, checkpoint_path, gather_dtensor, prefix, max_shard_size, use_safetensors, use_async=use_async
|
||||
)
|
||||
|
||||
def save_lora_as_pretrained(self, model, checkpoint, use_safetensors):
|
||||
if os.path.isfile(checkpoint):
|
||||
self.logger.error(f"Provided path ({checkpoint}) should be a directory, not a file", ranks=[0])
|
||||
return
|
||||
from peft import PeftModel
|
||||
|
||||
assert isinstance(model, ModelWrapper), "Please boost the model before saving!"
|
||||
def save_lora_as_pretrained(self, model, checkpoint, use_safetensors, state_dict: Optional[dict] = None):
|
||||
assert isinstance(model, LowLevelZeroModel), "Please boost the model before saving!"
|
||||
model._force_wait_all_gather()
|
||||
peft_model = model.unwrap()
|
||||
assert isinstance(
|
||||
peft_model, PeftModel
|
||||
), "The model doesn't have lora adapters, please enable lora before saving."
|
||||
return peft_model.save_pretrained(
|
||||
checkpoint,
|
||||
safe_serialization=use_safetensors,
|
||||
state_dict=tree_map(lambda x: x.data if torch.is_tensor(x) else x, peft_model.state_dict()),
|
||||
)
|
||||
super().save_lora_as_pretrained(model, checkpoint, use_safetensors, state_dict=state_dict)
|
||||
|
||||
|
||||
class LowLevelZeroPlugin(DPPluginBase):
|
||||
|
Reference in New Issue
Block a user