[checkpointio] support load-pin overlap (#6177)

* [checkpointio] support load-pin overlap

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* [test] add conftest

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
Hongxin Liu
2025-01-07 16:16:04 +08:00
committed by GitHub
parent 479067e9bc
commit ee81366cac
6 changed files with 56 additions and 32 deletions

View File

@@ -24,8 +24,8 @@ from colossalai.checkpoint_io.utils import (
get_optimizer_base_filenames,
get_shard_filename,
load_param_groups_into_optimizer,
load_shard_state_dict,
load_state_dict,
load_state_dict_shards,
load_states_into_optimizer,
save_param_groups,
save_state_dict,
@@ -276,13 +276,7 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
checkpoint_files, _ = ckpt_index_file.get_checkpoint_filenames()
for shard_file in checkpoint_files:
if shard_file.endswith(".safetensors"):
from colossalai.utils.safetensors import load_flat
state_dict = load_flat(shard_file)
else:
state_dict = load_shard_state_dict(Path(shard_file), use_safetensors=False)
for state_dict in load_state_dict_shards(checkpoint_files, True, False, low_cpu_mem_mode):
# shard state dict
for param_idx, state in state_dict.items():
for k, v in state.items():