[shardformer] support pipeline for deepseek v3 and optimize lora save (#6188)

* [shardformer] support pipeline for deepseek v3

* [checkpointio] fix lora save

* [devops] update ci env

* [booster] optimize lora

* fix test

* fix test
Author: Hongxin Liu
Date: 2025-02-14 14:48:54 +08:00
Committed by: GitHub
Parent: ec73f1b5e2
Commit: 014837e725

21 changed files with 478 additions and 91 deletions
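
For context, the end-to-end flow the LoRA-save changes touch looks roughly like the sketch below. It assumes ColossalAI's Booster API with TorchDDPPlugin; build_model() is a hypothetical stand-in for any model, and launcher details vary across versions.

# Minimal sketch of the LoRA save flow (assumptions: ColossalAI Booster API
# with TorchDDPPlugin; build_model() is a hypothetical placeholder).
import colossalai
from peft import LoraConfig

from colossalai.booster import Booster
from colossalai.booster.plugin import TorchDDPPlugin

colossalai.launch_from_torch()  # launcher signature varies across versions

booster = Booster(plugin=TorchDDPPlugin())
model = build_model()  # hypothetical: any nn.Module
# Inject LoRA adapters before boosting.
model = booster.enable_lora(model, lora_config=LoraConfig(task_type="CAUSAL_LM"))
model, *_ = booster.boost(model)

# ... training ...

# Save only the adapter weights; with this change the adapter state dict is
# materialized on CPU once, and only the master rank writes the files.
booster.save_lora_as_pretrained(model, "./lora_ckpt", use_safetensors=True)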


@@ -2,6 +2,7 @@ from typing import Callable, Dict, Iterator, List, Optional, Tuple, Union

 import torch
 import torch.nn as nn
+from peft import PeftModel
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.optim import Optimizer
 from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
@@ -166,7 +167,11 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO):
         )

     def save_lora_as_pretrained(
-        self, model: Union[nn.Module, ModelWrapper], checkpoint: str, use_safetensors: bool = False
+        self,
+        model: Union[nn.Module, ModelWrapper],
+        checkpoint: str,
+        use_safetensors: bool = False,
+        state_dict: Optional[dict] = None,
     ) -> None:
         """
         Save the lora adapters and adapter configuration file to checkpoint directory.
@@ -174,15 +179,17 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO):
         from peft import PeftModel

         assert isinstance(model, ModelWrapper), "Please boost the model before saving!"
+        peft_model = model.unwrap(unwrap_peft=False)
+        assert isinstance(
+            peft_model, PeftModel
+        ), "The model doesn't have lora adapters, please enable lora before saving."
+        if state_dict is None:
+            state_dict = tree_map(lambda x: x.data.cpu() if torch.is_tensor(x) else x, peft_model.state_dict())
         if self.coordinator.is_master():
-            peft_model = model.unwrap()
-            assert isinstance(
-                peft_model, PeftModel
-            ), "The model doesn't have lora adapters, please enable lora before saving."
             return peft_model.save_pretrained(
                 checkpoint,
                 safe_serialization=use_safetensors,
-                state_dict=tree_map(lambda x: x.data if torch.is_tensor(x) else x, peft_model.state_dict()),
+                state_dict=state_dict,
             )
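
The new state_dict parameter lets a caller build or gather the adapter state dict once and reuse it, and the added .cpu() moves the weights off the GPU before only the master rank serializes them. A sketch of the caller-side pattern (the helper name gather_lora_state_dict is hypothetical):

import torch
from torch.utils._pytree import tree_map

def gather_lora_state_dict(peft_model):
    # Hypothetical helper mirroring the default path above: copy every
    # adapter tensor to CPU so serialization does not touch GPU memory.
    return tree_map(
        lambda x: x.data.cpu() if torch.is_tensor(x) else x,
        peft_model.state_dict(),
    )

# Build the state dict once, then hand it to the checkpoint IO; when
# state_dict is passed explicitly, save_lora_as_pretrained skips rebuilding it.
lora_sd = gather_lora_state_dict(model.unwrap(unwrap_peft=False))
checkpoint_io.save_lora_as_pretrained(
    model, "./lora_ckpt", use_safetensors=True, state_dict=lora_sd
)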
@@ -191,8 +198,11 @@ class TorchDDPModel(ModelWrapper):
         super().__init__(module)
         self.module = DDP(module, *args, **kwargs)

-    def unwrap(self):
-        return self.module.module
+    def unwrap(self, unwrap_peft: bool = True) -> nn.Module:
+        model = self.module.module
+        if unwrap_peft and isinstance(model, PeftModel):
+            model = model.get_base_model()
+        return model


 class TorchDDPPlugin(DPPluginBase):
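
The reworked unwrap stays backward compatible: by default it strips both the DDP and PEFT wrappers, while unwrap_peft=False (used by the LoRA save path above) stops at the PeftModel so that its save_pretrained is still reachable. A small illustration, assuming model is a boosted LoRA model:

# Assuming model is a TorchDDPModel wrapping a PeftModel:
peft_model = model.unwrap(unwrap_peft=False)  # the PeftModel, adapters intact
base_model = model.unwrap()                   # the underlying base nn.Module
assert base_model is peft_model.get_base_model()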