Mirror of https://github.com/hpcaitech/ColossalAI.git
[moe] support mixtral (#5309)
* [moe] add mixtral block for single expert
* [moe] mixtral block fwd supports uneven ep
* [moe] mixtral block bwd supports uneven ep
* [moe] add mixtral moe layer
* [moe] simplify module replacement
* [moe] support saving sharded mixtral
* [moe] support loading sharded mixtral
* [moe] support saving sharded optimizer
* [moe] integrate moe manager into plugin
* [moe] fix optimizer load
* [moe] fix mixtral layer
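To make the change concrete, here is a minimal usage sketch of the plugin this commit extends. Only the constructor parameters visible in the diff below (tp_size, pp_size, ep_size, extra_dp_size, precision, zero_stage) are confirmed by the commit; the import path, the launch call, and the Booster wiring are assumptions based on ColossalAI conventions of this era, not part of the commit itself:

```python
# Hedged sketch, not part of the commit: constructing MoeHybridParallelPlugin
# with expert parallelism on 8 ranks. Import path and Booster usage are
# assumptions; the constructor arguments come from the diff below.
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin

colossalai.launch_from_torch(config={})  # distributed init (API of this era)

# With 8 ranks and tp=1, pp=1, ep=2: dp_size = 8 // (1*1) = 8 for the dense
# parameters, real_dp_size = 8 // (1*1*2) = 4 for the experts (see the
# asserts added in the diff).
plugin = MoeHybridParallelPlugin(
    tp_size=1,
    pp_size=1,
    ep_size=2,
    extra_dp_size=1,   # parameter introduced by this commit
    precision="fp16",
    zero_stage=1,
)
booster = Booster(plugin=plugin)
# model, optimizer, *_ = booster.boost(model=model, optimizer=optimizer)
```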
@@ -22,7 +22,7 @@ from colossalai.booster.plugin.hybrid_parallel_plugin import (
 )
 from colossalai.cluster import ProcessGroupMesh
 from colossalai.interface import ModelWrapper, OptimizerWrapper
-from colossalai.moe import MoECheckpintIO
+from colossalai.moe import MOE_MANAGER, MoECheckpintIO
 from colossalai.pipeline.schedule import OneForwardOneBackwardSchedule
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer import ShardConfig
@@ -150,6 +150,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
         self,
         tp_size: int,
         pp_size: int,
         ep_size: int,
+        extra_dp_size: int = 1,
         precision: str = "fp16",
         zero_stage: int = 0,
@@ -189,10 +190,26 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
 
         if enable_sequence_parallelism:
             assert tp_size > 1, "Tensor parallelism must be enabled when using sequence parallelism"
+        assert (
+            dist.get_world_size() % (tp_size * pp_size) == 0
+        ), f"world size {dist.get_world_size()} is not divisible by tp_size {tp_size} * pp_size {pp_size}"
+        assert (
+            dist.get_world_size() % (tp_size * pp_size * ep_size) == 0
+        ), f"world size {dist.get_world_size()} is not divisible by tp_size {tp_size} * pp_size {pp_size} * ep_size {ep_size}"
+        self.real_dp_size = dist.get_world_size() // (tp_size * pp_size * ep_size)
+        MOE_MANAGER.setup(
+            parallel="EP",
+            mode="fixed",
+            fixed_dp_size=self.real_dp_size,
+            fixed_ep_size=ep_size,
+            fixed_pp_size=pp_size,
+            use_ep_inside=use_ep_inside,
+        )
         self.tp_size = tp_size
         self.pp_size = pp_size
         self.dp_size = dist.get_world_size() // (tp_size * pp_size)
         self.ep_size = ep_size
+        self.moe_info = MOE_MANAGER.get_info(0)[1]
         self.precision = precision
         self.zero_stage = zero_stage
         self.cpu_offload = cpu_offload
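The two asserts and the real_dp_size computation encode a simple rank factorization: the world size must be divisible by tp_size * pp_size (giving dp_size for the dense parameters) and by tp_size * pp_size * ep_size (giving the data-parallel degree that remains for the experts, which is what gets passed to MOE_MANAGER.setup as fixed_dp_size). A worked example of just that arithmetic, with hypothetical sizes:

```python
# Plain-Python illustration of the sizing arithmetic enforced by the new
# asserts; the numbers are hypothetical, not from the commit.
world_size = 16
tp_size, pp_size, ep_size = 2, 2, 2

# Constraints checked in MoeHybridParallelPlugin.__init__:
assert world_size % (tp_size * pp_size) == 0
assert world_size % (tp_size * pp_size * ep_size) == 0

# Data-parallel degree for the dense (non-expert) parameters:
dp_size = world_size // (tp_size * pp_size)                 # 16 // 4 = 4
# Data-parallel degree left for the experts once EP claims its ranks;
# this is the value handed to MOE_MANAGER.setup(fixed_dp_size=...):
real_dp_size = world_size // (tp_size * pp_size * ep_size)  # 16 // 8 = 2
print(dp_size, real_dp_size)  # -> 4 2
```

With mode="fixed", MOE_MANAGER is given this dp/ep/pp factorization up front rather than inferring it per MoE layer; judging from the call, use_ep_inside selects whether the EP axis sits inside or outside the DP axis when the process-group mesh is laid out.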