[moe] solve dp axis issue

Authored by botbw on 2024-07-26 07:32:19 +00:00; committed by Hongxin Liu
parent 65daa87627
commit d1d1ab871e
2 changed files with 13 additions and 21 deletions

File: colossalai/booster/plugin/hybrid_parallel_plugin.py

@@ -1375,15 +1375,15 @@ class HybridParallelPlugin(PipelinePluginBase):
             kwargs (dict): optional parameters for ``torch.utils.data.DataLoader``, more details could be found in
                     `DataLoader <https://pytorch.org/docs/stable/_modules/torch/utils/data/dataloader.html#DataLoader>`_.
 
-        Returns:
+        Returns:`
             :class:`torch.utils.data.DataLoader`: A DataLoader used for training or testing.
         """
         _kwargs = kwargs.copy()
         distributed_sampler_cls = distributed_sampler_cls or DistributedSampler
         sampler = distributed_sampler_cls(
             dataset,
-            num_replicas=self.pg_mesh.size(self.dp_axis),
-            rank=self.pg_mesh.coordinate(self.dp_axis),
+            num_replicas=self.dp_group.size(),
+            rank=dist.get_group_rank(self.dp_group, global_rank=dist.get_rank()),
             shuffle=shuffle,
         )

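Note: the replacement derives both sampler arguments from the dp_group handle itself instead of from mesh-axis bookkeeping, so it stays correct however the mesh axes are ordered. A minimal standalone sketch of the same pattern (the helper name is illustrative, not part of the commit):

    import torch.distributed as dist

    def sampler_args(dp_group: dist.ProcessGroup):
        # Number of data-parallel replicas = size of the dp process group.
        num_replicas = dp_group.size()
        # Map this process's global rank to its rank inside dp_group;
        # dist.get_group_rank is the inverse of dist.get_global_rank.
        rank = dist.get_group_rank(dp_group, global_rank=dist.get_rank())
        return num_replicas, rank
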
File: colossalai/booster/plugin/moe_hybrid_parallel_plugin.py

@@ -211,7 +211,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
         gradient_checkpoint_config: Optional[GradientCheckpointConfig] = None,
         enable_metadata_cache: bool = True,
         make_vocab_size_divisible_by: int = 64,
-        dp_outside: bool = True,
+        moe_dp_outside: bool = True,
         overlap_p2p: bool = True,
         overlap_allgather: bool = False,
     ) -> None:
@@ -266,20 +266,12 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
         self.enable_flash_attention = enable_flash_attention
         self.enable_jit_fused = enable_jit_fused
         self.enable_sequence_parallelism = enable_sequence_parallelism
-        if dp_outside:
-            self.dp_axis, self.pp_axis, self.tp_axis, self.sp_axis = 0, 1, 2, 3
-            self.pg_mesh = ProcessGroupMesh(self.dp_size, self.pp_size, self.tp_size, self.sp_size)
-            self.moe_dp_axis, self.ep_axis = 0, 1
-            self.__moe_pg_mesh = ProcessGroupMesh(
-                self.moe_dp_size, self.ep_size, self.pp_size, self.tp_size, self.sp_size
-            )
+        if moe_dp_outside:
+            self.moe_dp_axis, self.pp_axis, self.ep_axis, self.tp_axis, self.sp_axis = 0, 1, 2, 3, 4
+            self.pg_mesh = ProcessGroupMesh(self.moe_dp_size, self.pp_size, self.ep_size, self.tp_size, self.sp_size)
         else:
-            self.pp_axis, self.dp_axis, self.tp_axis, self.sp_axis = 0, 1, 2, 3
-            self.pg_mesh = ProcessGroupMesh(self.pp_size, self.dp_size, self.tp_size, self.sp_size)
-            self.moe_dp_axis, self.ep_axis = 1, 2
-            self.__moe_pg_mesh = ProcessGroupMesh(
-                self.pp_size, self.moe_dp_size, self.ep_size, self.tp_size, self.sp_size
-            )
+            self.pp_axis, self.moe_dp_axis, self.ep_axis, self.tp_axis, self.sp_axis = 0, 1, 2, 3, 4
+            self.pg_mesh = ProcessGroupMesh(self.pp_size, self.moe_dp_size, self.ep_size, self.tp_size, self.sp_size)
 
         self.stage_manager = None
         self.schedule = None
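
Note: after this change there is a single five-axis mesh, and every parallel group is a slice of one rank grid; the old dp dimension is factored into moe_dp x ep. A toy sketch of the idea with made-up sizes, using numpy as a stand-in for ProcessGroupMesh (none of these names come from the commit):

    import numpy as np

    # Hypothetical sizes: 2 pipeline stages, 2 MoE-dp replicas,
    # 2 expert-parallel shards, 2 tensor-parallel shards, sp disabled.
    pp, moe_dp, ep, tp, sp = 2, 2, 2, 2, 1
    mesh = np.arange(pp * moe_dp * ep * tp * sp).reshape(pp, moe_dp, ep, tp, sp)

    # A "group along an axis" holds the ranks that differ only on that axis,
    # e.g. the expert-parallel group containing global rank 0:
    ep_group_of_rank0 = mesh[0, 0, :, 0, 0]  # array([0, 4])
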
@@ -323,10 +315,10 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
             raise NotImplementedError()
 
         self.tp_group = self.pg_mesh.get_group_along_axis(self.tp_axis)
-        self.dp_group = self.pg_mesh.get_group_along_axis(self.dp_axis)
+        self.dp_group = self.pg_mesh.get_group_along_axis([self.moe_dp_axis, self.ep_axis])
         self.pp_group = self.pg_mesh.get_group_along_axis(self.pp_axis)
-        self.moe_dp_group = self.__moe_pg_mesh.get_group_along_axis(self.moe_dp_axis)
-        self.ep_group = self.__moe_pg_mesh.get_group_along_axis(self.ep_axis)
+        self.moe_dp_group = self.pg_mesh.get_group_along_axis(self.moe_dp_axis)
+        self.ep_group = self.pg_mesh.get_group_along_axis(self.ep_axis)
         if self.enable_sequence_parallelism and self.sequence_parallelism_mode in ["split_gather", "ring"]:
             self.sp_group = self.pg_mesh.get_group_along_axis(self.tp_axis)
         else:
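
Note: with the unified mesh, the plain data-parallel group is recovered by grouping along the moe_dp and ep axes jointly, since the old dp dimension was factored as moe_dp x ep. Continuing the toy numpy sketch above:

    # The non-MoE dp group is the moe_dp and ep axes flattened together:
    dp_group_of_rank0 = mesh[0, :, :, 0, 0].ravel()  # array([ 0,  4,  8, 12])
    assert dp_group_of_rank0.size == moe_dp * ep
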
@@ -420,7 +412,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
         # sync gradients across DP * SP ranks
         if self.enable_sequence_parallelism and self.sequence_parallelism_mode == "all_to_all":
-            dp_group = self.pg_mesh.create_group_along_axis([self.dp_axis, self.sp_axis])
+            dp_group = self.pg_mesh.create_group_along_axis([self.moe_dp_axis, self.ep_axis, self.sp_axis])
         else:
             dp_group = self.dp_group
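
Note: under all_to_all sequence parallelism, gradients must additionally be synchronized across the sequence-parallel axis, hence the three-axis group. In the same illustrative style, with a nontrivial sp size this time:

    # Hypothetical sizes with sequence parallelism enabled (sp = 2).
    pp, moe_dp, ep, tp, sp = 2, 2, 2, 1, 2
    mesh = np.arange(pp * moe_dp * ep * tp * sp).reshape(pp, moe_dp, ep, tp, sp)
    # Gradient-sync group for rank 0: moe_dp, ep, and sp axes jointly.
    dp_sp_group_of_rank0 = mesh[0, :, :, 0, :].reshape(-1)  # ranks 0..7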