[checkpointio] support asyncio for 3d (#6152)

* fix

* fix

* fix

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update utils.py

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Author: flybird11111
Date: 2024-12-23 10:24:22 +08:00
Committed by: GitHub
Parent: aaafb38851
Commit: 130229fdcb
17 changed files with 776 additions and 188 deletions
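For context, this commit extends asynchronous checkpoint I/O to the 3D-parallel (tensor/pipeline/sequence-parallel, including MoE) code paths. Below is a minimal, hedged usage sketch; the `use_async` flag and the exact plugin arguments are assumptions based on this PR series and may differ between ColossalAI versions.

```python
# Hedged sketch: save a sharded checkpoint asynchronously under 3D parallelism.
# Launch with torchrun; plugin sizes and the toy GPT-2 model are illustrative only.
import torch
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin
from transformers import GPT2Config, GPT2LMHeadModel

colossalai.launch_from_torch()

plugin = HybridParallelPlugin(
    tp_size=2,
    pp_size=1,
    zero_stage=1,
    enable_sequence_parallelism=True,
    sequence_parallelism_mode="split_gather",  # sp group mirrors the tp group in this mode
)
booster = Booster(plugin=plugin)

model = GPT2LMHeadModel(GPT2Config(n_layer=2))
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
model, optimizer, *_ = booster.boost(model, optimizer)

# `use_async=True` is the assumed switch for asynchronous writes: the save is
# handed off to background writers instead of blocking the training loop.
booster.save_model(model, "ckpt/model", shard=True, use_async=True)
booster.save_optimizer(optimizer, "ckpt/optim", shard=True, use_async=True)
```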


@@ -44,12 +44,13 @@ class MoECheckpointIO(HybridParallelCheckpointIO):
         global_dp_group: ProcessGroup,
         pp_group: ProcessGroup,
         tp_group: ProcessGroup,
+        sp_group: ProcessGroup,
         ep_group: ProcessGroup,
         moe_dp_group: ProcessGroup,
         zero_stage: int,
         verbose: bool = True,
     ) -> None:
-        super().__init__(global_dp_group, pp_group, tp_group, zero_stage, verbose)
+        super().__init__(global_dp_group, pp_group, tp_group, sp_group, zero_stage, verbose)
         self.global_dp_group = global_dp_group
         self.global_dp_rank = dist.get_rank(global_dp_group)
         self.global_dp_size = dist.get_world_size(global_dp_group)
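The hunk above threads the new sequence-parallel process group through MoECheckpointIO and forwards it to the parent HybridParallelCheckpointIO. A hedged construction sketch follows; real callers (the MoE hybrid-parallel plugin) build these groups from their device mesh, so the stand-in world groups and the import path here are assumptions only.

```python
# Hedged sketch: constructing MoECheckpointIO after the signature change.
# Using the default world group for every axis is a placeholder; the plugin
# normally derives each process group from its device mesh.
import torch.distributed as dist
from colossalai.checkpoint_io import MoECheckpointIO  # import path may differ by version

dist.init_process_group(backend="nccl")
world = dist.group.WORLD

ckpt_io = MoECheckpointIO(
    global_dp_group=world,
    pp_group=world,
    tp_group=world,
    sp_group=world,       # new argument introduced by this commit
    ep_group=world,
    moe_dp_group=world,
    zero_stage=1,
)
```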
@@ -158,7 +159,7 @@ class MoECheckpointIO(HybridParallelCheckpointIO):
         state_dict_shard = MoECheckpointIO._model_sharder(model, size_per_shard=size_per_shard)
         weights_name, save_index_file = get_model_base_filenames(prefix, use_safetensors)
         index_file = CheckpointIndexFile(checkpoint)
-        control_saving = self.tp_rank == 0
+        control_saving = self.tp_rank == 0 and self.sp_rank == 0

         if self.pp_size == 1 and self.ep_size == 1:
             # When pipeline is not used, save the model shards as in general checkpointIO
@@ -415,7 +416,7 @@ class MoECheckpointIO(HybridParallelCheckpointIO):
         # e.g. dp_size = 4, moe_dp_size = 2, ep_size = 2 and use gather
         # rank 0 saves moe & non-moe params; rank 1 only saves moe params
         # rank 3 & 4 save nothing
-        control_saving = self.tp_rank == 0 and self.moe_dp_rank == 0
+        control_saving = self.tp_rank == 0 and self.moe_dp_rank == 0 and self.sp_rank == 0

         if self.pp_size == 1 and self.ep_size == 1:
             # When pipeline is not used, save the optimizer shards as in general checkpointIO
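`control_saving` picks the single replica along each replicated axis (tensor-, sequence- and MoE-data-parallel) that is allowed to write files, so the added `sp_rank == 0` condition keeps sequence-parallel replicas from emitting duplicate shards on this path. A small standalone sketch of the updated predicate, with illustrative group sizes:

```python
# Stand-in for the updated writer-selection predicate: only the first rank
# along every replicated dimension writes checkpoint files.
def controls_saving(tp_rank: int, sp_rank: int, moe_dp_rank: int) -> bool:
    return tp_rank == 0 and sp_rank == 0 and moe_dp_rank == 0

# With tp=2, sp=2, moe_dp=2 there are eight replicas of each shard; exactly one writes.
writers = [
    (tp, sp, mdp)
    for tp in range(2)
    for sp in range(2)
    for mdp in range(2)
    if controls_saving(tp, sp, mdp)
]
assert writers == [(0, 0, 0)]
```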