[hotfix] fix lora load (#6231)

* [hotfix] fix lora load
* [hotfix] fix hp load
* accelerate deepseek loading
@@ -103,11 +103,11 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO):
         if use_async:
             from colossalai.utils.safetensors import save

-            if id(model) not in self.pinned_state_dicts:
-                self.pinned_state_dicts[id(model)] = create_pinned_state_dict(full_model_state)
+            if hash(model) not in self.pinned_state_dicts:
+                self.pinned_state_dicts[hash(model)] = create_pinned_state_dict(full_model_state)
             for k, v in full_model_state.items():
-                self.pinned_state_dicts[id(model)][k].copy_(v)
-                full_model_state[k] = self.pinned_state_dicts[id(model)][k]
+                self.pinned_state_dicts[hash(model)][k].copy_(v)
+                full_model_state[k] = self.pinned_state_dicts[hash(model)][k]
             writer = save(checkpoint, full_model_state)
             self.async_writers.append(writer)
         else:
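Note on the pattern in the hunk above (an editorial sketch, not part of the commit): the code caches a page-locked (pinned) host copy of the full state dict so the async writer can consume it while training continues, and reuses that buffer across saves of the same model instead of re-allocating pinned memory every time. A minimal stand-alone sketch of that caching pattern in plain PyTorch; `create_pinned_state_dict` and `save` are ColossalAI helpers, so the stand-in below is a simplified assumption, not the library's implementation:

```python
import torch

def create_pinned_state_dict(state_dict):
    # Simplified stand-in for ColossalAI's helper: allocate page-locked
    # CPU buffers matching each tensor's shape and dtype.
    return {
        k: torch.empty(v.shape, dtype=v.dtype, device="cpu", pin_memory=True)
        for k, v in state_dict.items()
    }

pinned_state_dicts = {}  # keyed by hash(model), as in the fix above

def stage_for_async_save(model, full_model_state):
    key = hash(model)
    if key not in pinned_state_dicts:
        pinned_state_dicts[key] = create_pinned_state_dict(full_model_state)
    for k, v in full_model_state.items():
        pinned_state_dicts[key][k].copy_(v)  # device -> pinned host copy
        full_model_state[k] = pinned_state_dicts[key][k]
    return full_model_state  # now safe to hand to an async writer
```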
@@ -186,9 +186,9 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO):
         state_dict = model.unwrap().state_dict()

         if use_async and self.coordinator.is_master():
-            if id(model) not in self.pinned_state_dicts:
-                self.pinned_state_dicts[id(model)] = {}
-            pinned_state_dicts = self.pinned_state_dicts[id(model)]
+            if hash(model) not in self.pinned_state_dicts:
+                self.pinned_state_dicts[hash(model)] = {}
+            pinned_state_dicts = self.pinned_state_dicts[hash(model)]
         else:
             pinned_state_dicts = None
         state_dict_shard = utils.shard_model_checkpoint(
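Why `hash(model)` rather than `id(model)`? The commit message does not say, but one plausible reading (an inference, not confirmed by the diff): `id()` is tied to one specific wrapper object, so if the model is rewrapped between saves (for example when LoRA weights are reapplied or a checkpoint is reloaded) the cached pinned buffers are never found again, whereas a wrapper that delegates `__hash__` to the underlying module keeps the cache key stable. A toy illustration with a hypothetical `Wrapper`:

```python
class Module:  # stand-in for a torch.nn.Module
    pass

class Wrapper:  # hypothetical wrapper that delegates hashing
    def __init__(self, module):
        self.module = module

    def __hash__(self):
        # Every wrapper around the same module hashes to the same key.
        return hash(self.module)

m = Module()
w1, w2 = Wrapper(m), Wrapper(m)  # e.g. model rewrapped after a reload
assert hash(w1) == hash(w2)      # hash-keyed cache entry is reused
assert id(w1) != id(w2)          # id-keyed cache would miss and re-pin
```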