diff --git a/colossalai/engine/gradient_handler/_moe_gradient_handler.py b/colossalai/engine/gradient_handler/_moe_gradient_handler.py
index 02cea5e67..b499345d4 100644
--- a/colossalai/engine/gradient_handler/_moe_gradient_handler.py
+++ b/colossalai/engine/gradient_handler/_moe_gradient_handler.py
@@ -1,45 +1,46 @@
-from colossalai.core import global_context as gpc
-from colossalai.registry import GRADIENT_HANDLER
-from colossalai.utils.moe import get_moe_epsize_param_dict
-from ._base_gradient_handler import BaseGradientHandler
-from ...context.parallel_mode import ParallelMode
-from .utils import bucket_allreduce
-from colossalai.context.moe_context import MOE_CONTEXT
-
-
-@GRADIENT_HANDLER.register_module
-class MoeGradientHandler(BaseGradientHandler):
-    """A helper class to handle all-reduce operations in a data parallel group and
-    moe model parallel. A all-reduce collective communication will be operated in
-    :func:`handle_gradient` among a data parallel group.
-    For better performance, it bucketizes the gradients of all parameters that are
-    the same type to improve the efficiency of communication.
-
-    Args:
-        model (Module): Model where the gradients accumulate.
-        optimizer (Optimizer): Optimizer for updating the parameters.
-    """
-
-    def __init__(self, model, optimizer=None):
-        super().__init__(model, optimizer)
-
-    def handle_gradient(self):
-        """A method running an all-reduce operation in a data parallel group.
-        Then running an all-reduce operation for all parameters in experts
-        across moe model parallel group
-        """
-        global_data = gpc.data_parallel_size
-
-        if global_data > 1:
-            epsize_param_dict = get_moe_epsize_param_dict(self._model)
-
-            # epsize is 1, indicating the params are replicated among processes in data parallelism
-            # use the ParallelMode.DATA to get data parallel group
-            # reduce gradients for all parameters in data parallelism
-            if 1 in epsize_param_dict:
-                bucket_allreduce(param_list=epsize_param_dict[1], group=gpc.get_group(ParallelMode.DATA))
-
-            for ep_size in epsize_param_dict:
-                if ep_size != 1 and ep_size != MOE_CONTEXT.world_size:
-                    bucket_allreduce(param_list=epsize_param_dict[ep_size],
-                                     group=MOE_CONTEXT.parallel_info_dict[ep_size].dp_group)
+from colossalai.context.moe_context import MOE_CONTEXT
+from colossalai.core import global_context as gpc
+from colossalai.registry import GRADIENT_HANDLER
+from colossalai.utils.moe import get_moe_epsize_param_dict
+
+from ...context.parallel_mode import ParallelMode
+from ._base_gradient_handler import BaseGradientHandler
+from .utils import bucket_allreduce
+
+
+@GRADIENT_HANDLER.register_module
+class MoeGradientHandler(BaseGradientHandler):
+    """A helper class to handle all-reduce operations in a data parallel group and
+    MoE model parallelism. An all-reduce collective communication is launched in
+    :func:`handle_gradient` among a data parallel group.
+    For better performance, it bucketizes the gradients of all parameters of the
+    same type to improve the efficiency of communication.
+
+    Args:
+        model (Module): Model where the gradients accumulate.
+        optimizer (Optimizer): Optimizer for updating the parameters.
+    """
+
+    def __init__(self, model, optimizer=None):
+        super().__init__(model, optimizer)
+
+    def handle_gradient(self):
+        """Runs an all-reduce operation over the data parallel group, then
+        runs an all-reduce for all expert parameters across the MoE model
+        parallel group.
+        """
+        global_data = gpc.data_parallel_size
+
+        if global_data > 1:
+            epsize_param_dict = get_moe_epsize_param_dict(self._model)
+
+            # epsize is 1, indicating the params are replicated among processes in data parallelism
+            # use the ParallelMode.DATA to get data parallel group
+            # reduce gradients for all parameters in data parallelism
+            if 1 in epsize_param_dict:
+                bucket_allreduce(param_list=epsize_param_dict[1], group=gpc.get_group(ParallelMode.DATA))
+
+            for ep_size in epsize_param_dict:
+                if ep_size != 1 and ep_size != MOE_CONTEXT.world_size:
+                    bucket_allreduce(param_list=epsize_param_dict[ep_size],
+                                     group=MOE_CONTEXT.parallel_info_dict[ep_size].dp_group)
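
For context, below the diff is a minimal usage sketch of how a gradient handler like this is typically driven in a training step. It is not part of this change: the model, optimizer, data, and loss are placeholders, the import assumes the package __init__ re-exports MoeGradientHandler, and handle_gradient() only does real work after colossalai.launch has initialized the distributed context.

# Illustrative sketch only, not part of this diff. Placeholders: `model`,
# `optimizer`, `batches`, and the loss. Assumes the gradient_handler package
# __init__ re-exports MoeGradientHandler, and that colossalai.launch(...) has
# already set up the global/MoE contexts used inside handle_gradient().
import torch
import torch.nn as nn

from colossalai.engine.gradient_handler import MoeGradientHandler

model = nn.Linear(16, 16)                        # stand-in for a real MoE model
optimizer = torch.optim.Adam(model.parameters())
handler = MoeGradientHandler(model, optimizer)

batches = [torch.randn(8, 16) for _ in range(4)]  # placeholder data
for batch in batches:
    loss = model(batch).sum()                    # placeholder loss
    loss.backward()
    handler.handle_gradient()                    # all-reduce grads before the step
    optimizer.step()
    optimizer.zero_grad()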