[hotfix] fix lora load (#6231)

* [hotfix] fix lora load

* [hotfix] fix hp load

* accelerate deepseek loading
This commit is contained in:
Hongxin Liu
2025-03-01 19:04:14 +08:00
committed by GitHub
parent f32861ccc5
commit 56fe130b15
10 changed files with 146 additions and 38 deletions

View File

@@ -60,9 +60,9 @@ class GeneralCheckpointIO(CheckpointIO):
if use_async:
from colossalai.utils.safetensors import move_and_save
if id(model) not in self.pinned_state_dicts:
self.pinned_state_dicts[id(model)] = create_pinned_state_dict(state_dict)
writer = move_and_save(checkpoint, state_dict, self.pinned_state_dicts[id(model)])
if hash(model) not in self.pinned_state_dicts:
self.pinned_state_dicts[hash(model)] = create_pinned_state_dict(state_dict)
writer = move_and_save(checkpoint, state_dict, self.pinned_state_dicts[hash(model)])
self.async_writers.append(writer)
else:
# save the checkpoint
@@ -234,7 +234,7 @@ class GeneralCheckpointIO(CheckpointIO):
index_file = CheckpointIndexFile(checkpoint_path)
if use_async:
pinned_state_dict = self.pinned_state_dicts.get(id(model), None)
pinned_state_dict = self.pinned_state_dicts.get(hash(model), None)
total_size, new_pinned_state_dict, writers = async_move_save_state_dict_shards(
sharded_state_dict=state_dict_shard,
checkpoint=checkpoint_path,
@@ -243,7 +243,7 @@ class GeneralCheckpointIO(CheckpointIO):
is_master=True,
pinned_state_dict=pinned_state_dict,
)
self.pinned_state_dicts[id(model)] = new_pinned_state_dict
self.pinned_state_dicts[hash(model)] = new_pinned_state_dict
self.async_writers.extend(writers)
else:
# Save shards of optimizer states.

View File

@@ -249,9 +249,9 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
# Only devices with tp_rank == 0 are responsible for model saving.
control_saving = self.tp_rank == 0 and self.sp_rank == 0
if control_saving and use_async:
if id(model) not in self.pinned_state_dicts:
self.pinned_state_dicts[id(model)] = {}
pinned_state_dicts = self.pinned_state_dicts[id(model)]
if hash(model) not in self.pinned_state_dicts:
self.pinned_state_dicts[hash(model)] = {}
pinned_state_dicts = self.pinned_state_dicts[hash(model)]
else:
pinned_state_dicts = None
state_dict_shard = HybridParallelCheckpointIO._model_sharder(
@@ -789,11 +789,11 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
if use_async:
from colossalai.utils.safetensors import save
if id(model) not in self.pinned_state_dicts:
self.pinned_state_dicts[id(model)] = create_pinned_state_dict(state_dict)
if hash(model) not in self.pinned_state_dicts:
self.pinned_state_dicts[hash(model)] = create_pinned_state_dict(state_dict)
for name, param in state_dict.items():
self.pinned_state_dicts[id(model)][name].copy_(param)
state_dict[name] = self.pinned_state_dicts[id(model)][name]
self.pinned_state_dicts[hash(model)][name].copy_(param)
state_dict[name] = self.pinned_state_dicts[hash(model)][name]
writer = save(path=checkpoint, state_dict=state_dict)
self.async_writers.append(writer)
else:
@@ -811,11 +811,11 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
if use_async:
from colossalai.utils.safetensors import save
if id(model) not in self.pinned_state_dicts:
self.pinned_state_dicts[id(model)] = create_pinned_state_dict(complete_state_dict)
if hash(model) not in self.pinned_state_dicts:
self.pinned_state_dicts[hash(model)] = create_pinned_state_dict(complete_state_dict)
for name, param in complete_state_dict.items():
self.pinned_state_dicts[id(model)][name].copy_(param)
complete_state_dict[name] = self.pinned_state_dicts[id(model)][name]
self.pinned_state_dicts[hash(model)][name].copy_(param)
complete_state_dict[name] = self.pinned_state_dicts[hash(model)][name]
writer = save(path=checkpoint, state_dict=complete_state_dict)
self.async_writers.append(writer)
else:

View File

@@ -701,15 +701,18 @@ class MoECheckpointIO(HybridParallelCheckpointIO):
all_param = None
# gather param from every ep rank
# dist.all_gather(all_param, param, group=ep_group)
dist.gather(param, all_param, group=ep_group)
dist.gather(param, all_param, dst=dist.get_global_rank(ep_group, 0), group=ep_group)
if ep_rank == 0:
all_param = torch.cat(all_param, dim=0)
state_dict[name] = all_param.cpu()
if self.pp_size > 1:
if self.dp_rank == 0:
out = [None for _ in range(self.pp_size)]
dist.gather_object(state_dict, out, group=self.pp_group)
if self.pp_rank == 0:
out = [None for _ in range(self.pp_size)]
else:
out = None
dist.gather_object(state_dict, out, dst=dist.get_global_rank(self.pp_group, 0), group=self.pp_group)
if self.pp_rank == 0:
new_state_dict = {}
for o in out:

View File

@@ -20,6 +20,7 @@ from torch.optim import Optimizer
from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten
from colossalai.accelerator import get_accelerator
from colossalai.interface.model import PeftUnwrapMixin
from colossalai.tensor.d_tensor import (
is_customized_distributed_tensor,
is_distributed_tensor,
@@ -554,6 +555,8 @@ def save_config_file(model: nn.Module, checkpoint_path: str, is_master: bool = T
from transformers.modeling_utils import unwrap_model as unwrap_huggingface_model
except ImportError:
return
if isinstance(model, PeftUnwrapMixin):
model = model.base_model
if not isinstance(model, PreTrainedModel):
return
@@ -692,6 +695,9 @@ def load_state_dict_into_model(
state_dict (dict): a dict containing parameters and
persistent buffers.
"""
if isinstance(model, PeftUnwrapMixin):
state_dict = model.patch_state_dict(state_dict)
model = model.base_model
if not isinstance(state_dict, Mapping):
raise TypeError("Expected state_dict to be dict-like, got {}.".format(type(state_dict)))