diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py
index cad9ca95c..03b7bebb1 100644
--- a/colossalai/booster/plugin/hybrid_parallel_plugin.py
+++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py
@@ -27,6 +27,7 @@ from colossalai.checkpoint_io import CheckpointIO, HybridParallelCheckpointIO
 from colossalai.cluster import ProcessGroupMesh
 from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper
 from colossalai.interface.optimizer import DistributedOptim
+from colossalai.logging import get_dist_logger
 from colossalai.nn.optimizer import DistGaloreAwamW, cast_to_distributed
 from colossalai.pipeline.schedule import InterleavedSchedule, OneForwardOneBackwardSchedule
 from colossalai.pipeline.stage_manager import PipelineStageManager
@@ -1068,8 +1069,10 @@ class HybridParallelPlugin(PipelinePluginBase):
         self.pp_axis, self.dp_axis, self.tp_axis, self.sp_axis = 0, 1, 2, 3
         self.pg_mesh = ProcessGroupMesh(self.pp_size, self.dp_size, self.tp_size, self.sp_size)

-        self.logger.info(f"{type(self).__name__}: {self.pp_size=} {self.dp_size=} {self.tp_size=} {self.sp_size=}", ranks=[0])
-
+        self.logger.info(
+            f"{type(self).__name__}: {self.pp_size=} {self.dp_size=} {self.tp_size=} {self.sp_size=}", ranks=[0]
+        )
+
         self.stage_manager = None
         self.schedule = None
         self.custom_policy = custom_policy
diff --git a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
index 56b731d13..a02deb80d 100644
--- a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
+++ b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
@@ -15,6 +15,7 @@ from colossalai.booster.plugin.hybrid_parallel_plugin import (
     HybridParallelModule,
     HybridParallelNaiveOptimizer,
     HybridParallelPlugin,
+    HybridParallelZeroOptimizer,
     get_param_info,
     reinitialize_optimizer,
 )
@@ -22,16 +23,18 @@ from colossalai.checkpoint_io import MoECheckpointIO
 from colossalai.cluster.process_group_mesh import ProcessGroupMesh
 from colossalai.interface import ModelWrapper, OptimizerWrapper
 from colossalai.tensor.moe_tensor.api import is_moe_tensor
-from colossalai.zero.low_level import LowLevelZeroOptimizer

-class MoeHybridParallelZeroOptimizer(LowLevelZeroOptimizer):
+
+class MoeHybridParallelZeroOptimizer(HybridParallelZeroOptimizer):
     def __init__(
         self,
         optimizer: Optimizer,
         model: Module,
         use_pipeline: bool,
         force_overlap_comm: bool,  # force overlap comm
-        dp_process_group: ProcessGroup,  # dp pg for comm
+        dp_process_group: Optional[ProcessGroup],  # the dp pg for comm
+        tp_process_group: Optional[ProcessGroup],  # if using tp
+        pp_process_group: Optional[ProcessGroup],  # if using pp
         moe_dp_group: ProcessGroup,  # moe dp pg for comm
         param_info: OrderedDict,
         initial_scale: int = 2**16,  # grad scaler config
@@ -49,32 +52,28 @@ class MoeHybridParallelZeroOptimizer(LowLevelZeroOptimizer):
         partition_grad: bool = False,  # stage 2 flag
         cpu_offload: bool = False,  # cpu offload
         forced_dtype: Optional[torch.dtype] = None,
-    ):
-
+    ):
         WARN_STR = "Note that you need to make sure every expert are routed (i.e.) every expert has backward, otherwise this might lead to program hang or inconsistent result"
         if not force_overlap_comm and (overlap_communication or partition_grad):
-            raise RuntimeError(WARN_STR + " If you are not sure about this, set (overlap_communication=False and partition_grad=False) or force_overlap_comm=True")
-
+            raise RuntimeError(
+                WARN_STR
+                + " If you are not sure about this, set (overlap_communication=False and partition_grad=False) or force_overlap_comm=True"
+            )
+
         if force_overlap_comm:
             overlap_communication = True
             warnings.warn(WARN_STR + " Please make sure of this.")

-        self.param_info = param_info
-        self.stage_manager = model.stage_manager
-        self.shared_params = model.shared_params
-        self.dp_pg = dp_process_group
-
-        if use_pipeline:
-            reinitialize_optimizer(optimizer, model)
-
         pg_param_list = {
             dp_process_group: list(filter(lambda p: not is_moe_tensor(p), model.parameters())),
             moe_dp_group: list(filter(is_moe_tensor, model.parameters())),
         }

         super().__init__(
+            model=model,
             optimizer=optimizer,
-            pg_to_param_list=pg_param_list,
+            use_pipeline=use_pipeline,
+            param_info=param_info,
             initial_scale=initial_scale,
             min_scale=min_scale,
             growth_factor=growth_factor,
@@ -89,7 +88,12 @@ class MoeHybridParallelZeroOptimizer(LowLevelZeroOptimizer):
             overlap_communication=overlap_communication,
             partition_grad=partition_grad,
             cpu_offload=cpu_offload,
+            # dp_process_group=dp_process_group,
+            tp_process_group=tp_process_group,
+            pp_process_group=pp_process_group,
             forced_dtype=forced_dtype,
+            ## moe args
+            pg_to_param_list=pg_param_list,
         )


@@ -180,7 +184,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
                         optimizer, model, use_pipeline=self.enable_pipeline_parallelism, param_info=param_info
                     )
             else:
-                if not(self.dp_size > 1 or self.moe_dp_size > 1):
+                if not (self.dp_size > 1 or self.moe_dp_size > 1):
                     warnings.warn(
                         "Use Zero Optimizer when data parallel size is 1 may introduce unnecessary overhead. "
                         "If you do not intend to use cpu_offload, please consider set zero_stage=0."
@@ -192,6 +196,8 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
                     force_overlap_comm=self.force_overlap_comm,
                     param_info=param_info,
                     dp_process_group=self.dp_group,
+                    tp_process_group=self.tp_group,
+                    pp_process_group=self.pp_group,
                     moe_dp_group=self.moe_dp_group,
                     verbose=True,
                     clip_grad_norm=self.max_norm,
diff --git a/tests/test_shardformer/test_model/test_shard_mixtral.py b/tests/test_shardformer/test_model/test_shard_mixtral.py
index 4e9d3878b..123e590c9 100644
--- a/tests/test_shardformer/test_model/test_shard_mixtral.py
+++ b/tests/test_shardformer/test_model/test_shard_mixtral.py
@@ -117,23 +117,35 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
             "tp_size": 1,
             "pp_size": 1,
             "ep_size": 1,
-            "zero_stage": 2,
+            "zero_stage": 1,
+            "overlap_communication": False,
             "precision": "fp32",
-        },  # [dp(2) + pp(2)] + [moe_dp(4)]
+        },  # [dp(4)] + [moe_dp(4)]
+        {
+            "tp_size": 1,
+            "pp_size": 2,
+            "num_microbatches": 2,
+            "ep_size": 1,
+            "zero_stage": 1,
+            "overlap_communication": False,
+            "precision": "fp32",
+        },  # [dp(2) + pp(2)] + [moe_pp(2)]
+        {
+            "tp_size": 2,
+            "pp_size": 2,
+            "num_microbatches": 2,
+            "ep_size": 1,
+            "zero_stage": 1,
+            "overlap_communication": False,
+            "precision": "fp32",
+        },  # [pp(2) + tp(2)] + [pp(2), replicate(2)]
         # {
         #     "tp_size": 1,
         #     "pp_size": 2,
         #     "num_microbatches": 2,
-        #     "ep_size": 1,
-        #     "zero_stage": 1,
-        #     "precision": "fp32",
-        # },  # [dp(2) + pp(2)] + [moe_dp(4)]
-        # {
-        #     "tp_size": 1,
-        #     "pp_size": 2,
-        #     "num_microbatches": 2,
-        #     "ep_size": 4,
+        #     "ep_size": 2,
         #     "zero_stage": 1,
+        #     "overlap_communication": False,
         #     "precision": "fp32",
         # },  # [dp(2) + pp(2)] + [ep(4))]
         # {
@@ -141,13 +153,15 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
         #     "pp_size": 1,
         #     "ep_size": 2,
         #     "zero_stage": 0,
+        #     "overlap_communication": False,
         #     "precision": "fp32",
         # },  # [dp(4)] + [ep(2) + moe_tp(2)]
         # {
-        # "tp_size": 1,
-        # "pp_size": 1,
-        # "ep_size": 4,
-        # "zero_stage": 0,
+        #     "tp_size": 1,
+        #     "pp_size": 1,
+        #     "ep_size": 4,
+        #     "overlap_communication": False,
+        #     "zero_stage": 0,
         #     "precision": "fp32"
         # },  # full dp for non-moe and full ep for moe
     ],
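
Usage sketch (not part of the patch above): the new test entries map one-to-one onto MoeHybridParallelPlugin kwargs, so the [pp(2) + tp(2)] + ZeRO-1 case can be driven through the Booster API roughly as below. The model helper and criterion are illustrative assumptions, not code from this PR.

# Launch with 4 ranks, e.g. `torchrun --nproc_per_node=4 this_script.py`.
import torch

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin

colossalai.launch_from_torch()

# Mirrors the newly enabled [pp(2) + tp(2)] test config with ZeRO stage 1.
plugin = MoeHybridParallelPlugin(
    tp_size=2,
    pp_size=2,
    num_microbatches=2,
    ep_size=1,
    zero_stage=1,
    overlap_communication=False,
    precision="fp32",
)
booster = Booster(plugin=plugin)

model = build_moe_model()  # hypothetical helper returning a Mixtral-style MoE model
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)


def criterion(outputs, inputs):  # assumes a HF-style output object carrying .loss
    return outputs.loss


# boost() wraps the optimizer in MoeHybridParallelZeroOptimizer; with the change above,
# the tp/pp process groups are forwarded to HybridParallelZeroOptimizer, while non-MoE
# parameters are sharded over dp_process_group and MoE parameters over moe_dp_group.
model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)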