mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-01 01:06:00 +00:00
[hotfix] fix hybrid checkpointio for sp+dp (#6184)
* Update hybrid_parallel_plugin.py
* Update hybrid_parallel_plugin.py
* Update hybrid_parallel_plugin.py
* [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci
* Update build_on_pr.yml
* Update test_zerobubble_pp.py
* fix
* fix
* [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci
* fix

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
@@ -885,12 +885,12 @@ def run_with_booster_moehybridplugin(config: Tuple[int, ...]):
|
||||
parallel_optimizer.backward(parallel_output)
|
||||
parallel_optimizer.step()
|
||||
parallel_optimizer.zero_grad()
|
||||
- dist.all_reduce(parallel_output, group=plugin.dp_group)
|
||||
+ dist.all_reduce(parallel_output, group=plugin.mixed_dp_group)
|
||||
|
||||
# ===================================================================================
|
||||
# run normal model with all dp(different) inputs
|
||||
all_inputs = [input_embeddings.clone() for _ in range(dp_size)]
|
||||
- dist.all_gather(all_inputs, input_embeddings, group=plugin.dp_group)
|
||||
+ dist.all_gather(all_inputs, input_embeddings, group=plugin.mixed_dp_group)
|
||||
torch_output_sum = 0
|
||||
for input_data_ in all_inputs:
|
||||
torch_output = torch_model(inputs_embeds=input_data_.to(dtype)).last_hidden_state.mean()
|
||||
@@ -1040,12 +1040,12 @@ def run_with_booster_hybridplugin(config: Tuple[int, ...]):
|
||||
parallel_optimizer.backward(parallel_output)
|
||||
parallel_optimizer.step()
|
||||
parallel_optimizer.zero_grad()
|
||||
- dist.all_reduce(parallel_output, group=plugin.dp_group)
|
||||
+ dist.all_reduce(parallel_output, group=plugin.mixed_dp_group)
|
||||
|
||||
# ===================================================================================
|
||||
# run normal model with all dp(different) inputs
|
||||
all_inputs = [input_embeddings.clone() for _ in range(dp_size)]
|
||||
- dist.all_gather(all_inputs, input_embeddings, group=plugin.dp_group)
|
||||
+ dist.all_gather(all_inputs, input_embeddings, group=plugin.mixed_dp_group)
|
||||
torch_output_sum = 0
|
||||
for input_data_ in all_inputs:
|
||||
torch_output = torch_model(inputs_embeds=input_data_.to(dtype)).last_hidden_state.mean()
|
||||
|
Reference in New Issue
Block a user