[hotfix] fix lora load (#6231)

* [hotfix] fix lora load

* [hotfix] fix hp load

* accelerate deepseek loading
commit 56fe130b15
parent f32861ccc5
Author: Hongxin Liu
Date: 2025-03-01 19:04:14 +08:00
Committed by: GitHub

10 changed files with 146 additions and 38 deletions


@@ -103,11 +103,11 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO):
             if use_async:
                 from colossalai.utils.safetensors import save

-                if id(model) not in self.pinned_state_dicts:
-                    self.pinned_state_dicts[id(model)] = create_pinned_state_dict(full_model_state)
+                if hash(model) not in self.pinned_state_dicts:
+                    self.pinned_state_dicts[hash(model)] = create_pinned_state_dict(full_model_state)
                 for k, v in full_model_state.items():
-                    self.pinned_state_dicts[id(model)][k].copy_(v)
-                    full_model_state[k] = self.pinned_state_dicts[id(model)][k]
+                    self.pinned_state_dicts[hash(model)][k].copy_(v)
+                    full_model_state[k] = self.pinned_state_dicts[hash(model)][k]
                 writer = save(checkpoint, full_model_state)
                 self.async_writers.append(writer)
             else:
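
Both hunks replace id(model) with hash(model) as the key into self.pinned_state_dicts. The commit message does not spell out why; a plausible reading is key consistency: if other checkpoint-IO paths key the same cache by hash(model), mixing id() and hash() keys makes a lookup on one path miss the buffers created on the other. For objects that do not override __hash__, hash() is itself derived from id(), so the point is one keying scheme everywhere rather than the hash function itself. Below is a minimal sketch of the keyed pinned-buffer cache; create_pinned_buffers is a hypothetical stand-in for the create_pinned_state_dict helper used above, and the cache layout is assumed, not taken from ColossalAI.

# Sketch only, not ColossalAI code: a per-model cache of pinned CPU buffers.
import torch
import torch.nn as nn

_pinned_cache: dict = {}

def create_pinned_buffers(state_dict: dict) -> dict:
    # Hypothetical stand-in for create_pinned_state_dict: one pinned CPU
    # tensor per entry, so repeated device-to-host copies reuse memory.
    return {
        k: torch.empty(v.shape, dtype=v.dtype, device="cpu", pin_memory=True)
        for k, v in state_dict.items()
    }

def get_pinned(model: nn.Module, state_dict: dict) -> dict:
    key = hash(model)  # same keying scheme as the lookup sites above
    if key not in _pinned_cache:
        _pinned_cache[key] = create_pinned_buffers(state_dict)
    return _pinned_cache[key]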
@@ -186,9 +186,9 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO):
         state_dict = model.unwrap().state_dict()

         if use_async and self.coordinator.is_master():
-            if id(model) not in self.pinned_state_dicts:
-                self.pinned_state_dicts[id(model)] = {}
-            pinned_state_dicts = self.pinned_state_dicts[id(model)]
+            if hash(model) not in self.pinned_state_dicts:
+                self.pinned_state_dicts[hash(model)] = {}
+            pinned_state_dicts = self.pinned_state_dicts[hash(model)]
         else:
             pinned_state_dicts = None
         state_dict_shard = utils.shard_model_checkpoint(
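
For context, the use_async branches above follow a common staging pattern: copy tensors into reusable pinned CPU buffers, start the file write in the background, and keep the writer handle (self.async_writers) so it can be awaited before shutdown. Below is a minimal sketch of that pattern under stated assumptions; async_save and its use of a plain thread plus torch.save are illustrative, not the colossalai.utils.safetensors.save implementation.

# Illustrative async-save sketch, not ColossalAI's API: stage into pinned
# buffers, write on a background thread, return a joinable handle.
import threading
import torch

def async_save(path: str, state_dict: dict, pinned: dict) -> threading.Thread:
    for k, v in state_dict.items():
        pinned[k].copy_(v)  # device-to-host copy into reusable pinned memory
    writer = threading.Thread(target=torch.save, args=(dict(pinned), path))
    writer.start()
    return writer

# Usage mirrors the hunks above: collect handles, then join before exit.
# writers.append(async_save("ckpt.pt", model.state_dict(), pinned))
# for w in writers:
#     w.join()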