[checkpointio] fix gemini and hybrid parallel optim checkpoint (#5347)

* [checkpointio] fix hybrid parallel optim checkpoint

* [extension] fix cuda extension

* [checkpointio] fix gemini optimizer checkpoint

* polish code
Hongxin Liu
2024-02-01 16:13:06 +08:00
committed by GitHub
parent c5239840e6
commit ffffc32dc7
5 changed files with 35 additions and 8 deletions
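
Every hunk in this commit applies the same pattern: instead of writing optimizer.param_info["param_groups"] to the checkpoint verbatim, the param groups are rebuilt from the live optimizer.param_groups (so current hyperparameters such as the learning rate are what gets saved), and only the "params" index lists are taken from the recorded param_info. A minimal, self-contained sketch of that merge, using a plain torch.optim.Adam in place of the wrapped hybrid-parallel optimizer; the helper name build_param_groups_for_saving and the toy param_info layout are illustrative, not ColossalAI API:

```python
import torch

# Hypothetical helper mirroring the pattern used in this commit: keep the live
# hyperparameters from optimizer.param_groups, but take each group's "params"
# entry (index list) from the param_info recorded when the optimizer was wrapped.
def build_param_groups_for_saving(optimizer, param_info):
    return [
        {**group, "params": group_info["params"]}
        for group, group_info in zip(optimizer.param_groups, param_info["param_groups"])
    ]

# Toy setup: a plain torch optimizer stands in for the wrapped one.
model = torch.nn.Linear(4, 4)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Assumed shape of param_info: per-group metadata holding the saved "params" ids.
param_info = {"param_groups": [{"params": [0, 1]}]}

optimizer.param_groups[0]["lr"] = 5e-4  # e.g. changed later by a scheduler
param_groups = build_param_groups_for_saving(optimizer, param_info)
print(param_groups[0]["lr"], param_groups[0]["params"])  # 0.0005 [0, 1]
```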


@@ -1,7 +1,7 @@
 import copy
-from functools import reduce
 import logging
 import os
+from functools import reduce
 from pathlib import Path
 from shutil import rmtree
 from typing import Dict, Iterator, Optional, OrderedDict, Tuple
@@ -445,7 +445,11 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
             # Store param groups.
             index_file.append_meta_data("param_groups", param_group_file)
             group_file_path = os.path.join(checkpoint, param_group_file)
-            save_param_groups(optimizer.param_info, group_file_path)
+            param_groups = [
+                {**group, "params": group_info["params"]}
+                for group, group_info in zip(optimizer.param_groups, optimizer.param_info["param_groups"])
+            ]
+            save_param_groups({"param_groups": param_groups}, group_file_path)
             # Store index file.
             index_file.append_meta_data("total_size", total_size)
             index_file.write_index_file(save_index_file)
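
With this change the param-group file holds entries shaped like a regular torch.optim state_dict()["param_groups"] item, i.e. the group's hyperparameters plus an integer "params" list. Below is a small stand-in for save_param_groups, assuming the utility simply persists the "param_groups" entry of the dict it receives (the real helper lives in ColossalAI's checkpoint_io utilities and may do more); the file name used here is illustrative:

```python
import os
import tempfile
import torch

# Stand-in for save_param_groups under the assumption that it persists the
# "param_groups" entry of the given dict to the target path.
def save_param_groups_sketch(state_dict_like, group_file_path):
    torch.save(state_dict_like["param_groups"], group_file_path)

param_groups = [{"lr": 5e-4, "betas": (0.9, 0.999), "params": [0, 1]}]
path = os.path.join(tempfile.mkdtemp(), "param_groups.bin")
save_param_groups_sketch({"param_groups": param_groups}, path)
print(torch.load(path)[0]["lr"])  # 0.0005
```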
@@ -504,7 +508,11 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
             # Store param groups.
             final_index_file.append_meta_data("param_groups", param_group_file)
             group_file_path = os.path.join(checkpoint, param_group_file)
-            save_param_groups(optimizer.param_info, group_file_path)
+            param_groups = [
+                {**group, "params": group_info["params"]}
+                for group, group_info in zip(optimizer.param_groups, optimizer.param_info["param_groups"])
+            ]
+            save_param_groups({"param_groups": param_groups}, group_file_path)

             final_index_file.write_index_file(final_index_file_path)
             rmtree(tmp_index_file_folder)
@@ -718,7 +726,11 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
         if self.pp_size == 1:
             # When pipeline is not used, let master rank directly save the collected state_dict.
-            state_dict = {"param_groups": optimizer.param_info["param_groups"], "state": local_states}
+            param_groups = [
+                {**group, "params": group_info["params"]}
+                for group, group_info in zip(optimizer.param_groups, optimizer.param_info["param_groups"])
+            ]
+            state_dict = {"param_groups": param_groups, "state": local_states}
             if self.coordinator.is_master():
                 save_state_dict(state_dict, checkpoint, use_safetensors=False)
         else:
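
The unsharded save has two branches: without pipeline parallelism the master rank saves the collected local_states directly, while with pipeline parallelism (next hunk) each stage appears to contribute the optimizer states for the parameters it owns and the master folds them into a single "state" dict. A toy sketch of that fold; the names states_list and merge_stage_states are illustrative only:

```python
# Fold per-stage optimizer states into one dict keyed by parameter index,
# mirroring the "for _states in states_list: state_dict['state'].update(_states)"
# loop in the hunk below.
def merge_stage_states(states_list):
    merged = {}
    for stage_states in states_list:
        merged.update(stage_states)
    return merged

stage0 = {0: {"step": 10}, 1: {"step": 10}}
stage1 = {2: {"step": 10}}
print(sorted(merge_stage_states([stage0, stage1]).keys()))  # [0, 1, 2]
```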
@@ -729,7 +741,11 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
             # Only the master rank do the saving.
             if self.coordinator.is_master():
-                state_dict = {"param_groups": optimizer.param_info["param_groups"], "state": dict()}
+                param_groups = [
+                    {**group, "params": group_info["params"]}
+                    for group, group_info in zip(optimizer.param_groups, optimizer.param_info["param_groups"])
+                ]
+                state_dict = {"param_groups": param_groups, "state": dict()}
                 for _states in states_list:
                     state_dict["state"].update(_states)
                 save_state_dict(state_dict, checkpoint, use_safetensors=False)
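
Because the merged dict mirrors the layout of torch.optim.Optimizer.state_dict(), a "param_groups" list whose "params" entries are integer ids plus a "state" dict keyed by those ids, it can later be restored through an ordinary load_state_dict once the parameters are regathered. A round-trip sketch with a plain optimizer, under that layout assumption:

```python
import torch

# Build a small optimizer, take one step so it has state, then restore the
# resulting state_dict into a fresh optimizer of the same shape.
model = torch.nn.Linear(4, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
model(torch.randn(1, 4)).sum().backward()
optimizer.step()

saved = optimizer.state_dict()   # {"param_groups": [...], "state": {...}}
fresh = torch.optim.Adam(model.parameters(), lr=1e-3)
fresh.load_state_dict(saved)     # accepted because "params" are integer ids
print(fresh.state_dict()["param_groups"][0]["params"])  # [0, 1]
```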