[checkpointio] support load-pin overlap (#6177)

* [checkpointio] support load-pin overlap

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* [test] add conftest

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
Hongxin Liu
2025-01-07 16:16:04 +08:00
committed by GitHub
parent 479067e9bc
commit ee81366cac
6 changed files with 56 additions and 32 deletions

View File

@@ -18,9 +18,9 @@ from .utils import (
get_optimizer_base_filenames,
is_safetensors_available,
load_param_groups_into_optimizer,
load_shard_state_dict,
load_state_dict,
load_state_dict_into_model,
load_state_dict_shards,
load_states_into_optimizer,
save_config_file,
save_param_groups,
@@ -94,11 +94,7 @@ class GeneralCheckpointIO(CheckpointIO):
checkpoint_files, _ = ckpt_index_file.get_checkpoint_filenames()
for shard_file in checkpoint_files:
if shard_file.endswith(".safetensors"):
state_dict = load_flat(shard_file)
else:
state_dict = load_shard_state_dict(Path(shard_file), use_safetensors=False)
for state_dict in load_state_dict_shards(checkpoint_files, True, False, low_cpu_mem_mode):
if not low_cpu_mem_mode:
state_dict = create_pinned_state_dict(state_dict, empty=False, num_threads=num_threads)
load_states_into_optimizer(optimizer, state_dict, id_map)
@@ -295,8 +291,7 @@ class GeneralCheckpointIO(CheckpointIO):
checkpoint_files, _ = ckpt_index_file.get_checkpoint_filenames()
missing_keys = []
for shard_file in checkpoint_files:
state_dict = load_shard_state_dict(Path(shard_file), use_safetensors)
for state_dict in load_state_dict_shards(checkpoint_files, False, use_safetensors, low_cpu_mem_mode):
if not low_cpu_mem_mode:
state_dict = create_pinned_state_dict(state_dict, empty=False, num_threads=num_threads)
load_state_dict_into_model(model, state_dict, missing_keys, strict, load_sub_module)