diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py
index 5a9bae479..fcb747814 100644
--- a/colossalai/booster/plugin/hybrid_parallel_plugin.py
+++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py
@@ -1375,15 +1375,15 @@ class HybridParallelPlugin(PipelinePluginBase):
             kwargs (dict): optional parameters for ``torch.utils.data.DataLoader``, more details could be found in
                     `DataLoader `_.
 
         Returns:
             :class:`torch.utils.data.DataLoader`: A DataLoader used for training or testing.
         """
         _kwargs = kwargs.copy()
         distributed_sampler_cls = distributed_sampler_cls or DistributedSampler
         sampler = distributed_sampler_cls(
             dataset,
-            num_replicas=self.pg_mesh.size(self.dp_axis),
-            rank=self.pg_mesh.coordinate(self.dp_axis),
+            num_replicas=self.dp_group.size(),
+            rank=dist.get_group_rank(self.dp_group, global_rank=dist.get_rank()),
             shuffle=shuffle,
         )
 
diff --git a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
index 7a16a1737..b3415af0e 100644
--- a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
+++ b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
@@ -211,7 +211,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
         gradient_checkpoint_config: Optional[GradientCheckpointConfig] = None,
         enable_metadata_cache: bool = True,
         make_vocab_size_divisible_by: int = 64,
-        dp_outside: bool = True,
+        moe_dp_outside: bool = True,
         overlap_p2p: bool = True,
         overlap_allgather: bool = False,
     ) -> None:
@@ -266,20 +266,12 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
         self.enable_flash_attention = enable_flash_attention
         self.enable_jit_fused = enable_jit_fused
         self.enable_sequence_parallelism = enable_sequence_parallelism
-        if dp_outside:
-            self.dp_axis, self.pp_axis, self.tp_axis, self.sp_axis = 0, 1, 2, 3
-            self.pg_mesh = ProcessGroupMesh(self.dp_size, self.pp_size, self.tp_size, self.sp_size)
-            self.moe_dp_axis, self.ep_axis = 0, 1
-            self.__moe_pg_mesh = ProcessGroupMesh(
-                self.moe_dp_size, self.ep_size, self.pp_size, self.tp_size, self.sp_size
-            )
+        if moe_dp_outside:
+            self.moe_dp_axis, self.pp_axis, self.ep_axis, self.tp_axis, self.sp_axis = 0, 1, 2, 3, 4
+            self.pg_mesh = ProcessGroupMesh(self.moe_dp_size, self.pp_size, self.ep_size, self.tp_size, self.sp_size)
         else:
-            self.pp_axis, self.dp_axis, self.tp_axis, self.sp_axis = 0, 1, 2, 3
-            self.pg_mesh = ProcessGroupMesh(self.pp_size, self.dp_size, self.tp_size, self.sp_size)
-            self.moe_dp_axis, self.ep_axis = 1, 2
-            self.__moe_pg_mesh = ProcessGroupMesh(
-                self.pp_size, self.moe_dp_size, self.ep_size, self.tp_size, self.sp_size
-            )
+            self.pp_axis, self.moe_dp_axis, self.ep_axis, self.tp_axis, self.sp_axis = 0, 1, 2, 3, 4
+            self.pg_mesh = ProcessGroupMesh(self.pp_size, self.moe_dp_size, self.ep_size, self.tp_size, self.sp_size)
 
         self.stage_manager = None
         self.schedule = None
@@ -323,10 +315,10 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
             raise NotImplementedError()
 
         self.tp_group = self.pg_mesh.get_group_along_axis(self.tp_axis)
-        self.dp_group = self.pg_mesh.get_group_along_axis(self.dp_axis)
+        self.dp_group = self.pg_mesh.get_group_along_axis([self.moe_dp_axis, self.ep_axis])
         self.pp_group = self.pg_mesh.get_group_along_axis(self.pp_axis)
-        self.moe_dp_group = self.__moe_pg_mesh.get_group_along_axis(self.moe_dp_axis)
-        self.ep_group = self.__moe_pg_mesh.get_group_along_axis(self.ep_axis)
+        self.moe_dp_group = self.pg_mesh.get_group_along_axis(self.moe_dp_axis)
+        self.ep_group = self.pg_mesh.get_group_along_axis(self.ep_axis)
         if self.enable_sequence_parallelism and self.sequence_parallelism_mode in ["split_gather", "ring"]:
             self.sp_group = self.pg_mesh.get_group_along_axis(self.tp_axis)
         else:
@@ -420,7 +412,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
 
         # sync gradients across DP * SP ranks
         if self.enable_sequence_parallelism and self.sequence_parallelism_mode == "all_to_all":
-            dp_group = self.pg_mesh.create_group_along_axis([self.dp_axis, self.sp_axis])
+            dp_group = self.pg_mesh.create_group_along_axis([self.moe_dp_axis, self.ep_axis, self.sp_axis])
        else:
             dp_group = self.dp_group
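The `prepare_dataloader` hunk stops deriving the sampler's rank from mesh coordinates and instead asks the data-parallel process group directly, which stays correct no matter how the mesh axes are reordered. Below is a minimal sketch of that rank translation, assuming four processes launched with `torchrun --nproc_per_node=4`; the two-group layout is made up for the demo and is not the plugin's actual mesh.

```python
# Sketch of the sampler arguments after this change; the group layout
# (ranks {0, 2} and {1, 3} forming two data-parallel groups) is illustrative.
import torch.distributed as dist

dist.init_process_group(backend="gloo")

# new_group must be called collectively on every rank, in the same order.
group_even = dist.new_group([0, 2])
group_odd = dist.new_group([1, 3])
dp_group = group_even if dist.get_rank() % 2 == 0 else group_odd

# DistributedSampler is now sized and ranked from the group itself:
num_replicas = dp_group.size()  # 2, not the world size (4)
rank = dist.get_group_rank(dp_group, global_rank=dist.get_rank())  # 0 or 1

print(f"global rank {dist.get_rank()} -> sampler rank {rank}/{num_replicas}")
dist.destroy_process_group()
```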
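In `moe_hybrid_parallel_plugin.py`, the pair of meshes (a 4-D `pg_mesh` plus a private 5-D `__moe_pg_mesh`) is replaced by a single 5-D mesh, so every group is carved out of one consistent rank layout. The numpy sketch below only mimics the mesh's rank bookkeeping to show why flattening the `(moe_dp, ep)` axes reproduces the old data-parallel group; the sizes (8 ranks with `moe_dp=2, pp=1, ep=2, tp=2, sp=1`) are arbitrary assumptions.

```python
# Rank-bookkeeping sketch only: no process groups are created, and the
# sizes are illustrative. Axis order matches moe_dp_outside=True.
import numpy as np

moe_dp, pp, ep, tp, sp = 2, 1, 2, 2, 1
mesh = np.arange(moe_dp * pp * ep * tp * sp).reshape(moe_dp, pp, ep, tp, sp)

# A group along one axis fixes all other coordinates, e.g. the ep group
# containing global rank 0 (analogous to get_group_along_axis(ep_axis)):
print(mesh[0, 0, :, 0, 0])  # [0 2]

# dp_group now flattens (moe_dp, ep) together, analogous to
# get_group_along_axis([moe_dp_axis, ep_axis]); each group holds
# moe_dp * ep == dp_size ranks:
dp_groups = mesh.transpose(1, 3, 4, 0, 2).reshape(pp, tp, sp, -1)
print(dp_groups[0, 0, 0])  # [0 2 4 6]: the dp group containing rank 0
```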
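The last hunk applies the same idea to gradient synchronization: under `all_to_all` sequence parallelism each rank holds only a sequence shard, so (per the diff's own comment) gradients must be averaged across the product of the `moe_dp`, `ep`, and `sp` axes rather than `moe_dp * ep` alone. Continuing the bookkeeping sketch with an assumed `sp=2` (16 ranks total):

```python
import numpy as np

moe_dp, pp, ep, tp, sp = 2, 1, 2, 2, 2
mesh = np.arange(moe_dp * pp * ep * tp * sp).reshape(moe_dp, pp, ep, tp, sp)

# Analogous to create_group_along_axis([moe_dp_axis, ep_axis, sp_axis]):
# flatten the three axes whose gradients must be kept in sync.
grad_sync_groups = mesh.transpose(1, 3, 0, 2, 4).reshape(pp, tp, -1)
print(grad_sync_groups[0, 0])  # [ 0  1  4  5  8  9 12 13]: moe_dp*ep*sp = 8 ranks
```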