[moe] merge moe into main (#4978)
* update moe module
* support openmoe
colossalai/moe/manager.py (new file, 162 lines)
from typing import Tuple

import torch
import torch.distributed as dist

from colossalai.context.singleton_meta import SingletonMeta
from colossalai.tensor.moe_tensor.api import get_moe_info
from colossalai.tensor.moe_tensor.moe_info import MoeParallelInfo


class MoeManager(metaclass=SingletonMeta):
    """MoE manager. This class manages the different parallel groups
    used in the MoE context and the MoE losses collected during training.
    """

    def __init__(self):
        self.parallel = None
        self.seed = None
        self.mode = None
        self.use_ep_inside = None
        self.world_size = None
        self._parallel_info_dict = dict()

        # router losses
        self.router_aux_loss = []
        self.router_z_loss = []

        # fixed mode
        self.pp_size = None
        self.dp_size = None
        self.ep_size = None

        # dynamic mode
        # Users may want to set the maximum expert parallel size smaller than the world size,
        # since very low bandwidth across nodes may constrain the performance of MoE.
        # A maximum expert parallel size naturally implies a minimum data parallel size.
        self.max_ep_size = None

        self.has_setup = False

    @property
    def parallel_info_dict(self):
        return self._parallel_info_dict

    @property
    def is_initialized(self):
        return self.has_setup

    def setup(
        self,
        seed: int,
        parallel: str = None,
        mode: str = "dynamic",
        max_ep_size: int = 8,
        fixed_dp_size: int = 0,
        fixed_ep_size: int = 0,
        fixed_pp_size: int = 0,
        use_ep_inside: bool = True,
    ) -> None:
        """
        Set up the MoE distributed context.

        Args:
            seed (int): Random seed; each rank offsets it by its own rank index.
            parallel (str, optional): Parallel mode, should be "EP", "TP" or None. Defaults to None.
            mode (str, optional): Should be "fixed" or "dynamic". Defaults to "dynamic".
                In fixed mode, the ep size and dp size are fixed.
                In dynamic mode, the ep size and dp size are derived from the number of experts.
            max_ep_size (int, optional): Max ep size in dynamic mode. Defaults to 8.
            fixed_dp_size (int, optional): Fixed dp size in fixed mode. Defaults to 0.
            fixed_ep_size (int, optional): Fixed ep size in fixed mode. Defaults to 0.
            fixed_pp_size (int, optional): Fixed pp size in fixed mode. Defaults to 0.
            use_ep_inside (bool, optional): Use ep inside dp if True, dp inside ep if False. Defaults to True.
        """
        assert not self.is_initialized, "MoE distributed context shouldn't be set up again"
        assert torch.cuda.is_available(), "MoE requires CUDA to be available"

        self.seed = seed + dist.get_rank()
        self.parallel = parallel
        self.use_ep_inside = use_ep_inside
        self.world_size = dist.get_world_size()

        # init by mode
        self.mode = mode
        assert self.mode in ["fixed", "dynamic"], "mode should be fixed or dynamic"
        if self.mode == "dynamic":
            self.max_ep_size = min(max_ep_size, self.world_size)
        else:
            assert (fixed_dp_size > 0 and fixed_ep_size > 0
                    and fixed_pp_size > 0), "dp_size, ep_size and pp_size should be greater than 0"
            assert (isinstance(fixed_dp_size, int) and isinstance(fixed_ep_size, int)
                    and isinstance(fixed_pp_size, int)), "dp_size, ep_size and pp_size should be int"
            self.ep_size = fixed_ep_size
            self.dp_size = fixed_dp_size
            self.pp_size = fixed_pp_size

        self.has_setup = True

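    # A minimal usage sketch for setup() with illustrative values (assumes
    # torch.distributed has already been initialized, since setup() queries the
    # rank and world size):
    #
    #   MOE_MANAGER.setup(seed=42, parallel="EP", mode="dynamic", max_ep_size=4)
    #
    # or, with an explicit topology in fixed mode:
    #
    #   MOE_MANAGER.setup(seed=42, parallel="EP", mode="fixed",
    #                     fixed_dp_size=2, fixed_ep_size=4, fixed_pp_size=1)
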
    def get_info(self, num_experts: int, use_tp: bool = False) -> Tuple[int, MoeParallelInfo]:
        """Calculate the Data Parallel Group and Expert Parallel Group.

        Parameters
        ----------
        num_experts : int
            The number of experts.
        use_tp : bool
            Whether tensor parallelism is used for the experts. If True, every
            rank keeps all experts locally.

        Returns
        -------
        int, MoeParallelInfo
            The number of local experts and the MoeParallelInfo of the current ep_size.
        """

        if self.mode == "dynamic":
            gt_flag = num_experts % self.max_ep_size == 0  # num_experts is a multiple of max_ep_size
            lt_flag = self.max_ep_size % num_experts == 0  # max_ep_size is a multiple of num_experts
            assert gt_flag or lt_flag, ("Automatic expert placement requires the number of experts to be "
                                        "a multiple of the ep size, or vice versa.")
            dp_size = 1 if gt_flag else self.world_size // num_experts
            ep_size = min(self.world_size // dp_size, self.max_ep_size)
            dp_size = self.world_size // ep_size
            pp_size = 1
        else:
            dp_size = self.dp_size
            ep_size = self.ep_size
            pp_size = self.pp_size

        # Calculate the number of experts for each GPU
        if use_tp:
            num_local_experts = num_experts
        else:
            if self.mode == "dynamic":
                num_local_experts = 1 if lt_flag else num_experts // self.max_ep_size
            else:
                num_local_experts = num_experts // ep_size

        if ep_size not in self.parallel_info_dict:
            self.parallel_info_dict[ep_size] = get_moe_info(ep_size, dp_size, pp_size, ep_inside=self.use_ep_inside)
            if dist.get_rank() == 0:
                if self.use_ep_inside:
                    print(f"MoE Parallel: pp {pp_size}, dp {dp_size}, ep {ep_size}")
                else:
                    print(f"MoE Parallel: pp {pp_size}, ep {ep_size}, dp {dp_size}")

        return num_local_experts, self.parallel_info_dict[ep_size]

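    # Worked example of the dynamic-mode arithmetic above, using illustrative
    # numbers: world_size = 8 and max_ep_size = 4.
    #
    #   num_experts = 16: gt_flag is True, so ep_size = 4, dp_size = 2, and each
    #                     rank holds 16 // 4 = 4 local experts.
    #   num_experts = 2:  lt_flag is True, so dp_size = 8 // 2 = 4, ep_size = 2,
    #                     and each rank holds a single local expert.
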
    def reset_loss(self):
        self.router_aux_loss, self.router_z_loss = [], []

    def add_loss(self, aux_loss: float = 0.0, z_loss: float = 0.0):
        self.router_aux_loss.append(aux_loss)
        self.router_z_loss.append(z_loss)

    def get_loss(self):
        cur_loss = self.router_aux_loss, self.router_z_loss
        return cur_loss

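    # A possible loss-accumulation pattern (a sketch with hypothetical loss
    # weights): each router calls add_loss() during the forward pass, the
    # training loop folds the collected terms into the objective, then resets:
    #
    #   aux_losses, z_losses = MOE_MANAGER.get_loss()
    #   loss = task_loss + 0.01 * sum(aux_losses) + 0.001 * sum(z_losses)
    #   MOE_MANAGER.reset_loss()
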
    def get_parallel(self):
        return self.parallel


MOE_MANAGER = MoeManager()
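Because MoeManager uses SingletonMeta as its metaclass, every module that imports MOE_MANAGER shares the same state; a minimal sketch, assuming SingletonMeta implements the usual return-the-existing-instance behaviour:

    from colossalai.moe.manager import MOE_MANAGER, MoeManager

    # holds if SingletonMeta hands back the already-created instance
    assert MoeManager() is MOE_MANAGER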