mirror of https://github.com/hpcaitech/ColossalAI.git synced 2025-04-30 12:45:33 +00:00

[hotfix] fix hybrid checkpointio for sp+dp ()

* Update hybrid_parallel_plugin.py

* Update hybrid_parallel_plugin.py

* Update hybrid_parallel_plugin.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update build_on_pr.yml

* Update test_zerobubble_pp.py

* fix

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Author: flybird11111
Date: 2025-02-06 17:21:04 +08:00 (committed by GitHub)
Parent: ca0aa2365d
Commit: 17062c83b9
Signature: no known key found for this signature in database (GPG Key ID: B5690EEEBB952194)
6 changed files with 35 additions and 30 deletions
Changed paths:
* .github/workflows
* colossalai/booster/plugin
* tests/test_pipeline/test_schedule
* tests/test_shardformer/test_model
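
The substance of the hotfix: when sequence parallelism runs in a mode that does not reuse the tensor-parallel group (e.g. all_to_all), parameters stay replicated across both the DP and SP axes, so gradient synchronization, ZeRO sharding, and checkpoint IO must operate on a process group spanning DP * SP rather than DP alone. The diff introduces self.mixed_dp_group for this and threads it through HybridParallelModule, the ZeRO optimizer, and get_checkpoint_io. Below is a minimal illustration of the group layout using plain torch.distributed with assumed axis sizes; it is not ColossalAI's ProcessGroupMesh API.

# Minimal sketch (plain torch.distributed, sizes assumed); launch with:
#   torchrun --nproc_per_node=4 sketch.py
import torch.distributed as dist

dp_size, sp_size = 2, 2  # hypothetical 2D layout, rank = dp_coord * sp_size + sp_coord

dist.init_process_group(backend="gloo")
assert dist.get_world_size() == dp_size * sp_size

# Plain DP groups: ranks sharing the same SP coordinate (what dp_group alone covers).
dp_groups = [dist.new_group([d * sp_size + s for d in range(dp_size)]) for s in range(sp_size)]

# Mixed DP * SP group: all ranks that hold the same replicated parameters once SP only
# shards activations along the sequence dimension; this is the group that gradient sync,
# ZeRO sharding, and checkpoint IO have to agree on.
mixed_dp_group = dist.new_group(list(range(dp_size * sp_size)))

print(f"rank {dist.get_rank()}: dp group size = {dp_size}, "
      f"mixed dp*sp group size = {dist.get_world_size(mixed_dp_group)}")
dist.destroy_process_group()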

.github/workflows/build_on_pr.yml

@@ -199,7 +199,7 @@ jobs:
             fi
       - name: Upload test coverage artifact
-        uses: actions/upload-artifact@v3
+        uses: actions/upload-artifact@v4
         with:
           name: report
           path: report/

colossalai/booster/plugin/hybrid_parallel_plugin.py

@@ -1188,6 +1188,15 @@ class HybridParallelPlugin(PipelinePluginBase):
         else:
             self.sp_group = self.pg_mesh.get_group_along_axis(self.sp_axis)

+        # sync gradients across DP * SP ranks
+        # sync gradients across DP * SP ranks
+        # Apply Hybrid ZeRO across DP * SP ranks
+        if self.enable_sequence_parallelism and not is_share_sp_tp(self.sequence_parallelism_mode):
+            self.mixed_dp_group = self.pg_mesh.create_group_along_axis([self.dp_axis, self.sp_axis])
+            self.dp_size = get_world_size(self.mixed_dp_group)
+        else:
+            self.mixed_dp_group = self.dp_group
+
         self.shard_config = ShardConfig(
             tensor_parallel_process_group=self.tp_group,
             sequence_parallel_process_group=self.sp_group,
@@ -1298,19 +1307,11 @@
         use_ddp = (self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0) or (
             self.dp_size == 1 and self.pp_size == 1
         )
-        # sync gradients across DP * SP ranks
-        # sync gradients across DP * SP ranks
-        # Apply Hybrid ZeRO across DP * SP ranks
-        if self.enable_sequence_parallelism and not is_share_sp_tp(self.sequence_parallelism_mode):
-            dp_group = self.pg_mesh.create_group_along_axis([self.dp_axis, self.sp_axis])
-            self.dp_size = get_world_size(dp_group)
-        else:
-            dp_group = self.dp_group
         model = HybridParallelModule(
             model,
             precision=self.precision,
             shard_config=self.shard_config,
-            dp_group=dp_group,
+            dp_group=self.mixed_dp_group,
             tp_group=self.tp_group,
             sp_group=self.sp_group,
             use_ddp=use_ddp,
@@ -1359,7 +1360,7 @@
                     model,
                     use_pipeline=self.enable_pipeline_parallelism,
                     param_info=param_info,
-                    dp_process_group=dp_group,
+                    dp_process_group=self.mixed_dp_group,
                     tp_process_group=self.tp_group,
                     pp_process_group=self.pp_group,
                     verbose=True,
@@ -1488,7 +1489,9 @@
         )

     def get_checkpoint_io(self) -> CheckpointIO:
-        return HybridParallelCheckpointIO(self.dp_group, self.pp_group, self.tp_group, self.sp_group, self.zero_stage)
+        return HybridParallelCheckpointIO(
+            self.mixed_dp_group, self.pp_group, self.tp_group, self.sp_group, self.zero_stage
+        )

     def no_sync(self, model: Module, optimizer: OptimizerWrapper) -> Iterator[None]:
         assert (
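
For context, a hedged usage sketch of the configuration this hotfix targets: HybridParallelPlugin with all_to_all sequence parallelism plus data parallelism, saving and reloading a sharded checkpoint through the Booster. The sizes, toy model, and checkpoint path are illustrative only.

# Hedged sketch, assuming 4 GPUs (sp=2, dp=2) launched via torchrun; not taken from the repo's tests.
import colossalai
import torch
import torch.nn as nn
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

colossalai.launch_from_torch()  # initialize distributed state from torchrun env vars

plugin = HybridParallelPlugin(
    tp_size=1,
    pp_size=1,
    sp_size=2,  # dp_size is derived from world_size // (tp_size * pp_size * sp_size)
    enable_sequence_parallelism=True,
    sequence_parallelism_mode="all_to_all",
    zero_stage=1,
    precision="bf16",
)
booster = Booster(plugin=plugin)

model = nn.Linear(16, 16).cuda()
optimizer = torch.optim.Adam(model.parameters())
model, optimizer, *_ = booster.boost(model, optimizer)

# With the fix, sharded save/load run through HybridParallelCheckpointIO built on the
# mixed DP * SP group, so every rank holding a replica participates consistently.
booster.save_model(model, "ckpt_model", shard=True)
booster.load_model(model, "ckpt_model")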

colossalai/booster/plugin/moe_hybrid_parallel_plugin.py

@@ -351,6 +351,14 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
             self.sp_group = self.pg_mesh.get_group_along_axis(self.tp_axis)
         else:
             self.sp_group = self.pg_mesh.get_group_along_axis(self.sp_axis)
+
+        # sync gradients across DP * SP ranks
+        if self.enable_sequence_parallelism and self.sequence_parallelism_mode == "all_to_all":
+            self.mixed_dp_group = self.pg_mesh.create_group_along_axis([self.moe_dp_axis, self.ep_axis, self.sp_axis])
+            self.dp_size = dist.get_world_size(self.mixed_dp_group)
+        else:
+            self.mixed_dp_group = self.dp_group
+
         self.use_fp8 = use_fp8

         self.shard_config = ShardConfig(
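
The MoE plugin applies the same idea, except the replica set for checkpointing spans the moe_dp, ep, and sp axes together. A small worked example of the resulting group sizes, with axis sizes and the derivation of moe_dp assumed purely for illustration:

# Hypothetical 8-GPU MoE layout (illustration only, not read from any test config).
world_size = 8
pp_size, tp_size, sp_size, ep_size = 1, 1, 2, 2

# Non-expert data parallelism left after the other axes are assigned (assumed derivation):
moe_dp_size = world_size // (pp_size * tp_size * sp_size * ep_size)  # -> 2

# With all_to_all sequence parallelism, the flattened checkpoint/gradient group covers
# the moe_dp, ep, and sp axes together:
mixed_dp_size = moe_dp_size * ep_size * sp_size  # -> 8
print(moe_dp_size, mixed_dp_size)
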
@@ -404,7 +412,7 @@
     def get_checkpoint_io(self) -> MoECheckpointIO:
         return MoECheckpointIO(
-            self.dp_group,
+            self.mixed_dp_group,
             self.pp_group,
             self.tp_group,
             self.sp_group,
@@ -435,12 +443,6 @@
             and self.sequence_parallelism_mode == "all_to_all"
         )
-
-        # sync gradients across DP * SP ranks
-        if self.enable_sequence_parallelism and self.sequence_parallelism_mode == "all_to_all":
-            dp_group = self.pg_mesh.create_group_along_axis([self.moe_dp_axis, self.ep_axis, self.sp_axis])
-        else:
-            dp_group = self.dp_group

         if use_ddp:
             self.logger.warning(
                 f"Will have to check all params are used in pytorch DDP since not all experts are always activated",
@@ -448,7 +450,7 @@
             )
             self.ddp_config["find_unused_parameters"] = True

-            if dist.get_process_group_ranks(dp_group) != dist.get_process_group_ranks(self.moe_dp_group):
+            if dist.get_process_group_ranks(self.mixed_dp_group) != dist.get_process_group_ranks(self.moe_dp_group):
                 raise ValueError(
                     f"if pytorch DDP is used, dp_group and moe_dp_group are expected to be the same since DDP can only reduce grad across a single group, but found dp_group {dist.get_process_group_ranks(dp_group)} and moe_dp_group {dist.get_process_group_ranks(self.moe_dp_group)}, you might want to modify your config to bypass DDP \nhint: check the above ddp condition to by pass this"
                 )
@@ -457,7 +459,7 @@
             module=model,
             precision=self.precision,
             shard_config=self.shard_config,
-            dp_group=dp_group,
+            dp_group=self.mixed_dp_group,
             tp_group=self.tp_group,
             sp_group=self.sp_group,
             use_ddp=use_ddp,
@@ -507,7 +509,7 @@
                 model,
                 use_pipeline=self.enable_pipeline_parallelism,
                 param_info=param_info,
-                dp_process_group=dp_group,
+                dp_process_group=self.mixed_dp_group,
                 tp_process_group=self.tp_group,
                 pp_process_group=self.pp_group,
                 moe_dp_group=self.moe_dp_group,

tests/test_pipeline/test_schedule/test_zerobubble_pp.py

@@ -885,12 +885,12 @@ def run_with_booster_moehybridplugin(config: Tuple[int, ...]):
         parallel_optimizer.backward(parallel_output)
         parallel_optimizer.step()
         parallel_optimizer.zero_grad()
-        dist.all_reduce(parallel_output, group=plugin.dp_group)
+        dist.all_reduce(parallel_output, group=plugin.mixed_dp_group)

     # ===================================================================================
     # run normal model with all dp(different) inputs
     all_inputs = [input_embeddings.clone() for _ in range(dp_size)]
-    dist.all_gather(all_inputs, input_embeddings, group=plugin.dp_group)
+    dist.all_gather(all_inputs, input_embeddings, group=plugin.mixed_dp_group)
     torch_output_sum = 0
     for input_data_ in all_inputs:
         torch_output = torch_model(inputs_embeds=input_data_.to(dtype)).last_hidden_state.mean()
@@ -1040,12 +1040,12 @@ def run_with_booster_hybridplugin(config: Tuple[int, ...]):
         parallel_optimizer.backward(parallel_output)
         parallel_optimizer.step()
         parallel_optimizer.zero_grad()
-        dist.all_reduce(parallel_output, group=plugin.dp_group)
+        dist.all_reduce(parallel_output, group=plugin.mixed_dp_group)

     # ===================================================================================
     # run normal model with all dp(different) inputs
     all_inputs = [input_embeddings.clone() for _ in range(dp_size)]
-    dist.all_gather(all_inputs, input_embeddings, group=plugin.dp_group)
+    dist.all_gather(all_inputs, input_embeddings, group=plugin.mixed_dp_group)
     torch_output_sum = 0
     for input_data_ in all_inputs:
         torch_output = torch_model(inputs_embeds=input_data_.to(dtype)).last_hidden_state.mean()
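
The test updates follow directly: the parallel run's per-rank outputs are reduced over every rank that processed a distinct data-parallel shard, and with SP enabled that set is the mixed DP * SP group. A condensed sketch of the comparison pattern these tests use (names and tolerances are simplified/hypothetical):

# Condensed sketch of the test pattern; not a drop-in replacement for the actual tests.
import torch
import torch.distributed as dist

def compare_with_reference(parallel_output: torch.Tensor,
                           torch_output_sum: torch.Tensor,
                           mixed_dp_group: dist.ProcessGroup,
                           rtol: float = 1e-4, atol: float = 1e-4) -> None:
    # Sum the parallel model's outputs over the full DP * SP replica group...
    dist.all_reduce(parallel_output, group=mixed_dp_group)
    # ...and check against the single-process reference summed over the same inputs.
    assert torch.allclose(parallel_output, torch_output_sum, rtol=rtol, atol=atol)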

tests/test_shardformer/test_model (DeepSeek model test)

@@ -125,12 +125,12 @@ def run_deepseek_commom(parallel_config: Tuple[int, ...]):
         parallel_optimizer.backward(parallel_output)
         parallel_optimizer.step()
         parallel_optimizer.zero_grad()
-        dist.all_reduce(parallel_output, group=plugin.dp_group)
+        dist.all_reduce(parallel_output, group=plugin.mixed_dp_group)

     # ===================================================================================
     # run normal model with all dp(different) inputs
     all_inputs = [torch.empty_like(input_embeddings) for _ in range(dp_size)]
-    dist.all_gather(all_inputs, input_embeddings, group=plugin.dp_group)
+    dist.all_gather(all_inputs, input_embeddings, group=plugin.mixed_dp_group)
     torch_output_sum = 0
     for input_data_ in all_inputs:
         torch_output = torch_model(inputs_embeds=input_data_.to(dtype)).last_hidden_state.mean()

tests/test_shardformer/test_model (Mixtral model test)

@@ -118,12 +118,12 @@ def run_mixtral_commom(config: Tuple[int, ...]):
         parallel_optimizer.backward(parallel_output)
         parallel_optimizer.step()
         parallel_optimizer.zero_grad()
-        dist.all_reduce(parallel_output, group=plugin.dp_group)
+        dist.all_reduce(parallel_output, group=plugin.mixed_dp_group)

     # ===================================================================================
     # run normal model with all dp(different) inputs
     all_inputs = [torch.empty_like(input_embeddings) for _ in range(dp_size)]
-    dist.all_gather(all_inputs, input_embeddings, group=plugin.dp_group)
+    dist.all_gather(all_inputs, input_embeddings, group=plugin.mixed_dp_group)
     torch_output_sum = 0
     for input_data_ in all_inputs:
         torch_output = torch_model(inputs_embeds=input_data_.to(dtype)).last_hidden_state.mean()