From 56fe130b156bdf1242e6ee21f18a37b22948c4f5 Mon Sep 17 00:00:00 2001
From: Hongxin Liu
Date: Sat, 1 Mar 2025 19:04:14 +0800
Subject: [PATCH] [hotfix] fix lora load (#6231)

* [hotfix] fix lora load

* [hotfix] fix hp load

* accelerate deepseek loading
---
 .../training_scripts/lora_finetune.py         |   2 +-
 colossalai/booster/plugin/gemini_plugin.py    |  14 +--
 .../booster/plugin/hybrid_parallel_plugin.py  |   3 +-
 colossalai/booster/plugin/torch_ddp_plugin.py |   3 +-
 .../booster/plugin/torch_fsdp_plugin.py       |  14 +--
 .../checkpoint_io/general_checkpoint_io.py    |  10 +-
 .../hybrid_parallel_checkpoint_io.py          |  22 ++--
 colossalai/checkpoint_io/moe_checkpoint.py    |   9 +-
 colossalai/checkpoint_io/utils.py             |   6 ++
 colossalai/interface/model.py                 | 101 +++++++++++++++++-
 10 files changed, 146 insertions(+), 38 deletions(-)

diff --git a/applications/ColossalChat/examples/training_scripts/lora_finetune.py b/applications/ColossalChat/examples/training_scripts/lora_finetune.py
index 851ad6a2d..4045556d7 100644
--- a/applications/ColossalChat/examples/training_scripts/lora_finetune.py
+++ b/applications/ColossalChat/examples/training_scripts/lora_finetune.py
@@ -257,7 +257,7 @@ def train(args) -> None:
     )
     torch.set_default_dtype(torch.float)
 
-    booster.load_model(model, args.pretrained)
+    booster.load_model(model, args.pretrained, low_cpu_mem_mode=False, num_threads=8)
 
     coordinator.print_on_master(
         f"Booster init max device memory: {accelerator.max_memory_allocated() / 1024 ** 2:.2f} MB"
diff --git a/colossalai/booster/plugin/gemini_plugin.py b/colossalai/booster/plugin/gemini_plugin.py
index 4b1224c68..a81f9b05d 100644
--- a/colossalai/booster/plugin/gemini_plugin.py
+++ b/colossalai/booster/plugin/gemini_plugin.py
@@ -85,11 +85,11 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
             if use_async:
                 from colossalai.utils.safetensors import save
 
-                if id(model) not in self.pinned_state_dicts:
-                    self.pinned_state_dicts[id(model)] = create_pinned_state_dict(state_dict)
+                if hash(model) not in self.pinned_state_dicts:
+                    self.pinned_state_dicts[hash(model)] = create_pinned_state_dict(state_dict)
                 for k, v in state_dict.items():
-                    self.pinned_state_dicts[id(model)][k].copy_(v)
-                    state_dict[k] = self.pinned_state_dicts[id(model)][k]
+                    self.pinned_state_dicts[hash(model)][k].copy_(v)
+                    state_dict[k] = self.pinned_state_dicts[hash(model)][k]
                 writer = save(checkpoint, state_dict)
                 self.async_writers.append(writer)
             else:
@@ -172,9 +172,9 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
         Path(checkpoint_path).mkdir(parents=True, exist_ok=True)
 
         if use_async and self.coordinator.is_master():
-            if id(model) not in self.pinned_state_dicts:
-                self.pinned_state_dicts[id(model)] = {}
-            pinned_state_dicts = self.pinned_state_dicts[id(model)]
+            if hash(model) not in self.pinned_state_dicts:
+                self.pinned_state_dicts[hash(model)] = {}
+            pinned_state_dicts = self.pinned_state_dicts[hash(model)]
         else:
             pinned_state_dicts = None
         state_dict_shard = model.state_dict_shard(
diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py
index 1684fd702..1e0f7be24 100644
--- a/colossalai/booster/plugin/hybrid_parallel_plugin.py
+++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py
@@ -26,6 +26,7 @@ from colossalai.amp.naive_amp.mixed_precision_optimizer import MixedPrecisionOpt
 from colossalai.checkpoint_io import CheckpointIO, HybridParallelCheckpointIO
 from colossalai.cluster import ProcessGroupMesh
 from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper
+from colossalai.interface.model import PeftUnwrapMixin
 from colossalai.interface.optimizer import DistributedOptim
 from colossalai.logging import get_dist_logger
 from colossalai.nn.optimizer import DistGaloreAwamW, cast_to_distributed
@@ -225,7 +226,7 @@ class HybridParallelModule(ModelWrapper, AMPModelMixin):
         if isinstance(model, DDP):
             model = model.module
         if unwrap_peft and isinstance(model, PeftModel):
-            model = model.get_base_model()
+            model = PeftUnwrapMixin(model)
         return model
 
     def _force_wait_all_gather(self):
diff --git a/colossalai/booster/plugin/torch_ddp_plugin.py b/colossalai/booster/plugin/torch_ddp_plugin.py
index e74b1a959..9cb5adf01 100644
--- a/colossalai/booster/plugin/torch_ddp_plugin.py
+++ b/colossalai/booster/plugin/torch_ddp_plugin.py
@@ -12,6 +12,7 @@ from torch.utils.data import DataLoader
 from colossalai.checkpoint_io import CheckpointIO, GeneralCheckpointIO
 from colossalai.cluster import DistCoordinator
 from colossalai.interface import ModelWrapper, OptimizerWrapper
+from colossalai.interface.model import PeftUnwrapMixin
 from colossalai.logging import get_dist_logger
 from colossalai.quantization import BnbQuantizationConfig, quantize_model
 from colossalai.utils import get_current_device
@@ -201,7 +202,7 @@ class TorchDDPModel(ModelWrapper):
     def unwrap(self, unwrap_peft: bool = True) -> nn.Module:
         model = self.module.module
         if unwrap_peft and isinstance(model, PeftModel):
-            model = model.get_base_model()
+            model = PeftUnwrapMixin(model)
         return model
 
 
diff --git a/colossalai/booster/plugin/torch_fsdp_plugin.py b/colossalai/booster/plugin/torch_fsdp_plugin.py
index d713203fe..6e652e549 100644
--- a/colossalai/booster/plugin/torch_fsdp_plugin.py
+++ b/colossalai/booster/plugin/torch_fsdp_plugin.py
@@ -103,11 +103,11 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO):
             if use_async:
                 from colossalai.utils.safetensors import save
 
-                if id(model) not in self.pinned_state_dicts:
-                    self.pinned_state_dicts[id(model)] = create_pinned_state_dict(full_model_state)
+                if hash(model) not in self.pinned_state_dicts:
+                    self.pinned_state_dicts[hash(model)] = create_pinned_state_dict(full_model_state)
                 for k, v in full_model_state.items():
-                    self.pinned_state_dicts[id(model)][k].copy_(v)
-                    full_model_state[k] = self.pinned_state_dicts[id(model)][k]
+                    self.pinned_state_dicts[hash(model)][k].copy_(v)
+                    full_model_state[k] = self.pinned_state_dicts[hash(model)][k]
                 writer = save(checkpoint, full_model_state)
                 self.async_writers.append(writer)
             else:
@@ -186,9 +186,9 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO):
         state_dict = model.unwrap().state_dict()
 
         if use_async and self.coordinator.is_master():
-            if id(model) not in self.pinned_state_dicts:
-                self.pinned_state_dicts[id(model)] = {}
-            pinned_state_dicts = self.pinned_state_dicts[id(model)]
+            if hash(model) not in self.pinned_state_dicts:
+                self.pinned_state_dicts[hash(model)] = {}
+            pinned_state_dicts = self.pinned_state_dicts[hash(model)]
         else:
             pinned_state_dicts = None
         state_dict_shard = utils.shard_model_checkpoint(
diff --git a/colossalai/checkpoint_io/general_checkpoint_io.py b/colossalai/checkpoint_io/general_checkpoint_io.py
index 3e600c94d..5dfb09248 100644
--- a/colossalai/checkpoint_io/general_checkpoint_io.py
+++ b/colossalai/checkpoint_io/general_checkpoint_io.py
@@ -60,9 +60,9 @@ class GeneralCheckpointIO(CheckpointIO):
         if use_async:
             from colossalai.utils.safetensors import move_and_save
 
-            if id(model) not in self.pinned_state_dicts:
-                self.pinned_state_dicts[id(model)] = create_pinned_state_dict(state_dict)
-            writer = move_and_save(checkpoint, state_dict, self.pinned_state_dicts[id(model)])
+            if hash(model) not in self.pinned_state_dicts:
+                self.pinned_state_dicts[hash(model)] = create_pinned_state_dict(state_dict)
+            writer = move_and_save(checkpoint, state_dict, self.pinned_state_dicts[hash(model)])
             self.async_writers.append(writer)
         else:
             # save the checkpoint
@@ -234,7 +234,7 @@ class GeneralCheckpointIO(CheckpointIO):
         index_file = CheckpointIndexFile(checkpoint_path)
 
         if use_async:
-            pinned_state_dict = self.pinned_state_dicts.get(id(model), None)
+            pinned_state_dict = self.pinned_state_dicts.get(hash(model), None)
             total_size, new_pinned_state_dict, writers = async_move_save_state_dict_shards(
                 sharded_state_dict=state_dict_shard,
                 checkpoint=checkpoint_path,
@@ -243,7 +243,7 @@ class GeneralCheckpointIO(CheckpointIO):
                 is_master=True,
                 pinned_state_dict=pinned_state_dict,
             )
-            self.pinned_state_dicts[id(model)] = new_pinned_state_dict
+            self.pinned_state_dicts[hash(model)] = new_pinned_state_dict
             self.async_writers.extend(writers)
         else:
             # Save shards of optimizer states.
diff --git a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
index 5de32e666..9d9726352 100644
--- a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
+++ b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
@@ -249,9 +249,9 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
         # Only devices with tp_rank == 0 are responsible for model saving.
         control_saving = self.tp_rank == 0 and self.sp_rank == 0
         if control_saving and use_async:
-            if id(model) not in self.pinned_state_dicts:
-                self.pinned_state_dicts[id(model)] = {}
-            pinned_state_dicts = self.pinned_state_dicts[id(model)]
+            if hash(model) not in self.pinned_state_dicts:
+                self.pinned_state_dicts[hash(model)] = {}
+            pinned_state_dicts = self.pinned_state_dicts[hash(model)]
         else:
             pinned_state_dicts = None
         state_dict_shard = HybridParallelCheckpointIO._model_sharder(
@@ -789,11 +789,11 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
                 if use_async:
                     from colossalai.utils.safetensors import save
 
-                    if id(model) not in self.pinned_state_dicts:
-                        self.pinned_state_dicts[id(model)] = create_pinned_state_dict(state_dict)
+                    if hash(model) not in self.pinned_state_dicts:
+                        self.pinned_state_dicts[hash(model)] = create_pinned_state_dict(state_dict)
                     for name, param in state_dict.items():
-                        self.pinned_state_dicts[id(model)][name].copy_(param)
-                        state_dict[name] = self.pinned_state_dicts[id(model)][name]
+                        self.pinned_state_dicts[hash(model)][name].copy_(param)
+                        state_dict[name] = self.pinned_state_dicts[hash(model)][name]
                     writer = save(path=checkpoint, state_dict=state_dict)
                     self.async_writers.append(writer)
                 else:
@@ -811,11 +811,11 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
                 if use_async:
                     from colossalai.utils.safetensors import save
 
-                    if id(model) not in self.pinned_state_dicts:
-                        self.pinned_state_dicts[id(model)] = create_pinned_state_dict(complete_state_dict)
+                    if hash(model) not in self.pinned_state_dicts:
+                        self.pinned_state_dicts[hash(model)] = create_pinned_state_dict(complete_state_dict)
                     for name, param in complete_state_dict.items():
-                        self.pinned_state_dicts[id(model)][name].copy_(param)
-                        complete_state_dict[name] = self.pinned_state_dicts[id(model)][name]
+                        self.pinned_state_dicts[hash(model)][name].copy_(param)
+                        complete_state_dict[name] = self.pinned_state_dicts[hash(model)][name]
                     writer = save(path=checkpoint, state_dict=complete_state_dict)
                     self.async_writers.append(writer)
                 else:
diff --git a/colossalai/checkpoint_io/moe_checkpoint.py b/colossalai/checkpoint_io/moe_checkpoint.py
index 586c7863f..85e36f7c6 100644
--- a/colossalai/checkpoint_io/moe_checkpoint.py
+++ b/colossalai/checkpoint_io/moe_checkpoint.py
@@ -701,15 +701,18 @@ class MoECheckpointIO(HybridParallelCheckpointIO):
                     all_param = None
                 # gather param from every ep rank
                 # dist.all_gather(all_param, param, group=ep_group)
-                dist.gather(param, all_param, group=ep_group)
+                dist.gather(param, all_param, dst=dist.get_global_rank(ep_group, 0), group=ep_group)
                 if ep_rank == 0:
                     all_param = torch.cat(all_param, dim=0)
                     state_dict[name] = all_param.cpu()
 
         if self.pp_size > 1:
             if self.dp_rank == 0:
-                out = [None for _ in range(self.pp_size)]
-                dist.gather_object(state_dict, out, group=self.pp_group)
+                if self.pp_rank == 0:
+                    out = [None for _ in range(self.pp_size)]
+                else:
+                    out = None
+                dist.gather_object(state_dict, out, dst=dist.get_global_rank(self.pp_group, 0), group=self.pp_group)
                 if self.pp_rank == 0:
                     new_state_dict = {}
                     for o in out:
diff --git a/colossalai/checkpoint_io/utils.py b/colossalai/checkpoint_io/utils.py
index 2d826bd15..4b36dbe00 100644
--- a/colossalai/checkpoint_io/utils.py
+++ b/colossalai/checkpoint_io/utils.py
@@ -20,6 +20,7 @@ from torch.optim import Optimizer
 from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten
 
 from colossalai.accelerator import get_accelerator
+from colossalai.interface.model import PeftUnwrapMixin
 from colossalai.tensor.d_tensor import (
     is_customized_distributed_tensor,
     is_distributed_tensor,
@@ -554,6 +555,8 @@ def save_config_file(model: nn.Module, checkpoint_path: str, is_master: bool = T
         from transformers.modeling_utils import unwrap_model as unwrap_huggingface_model
     except ImportError:
         return
+    if isinstance(model, PeftUnwrapMixin):
+        model = model.base_model
     if not isinstance(model, PreTrainedModel):
         return
 
@@ -692,6 +695,9 @@ def load_state_dict_into_model(
         state_dict (dict): a dict containing parameters and persistent buffers.
     """
+    if isinstance(model, PeftUnwrapMixin):
+        state_dict = model.patch_state_dict(state_dict)
+        model = model.base_model
     if not isinstance(state_dict, Mapping):
         raise TypeError("Expected state_dict to be dict-like, got {}.".format(type(state_dict)))
 
diff --git a/colossalai/interface/model.py b/colossalai/interface/model.py
index d112c2723..8dbd15c63 100644
--- a/colossalai/interface/model.py
+++ b/colossalai/interface/model.py
@@ -1,5 +1,102 @@
+import re
+from typing import Dict, Set
+
+import torch
 import torch.nn as nn
-from peft import PeftModel
+from peft import PeftModel, PeftType
+
+
+def extract_lora_layers(model: PeftModel, names: Set[str], adapter_name: str = "default"):
+    config = model.peft_config[adapter_name]
+    if config.peft_type != PeftType.LORA:
+        raise ValueError(f"Adapter {adapter_name} is not a LORA adapter.")
+    # to_return = lora_state_dict(model, bias=model.peft_config.bias)
+    # adapted from `https://github.com/microsoft/LoRA/blob/main/loralib/utils.py`
+    # to be used directly with the state dict which is necessary when using DeepSpeed or FSDP
+    bias = config.bias
+    if bias == "none":
+        to_return = {k for k in names if "lora_" in k}
+    elif bias == "all":
+        to_return = {k for k in names if "lora_" in k or "bias" in k}
+    elif bias == "lora_only":
+        to_return = set()
+        for k in names:
+            if "lora_" in k:
+                to_return.add(k)
+                bias_name = k.split("lora_")[0] + "bias"
+                if bias_name in names:
+                    to_return.add(bias_name)
+    else:
+        raise NotImplementedError
+    to_return = {k for k in to_return if (("lora_" in k and adapter_name in k) or ("bias" in k))}
+    if config.use_dora:
+        # Here we take care of a refactor of DoRA which changed lora_magnitude_vector from a ParameterDict to a
+        # ModuleDict with a DoraLayer instance. The old parameter is now the "weight" attribute of that layer. Since
+        # we want the state_dict format not to change, we remove the "weight" part.
+        new_dora_suffix = f"lora_magnitude_vector.{adapter_name}.weight"
+
+        def renamed_dora_weights(k):
+            if k.endswith(new_dora_suffix):
+                k = k[:-7]  # remove ".weight"
+            return k
+
+        to_return = {renamed_dora_weights(k) for k in to_return}
+
+    to_return = {re.sub(f"lora_\S\.{adapter_name}\.(weight|bias)", "base_layer", k) for k in to_return}
+    return to_return
+
+
+class PeftUnwrapMixin:
+    def __init__(self, peft_model: PeftModel):
+        self.base_model = peft_model.get_base_model()
+        # peft does not affect buffers
+        self.lora_layers = extract_lora_layers(peft_model, set(n for n, p in self.base_model.named_parameters()))
+        potential_lora_weights = set()
+        for n in self.lora_layers:
+            potential_lora_weights.add(f"{n}.weight")
+            potential_lora_weights.add(f"{n}.bias")
+        self.lora_param_to_origin_param = {n: n.replace("base_layer.", "") for n in potential_lora_weights}
+        self.origin_param_to_lora_param = {v: k for k, v in self.lora_param_to_origin_param.items()}
+
+    def named_parameters(self):
+        for n, p in self.base_model.named_parameters():
+            if n in self.lora_param_to_origin_param:
+                n = self.lora_param_to_origin_param[n]
+            yield n, p
+
+    def named_buffers(self):
+        return self.base_model.named_buffers()
+
+    @property
+    def _modules(self):
+        return self.base_model._modules
+
+    @property
+    def _non_persistent_buffers_set(self):
+        return self.base_model._non_persistent_buffers_set
+
+    def patch_state_dict(self, state_dict: Dict[str, torch.Tensor]):
+        new_state_dict = {}
+        for k, v in state_dict.items():
+            if k in self.origin_param_to_lora_param:
+                k = self.origin_param_to_lora_param[k]
+            new_state_dict[k] = v
+        return new_state_dict
+
+    def state_dict(self):
+        state_dict = {}
+        for k, v in self.base_model.state_dict().items():
+            if k in self.lora_param_to_origin_param:
+                k = self.lora_param_to_origin_param[k]
+            state_dict[k] = v
+        return state_dict
+
+    def load_state_dict(self, state_dict, strict: bool = True, assign: bool = False):
+        state_dict = self.patch_state_dict(state_dict)
+        self.base_model.load_state_dict(state_dict, strict=strict, assign=assign)
+
+    def __hash__(self):
+        return hash(self.base_model)
 
 
 class ModelWrapper(nn.Module):
@@ -23,7 +120,7 @@ class ModelWrapper(nn.Module):
         else:
             model = self.module
         if unwrap_peft and isinstance(model, PeftModel):
-            model = model.get_base_model()
+            model = PeftUnwrapMixin(model)
         return model
 
     def forward(self, *args, **kwargs):
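
Usage sketch: a minimal example of how the new PeftUnwrapMixin behaves once a plugin unwraps a LoRA model. GPT-2 and peft's default LoRA target modules are illustrative assumptions, not taken from this patch.

# Illustrative only: round-trips a state dict through the wrapper introduced above.
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM

from colossalai.interface.model import PeftUnwrapMixin

base = AutoModelForCausalLM.from_pretrained("gpt2")
peft_model = get_peft_model(base, LoraConfig(task_type="CAUSAL_LM"))

# Plugins now wrap the peft model instead of calling get_base_model(), so that
# checkpoint keys keep their original (pre-LoRA) names.
unwrapped = PeftUnwrapMixin(peft_model)
sd = unwrapped.state_dict()    # "base_layer." is stripped from LoRA-wrapped linears
unwrapped.load_state_dict(sd)  # keys are re-mapped to the peft names before loading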