From 6cd4c32be4c0ced9a70e228530f383c5f4a580de Mon Sep 17 00:00:00 2001
From: Wang Binluo <32676639+wangbluo@users.noreply.github.com>
Date: Wed, 3 Jul 2024 20:02:19 +0800
Subject: [PATCH] [shardformer] fix the moe (#5883)

---
 colossalai/booster/plugin/__init__.py      | 10 +++++++-
 colossalai/shardformer/policies/mixtral.py | 28 ++++++++++------------
 2 files changed, 22 insertions(+), 16 deletions(-)

diff --git a/colossalai/booster/plugin/__init__.py b/colossalai/booster/plugin/__init__.py
index 62f3708fc..7e0e6ffdd 100644
--- a/colossalai/booster/plugin/__init__.py
+++ b/colossalai/booster/plugin/__init__.py
@@ -1,10 +1,18 @@
 from .gemini_plugin import GeminiPlugin
 from .hybrid_parallel_plugin import HybridParallelPlugin
 from .low_level_zero_plugin import LowLevelZeroPlugin
+from .moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
 from .plugin_base import Plugin
 from .torch_ddp_plugin import TorchDDPPlugin
 
-__all__ = ["Plugin", "TorchDDPPlugin", "GeminiPlugin", "LowLevelZeroPlugin", "HybridParallelPlugin"]
+__all__ = [
+    "Plugin",
+    "TorchDDPPlugin",
+    "GeminiPlugin",
+    "LowLevelZeroPlugin",
+    "HybridParallelPlugin",
+    "MoeHybridParallelPlugin",
+]
 
 import torch
 from packaging import version
diff --git a/colossalai/shardformer/policies/mixtral.py b/colossalai/shardformer/policies/mixtral.py
index f9721c79e..0fb858d78 100644
--- a/colossalai/shardformer/policies/mixtral.py
+++ b/colossalai/shardformer/policies/mixtral.py
@@ -40,21 +40,19 @@ class MixtralPolicy(Policy):
         if self.shard_config.enable_tensor_parallelism:
             raise NotImplementedError("Tensor parallelism is not supported for Mixtral model now.")
 
-        if getattr(self.shard_config, "ep_group", None) is None:
-            raise ValueError("You must pass in ep_group via shard_config for expert parallel!")
-
-        # expert parallel
-        self.append_or_create_submodule_replacement(
-            description=[
-                SubModuleReplacementDescription(
-                    suffix="block_sparse_moe",
-                    target_module=EPMixtralSparseMoeBlock,
-                    kwargs={"ep_group": self.shard_config.ep_group},
-                )
-            ],
-            policy=policy,
-            target_key=MixtralDecoderLayer,
-        )
+        if getattr(self.shard_config, "ep_group", None) is not None:
+            # expert parallel
+            self.append_or_create_submodule_replacement(
+                description=[
+                    SubModuleReplacementDescription(
+                        suffix="block_sparse_moe",
+                        target_module=EPMixtralSparseMoeBlock,
+                        kwargs={"ep_group": self.shard_config.ep_group},
+                    )
+                ],
+                policy=policy,
+                target_key=MixtralDecoderLayer,
+            )
 
         # optimization configuration
         if self.shard_config.enable_fused_normalization:
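After this patch, `MoeHybridParallelPlugin` is re-exported from the plugin package root, and `MixtralPolicy` applies the expert-parallel submodule replacement only when an `ep_group` is present on the shard config, instead of raising a `ValueError` when it is missing. A minimal usage sketch of the re-exported plugin follows; the constructor arguments (`tp_size`, `pp_size`, `ep_size`) are assumptions for illustration and may not match the plugin's actual signature, and a distributed environment must already be initialized (e.g. via `colossalai.launch_from_torch`).

```python
# Sketch only: exercises the import path made public by this patch.
from colossalai.booster import Booster
from colossalai.booster.plugin import MoeHybridParallelPlugin  # newly re-exported

# Assumed constructor arguments; check the plugin's signature before use.
plugin = MoeHybridParallelPlugin(tp_size=1, pp_size=1, ep_size=2)
booster = Booster(plugin=plugin)
# model, optimizer, criterion, dataloader, _ = booster.boost(model, optimizer, criterion, dataloader)
```

With this change, expert parallelism for Mixtral is opt-in: policies built without an `ep_group` simply keep the stock `block_sparse_moe` module rather than failing at shard time.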