[hotfix] fix hybrid checkpointio for sp+dp (#6184)

* Update hybrid_parallel_plugin.py

* Update hybrid_parallel_plugin.py

* Update hybrid_parallel_plugin.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update build_on_pr.yml

* Update test_zerobubble_pp.py

* fix

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
flybird11111
2025-02-06 17:21:04 +08:00
committed by GitHub
parent ca0aa2365d
commit 17062c83b9
6 changed files with 35 additions and 30 deletions

View File

@@ -885,12 +885,12 @@ def run_with_booster_moehybridplugin(config: Tuple[int, ...]):
parallel_optimizer.backward(parallel_output)
parallel_optimizer.step()
parallel_optimizer.zero_grad()
dist.all_reduce(parallel_output, group=plugin.dp_group)
dist.all_reduce(parallel_output, group=plugin.mixed_dp_group)
# ===================================================================================
# run normal model with all dp(different) inputs
all_inputs = [input_embeddings.clone() for _ in range(dp_size)]
dist.all_gather(all_inputs, input_embeddings, group=plugin.dp_group)
dist.all_gather(all_inputs, input_embeddings, group=plugin.mixed_dp_group)
torch_output_sum = 0
for input_data_ in all_inputs:
torch_output = torch_model(inputs_embeds=input_data_.to(dtype)).last_hidden_state.mean()
@@ -1040,12 +1040,12 @@ def run_with_booster_hybridplugin(config: Tuple[int, ...]):
parallel_optimizer.backward(parallel_output)
parallel_optimizer.step()
parallel_optimizer.zero_grad()
dist.all_reduce(parallel_output, group=plugin.dp_group)
dist.all_reduce(parallel_output, group=plugin.mixed_dp_group)
# ===================================================================================
# run normal model with all dp(different) inputs
all_inputs = [input_embeddings.clone() for _ in range(dp_size)]
dist.all_gather(all_inputs, input_embeddings, group=plugin.dp_group)
dist.all_gather(all_inputs, input_embeddings, group=plugin.mixed_dp_group)
torch_output_sum = 0
for input_data_ in all_inputs:
torch_output = torch_model(inputs_embeds=input_data_.to(dtype)).last_hidden_state.mean()