From 7bedd03739acea830cf283c29c3d5ed38277b291 Mon Sep 17 00:00:00 2001
From: hxwang
Date: Thu, 25 Jul 2024 09:49:57 +0000
Subject: [PATCH] [moe] remove force_overlap_comm flag and add warning instead

---
 .../plugin/moe_hybrid_parallel_plugin.py | 26 ++++++-------------
 1 file changed, 8 insertions(+), 18 deletions(-)

diff --git a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
index beac2d037..b49b886a0 100644
--- a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
+++ b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
@@ -42,7 +42,6 @@ class MoeHybridParallelZeroOptimizer(HybridParallelZeroOptimizer):
         optimizer: Optimizer,
         model: Module,
         use_pipeline: bool,
-        force_overlap_comm: bool,  # force overlap comm
         dp_process_group: Optional[ProcessGroup],  # the dp pg for comm
         tp_process_group: Optional[ProcessGroup],  # if using tp
         pp_process_group: Optional[ProcessGroup],  # if using pp
@@ -65,17 +64,6 @@ class MoeHybridParallelZeroOptimizer(HybridParallelZeroOptimizer):
         forced_dtype: Optional[torch.dtype] = None,
         overlap_allgather: bool = False,
     ):
-        WARN_STR = "Note that you need to make sure every expert are routed (i.e.) every expert has backward, otherwise this might lead to program hang or inconsistent result."
-        if not force_overlap_comm and (overlap_communication or partition_grad):
-            raise RuntimeError(
-                WARN_STR
-                + " If you are not sure about this, set (overlap_communication=False and partition_grad=False) or force_overlap_comm=True"
-            )
-
-        if force_overlap_comm:
-            overlap_communication = True
-            warnings.warn(WARN_STR + " Please make sure of this.")
-
         pg_param_list = {
             dp_process_group: list(filter(lambda p: not is_moe_tensor(p), model.parameters())),
             moe_dp_group: list(filter(is_moe_tensor, model.parameters())),
@@ -116,9 +104,6 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
     Modified from colossalai.booster.plugin.hybrid_parallel_plugin.HybridParallelPlugin
     Extra Args:
         ep_size (int): The size of expert parallelism
-        force_overlap_comm (bool):
-            For LowLevelZeroOptimizer, it might causes program hang when some experts are routed and overlap_communication is True during training.
-            This flag is used to force overlap_communication=True. Make sure every expert are routed when you use this.
     """
 
     def __init__(
@@ -167,8 +152,15 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
         dp_outside: bool = True,
         overlap_p2p: bool = True,
         overlap_allgather: bool = False,
-        force_overlap_comm: bool = False,
     ) -> None:
+        if overlap_communication or zero_stage == 2:
+            overlap_communication = False
+            zero_stage = 1
+            warnings.warn(
+                f"overlap_communication and zero_stage are set to False and 1 because "
+                f"ZeRO-2 or comm overlap cause program hang when some experts are not routed. "
+            )
+
         assert (
             dist.get_world_size() % (tp_size * pp_size) == 0
         ), f"World size {dist.get_world_size()} is not divisible by tp_size {tp_size} * pp_size {pp_size}"
@@ -326,7 +318,6 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
         )
 
         self.max_norm = max_norm
-        self.force_overlap_comm = force_overlap_comm
 
     def get_checkpoint_io(self) -> MoECheckpointIO:
         return MoECheckpointIO(
@@ -421,7 +412,6 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
             optimizer,
             model,
             use_pipeline=self.enable_pipeline_parallelism,
-            force_overlap_comm=self.force_overlap_comm,
            param_info=param_info,
             dp_process_group=dp_group,
             tp_process_group=self.tp_group,
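
Note for reviewers: a minimal usage sketch of the behavior after this patch. Only the parameter names that appear in the diff (ep_size, zero_stage, overlap_communication, tp_size, pp_size) are taken from the patch; the Booster flow, the distributed launch, and the model/optimizer setup below are assumptions for illustration, not part of this change.

    # Sketch only: assumes the distributed environment has already been initialized
    # (e.g. via colossalai.launch_from_torch()) and that `model`/`optimizer` are a
    # user-defined MoE model and its optimizer (hypothetical placeholders here).
    from colossalai.booster import Booster
    from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin

    plugin = MoeHybridParallelPlugin(
        tp_size=1,
        pp_size=1,
        ep_size=2,                   # expert parallelism size (see Extra Args in the docstring)
        zero_stage=2,                # with this patch: downgraded to 1 with a warning
        overlap_communication=True,  # with this patch: reset to False with a warning
    )
    booster = Booster(plugin=plugin)
    # model, optimizer, *_ = booster.boost(model, optimizer)

With the force_overlap_comm flag removed, requesting ZeRO-2 or communication overlap no longer raises; the plugin falls back to ZeRO-1 without overlap and emits the warning added in this patch.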