Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-03 01:55:12 +00:00
[checkpointio] support load-pin overlap (#6177)
* [checkpointio] support load-pin overlap

* [pre-commit.ci] auto fixes from pre-commit.com hooks

  for more information, see https://pre-commit.ci

* [test] add conftest

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
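The "load-pin overlap" in the title refers to hiding the host-side pin_memory() copies behind the disk reads of the next checkpoint shard, instead of loading every shard first and pinning afterwards. The sketch below is only a minimal illustration of that idea under stated assumptions, not ColossalAI's implementation: load_shards_overlapped and its load_fn argument are hypothetical names, and a single background reader thread stands in for whatever prefetching the real code does.

    from concurrent.futures import ThreadPoolExecutor

    import torch


    def load_shards_overlapped(shard_paths, load_fn):
        # Hypothetical helper: while the main thread pins the tensors of the
        # shard it already holds, a worker thread reads the next shard from
        # disk, so I/O and pinning overlap instead of running back to back.
        if not shard_paths:
            return
        with ThreadPoolExecutor(max_workers=1) as pool:
            future = pool.submit(load_fn, shard_paths[0])
            for next_path in list(shard_paths[1:]) + [None]:
                state_dict = future.result()  # wait for the in-flight disk read
                if next_path is not None:
                    future = pool.submit(load_fn, next_path)  # prefetch next shard
                yield {
                    k: v.pin_memory() if isinstance(v, torch.Tensor) else v
                    for k, v in state_dict.items()
                }

Pinned host buffers are also what later allow non_blocking=True host-to-device copies, so the overlap pays off again when the assembled state dict is moved onto the GPU.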
@@ -255,8 +255,8 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO):
         checkpoint_files, _ = ckpt_index_file.get_checkpoint_filenames()

         fsdp_state_dict = {}
-        for shard_file in checkpoint_files:
-            fsdp_state_dict.update(utils.load_shard_state_dict(Path(shard_file), use_safetensors))
+        for state_dict in utils.load_state_dict_shards(checkpoint_files, False, use_safetensors):
+            fsdp_state_dict.update(state_dict)

         with FSDP.state_dict_type(model.unwrap(), StateDictType.FULL_STATE_DICT):
             model.unwrap().load_state_dict(fsdp_state_dict, strict=False)
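Judging only from this call site and the one in the optimizer path below, utils.load_state_dict_shards takes the shard file list, a flag marking flattened optimizer shards, and a use_safetensors flag, and yields one partial state dict per shard. A rough, non-authoritative reading of that interface, reusing the same load_flat and utils.load_shard_state_dict helpers the removed lines called (the _sketch suffix marks it as hypothetical; the real helper also performs the load-pin overlap, which is omitted here):

    from pathlib import Path


    def load_state_dict_shards_sketch(checkpoint_files, is_optim, use_safetensors):
        # Sketch only: argument meanings are inferred from the two call sites
        # in this diff; the actual utils.load_state_dict_shards may differ.
        for shard_file in checkpoint_files:
            if is_optim and str(shard_file).endswith(".safetensors"):
                # flattened optimizer shards keep nested keys joined by "."
                yield load_flat(shard_file, seperator=".")
            else:
                yield utils.load_shard_state_dict(Path(shard_file), use_safetensors)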
@@ -388,11 +388,7 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO):
         # Load param
         fsdp_optim_state = {}
         checkpoint_files, _ = ckpt_index_file.get_checkpoint_filenames()
-        for shard_file in checkpoint_files:
-            if shard_file.endswith(".safetensors"):
-                state_dict_shard = load_flat(shard_file, seperator=".")
-            else:
-                state_dict_shard = utils.load_shard_state_dict(Path(shard_file), use_safetensors=False)
+        for state_dict_shard in utils.load_state_dict_shards(checkpoint_files, True, False):
             fsdp_optim_state.update(state_dict_shard)

         fsdp_optim_dict = dict(state=fsdp_optim_state, param_groups=saved_param_groups)
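For context on the load_flat(shard_file, seperator=".") call that the optimizer path used to make: safetensors files can only store a flat mapping from string keys to tensors, so nested optimizer state has to be flattened into dot-joined keys on save and rebuilt into a nested dict on load. The function below is only an illustration of that rebuild step, not ColossalAI's load_flat; the kwarg spelling mirrors the call above.

    def unflatten_state_dict(flat_dict, seperator="."):
        # Rebuild a nested dict from dot-joined keys, e.g. "state.0.exp_avg"
        # becomes nested["state"]["0"]["exp_avg"].
        nested = {}
        for flat_key, tensor in flat_dict.items():
            *parents, leaf = flat_key.split(seperator)
            node = nested
            for key in parents:
                node = node.setdefault(key, {})
            node[leaf] = tensor
        return nested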