diff --git a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
index 9625afc1b..9b6eae0d0 100644
--- a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
+++ b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
@@ -1,9 +1,8 @@
 import warnings
 from collections import defaultdict
 from types import MethodType
-from typing import Callable, Optional, OrderedDict, Tuple
+from typing import Callable, List, Optional, OrderedDict, Tuple

-import numpy as np
 import torch
 import torch.distributed as dist
 from torch.distributed import ProcessGroup
@@ -13,6 +12,8 @@ from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
 from torch.utils.data import DataLoader

 from colossalai.booster.plugin.hybrid_parallel_plugin import (
+    PRECISION_TORCH_TYPE,
+    SUPPORT_SP_MODE,
     HybridParallelAMPOptimizer,
     HybridParallelModule,
     HybridParallelNaiveOptimizer,
@@ -22,9 +23,16 @@ from colossalai.booster.plugin.hybrid_parallel_plugin import (
     reinitialize_optimizer,
 )
 from colossalai.checkpoint_io import MoECheckpointIO
+from colossalai.cluster.process_group_mesh import ProcessGroupMesh
 from colossalai.interface import ModelWrapper, OptimizerWrapper
 from colossalai.interface.optimizer import DistributedOptim
 from colossalai.nn.optimizer import cast_to_distributed
+from colossalai.pipeline.schedule.interleaved_pp import InterleavedSchedule
+from colossalai.pipeline.schedule.one_f_one_b import OneForwardOneBackwardSchedule
+from colossalai.pipeline.stage_manager import PipelineStageManager
+from colossalai.shardformer.policies.base_policy import Policy
+from colossalai.shardformer.shard.grad_ckpt_config import GradientCheckpointConfig
+from colossalai.shardformer.shard.shard_config import ShardConfig
 from colossalai.tensor.moe_tensor.api import is_moe_tensor
@@ -57,7 +65,7 @@ class MoeHybridParallelZeroOptimizer(HybridParallelZeroOptimizer):
         forced_dtype: Optional[torch.dtype] = None,
         overlap_allgather: bool = False,
     ):
-        WARN_STR = "Note that you need to make sure every expert are routed (i.e.) every expert has backward, otherwise this might lead to program hang or inconsistent result"
+        WARN_STR = "Note that you need to make sure every expert is routed (i.e., every expert has a backward pass); otherwise this might lead to a program hang or inconsistent results."
         if not force_overlap_comm and (overlap_communication or partition_grad):
             raise RuntimeError(
                 WARN_STR
@@ -105,130 +113,219 @@ class MoeHybridParallelZeroOptimizer(HybridParallelZeroOptimizer):

 class MoeHybridParallelPlugin(HybridParallelPlugin):
     """
-    TODO: add docstring
+    Modified from colossalai.booster.plugin.hybrid_parallel_plugin.HybridParallelPlugin.
+    Extra Args:
+        ep_size (int): The size of expert parallelism. Experts are sharded across ep_size ranks, so dp_size must be divisible by ep_size.
+        force_overlap_comm (bool): For LowLevelZeroOptimizer, communication overlap might cause the program to hang or produce inconsistent results when some experts are not routed (i.e., receive no backward) during training. Set this flag to True to force overlap_communication=True regardless.
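+
+    Example (a minimal sketch, not taken from the library docs; the model, optimizer, criterion and dataloader objects as well as the argument values are illustrative placeholders):
+        >>> from colossalai.booster import Booster
+        >>> from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
+        >>> # ep_size=2 shards the experts over 2 ranks; dp_size must be divisible by ep_size.
+        >>> plugin = MoeHybridParallelPlugin(tp_size=1, pp_size=1, ep_size=2, zero_stage=1, precision="fp16")
+        >>> booster = Booster(plugin=plugin)
+        >>> model, optimizer, criterion, dataloader, _ = booster.boost(model, optimizer, criterion, dataloader)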
""" - def __init__(self, ep_size: int, moe_tp_size: int = 1, force_overlap_comm=False, *args, **kwargs) -> None: - if "overlap_communication" not in kwargs: - kwargs["overlap_communication"] = False # default by true in super class - - super().__init__(*args, **kwargs) - - if ep_size <= 1: - raise ValueError("Use HybridParallelPlugin when ep_size <= 1") + def __init__( + self, + tp_size: int, + pp_size: int, + ep_size: int, + sp_size: int = None, + precision: str = "fp16", + zero_stage: int = 0, + enable_all_optimization: bool = False, + enable_fused_normalization: bool = False, + enable_flash_attention: bool = False, + enable_jit_fused: bool = False, + enable_sequence_parallelism: bool = False, + sequence_parallelism_mode: str = None, + enable_sequence_overlap: bool = False, + parallel_output: bool = True, + num_microbatches: Optional[int] = None, + microbatch_size: Optional[int] = None, + initial_scale: float = 2**16, + min_scale: float = 1, + growth_factor: float = 2, + backoff_factor: float = 0.5, + growth_interval: int = 1000, + hysteresis: int = 2, + max_scale: float = 2**32, + max_norm: float = 0, + broadcast_buffers: bool = True, + ddp_bucket_cap_mb: int = 25, + find_unused_parameters: bool = False, + check_reduction: bool = False, + gradient_as_bucket_view: bool = False, + static_graph: bool = False, + zero_bucket_size_in_m: int = 12, + cpu_offload: bool = False, + communication_dtype: Optional[torch.dtype] = None, + overlap_communication: bool = True, + custom_policy: Policy = None, + pp_style: str = "1f1b", + num_model_chunks: int = 1, + num_layers_per_stage: Optional[List[int]] = None, + gradient_checkpoint_config: Optional[GradientCheckpointConfig] = None, + enable_metadata_cache: bool = True, + make_vocab_size_divisible_by: int = 64, + dp_outside: bool = True, + overlap_p2p: bool = True, + overlap_allgather: bool = False, + force_overlap_comm: bool = False, + ) -> None: + assert ( + dist.get_world_size() % (tp_size * pp_size) == 0 + ), f"World size {dist.get_world_size()} is not divisible by tp_size {tp_size} * pp_size {pp_size}" + if enable_sequence_parallelism: + self.sequence_parallelism_mode = ( + sequence_parallelism_mode if sequence_parallelism_mode is not None else "all_to_all" + ) + assert ( + self.sequence_parallelism_mode in SUPPORT_SP_MODE + ), f"Sequence parallelism mode {self.sequence_parallelism_mode} is not in the supported list {SUPPORT_SP_MODE}" + if self.sequence_parallelism_mode in ["split_gather", "ring"]: + assert ( + tp_size > 1 + ), f"Sequence parallelism mode {self.sequence_parallelism_mode} must be enabled when using tensor parallelism" + if sp_size != 1: + warnings.warn( + f"The sp_size will be the same as tp_size in sequence parallelism mode {self.sequence_parallelism_mode}, will ignore the given sequence parallelism size." + ) + self.sp_size = 1 + self.dp_size = dist.get_world_size() // (tp_size * pp_size) + elif self.sequence_parallelism_mode in ["all_to_all"]: + self.sp_size = 1 if sp_size is None else sp_size + self.dp_size = dist.get_world_size() // (self.sp_size * pp_size * tp_size) + else: + self.dp_size = dist.get_world_size() // (tp_size * pp_size) + assert ( + sp_size == 1 or sp_size is None + ), f"You should not set sp_size when sequence parallelism is not enabled." 
+ self.sp_size = 1 + assert self.dp_size % ep_size == 0, f"dp_size should be divisible by ep_size, {self.dp_size=} {ep_size=}" + self.moe_dp_size = self.dp_size // ep_size self.ep_size = ep_size - self.moe_tp_size = moe_tp_size + self.tp_size = tp_size + self.pp_size = pp_size + self.precision = precision + self.zero_stage = zero_stage + self.cpu_offload = cpu_offload + self.enable_all_optimization = enable_all_optimization + self.enable_fused_normalization = enable_fused_normalization + self.enable_flash_attention = enable_flash_attention + self.enable_jit_fused = enable_jit_fused + self.enable_sequence_parallelism = enable_sequence_parallelism + if dp_outside: + self.dp_axis, self.pp_axis, self.tp_axis, self.sp_axis = 0, 1, 2, 3 + self.pg_mesh = ProcessGroupMesh(self.dp_size, self.pp_size, self.tp_size, self.sp_size) + self.moe_dp_axis, self.ep_axis = 0, 1 + self.moe_pg_mesh = ProcessGroupMesh( + self.moe_dp_size, self.ep_size, self.pp_size, self.tp_size, self.sp_size + ) + else: + self.pp_axis, self.dp_axis, self.tp_axis, self.sp_axis = 0, 1, 2, 3 + self.pg_mesh = ProcessGroupMesh(self.pp_size, self.dp_size, self.tp_size, self.sp_size) + self.moe_dp_axis, self.ep_axis = 1, 2 + self.moe_pg_mesh = ProcessGroupMesh( + self.pp_size, self.moe_dp_size, self.ep_size, self.tp_size, self.sp_size + ) - self._init_moe_param_comm() + self.stage_manager = None + self.schedule = None + self.custom_policy = custom_policy + assert zero_stage in (0, 1, 2) + if self.pp_size > 1: + assert pp_style in ["1f1b", "interleaved"], "Unsupported pipeline parallelism style" + assert pp_style == "interleaved" or num_model_chunks == 1, "num_model_chunks must be 1 when using 1f1b" + assert ( + num_microbatches is not None or microbatch_size is not None + ), "num_microbatches or microbatch_size must be specified when using pipeline parallelism" + assert ( + self.zero_stage <= 1 + ), "To avoid prohibitive gradient synchronization costs, zero stage must be 0 or 1 when using pipeline parallelism" + self.stage_manager = PipelineStageManager( + self.pg_mesh, + pipeline_axis=self.pp_axis, + enable_interleave=pp_style == "interleaved", + num_model_chunks=num_model_chunks, + num_layers_per_stage=num_layers_per_stage, + ) - self.use_ddp = (self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0) or ( - self.dp_size == 1 - and self.pp_size == 1 - and self.enable_sequence_parallelism - and self.sequence_parallelism_mode == "all_to_all" + if pp_style == "interleaved": + assert num_model_chunks > 1, "number of model chunks must be > 1 when using interleaved" + self.schedule = InterleavedSchedule( + stage_manager=self.stage_manager, + num_model_chunks=num_model_chunks, + num_microbatch=num_microbatches, + microbatch_size=microbatch_size, + enable_metadata_cache=enable_metadata_cache, + overlap_p2p=overlap_p2p, + ) + elif pp_style == "1f1b": + self.schedule = OneForwardOneBackwardSchedule( + stage_manager=self.stage_manager, + num_microbatches=num_microbatches, + microbatch_size=microbatch_size, + enable_metadata_cache=enable_metadata_cache, + ) + else: + raise NotImplementedError() + + self.tp_group = self.pg_mesh.get_group_along_axis(self.tp_axis) + self.dp_group = self.pg_mesh.get_group_along_axis(self.dp_axis) + self.pp_group = self.pg_mesh.get_group_along_axis(self.pp_axis) + self.moe_dp_group = self.moe_pg_mesh.get_group_along_axis(self.moe_dp_axis) + self.ep_group = self.moe_pg_mesh.get_group_along_axis(self.ep_axis) + if self.enable_sequence_parallelism and self.sequence_parallelism_mode in ["split_gather", "ring"]: 
+ self.sp_group = self.pg_mesh.get_group_along_axis(self.tp_axis) + else: + self.sp_group = self.pg_mesh.get_group_along_axis(self.sp_axis) + + self.shard_config = ShardConfig( + tensor_parallel_process_group=self.tp_group, + sequence_parallel_process_group=self.sp_group, + ep_group=self.ep_group, + moe_dp_group=self.moe_dp_group, + pipeline_stage_manager=self.stage_manager, + enable_tensor_parallelism=self.tp_size > 1, + enable_all_optimization=self.enable_all_optimization, + enable_fused_normalization=self.enable_fused_normalization, + enable_flash_attention=self.enable_flash_attention, + enable_jit_fused=self.enable_jit_fused, + enable_sequence_parallelism=enable_sequence_parallelism, + sequence_parallelism_mode=sequence_parallelism_mode, + enable_sequence_overlap=enable_sequence_overlap, + parallel_output=parallel_output, + make_vocab_size_divisible_by=make_vocab_size_divisible_by, + gradient_checkpoint_config=gradient_checkpoint_config, + ) + self.amp_config = dict( + initial_scale=initial_scale, + growth_factor=growth_factor, + backoff_factor=backoff_factor, + growth_interval=growth_interval, + hysteresis=hysteresis, + min_scale=min_scale, + max_scale=max_scale, ) - if self.use_ddp: - warnings.warn( - f"Will have to check all params are used in pytorch DDP since not all experts are always activated" - ) - self.ddp_config["find_unused_parameters"] = True + self.ddp_config = dict( + broadcast_buffers=broadcast_buffers, + bucket_cap_mb=ddp_bucket_cap_mb, + find_unused_parameters=find_unused_parameters, + check_reduction=check_reduction, + gradient_as_bucket_view=gradient_as_bucket_view, + static_graph=static_graph, + ) - if dist.get_process_group_ranks(self.dp_group) != dist.get_process_group_ranks(self.moe_dp_group): - # TODO it might make sense to support non-moe with tp on but moe with tp off - raise ValueError( - f"if ddp is used, dp_group and moe_dp_group are expected to be the same since DDP can only reduce grad across a single group, but found dp_group {dist.get_process_group_ranks(self.dp_group)} and moe_dp_group {dist.get_process_group_ranks(self.moe_dp_group)}, you might want to use HybridParallelPlugin or set zero_stage > 0" - ) - - # set param group in shard config - self.shard_config.ep_group = self.ep_group - self.shard_config.moe_dp_group = self.moe_dp_group - self.shard_config.moe_tp_group = self.moe_tp_group + self.zero_config = dict( + reduce_bucket_size=zero_bucket_size_in_m * 1024 * 1024, + communication_dtype=communication_dtype, + overlap_communication=overlap_communication, + cpu_offload=cpu_offload, + partition_grad=(self.zero_stage == 2), + forced_dtype=PRECISION_TORCH_TYPE[precision], + overlap_allgather=overlap_allgather, + ) + self.max_norm = max_norm self.force_overlap_comm = force_overlap_comm - def _init_moe_param_comm(self): - world_size = dist.get_world_size() - - if self.enable_sequence_parallelism: - if self.sequence_parallelism_mode == "all_to_all": - # if sequence parallelism is enabled, ep_group reuses sp_group - if self.ep_size != self.sp_size: - raise ValueError( - f"ep_size={self.ep_size} should be equal to sp_size={self.sp_size} or turned off when sequence parallelism is enabled" - ) - - # since we are reusing sp_group, moe_dp_group will be derived as dp_group - self.moe_dp_size = self.dp_size - self.moe_dp_group = self.dp_group - self.dp_sp_group = self.pg_mesh.create_group_along_axis([self.dp_axis, self.sp_axis]) - self.ep_group = self.sp_group - self.moe_tp_group = self.tp_group - else: - raise NotImplementedError( - 
f"sequence_parallelism_mode={self.sequence_parallelism_mode} is not supported" - ) - - else: - self.moe_dp_size = world_size // (self.pp_size * self.ep_size * self.moe_tp_size) - - if self.moe_dp_size * self.pp_size * self.ep_size * self.moe_tp_size != world_size: - raise ValueError( - f"world_size={world_size} is not divisible by pp_size={self.pp_size} * moe_dp_size={self.moe_dp_size} * ep_size={self.ep_size} * moe_tp_size={self.moe_tp_size}" - ) - - self.moe_dp_group = None - self.ep_group = None - self.moe_tp_group = None - self.dp_sp_group = self.dp_group - - # create submesh for ep, moe_dp, moe_tp - ranks_by_pp_stage = self.pg_mesh.get_group_along_axis( - [self.dp_axis, self.tp_axis, self.sp_axis], return_ranks_by_group=True - ) - - global_rank = self.pg_mesh.rank - pp_rank = self.pg_mesh.coordinate(self.pp_axis) - - # create groups from submesh - for stage_idx, stage_rank in enumerate(ranks_by_pp_stage): - # axis 0 is moe_dp, axis 1 is ep, axis 2 is moe_tp - submesh = np.array(stage_rank).reshape(self.moe_dp_size, self.ep_size, self.moe_tp_size) - - # hardcode here since we only have 3 axis - # moe_dp_group - for ep_idx in range(self.ep_size): - for moe_tp_idx in range(self.moe_tp_size): - moe_dp_ranks = submesh[:, ep_idx, moe_tp_idx].flatten().tolist() - group = dist.new_group(moe_dp_ranks) - if pp_rank == stage_idx and global_rank in moe_dp_ranks: - assert self.moe_dp_group is None - self.moe_dp_group = group - # ep_group - for moe_dp_idx in range(self.moe_dp_size): - for moe_tp_idx in range(self.moe_tp_size): - ep_ranks = submesh[moe_dp_idx, :, moe_tp_idx].flatten().tolist() - group = dist.new_group(ep_ranks) - if pp_rank == stage_idx and global_rank in ep_ranks: - assert self.ep_group is None - self.ep_group = group - # moe_tp_group - for moe_dp_idx in range(self.moe_dp_size): - for ep_idx in range(self.ep_size): - moe_tp_ranks = submesh[moe_dp_idx, ep_idx, :].flatten().tolist() - group = dist.new_group(moe_tp_ranks) - if pp_rank == stage_idx and global_rank in moe_tp_ranks: - assert self.moe_tp_group is None - self.moe_tp_group = group - - if dist.get_process_group_ranks(self.tp_group) != dist.get_process_group_ranks(self.moe_tp_group): - # NOTE: different tp settings between moe and non moe param are complex to handle - # we simply reuse tp_group as moe_tp_group, this implies that dp_size == moe_dp_size * ep_size - raise NotImplementedError( - f"Only support shared tp group between moe and non moe params, but found non-moe tp {dist.get_process_group_ranks(self.tp_group)}, moe tp {dist.get_process_group_ranks(self.moe_tp_group)}, please make sure tp_size == moe_tp_size" - ) - def get_checkpoint_io(self) -> MoECheckpointIO: return MoECheckpointIO( self.dp_group, self.pp_group, self.tp_group, self.ep_group, self.moe_dp_group, self.zero_stage @@ -249,14 +346,37 @@ class MoeHybridParallelPlugin(HybridParallelPlugin): optimizer = cast_to_distributed(optimizer) if not isinstance(model, ModelWrapper): + use_ddp = (self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0) or ( + self.dp_size == 1 + and self.pp_size == 1 + and self.enable_sequence_parallelism + and self.sequence_parallelism_mode == "all_to_all" + ) + if use_ddp: + warnings.warn( + f"Will have to check all params are used in pytorch DDP since not all experts are always activated" + ) + self.ddp_config["find_unused_parameters"] = True + + if dist.get_process_group_ranks(self.dp_group) != dist.get_process_group_ranks(self.moe_dp_group): + raise ValueError( + f"if pytorch ddp is used, dp_group and moe_dp_group are 
expected to be the same since DDP can only reduce grad across a single group, but found dp_group {dist.get_process_group_ranks(self.dp_group)} and moe_dp_group {dist.get_process_group_ranks(self.moe_dp_group)}, you might want to use HybridParallelPlugin (i.e. set ep_size = 1) or set zero_stage > 0" + ) + + # sync gradients across DP * SP ranks + if self.enable_sequence_parallelism and self.sequence_parallelism_mode == "all_to_all": + dp_group = self.pg_mesh.create_group_along_axis([self.dp_axis, self.sp_axis]) + else: + dp_group = self.dp_group + model = HybridParallelModule( module=model, precision=self.precision, shard_config=self.shard_config, - dp_group=self.dp_sp_group, + dp_group=dp_group, tp_group=self.tp_group, sp_group=self.sp_group, - use_ddp=self.use_ddp, + use_ddp=use_ddp, ddp_config=self.ddp_config, custom_policy=self.custom_policy, ) @@ -301,7 +421,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin): use_pipeline=self.enable_pipeline_parallelism, force_overlap_comm=self.force_overlap_comm, param_info=param_info, - dp_process_group=self.dp_sp_group, + dp_process_group=dp_group, tp_process_group=self.tp_group, pp_process_group=self.pp_group, moe_dp_group=self.moe_dp_group, diff --git a/colossalai/shardformer/modeling/deepseek.py b/colossalai/shardformer/modeling/deepseek.py index 52ea6c22b..a84a30972 100644 --- a/colossalai/shardformer/modeling/deepseek.py +++ b/colossalai/shardformer/modeling/deepseek.py @@ -61,13 +61,10 @@ class EPDeepseekMoE(nn.Module): def __init__(self): raise RuntimeError(f"Please use `from_native_module` to create an instance of {self.__class__.__name__}") - def setup_process_groups( - self, tp_group: ProcessGroup, moe_dp_group: ProcessGroup, ep_group: ProcessGroup, moe_tp_group: ProcessGroup - ): + def setup_process_groups(self, tp_group: ProcessGroup, moe_dp_group: ProcessGroup, ep_group: ProcessGroup): assert tp_group is not None assert moe_dp_group is not None assert ep_group is not None - assert moe_tp_group is not None self.ep_size = dist.get_world_size(ep_group) self.ep_rank = dist.get_rank(ep_group) @@ -85,16 +82,13 @@ class EPDeepseekMoE(nn.Module): self.moe_dp_group = moe_dp_group self.moe_dp_size = moe_dp_group.size() - # setup global tp group + # setup tp group self.tp_group = tp_group - - # setup moe tp group - self.moe_tp_group = moe_tp_group - if self.moe_tp_group.size() > 1: + if self.tp_group.size() > 1: for expert in held_experts: - expert.gate_proj = Linear1D_Col.from_native_module(expert.gate_proj, self.moe_tp_group) - expert.up_proj = Linear1D_Col.from_native_module(expert.up_proj, self.moe_tp_group) - expert.down_proj = Linear1D_Row.from_native_module(expert.down_proj, self.moe_tp_group) + expert.gate_proj = Linear1D_Col.from_native_module(expert.gate_proj, self.tp_group) + expert.up_proj = Linear1D_Col.from_native_module(expert.up_proj, self.tp_group) + expert.down_proj = Linear1D_Row.from_native_module(expert.down_proj, self.tp_group) for p in self.experts.parameters(): set_moe_tensor_ep_group(p, ep_group) @@ -105,7 +99,6 @@ class EPDeepseekMoE(nn.Module): tp_group: ProcessGroup, moe_dp_group: ProcessGroup, ep_group: ProcessGroup, - moe_tp_group: ProcessGroup, *args, **kwargs, ) -> "EPDeepseekMoE": @@ -113,7 +106,7 @@ class EPDeepseekMoE(nn.Module): if module.__class__.__name__ == "DeepseekMLP": return module module.__class__ = EPDeepseekMoE - module.setup_process_groups(tp_group, moe_dp_group, ep_group, moe_tp_group) + module.setup_process_groups(tp_group, moe_dp_group, ep_group) return module def forward(self, 
hidden_states: torch.Tensor) -> torch.Tensor: diff --git a/colossalai/shardformer/modeling/mixtral.py b/colossalai/shardformer/modeling/mixtral.py index 9148a9fba..029ac36cd 100644 --- a/colossalai/shardformer/modeling/mixtral.py +++ b/colossalai/shardformer/modeling/mixtral.py @@ -53,13 +53,10 @@ class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock): def __init__(self, *args, **kwargs): raise RuntimeError(f"Please use `from_native_module` to create an instance of {self.__class__.__name__}") - def setup_process_groups( - self, tp_group: ProcessGroup, moe_dp_group: ProcessGroup, ep_group: ProcessGroup, moe_tp_group: ProcessGroup - ): + def setup_process_groups(self, tp_group: ProcessGroup, moe_dp_group: ProcessGroup, ep_group: ProcessGroup): assert tp_group is not None assert moe_dp_group is not None assert ep_group is not None - assert moe_tp_group is not None # setup ep group self.ep_size = dist.get_world_size(ep_group) @@ -81,14 +78,11 @@ class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock): # setup global tp group self.tp_group = tp_group - - # setup moe tp group - self.moe_tp_group = moe_tp_group - if self.moe_tp_group.size() > 1: + if self.tp_group.size() > 1: for expert in held_experts: - expert.w1 = Linear1D_Col.from_native_module(expert.w1, self.moe_tp_group) - expert.w3 = Linear1D_Col.from_native_module(expert.w3, self.moe_tp_group) - expert.w2 = Linear1D_Row.from_native_module(expert.w2, self.moe_tp_group) + expert.w1 = Linear1D_Col.from_native_module(expert.w1, self.tp_group) + expert.w3 = Linear1D_Col.from_native_module(expert.w3, self.tp_group) + expert.w2 = Linear1D_Row.from_native_module(expert.w2, self.tp_group) for p in self.experts.parameters(): set_moe_tensor_ep_group(p, ep_group) @@ -99,14 +93,13 @@ class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock): tp_group: ProcessGroup, moe_dp_group: ProcessGroup, ep_group: ProcessGroup, - moe_tp_group: ProcessGroup, *args, **kwargs, ) -> "EPMixtralSparseMoeBlock": # TODO: better init LazyInitContext.materialize(module) module.__class__ = EPMixtralSparseMoeBlock - module.setup_process_groups(tp_group, moe_dp_group, ep_group, moe_tp_group) + module.setup_process_groups(tp_group, moe_dp_group, ep_group) return module def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: diff --git a/colossalai/shardformer/policies/deepseek.py b/colossalai/shardformer/policies/deepseek.py index 963bd9d67..d729a4ecc 100644 --- a/colossalai/shardformer/policies/deepseek.py +++ b/colossalai/shardformer/policies/deepseek.py @@ -154,7 +154,6 @@ class DeepseekPolicy(Policy): "ep_group": self.shard_config.ep_group, "tp_group": self.shard_config.tensor_parallel_process_group, "moe_dp_group": self.shard_config.moe_dp_group, - "moe_tp_group": self.shard_config.moe_tp_group, }, ) ], diff --git a/colossalai/shardformer/policies/mixtral.py b/colossalai/shardformer/policies/mixtral.py index 4de982f44..85895820e 100644 --- a/colossalai/shardformer/policies/mixtral.py +++ b/colossalai/shardformer/policies/mixtral.py @@ -155,7 +155,6 @@ class MixtralPolicy(Policy): "ep_group": self.shard_config.ep_group, "tp_group": self.shard_config.tensor_parallel_process_group, "moe_dp_group": self.shard_config.moe_dp_group, - "moe_tp_group": self.shard_config.moe_tp_group, }, ) ], diff --git a/colossalai/shardformer/shard/shard_config.py b/colossalai/shardformer/shard/shard_config.py index f12c78526..163d7a7bb 100644 --- a/colossalai/shardformer/shard/shard_config.py +++ b/colossalai/shardformer/shard/shard_config.py @@ -50,7 +50,6 @@ class ShardConfig: # for moe related 
moe_dp_group: Optional[ProcessGroup] = None ep_group: Optional[ProcessGroup] = None - moe_tp_group: Optional[ProcessGroup] = None # pipeline_parallel_size: int # data_parallel_size: int diff --git a/tests/test_moe/test_deepseek_layer.py b/tests/test_moe/test_deepseek_layer.py index e633cdd07..d18ba2eac 100644 --- a/tests/test_moe/test_deepseek_layer.py +++ b/tests/test_moe/test_deepseek_layer.py @@ -47,7 +47,6 @@ def check_deepseek_moe_layer(): model, ep_group=plugin.ep_group, moe_dp_group=plugin.moe_dp_group, - moe_tp_group=plugin.moe_tp_group, tp_group=plugin.tp_group, ) ep_output = model(x) diff --git a/tests/test_moe/test_mixtral_layer.py b/tests/test_moe/test_mixtral_layer.py index 5d9ca622a..bc41ac4f3 100644 --- a/tests/test_moe/test_mixtral_layer.py +++ b/tests/test_moe/test_mixtral_layer.py @@ -42,7 +42,6 @@ def check_mixtral_moe_layer(): ep_group=plugin.ep_group, tp_group=plugin.tp_group, moe_dp_group=plugin.moe_dp_group, - moe_tp_group=plugin.moe_tp_group, ) ep_output, ep_logits = model(x) assert_close(orig_logits, ep_logits) diff --git a/tests/test_shardformer/test_model/test_shard_deepseek.py b/tests/test_shardformer/test_model/test_shard_deepseek.py index 5e6c0bf1c..709963613 100644 --- a/tests/test_shardformer/test_model/test_shard_deepseek.py +++ b/tests/test_shardformer/test_model/test_shard_deepseek.py @@ -24,24 +24,28 @@ NUM_HEADS = 4 TOP_K = 2 -CHECKED_CONFIG = [ # FOR_WORLD=8 - (2, 1, 1, 4, 1), - (4, 1, 1, 2, 1), - (4, 1, 1, 1, 1), - (2, 1, 2, 1, 1), +CHECKED_CONFIG = [ # FOR_WORLD=4 + (1, 4, 1, 1, 1), + (1, 1, 4, 1, 1), + (1, 1, 1, 4, 1), + (1, 1, 1, 1, 4), + (0, 1, 4, 1, 1), + (0, 1, 1, 4, 1), + (0, 1, 1, 1, 4), + (1, 2, 1, 1, 1), ] @parameterize( "config", [ - (2, 1, 2, 1, 1), - # (2, 1, 1, 2, 1), - # (2, 1, 1, 1, 2), + (1, 2, 2, 1, 1), + (1, 2, 1, 2, 1), + (1, 2, 1, 1, 2), ], ) def run_zero_with_original_model(config: Tuple[int, ...]): - ep_size, stage, pp_size, tp_size, sp_size = config + stage, ep_size, pp_size, tp_size, sp_size = config world_size = dist.get_world_size() rank = dist.get_rank() dtype, precision = torch.float16, "fp16" @@ -53,7 +57,6 @@ def run_zero_with_original_model(config: Tuple[int, ...]): tp_size=tp_size, sp_size=sp_size, ep_size=ep_size, - moe_tp_size=tp_size, zero_stage=stage, enable_sequence_parallelism=sp_size > 1, sequence_parallelism_mode="all_to_all" if sp_size > 1 else None, diff --git a/tests/test_shardformer/test_model/test_shard_mixtral.py b/tests/test_shardformer/test_model/test_shard_mixtral.py index 419679797..a3e201b67 100644 --- a/tests/test_shardformer/test_model/test_shard_mixtral.py +++ b/tests/test_shardformer/test_model/test_shard_mixtral.py @@ -25,24 +25,28 @@ NUM_HEADS = 4 TOP_K = 1 CHECKED_CONFIG = [ # FOR WORLD=4 - (2, 1, 2, 2, 1), - (2, 1, 1, 2, 1), - (2, 1, 4, 1, 1), - (4, 1, 1, 1, 1), - (4, 1, 1, 2, 1), - (4, 1, 2, 1, 1), - (2, 1, 2, 1, 1), + (0, 1, 4, 1, 1), + (0, 1, 1, 4, 1), + (0, 1, 1, 1, 4), + (1, 4, 1, 1, 1), + (1, 1, 4, 1, 1), + (1, 1, 1, 4, 1), + (1, 1, 1, 1, 4), + (1, 2, 1, 1, 1), ] @parameterize( "config", [ - (2, 1, 1, 2, 1), + (1, 2, 2, 1, 1), + (1, 2, 1, 2, 1), + (1, 2, 1, 1, 2), + (0, 2, 1, 1, 1), ], ) def run_zero_with_original_model(config: Tuple[int, ...]): - ep_size, stage, pp_size, tp_size, sp_size = config + stage, ep_size, pp_size, tp_size, sp_size = config world_size = dist.get_world_size() rank = dist.get_rank() dtype, precision = torch.float16, "fp16" @@ -54,7 +58,6 @@ def run_zero_with_original_model(config: Tuple[int, ...]): tp_size=tp_size, sp_size=sp_size, ep_size=ep_size, - 
moe_tp_size=tp_size,
         zero_stage=stage,
         enable_sequence_parallelism=sp_size > 1,
         sequence_parallelism_mode="all_to_all" if sp_size > 1 else None,
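
Both updated tests now unpack their config tuples as (zero_stage, ep_size, pp_size, tp_size, sp_size); the previous ordering put ep_size first. A minimal sketch of how such a tuple maps onto the new plugin constructor (the helper name make_plugin and the overlap_communication/precision values are illustrative and not part of this patch; it assumes an already-initialized 4-rank process group):

    from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin

    def make_plugin(config):
        # New tuple ordering used by the updated tests: (zero_stage, ep_size, pp_size, tp_size, sp_size).
        stage, ep_size, pp_size, tp_size, sp_size = config
        return MoeHybridParallelPlugin(
            tp_size=tp_size,
            pp_size=pp_size,
            ep_size=ep_size,
            sp_size=sp_size,
            zero_stage=stage,
            enable_sequence_parallelism=sp_size > 1,
            sequence_parallelism_mode="all_to_all" if sp_size > 1 else None,
            # force_overlap_comm defaults to False, and ZeRO communication overlap can hang
            # when some experts receive no gradient, so overlap is disabled here.
            overlap_communication=False,
            precision="fp16",
        )

    plugin = make_plugin((1, 2, 1, 2, 1))  # ZeRO-1, ep_size=2, pp_size=1, tp_size=2, sp_size=1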