[checkpointio] fix for async io (#6189)

This commit is contained in:
flybird11111 2025-02-14 17:34:13 +08:00 committed by GitHub
parent 5ff5323538
commit ce0ec40811
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -315,12 +315,13 @@ def async_save_state_dict_shards(
checkpoint_file_path = os.path.join(checkpoint, shard_file) checkpoint_file_path = os.path.join(checkpoint, shard_file)
if state_preprocess: if state_preprocess:
state_dict, _ = _flatten_optim_state_dict(state_dict=shard, seperator=".") state_dict, metadata = _flatten_optim_state_dict(state_dict=shard, seperator=".")
else: else:
state_dict = shard state_dict = shard
metadata = None
# Only save on master rank. # Only save on master rank.
writer = save(checkpoint_file_path, state_dict=state_dict) writer = save(checkpoint_file_path, state_dict=state_dict, metadata=metadata)
writers.append(writer) writers.append(writer)
shard_filenames.append(shard_file) shard_filenames.append(shard_file)
del shard del shard
@ -377,9 +378,10 @@ def async_move_save_state_dict_shards(
checkpoint_file_path = os.path.join(checkpoint, shard_file) checkpoint_file_path = os.path.join(checkpoint, shard_file)
if state_preprocess: if state_preprocess:
state_dict, _ = _flatten_optim_state_dict(state_dict=shard) state_dict, metadata = _flatten_optim_state_dict(state_dict=shard)
else: else:
state_dict = shard state_dict = shard
metadata = None
if pinned_state_dict is not None: if pinned_state_dict is not None:
sub_pinned_state_dict = {k: pinned_state_dict[k] for k in state_dict.keys()} sub_pinned_state_dict = {k: pinned_state_dict[k] for k in state_dict.keys()}
@ -388,7 +390,7 @@ def async_move_save_state_dict_shards(
returned_state_dict.update(sub_pinned_state_dict) returned_state_dict.update(sub_pinned_state_dict)
# Only save on master rank. # Only save on master rank.
writer = move_and_save(checkpoint_file_path, state_dict, sub_pinned_state_dict) writer = move_and_save(checkpoint_file_path, state_dict, sub_pinned_state_dict, metadata)
writers.append(writer) writers.append(writer)
shard_filenames.append(shard_file) shard_filenames.append(shard_file)
del shard del shard