Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-07 03:52:01 +00:00)
[shardformer] support pipeline for deepseek v3 and optimize lora save (#6188)
* [shardformer] support pipeline for deepseek v3
* [checkpointio] fix lora save
* [devops] update ci env
* [booster] optimize lora
* fix test
* fix test
@@ -2,6 +2,7 @@ from typing import Callable, Dict, Iterator, List, Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
+from peft import PeftModel
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.optim import Optimizer
 from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
@@ -166,7 +167,11 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO):
         )
 
     def save_lora_as_pretrained(
-        self, model: Union[nn.Module, ModelWrapper], checkpoint: str, use_safetensors: bool = False
+        self,
+        model: Union[nn.Module, ModelWrapper],
+        checkpoint: str,
+        use_safetensors: bool = False,
+        state_dict: Optional[dict] = None,
     ) -> None:
         """
         Save the lora adapters and adapter configuration file to checkpoint directory.
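The new `state_dict` keyword is optional, so existing call sites keep working; a caller (e.g. a pipeline-parallel checkpoint IO) can now pass in an adapter state dict it has already gathered instead of having it rebuilt here. Below is a minimal end-to-end sketch of the save path this method serves — not part of the commit; it assumes a recent colossalai with peft installed, a script launched via torchrun, and illustrative module/config names.

# Minimal sketch (illustrative, not from this commit): boost a LoRA model with
# TorchDDPPlugin and save its adapters. Assumes colossalai and peft are
# installed and the script runs under torchrun.
import torch.nn as nn
from peft import LoraConfig

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import TorchDDPPlugin

colossalai.launch_from_torch()  # recent colossalai; older versions took config={}
booster = Booster(plugin=TorchDDPPlugin())

model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 8))
# "0" and "2" are the names of the Linear submodules in the Sequential above.
model = booster.enable_lora(model, lora_config=LoraConfig(target_modules=["0", "2"]))
model, *_ = booster.boost(model)

# Every rank reaches this call; with the change in this commit, the CPU state
# dict is built on all ranks and only the master rank writes the adapter files.
booster.save_lora_as_pretrained(model, "./lora_ckpt", use_safetensors=True)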
@@ -174,15 +179,17 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO):
         from peft import PeftModel
 
         assert isinstance(model, ModelWrapper), "Please boost the model before saving!"
+        peft_model = model.unwrap(unwrap_peft=False)
+        assert isinstance(
+            peft_model, PeftModel
+        ), "The model doesn't have lora adapters, please enable lora before saving."
+        if state_dict is None:
+            state_dict = tree_map(lambda x: x.data.cpu() if torch.is_tensor(x) else x, peft_model.state_dict())
         if self.coordinator.is_master():
-            peft_model = model.unwrap()
-            assert isinstance(
-                peft_model, PeftModel
-            ), "The model doesn't have lora adapters, please enable lora before saving."
             return peft_model.save_pretrained(
                 checkpoint,
                 safe_serialization=use_safetensors,
-                state_dict=tree_map(lambda x: x.data if torch.is_tensor(x) else x, peft_model.state_dict()),
+                state_dict=state_dict,
             )
 
 
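The body change pulls the state-dict materialization out of the master-only branch: every rank now builds a CPU copy of the adapter weights, and only the master serializes it, so plugins whose state-dict construction involves collectives can reuse this path. For reference, a runnable sketch of the detach-to-CPU pattern, assuming the `tree_map` in the diff is `torch.utils._pytree.tree_map`:

# Runnable sketch of the detach-to-CPU pattern used above; tree_map here is
# assumed to be torch.utils._pytree.tree_map (a private but long-stable API).
import torch
import torch.nn as nn
from torch.utils._pytree import tree_map

model = nn.Linear(4, 4)
cpu_state = tree_map(lambda x: x.data.cpu() if torch.is_tensor(x) else x, model.state_dict())
assert all(t.device.type == "cpu" for t in cpu_state.values())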
@@ -191,8 +198,11 @@ class TorchDDPModel(ModelWrapper):
         super().__init__(module)
         self.module = DDP(module, *args, **kwargs)
 
-    def unwrap(self):
-        return self.module.module
+    def unwrap(self, unwrap_peft: bool = True) -> nn.Module:
+        model = self.module.module
+        if unwrap_peft and isinstance(model, PeftModel):
+            model = model.get_base_model()
+        return model
 
 
 class TorchDDPPlugin(DPPluginBase):
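`unwrap` previously always returned the innermost module; the new `unwrap_peft` flag lets `save_lora_as_pretrained` keep the `PeftModel` wrapper (whose `save_pretrained` it needs) while other callers still receive the base module. A standalone illustration of the distinction, using peft directly (illustrative, not from the commit):

# Standalone illustration of what the unwrap_peft flag distinguishes,
# using peft directly; module names are illustrative.
import torch.nn as nn
from peft import LoraConfig, PeftModel, get_peft_model

base = nn.Sequential(nn.Linear(4, 4))
peft_model = get_peft_model(base, LoraConfig(target_modules=["0"]))

assert isinstance(peft_model, PeftModel)
# unwrap(unwrap_peft=True) would return the base module:
assert peft_model.get_base_model() is base
# unwrap(unwrap_peft=False) returns the PeftModel itself, which is what
# save_lora_as_pretrained needs so that peft's save_pretrained is available.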