[checkpointio] fix gemini and hybrid parallel optim checkpoint (#5347)
* [checkpointio] fix hybrid parallel optim checkpoint
* [extension] fix cuda extension
* [checkpointio] fix gemini optimizer checkpoint
* polish code
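Every hunk below applies the same fix: instead of dumping `optimizer.param_info` wholesale, the saved `param_groups` are rebuilt by taking each group's current hyperparameters from the live `optimizer.param_groups` and overriding only the `"params"` entry with the list recorded in `optimizer.param_info["param_groups"]`. The stand-alone sketch below (not ColossalAI code; the toy `param_info` dict is an assumed stand-in for what the plugin records) shows the merge on a plain PyTorch optimizer:

import torch

# Toy stand-ins for the real training state.
model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

# Assumed stand-in for optimizer.param_info["param_groups"]: each group's
# "params" recorded as stable integer indices rather than live tensors.
param_info = {"param_groups": [{"params": [0, 1]}]}

# The merge each hunk performs: keep the group's current hyperparameters
# (lr, momentum, ...) but replace the tensor list with the recorded indices,
# yielding a serializable, index-based param group.
param_groups = [
    {**group, "params": group_info["params"]}
    for group, group_info in zip(optimizer.param_groups, param_info["param_groups"])
]
print(param_groups[0]["lr"], param_groups[0]["params"])  # 0.1 [0, 1]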
colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
@@ -1,7 +1,7 @@
 import copy
-from functools import reduce
 import logging
 import os
+from functools import reduce
 from pathlib import Path
 from shutil import rmtree
 from typing import Dict, Iterator, Optional, OrderedDict, Tuple
@@ -445,7 +445,11 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
             # Store param groups.
             index_file.append_meta_data("param_groups", param_group_file)
             group_file_path = os.path.join(checkpoint, param_group_file)
-            save_param_groups(optimizer.param_info, group_file_path)
+            param_groups = [
+                {**group, "params": group_info["params"]}
+                for group, group_info in zip(optimizer.param_groups, optimizer.param_info["param_groups"])
+            ]
+            save_param_groups({"param_groups": param_groups}, group_file_path)
             # Store index file.
             index_file.append_meta_data("total_size", total_size)
             index_file.write_index_file(save_index_file)
@@ -504,7 +508,11 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
             # Store param groups.
             final_index_file.append_meta_data("param_groups", param_group_file)
             group_file_path = os.path.join(checkpoint, param_group_file)
-            save_param_groups(optimizer.param_info, group_file_path)
+            param_groups = [
+                {**group, "params": group_info["params"]}
+                for group, group_info in zip(optimizer.param_groups, optimizer.param_info["param_groups"])
+            ]
+            save_param_groups({"param_groups": param_groups}, group_file_path)

             final_index_file.write_index_file(final_index_file_path)
             rmtree(tmp_index_file_folder)
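The two hunks above patch the sharded save path, where the param groups go to a separate file referenced from the checkpoint index (the second hunk being the pipeline-parallel variant that merges per-stage index files into a final index before removing the temporary folder). The two hunks below apply the same rebuild to the unsharded path, where the whole optimizer state_dict is written as a single file.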
@@ -718,7 +726,11 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):

         if self.pp_size == 1:
             # When pipeline is not used, let master rank directly save the collected state_dict.
-            state_dict = {"param_groups": optimizer.param_info["param_groups"], "state": local_states}
+            param_groups = [
+                {**group, "params": group_info["params"]}
+                for group, group_info in zip(optimizer.param_groups, optimizer.param_info["param_groups"])
+            ]
+            state_dict = {"param_groups": param_groups, "state": local_states}
             if self.coordinator.is_master():
                 save_state_dict(state_dict, checkpoint, use_safetensors=False)
         else:
@@ -729,7 +741,11 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):

             # Only the master rank do the saving.
             if self.coordinator.is_master():
-                state_dict = {"param_groups": optimizer.param_info["param_groups"], "state": dict()}
+                param_groups = [
+                    {**group, "params": group_info["params"]}
+                    for group, group_info in zip(optimizer.param_groups, optimizer.param_info["param_groups"])
+                ]
+                state_dict = {"param_groups": param_groups, "state": dict()}
                 for _states in states_list:
                     state_dict["state"].update(_states)
                 save_state_dict(state_dict, checkpoint, use_safetensors=False)
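After the patch, both unsharded branches emit a dict with the same shape as a standard `torch.optim` state_dict: `"param_groups"` holding integer `"params"` indices, plus a `"state"` dict keyed by those indices. A minimal check of that layout on a plain SGD optimizer (toy model and variable names are assumptions, not ColossalAI code):

import torch

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

# torch's native state_dict already uses index-based "params" plus a
# "state" dict keyed by the same indices.
sd = optimizer.state_dict()
print(sd["param_groups"][0]["params"])  # [0, 1]

# A dict shaped like the one the patched code saves round-trips cleanly.
rebuilt = {"param_groups": sd["param_groups"], "state": sd["state"]}
optimizer.load_state_dict(rebuilt)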