Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-06 03:20:52 +00:00
[checkpointio] support load-pin overlap (#6177)
* [checkpointio] support load-pin overlap

* [pre-commit.ci] auto fixes from pre-commit.com hooks

  for more information, see https://pre-commit.ci

* [test] add conftest

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
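The title describes overlapping checkpoint-shard loading with memory pinning. A minimal sketch of that idea, assuming a single prefetch thread; the function and argument names below are illustrative, not ColossalAI's API:

```python
from concurrent.futures import ThreadPoolExecutor


def load_shards_overlapped(shard_files, load_fn, pin_fn):
    """Yield pinned shards, overlapping the disk read of shard N+1
    with the pinning of shard N."""
    if not shard_files:
        return
    with ThreadPoolExecutor(max_workers=1) as pool:
        future = pool.submit(load_fn, shard_files[0])  # prefetch the first shard
        for next_file in shard_files[1:]:
            current = future.result()                  # wait for the current read
            future = pool.submit(load_fn, next_file)   # start the next read early
            yield pin_fn(current)                      # pin while that read proceeds
        yield pin_fn(future.result())                  # last shard
```

With `pin_fn` doing the copy into page-locked memory, reading the next file and pinning the current one run concurrently, which is the overlap the title refers to.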
@@ -24,8 +24,8 @@ from colossalai.checkpoint_io.utils import (
     get_optimizer_base_filenames,
     get_shard_filename,
     load_param_groups_into_optimizer,
-    load_shard_state_dict,
     load_state_dict,
+    load_state_dict_shards,
     load_states_into_optimizer,
     save_param_groups,
     save_state_dict,
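The import hunk swaps the per-file loader `load_shard_state_dict` for a generator-style helper, `load_state_dict_shards`, that iterates the shard files itself. A hedged sketch of what the new helper plausibly wraps, folding the per-file branch deleted in the hunk below into a single iterator (illustrative only; it does not show the pin overlap or the extra flags visible at the call site, and the real implementation lives in colossalai/checkpoint_io/utils.py):

```python
from pathlib import Path

from colossalai.checkpoint_io.utils import load_shard_state_dict


def iter_shard_state_dicts(checkpoint_files):
    """Yield one loaded state dict per shard file."""
    for shard_file in checkpoint_files:
        if shard_file.endswith(".safetensors"):
            from colossalai.utils.safetensors import load_flat

            yield load_flat(shard_file)
        else:
            yield load_shard_state_dict(Path(shard_file), use_safetensors=False)
```

Moving the branch behind a generator keeps the call site to one line and gives the helper a natural place to prefetch and pin shards behind the scenes.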
@@ -276,13 +276,7 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
 
         checkpoint_files, _ = ckpt_index_file.get_checkpoint_filenames()
 
-        for shard_file in checkpoint_files:
-            if shard_file.endswith(".safetensors"):
-                from colossalai.utils.safetensors import load_flat
-
-                state_dict = load_flat(shard_file)
-            else:
-                state_dict = load_shard_state_dict(Path(shard_file), use_safetensors=False)
+        for state_dict in load_state_dict_shards(checkpoint_files, True, False, low_cpu_mem_mode):
             # shard state dict
             for param_idx, state in state_dict.items():
                 for k, v in state.items():
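A sketch of the pinning step that the load-pin overlap hides behind disk reads. The assumption here is that when `low_cpu_mem_mode` is False, each loaded shard is copied into page-locked (pinned) CPU memory so later host-to-device transfers can run asynchronously; the nested layout mirrors the loop in the hunk above, `state_dict[param_idx][k] -> tensor`:

```python
import torch


def pin_optimizer_shard(state_dict):
    """Copy every tensor in a sharded optimizer state dict into pinned memory.
    Layout assumed: {param_idx: {state_key: tensor_or_scalar}}."""
    pinned = {}
    for param_idx, state in state_dict.items():
        pinned[param_idx] = {
            k: (v.pin_memory() if isinstance(v, torch.Tensor) else v)
            for k, v in state.items()
        }
    return pinned
```

Combined with a one-shard prefetch like the first sketch, pinning shard N runs while shard N+1 is still being read, so neither step serializes the other.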