Repository: https://github.com/hpcaitech/ColossalAI.git
[moe] solve dp axis issue

Commit: d1d1ab871e
Parent: 65daa87627
@@ -1375,15 +1375,15 @@ class HybridParallelPlugin(PipelinePluginBase):
             kwargs (dict): optional parameters for ``torch.utils.data.DataLoader``, more details could be found in
                 `DataLoader <https://pytorch.org/docs/stable/_modules/torch/utils/data/dataloader.html#DataLoader>`_.
 
-        Returns:
+        Returns:`
            :class:`torch.utils.data.DataLoader`: A DataLoader used for training or testing.
         """
         _kwargs = kwargs.copy()
         distributed_sampler_cls = distributed_sampler_cls or DistributedSampler
         sampler = distributed_sampler_cls(
             dataset,
-            num_replicas=self.pg_mesh.size(self.dp_axis),
-            rank=self.pg_mesh.coordinate(self.dp_axis),
+            num_replicas=self.dp_group.size(),
+            rank=dist.get_group_rank(self.dp_group, global_rank=dist.get_rank()),
             shuffle=shuffle,
         )
 
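The dataloader hunk above swaps mesh-coordinate lookups for the data-parallel process group itself, so the sampler stays correct when that group spans more than one mesh axis. A minimal sketch of the same pattern, assuming an initialized torch.distributed environment; the helper name is ours, not the plugin's:

import torch.distributed as dist
from torch.utils.data import DistributedSampler

def build_dp_sampler(dataset, dp_group, shuffle=True):
    # Replica count and rank come from the group object, not from global
    # mesh coordinates, mirroring the change in the diff above.
    return DistributedSampler(
        dataset,
        num_replicas=dp_group.size(),
        rank=dist.get_group_rank(dp_group, global_rank=dist.get_rank()),
        shuffle=shuffle,
    )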
@@ -211,7 +211,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
         gradient_checkpoint_config: Optional[GradientCheckpointConfig] = None,
         enable_metadata_cache: bool = True,
         make_vocab_size_divisible_by: int = 64,
-        dp_outside: bool = True,
+        moe_dp_outside: bool = True,
         overlap_p2p: bool = True,
         overlap_allgather: bool = False,
     ) -> None:
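The constructor hunk renames the dp_outside flag to moe_dp_outside. A hedged usage sketch, assuming the plugin's usual tp_size / pp_size / ep_size keyword arguments and illustrative sizes (none of the concrete values below come from the commit):

from colossalai.booster.plugin import MoeHybridParallelPlugin

# moe_dp becomes the outermost mesh axis when moe_dp_outside=True (the new default).
plugin = MoeHybridParallelPlugin(
    tp_size=2,
    pp_size=1,
    ep_size=2,
    moe_dp_outside=True,
)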
@@ -266,20 +266,12 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
         self.enable_flash_attention = enable_flash_attention
         self.enable_jit_fused = enable_jit_fused
         self.enable_sequence_parallelism = enable_sequence_parallelism
-        if dp_outside:
-            self.dp_axis, self.pp_axis, self.tp_axis, self.sp_axis = 0, 1, 2, 3
-            self.pg_mesh = ProcessGroupMesh(self.dp_size, self.pp_size, self.tp_size, self.sp_size)
-            self.moe_dp_axis, self.ep_axis = 0, 1
-            self.__moe_pg_mesh = ProcessGroupMesh(
-                self.moe_dp_size, self.ep_size, self.pp_size, self.tp_size, self.sp_size
-            )
+        if moe_dp_outside:
+            self.moe_dp_axis, self.pp_axis, self.ep_axis, self.tp_axis, self.sp_axis = 0, 1, 2, 3, 4
+            self.pg_mesh = ProcessGroupMesh(self.moe_dp_size, self.pp_size, self.ep_size, self.tp_size, self.sp_size)
         else:
-            self.pp_axis, self.dp_axis, self.tp_axis, self.sp_axis = 0, 1, 2, 3
-            self.pg_mesh = ProcessGroupMesh(self.pp_size, self.dp_size, self.tp_size, self.sp_size)
-            self.moe_dp_axis, self.ep_axis = 1, 2
-            self.__moe_pg_mesh = ProcessGroupMesh(
-                self.pp_size, self.moe_dp_size, self.ep_size, self.tp_size, self.sp_size
-            )
+            self.pp_axis, self.moe_dp_axis, self.ep_axis, self.tp_axis, self.sp_axis = 0, 1, 2, 3, 4
+            self.pg_mesh = ProcessGroupMesh(self.pp_size, self.moe_dp_size, self.ep_size, self.tp_size, self.sp_size)
 
         self.stage_manager = None
         self.schedule = None
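The mesh hunk above collapses the separate __moe_pg_mesh into a single five-axis ProcessGroupMesh, with the old dp axis replaced by a (moe_dp, ep) pair. A small arithmetic sketch of the layout this implies; the sizes are made up for illustration:

# Axis order with moe_dp_outside=True, per the diff: (moe_dp, pp, ep, tp, sp).
moe_dp_size, pp_size, ep_size, tp_size, sp_size = 2, 2, 2, 2, 1

# The world size must equal the product of every axis of the mesh.
world_size = moe_dp_size * pp_size * ep_size * tp_size * sp_size  # 16

# The non-MoE data-parallel degree is recovered by merging the moe_dp and ep axes.
dp_size = moe_dp_size * ep_size  # 4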
@@ -323,10 +315,10 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
             raise NotImplementedError()
 
         self.tp_group = self.pg_mesh.get_group_along_axis(self.tp_axis)
-        self.dp_group = self.pg_mesh.get_group_along_axis(self.dp_axis)
+        self.dp_group = self.pg_mesh.get_group_along_axis([self.moe_dp_axis, self.ep_axis])
         self.pp_group = self.pg_mesh.get_group_along_axis(self.pp_axis)
-        self.moe_dp_group = self.__moe_pg_mesh.get_group_along_axis(self.moe_dp_axis)
-        self.ep_group = self.__moe_pg_mesh.get_group_along_axis(self.ep_axis)
+        self.moe_dp_group = self.pg_mesh.get_group_along_axis(self.moe_dp_axis)
+        self.ep_group = self.pg_mesh.get_group_along_axis(self.ep_axis)
         if self.enable_sequence_parallelism and self.sequence_parallelism_mode in ["split_gather", "ring"]:
             self.sp_group = self.pg_mesh.get_group_along_axis(self.tp_axis)
         else:
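With only one mesh, dp_group is now obtained by merging the moe_dp and ep axes rather than by reading a dedicated dp axis. A plain-Python illustration of the grouping semantics assumed here (no torch needed; this does not reimplement get_group_along_axis, and the sizes are invented):

import numpy as np

shape = (2, 1, 2, 2, 1)          # (moe_dp, pp, ep, tp, sp) -- example sizes only
moe_dp_axis, ep_axis = 0, 2
ranks = np.arange(np.prod(shape)).reshape(shape)

# Moving the merged axes to the front and flattening them groups together all
# ranks that differ only in their moe_dp / ep coordinates.
merged = np.moveaxis(ranks, (moe_dp_axis, ep_axis), (0, 1)).reshape(
    shape[moe_dp_axis] * shape[ep_axis], -1
)
dp_groups = [sorted(merged[:, i].tolist()) for i in range(merged.shape[1])]
print(dp_groups)  # each inner list is one dp group of size moe_dp * ep = 4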
@@ -420,7 +412,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
 
             # sync gradients across DP * SP ranks
             if self.enable_sequence_parallelism and self.sequence_parallelism_mode == "all_to_all":
-                dp_group = self.pg_mesh.create_group_along_axis([self.dp_axis, self.sp_axis])
+                dp_group = self.pg_mesh.create_group_along_axis([self.moe_dp_axis, self.ep_axis, self.sp_axis])
             else:
                 dp_group = self.dp_group
 
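The last hunk applies the same axis merge to the gradient-sync group used with all_to_all sequence parallelism: it now spans the moe_dp, ep, and sp axes instead of dp and sp. A quick arithmetic check with illustrative sizes (not taken from the commit):

moe_dp_size, ep_size, sp_size = 2, 2, 2

# Ranks per gradient-sync group when sequence_parallelism_mode == "all_to_all":
# the merged (moe_dp, ep) axes reproduce the old dp axis, and sp is folded in on top.
grad_sync_ranks = moe_dp_size * ep_size * sp_size  # 8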