mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-21 01:24:04 +00:00
[checkpointio] support load-pin overlap (#6177)
* [checkpointio] support load-pin overlap
* [pre-commit.ci] auto fixes from pre-commit.com hooks

  for more information, see https://pre-commit.ci
* [test] add conftest

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
@@ -20,7 +20,7 @@ from colossalai.checkpoint_io.utils import (
|
||||
create_pinned_state_dict,
|
||||
get_model_base_filenames,
|
||||
get_optimizer_base_filenames,
|
||||
load_shard_state_dict,
|
||||
load_state_dict_shards,
|
||||
save_config_file,
|
||||
save_state_dict,
|
||||
save_state_dict_shards,
|
||||
@@ -29,7 +29,6 @@ from colossalai.cluster import DistCoordinator, ProcessGroupMesh
|
||||
from colossalai.interface import ModelWrapper, OptimizerWrapper
|
||||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.shardformer import ShardConfig, ShardFormer
|
||||
from colossalai.utils.safetensors import load_flat
|
||||
from colossalai.zero import GeminiDDP, GeminiOptimizer
|
||||
from colossalai.zero.gemini.memory_tracer import MemStats
|
||||
|
||||
@@ -350,11 +349,9 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
|
||||
|
||||
# Load optimizer states from shard files under checkpoint path.
|
||||
# For each file, only load the states managed by current process.
|
||||
for shard_file in checkpoint_files:
|
||||
if shard_file.endswith(".safetensors"):
|
||||
state_dict_shard = load_flat(shard_file)
|
||||
else:
|
||||
state_dict_shard = load_shard_state_dict(Path(shard_file), use_safetensors=False)
|
||||
for state_dict_shard in load_state_dict_shards(
|
||||
checkpoint_files, True, False, low_cpu_mem_mode=low_cpu_mem_mode
|
||||
):
|
||||
if not low_cpu_mem_mode:
|
||||
state_dict_shard = create_pinned_state_dict(state_dict_shard, empty=False, num_threads=num_threads)
|
||||
optimizer.load_param_states(state_dict_shard)
|
||||
|
Reference in New Issue
Block a user