diff --git a/.github/workflows/build_on_pr.yml b/.github/workflows/build_on_pr.yml
index 8d96ca1b9..b05cb660b 100644
--- a/.github/workflows/build_on_pr.yml
+++ b/.github/workflows/build_on_pr.yml
@@ -199,7 +199,7 @@ jobs:
           fi
 
       - name: Upload test coverage artifact
-        uses: actions/upload-artifact@v3
+        uses: actions/upload-artifact@v4
         with:
           name: report
           path: report/
diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py
index bc9425a0b..62046bc36 100644
--- a/colossalai/booster/plugin/hybrid_parallel_plugin.py
+++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py
@@ -1188,6 +1188,15 @@ class HybridParallelPlugin(PipelinePluginBase):
         else:
             self.sp_group = self.pg_mesh.get_group_along_axis(self.sp_axis)
 
+        # sync gradients across DP * SP ranks
+        # sync gradients across DP * SP ranks
+        # Apply Hybrid ZeRO across DP * SP ranks
+        if self.enable_sequence_parallelism and not is_share_sp_tp(self.sequence_parallelism_mode):
+            self.mixed_dp_group = self.pg_mesh.create_group_along_axis([self.dp_axis, self.sp_axis])
+            self.dp_size = get_world_size(self.mixed_dp_group)
+        else:
+            self.mixed_dp_group = self.dp_group
+
         self.shard_config = ShardConfig(
             tensor_parallel_process_group=self.tp_group,
             sequence_parallel_process_group=self.sp_group,
@@ -1298,19 +1307,11 @@ class HybridParallelPlugin(PipelinePluginBase):
         use_ddp = (self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0) or (
             self.dp_size == 1 and self.pp_size == 1
         )
-        # sync gradients across DP * SP ranks
-        # sync gradients across DP * SP ranks
-        # Apply Hybrid ZeRO across DP * SP ranks
-        if self.enable_sequence_parallelism and not is_share_sp_tp(self.sequence_parallelism_mode):
-            dp_group = self.pg_mesh.create_group_along_axis([self.dp_axis, self.sp_axis])
-            self.dp_size = get_world_size(dp_group)
-        else:
-            dp_group = self.dp_group
         model = HybridParallelModule(
             model,
             precision=self.precision,
             shard_config=self.shard_config,
-            dp_group=dp_group,
+            dp_group=self.mixed_dp_group,
             tp_group=self.tp_group,
             sp_group=self.sp_group,
             use_ddp=use_ddp,
@@ -1359,7 +1360,7 @@ class HybridParallelPlugin(PipelinePluginBase):
                     model,
                     use_pipeline=self.enable_pipeline_parallelism,
                     param_info=param_info,
-                    dp_process_group=dp_group,
+                    dp_process_group=self.mixed_dp_group,
                     tp_process_group=self.tp_group,
                     pp_process_group=self.pp_group,
                     verbose=True,
@@ -1488,7 +1489,9 @@ class HybridParallelPlugin(PipelinePluginBase):
         )
 
     def get_checkpoint_io(self) -> CheckpointIO:
-        return HybridParallelCheckpointIO(self.dp_group, self.pp_group, self.tp_group, self.sp_group, self.zero_stage)
+        return HybridParallelCheckpointIO(
+            self.mixed_dp_group, self.pp_group, self.tp_group, self.sp_group, self.zero_stage
+        )
 
     def no_sync(self, model: Module, optimizer: OptimizerWrapper) -> Iterator[None]:
         assert (
diff --git a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
index 6937b8d74..35f076e02 100644
--- a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
+++ b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
@@ -351,6 +351,14 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
             self.sp_group = self.pg_mesh.get_group_along_axis(self.tp_axis)
         else:
             self.sp_group = self.pg_mesh.get_group_along_axis(self.sp_axis)
+
+        # sync gradients across DP * SP ranks
+        if self.enable_sequence_parallelism and self.sequence_parallelism_mode == "all_to_all":
+            self.mixed_dp_group = self.pg_mesh.create_group_along_axis([self.moe_dp_axis, self.ep_axis, self.sp_axis])
+            self.dp_size = dist.get_world_size(self.mixed_dp_group)
+        else:
+            self.mixed_dp_group = self.dp_group
+
         self.use_fp8 = use_fp8
 
         self.shard_config = ShardConfig(
@@ -404,7 +412,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
 
     def get_checkpoint_io(self) -> MoECheckpointIO:
         return MoECheckpointIO(
-            self.dp_group,
+            self.mixed_dp_group,
             self.pp_group,
             self.tp_group,
             self.sp_group,
@@ -435,12 +443,6 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
             and self.sequence_parallelism_mode == "all_to_all"
         )
 
-        # sync gradients across DP * SP ranks
-        if self.enable_sequence_parallelism and self.sequence_parallelism_mode == "all_to_all":
-            dp_group = self.pg_mesh.create_group_along_axis([self.moe_dp_axis, self.ep_axis, self.sp_axis])
-        else:
-            dp_group = self.dp_group
-
         if use_ddp:
             self.logger.warning(
                 f"Will have to check all params are used in pytorch DDP since not all experts are always activated",
@@ -448,7 +450,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
             )
             self.ddp_config["find_unused_parameters"] = True
 
-        if dist.get_process_group_ranks(dp_group) != dist.get_process_group_ranks(self.moe_dp_group):
+        if dist.get_process_group_ranks(self.mixed_dp_group) != dist.get_process_group_ranks(self.moe_dp_group):
             raise ValueError(
-                f"if pytorch DDP is used, dp_group and moe_dp_group are expected to be the same since DDP can only reduce grad across a single group, but found dp_group {dist.get_process_group_ranks(dp_group)} and moe_dp_group {dist.get_process_group_ranks(self.moe_dp_group)}, you might want to modify your config to bypass DDP \nhint: check the above ddp condition to by pass this"
+                f"if pytorch DDP is used, dp_group and moe_dp_group are expected to be the same since DDP can only reduce grad across a single group, but found dp_group {dist.get_process_group_ranks(self.mixed_dp_group)} and moe_dp_group {dist.get_process_group_ranks(self.moe_dp_group)}, you might want to modify your config to bypass DDP \nhint: check the above ddp condition to by pass this"
             )
diff --git a/tests/test_pipeline/test_schedule/test_zerobubble_pp.py b/tests/test_pipeline/test_schedule/test_zerobubble_pp.py
index a01b75eee..67b05f027 100644
--- a/tests/test_pipeline/test_schedule/test_zerobubble_pp.py
+++ b/tests/test_pipeline/test_schedule/test_zerobubble_pp.py
@@ -885,12 +885,12 @@ def run_with_booster_moehybridplugin(config: Tuple[int, ...]):
             parallel_optimizer.backward(parallel_output)
             parallel_optimizer.step()
             parallel_optimizer.zero_grad()
-            dist.all_reduce(parallel_output, group=plugin.dp_group)
+            dist.all_reduce(parallel_output, group=plugin.mixed_dp_group)
 
     # ===================================================================================
     # run normal model with all dp(different) inputs
     all_inputs = [input_embeddings.clone() for _ in range(dp_size)]
-    dist.all_gather(all_inputs, input_embeddings, group=plugin.dp_group)
+    dist.all_gather(all_inputs, input_embeddings, group=plugin.mixed_dp_group)
     torch_output_sum = 0
     for input_data_ in all_inputs:
         torch_output = torch_model(inputs_embeds=input_data_.to(dtype)).last_hidden_state.mean()
@@ -1040,12 +1040,12 @@ def run_with_booster_hybridplugin(config: Tuple[int, ...]):
             parallel_optimizer.backward(parallel_output)
             parallel_optimizer.step()
             parallel_optimizer.zero_grad()
-            dist.all_reduce(parallel_output, group=plugin.dp_group)
+            dist.all_reduce(parallel_output, group=plugin.mixed_dp_group)
 
     # ===================================================================================
     # run normal model with all dp(different) inputs
     all_inputs = [input_embeddings.clone() for _ in range(dp_size)]
-    dist.all_gather(all_inputs, input_embeddings, group=plugin.dp_group)
+    dist.all_gather(all_inputs, input_embeddings, group=plugin.mixed_dp_group)
     torch_output_sum = 0
     for input_data_ in all_inputs:
         torch_output = torch_model(inputs_embeds=input_data_.to(dtype)).last_hidden_state.mean()
diff --git a/tests/test_shardformer/test_model/test_shard_deepseek.py b/tests/test_shardformer/test_model/test_shard_deepseek.py
index 4b92dbdee..20dfa78c6 100644
--- a/tests/test_shardformer/test_model/test_shard_deepseek.py
+++ b/tests/test_shardformer/test_model/test_shard_deepseek.py
@@ -125,12 +125,12 @@ def run_deepseek_commom(parallel_config: Tuple[int, ...]):
         parallel_optimizer.backward(parallel_output)
         parallel_optimizer.step()
         parallel_optimizer.zero_grad()
-        dist.all_reduce(parallel_output, group=plugin.dp_group)
+        dist.all_reduce(parallel_output, group=plugin.mixed_dp_group)
 
     # ===================================================================================
     # run normal model with all dp(different) inputs
     all_inputs = [torch.empty_like(input_embeddings) for _ in range(dp_size)]
-    dist.all_gather(all_inputs, input_embeddings, group=plugin.dp_group)
+    dist.all_gather(all_inputs, input_embeddings, group=plugin.mixed_dp_group)
     torch_output_sum = 0
     for input_data_ in all_inputs:
         torch_output = torch_model(inputs_embeds=input_data_.to(dtype)).last_hidden_state.mean()
diff --git a/tests/test_shardformer/test_model/test_shard_mixtral.py b/tests/test_shardformer/test_model/test_shard_mixtral.py
index 940c66cf6..b69113072 100644
--- a/tests/test_shardformer/test_model/test_shard_mixtral.py
+++ b/tests/test_shardformer/test_model/test_shard_mixtral.py
@@ -118,12 +118,12 @@ def run_mixtral_commom(config: Tuple[int, ...]):
         parallel_optimizer.backward(parallel_output)
         parallel_optimizer.step()
        parallel_optimizer.zero_grad()
-        dist.all_reduce(parallel_output, group=plugin.dp_group)
+        dist.all_reduce(parallel_output, group=plugin.mixed_dp_group)
 
     # ===================================================================================
     # run normal model with all dp(different) inputs
     all_inputs = [torch.empty_like(input_embeddings) for _ in range(dp_size)]
-    dist.all_gather(all_inputs, input_embeddings, group=plugin.dp_group)
+    dist.all_gather(all_inputs, input_embeddings, group=plugin.mixed_dp_group)
     torch_output_sum = 0
     for input_data_ in all_inputs:
         torch_output = torch_model(inputs_embeds=input_data_.to(dtype)).last_hidden_state.mean()
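
The patch moves the merged DP * SP process group out of `configure()` and into the plugin constructor, exposing it as `plugin.mixed_dp_group` so that `get_checkpoint_io()` and the tests can reuse the same group. Below is a minimal usage sketch (not part of the patch) mirroring the updated tests. It assumes a 4-GPU `torchrun --nproc_per_node=4` launch and a recent ColossalAI version; the plugin arguments are illustrative, and the model/optimizer boosting step is elided.

```python
import torch
import torch.distributed as dist

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

# Initialize torch.distributed from the torchrun environment variables.
colossalai.launch_from_torch()

# tp=1, pp=1, sp=2 on 4 ranks -> dp=2, so the merged DP * SP group spans all 4 ranks.
plugin = HybridParallelPlugin(
    tp_size=1,
    pp_size=1,
    sp_size=2,
    enable_sequence_parallelism=True,
    sequence_parallelism_mode="all_to_all",
    zero_stage=1,
)
booster = Booster(plugin=plugin)
# model, optimizer, ... = booster.boost(model, optimizer, ...)  # placeholders, elided

# With this patch the merged group is created in __init__, so reductions that must
# span both DP and SP ranks can use plugin.mixed_dp_group directly, as the tests do.
loss = torch.tensor(0.0, device="cuda")
dist.all_reduce(loss, group=plugin.mixed_dp_group)
```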