[checkpointio] support load-pin overlap (#6177)

* [checkpointio] support load-pin overlap

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* [test] add conftest

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Author: Hongxin Liu
Date: 2025-01-07 16:16:04 +08:00
Committed by: GitHub
Parent: 479067e9bc
Commit: ee81366cac

6 changed files with 56 additions and 32 deletions


@@ -20,7 +20,7 @@ from colossalai.checkpoint_io.utils import (
     create_pinned_state_dict,
     get_model_base_filenames,
     get_optimizer_base_filenames,
-    load_shard_state_dict,
+    load_state_dict_shards,
     save_config_file,
     save_state_dict,
     save_state_dict_shards,
@@ -29,7 +29,6 @@ from colossalai.cluster import DistCoordinator, ProcessGroupMesh
 from colossalai.interface import ModelWrapper, OptimizerWrapper
 from colossalai.logging import get_dist_logger
 from colossalai.shardformer import ShardConfig, ShardFormer
-from colossalai.utils.safetensors import load_flat
 from colossalai.zero import GeminiDDP, GeminiOptimizer
 from colossalai.zero.gemini.memory_tracer import MemStats
@@ -350,11 +349,9 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
         # Load optimizer states from shard files under checkpoint path.
         # For each file, only load the states managed by current process.
-        for shard_file in checkpoint_files:
-            if shard_file.endswith(".safetensors"):
-                state_dict_shard = load_flat(shard_file)
-            else:
-                state_dict_shard = load_shard_state_dict(Path(shard_file), use_safetensors=False)
+        for state_dict_shard in load_state_dict_shards(
+            checkpoint_files, True, False, low_cpu_mem_mode=low_cpu_mem_mode
+        ):
             if not low_cpu_mem_mode:
                 state_dict_shard = create_pinned_state_dict(state_dict_shard, empty=False, num_threads=num_threads)
             optimizer.load_param_states(state_dict_shard)
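
The switch to a generator-based load_state_dict_shards is what enables the load-pin overlap named in the commit title: while the caller pins and applies shard i, the next shard can already be read from disk. Below is a minimal sketch of the same idea in plain PyTorch, assuming torch.load-able shard files; the helper name, single-worker prefetch, and per-tensor pinning are illustrative, not ColossalAI's actual implementation:

import concurrent.futures
from typing import Dict, Iterator, List

import torch


def load_shards_with_pin_overlap(shard_files: List[str]) -> Iterator[Dict[str, torch.Tensor]]:
    # Hypothetical helper: overlap disk reads with pinning, one shard in flight.
    if not shard_files:
        return
    with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
        # Kick off the first read before the loop starts consuming.
        future = pool.submit(torch.load, shard_files[0], map_location="cpu")
        for next_file in list(shard_files[1:]) + [None]:
            shard = future.result()  # wait only for the in-flight read
            if next_file is not None:
                # Overlap: the next read proceeds while we pin the current shard.
                future = pool.submit(torch.load, next_file, map_location="cpu")
            # Pinned host memory makes later host-to-device copies async-capable.
            yield {k: v.pin_memory() for k, v in shard.items()}

This would be consumed exactly like the rewritten loop above, e.g. feeding each pinned shard to optimizer.load_param_states(shard) while the next file is still being read. Keeping a single prefetched shard in flight bounds host memory to roughly two shards while hiding most of the disk latency.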