[hotfix] fix the bug of repeatedly storing param group (#4951)

Baizhou Zhang
2023-10-31 14:48:01 +08:00
committed by GitHub
parent be82b5d4ca
commit c040d70aa0
2 changed files with 10 additions and 9 deletions

@@ -119,11 +119,12 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
         # Preparing file paths and index file.
         states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix)
         index_file = CheckpointIndexFile(checkpoint)
+        index_file.append_meta_data("param_groups", param_group_file)
 
         # Store the information of param groups to param_group_file.
-        index_file.append_meta_data("param_groups", param_group_file)
-        group_file_path = os.path.join(checkpoint, param_group_file)
-        save_param_groups(state_dict, group_file_path)
+        if self.coordinator.is_master():
+            group_file_path = os.path.join(checkpoint, param_group_file)
+            save_param_groups(state_dict, group_file_path)
 
         # Save shards of optimizer states.
         total_size = 0
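
What the hunk shows: the param-group file used to be written unconditionally in this save path, so every process that entered it rewrote the same file; the fix moves the write under self.coordinator.is_master() so only the master rank stores the param groups, while optimizer-state shards are still saved per rank. As a rough standalone sketch of that master-only-write idiom (the function name, file names, and direct use of torch.distributed below are assumptions for illustration, not ColossalAI's actual API):

import os

import torch
import torch.distributed as dist


def save_sharded_optimizer_sketch(optimizer: torch.optim.Optimizer, checkpoint_dir: str) -> None:
    # Hypothetical sketch: each rank saves its own shard of optimizer states,
    # but the shared param-group metadata is written by the master rank only.
    rank = dist.get_rank() if dist.is_initialized() else 0
    os.makedirs(checkpoint_dir, exist_ok=True)

    # Every rank stores the optimizer states it owns.
    shard_path = os.path.join(checkpoint_dir, f"optim_states_rank{rank}.bin")
    torch.save(optimizer.state_dict()["state"], shard_path)

    # Only rank 0 stores the param groups, so the file is written once
    # instead of being rewritten by every process (the bug this commit fixes).
    if rank == 0:
        group_path = os.path.join(checkpoint_dir, "param_groups.bin")
        torch.save(optimizer.state_dict()["param_groups"], group_path)

    # Wait until all ranks have finished writing before returning.
    if dist.is_initialized():
        dist.barrier()

Launched with torchrun, every process would call the sketch with the same checkpoint_dir, yet only rank 0 ends up writing param_groups.bin, mirroring the is_master() guard added in the diff.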