Mirror of https://github.com/hpcaitech/ColossalAI.git
[shardformer] fix the moe (#5883)
commit 6cd4c32be4
parent eb24fcd914
@@ -1,10 +1,18 @@
 from .gemini_plugin import GeminiPlugin
 from .hybrid_parallel_plugin import HybridParallelPlugin
 from .low_level_zero_plugin import LowLevelZeroPlugin
+from .moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
 from .plugin_base import Plugin
 from .torch_ddp_plugin import TorchDDPPlugin

-__all__ = ["Plugin", "TorchDDPPlugin", "GeminiPlugin", "LowLevelZeroPlugin", "HybridParallelPlugin"]
+__all__ = [
+    "Plugin",
+    "TorchDDPPlugin",
+    "GeminiPlugin",
+    "LowLevelZeroPlugin",
+    "HybridParallelPlugin",
+    "MoeHybridParallelPlugin",
+]

 import torch
 from packaging import version
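For context, re-exporting MoeHybridParallelPlugin from the plugin package means it can be imported alongside the other booster plugins. A minimal sketch of how a training script might pick it up after this change; the constructor arguments shown are illustrative assumptions, not taken from this diff:

from colossalai.booster import Booster
from colossalai.booster.plugin import MoeHybridParallelPlugin  # now re-exported via __all__

# NOTE: the keyword arguments below are assumptions for illustration only.
plugin = MoeHybridParallelPlugin(tp_size=1, pp_size=1, ep_size=2)
booster = Booster(plugin=plugin)
# model, optimizer, dataloader, etc. would then be wrapped via booster.boost(...)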
@@ -40,9 +40,7 @@ class MixtralPolicy(Policy):
         if self.shard_config.enable_tensor_parallelism:
             raise NotImplementedError("Tensor parallelism is not supported for Mixtral model now.")
-        if getattr(self.shard_config, "ep_group", None) is None:
-            raise ValueError("You must pass in ep_group via shard_config for expert parallel!")

         if getattr(self.shard_config, "ep_group", None) is not None:
             # expert parallel
             self.append_or_create_submodule_replacement(
                 description=[
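Reading the hunk header (9 old lines, 7 new lines) together with the lines shown, this change appears to drop the hard requirement that ep_group be passed via shard_config: the expert-parallel submodule replacement is applied only when ep_group is present, and the policy now falls through instead of raising otherwise. A small self-contained sketch of that guard, using a stand-in FakeShardConfig rather than the real colossalai ShardConfig:

from dataclasses import dataclass
from typing import Any, Optional

@dataclass
class FakeShardConfig:
    # Stand-in for colossalai's ShardConfig, only for illustrating the guard.
    enable_tensor_parallelism: bool = False
    ep_group: Optional[Any] = None  # process group used for expert parallelism

def wants_expert_parallel(shard_config: FakeShardConfig) -> bool:
    """Mirror of the guard in the hunk above: act only when ep_group is set."""
    if shard_config.enable_tensor_parallelism:
        raise NotImplementedError("Tensor parallelism is not supported for Mixtral model now.")
    if getattr(shard_config, "ep_group", None) is not None:
        # Expert parallel: the real policy would call
        # append_or_create_submodule_replacement(...) here.
        return True
    # No ep_group: fall through instead of raising ValueError.
    return False

assert wants_expert_parallel(FakeShardConfig()) is False
assert wants_expert_parallel(FakeShardConfig(ep_group=object())) is True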