[hotfix] fix lora load (#6231)

* [hotfix] fix lora load

* [hotfix] fix hp load

* accelerate deepseek loading
This commit is contained in:
Hongxin Liu
2025-03-01 19:04:14 +08:00
committed by GitHub
parent f32861ccc5
commit 56fe130b15
10 changed files with 146 additions and 38 deletions

View File

@@ -12,6 +12,7 @@ from torch.utils.data import DataLoader
from colossalai.checkpoint_io import CheckpointIO, GeneralCheckpointIO
from colossalai.cluster import DistCoordinator
from colossalai.interface import ModelWrapper, OptimizerWrapper
from colossalai.interface.model import PeftUnwrapMixin
from colossalai.logging import get_dist_logger
from colossalai.quantization import BnbQuantizationConfig, quantize_model
from colossalai.utils import get_current_device
@@ -201,7 +202,7 @@ class TorchDDPModel(ModelWrapper):
def unwrap(self, unwrap_peft: bool = True) -> nn.Module:
model = self.module.module
if unwrap_peft and isinstance(model, PeftModel):
model = model.get_base_model()
model = PeftUnwrapMixin(model)
return model