diff --git a/colossalai/booster/plugin/__init__.py b/colossalai/booster/plugin/__init__.py
index 62f3708fc..7e0e6ffdd 100644
--- a/colossalai/booster/plugin/__init__.py
+++ b/colossalai/booster/plugin/__init__.py
@@ -1,10 +1,18 @@
 from .gemini_plugin import GeminiPlugin
 from .hybrid_parallel_plugin import HybridParallelPlugin
 from .low_level_zero_plugin import LowLevelZeroPlugin
+from .moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
 from .plugin_base import Plugin
 from .torch_ddp_plugin import TorchDDPPlugin
 
-__all__ = ["Plugin", "TorchDDPPlugin", "GeminiPlugin", "LowLevelZeroPlugin", "HybridParallelPlugin"]
+__all__ = [
+    "Plugin",
+    "TorchDDPPlugin",
+    "GeminiPlugin",
+    "LowLevelZeroPlugin",
+    "HybridParallelPlugin",
+    "MoeHybridParallelPlugin",
+]
 
 import torch
 from packaging import version
diff --git a/colossalai/shardformer/policies/mixtral.py b/colossalai/shardformer/policies/mixtral.py
index f9721c79e..0fb858d78 100644
--- a/colossalai/shardformer/policies/mixtral.py
+++ b/colossalai/shardformer/policies/mixtral.py
@@ -40,21 +40,19 @@ class MixtralPolicy(Policy):
 
         if self.shard_config.enable_tensor_parallelism:
             raise NotImplementedError("Tensor parallelism is not supported for Mixtral model now.")
-        if getattr(self.shard_config, "ep_group", None) is None:
-            raise ValueError("You must pass in ep_group via shard_config for expert parallel!")
-
-        # expert parallel
-        self.append_or_create_submodule_replacement(
-            description=[
-                SubModuleReplacementDescription(
-                    suffix="block_sparse_moe",
-                    target_module=EPMixtralSparseMoeBlock,
-                    kwargs={"ep_group": self.shard_config.ep_group},
-                )
-            ],
-            policy=policy,
-            target_key=MixtralDecoderLayer,
-        )
+        if getattr(self.shard_config, "ep_group", None) is not None:
+            # expert parallel
+            self.append_or_create_submodule_replacement(
+                description=[
+                    SubModuleReplacementDescription(
+                        suffix="block_sparse_moe",
+                        target_module=EPMixtralSparseMoeBlock,
+                        kwargs={"ep_group": self.shard_config.ep_group},
+                    )
+                ],
+                policy=policy,
+                target_key=MixtralDecoderLayer,
+            )
 
         # optimization configuration
         if self.shard_config.enable_fused_normalization:
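
With `MoeHybridParallelPlugin` re-exported from `colossalai.booster.plugin`, it can be handed to a `Booster` like the other plugins. Below is a minimal usage sketch, not taken from this diff: the `tp_size`/`pp_size`/`ep_size`/`precision` keyword arguments, the `launch_from_torch` call, and the toy model are assumptions for illustration; check the plugin's actual constructor before relying on them.

```python
# Minimal sketch (assumptions, not part of this diff): the plugin's keyword
# arguments (tp_size, pp_size, ep_size, precision) and the toy model are
# illustrative only; a real run is launched with torchrun across several ranks.
import colossalai
import torch
import torch.nn as nn
from colossalai.booster import Booster
from colossalai.booster.plugin import MoeHybridParallelPlugin  # newly re-exported

colossalai.launch_from_torch(config={})  # older launch API; adjust to your version

plugin = MoeHybridParallelPlugin(tp_size=1, pp_size=1, ep_size=2, precision="bf16")
booster = Booster(plugin=plugin)

model = nn.Linear(16, 16)  # stand-in; in practice a Mixtral-style MoE model
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
model, optimizer, *_ = booster.boost(model, optimizer)
```

On the policy side, the second hunk makes expert-parallel sharding in `MixtralPolicy` conditional: when `shard_config.ep_group` is unset, `block_sparse_moe` is left unmodified instead of raising a `ValueError`.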