[moe] merge moe into main (#4978)
* update moe module
* support openmoe
colossalai/tensor/moe_tensor/__init__.py | 0 | Normal file
colossalai/tensor/moe_tensor/api.py | 137 | Normal file
@@ -0,0 +1,137 @@
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup

from .moe_info import MoeParallelInfo


def is_moe_tensor(tensor: torch.Tensor) -> bool:
    """
    Check whether the given tensor is a moe tensor.

    Args:
        tensor (torch.Tensor): The tensor to be checked.

    Returns:
        bool: Whether the given tensor is a moe tensor.
    """
    return hasattr(tensor, "moe_info")


def set_moe_tensor_info(tensor: torch.Tensor, moe_info: MoeParallelInfo) -> None:
    """
    Set moe info for the given tensor.

    Args:
        tensor (torch.Tensor): The tensor to be set.
        moe_info (MoeParallelInfo): The moe info to be set.
    """
    tensor.__setattr__("moe_info", moe_info)


def get_moe_info(ep_size: int, dp_size: int, pp_size: int, ep_inside: bool) -> MoeParallelInfo:
    """
    Get moe info for the given parallel sizes.

    Args:
        ep_size (int): The expert parallel size.
        dp_size (int): The data parallel size.
        pp_size (int): The pipeline parallel size.
        ep_inside (bool): Use ep inside dp if True, dp inside ep if False.

    Returns:
        MoeParallelInfo: The moe info built from the given sizes.
    """
    return MoeParallelInfo(ep_inside, ep_size, dp_size, pp_size)


def get_ep_group(tensor: torch.Tensor) -> ProcessGroup:
    """
    Get the expert parallel group of the given tensor.

    Args:
        tensor (torch.Tensor): The tensor to be checked.

    Returns:
        torch.distributed.ProcessGroup: The expert parallel group of the given tensor.
    """
    return tensor.moe_info.ep_group


def get_ep_size(tensor: torch.Tensor) -> int:
    """
    Get the expert parallel size of the given tensor.

    Args:
        tensor (torch.Tensor): The tensor to be checked.

    Returns:
        int: The expert parallel size of the given tensor.
    """
    return tensor.moe_info.ep_size


def get_dp_group(tensor: torch.Tensor) -> ProcessGroup:
    """
    Get the data parallel group of the given tensor.

    Args:
        tensor (torch.Tensor): The tensor to be checked.

    Returns:
        torch.distributed.ProcessGroup: The data parallel group of the given tensor.
    """
    return tensor.moe_info.dp_group


def get_ep_rank(tensor: torch.Tensor) -> int:
    """
    Get the expert parallel rank of the given tensor.

    Args:
        tensor (torch.Tensor): The tensor to be checked.

    Returns:
        int: The expert parallel rank of the given tensor.
    """
    return dist.get_rank(get_ep_group(tensor))


def get_dp_rank(tensor: torch.Tensor) -> int:
    """
    Get the data parallel rank of the given tensor.

    Args:
        tensor (torch.Tensor): The tensor to be checked.

    Returns:
        int: The data parallel rank of the given tensor.
    """
    return dist.get_rank(get_dp_group(tensor))


def get_ep_group_ranks(tensor: torch.Tensor) -> list:
    """
    Get the expert parallel group ranks of the given tensor.

    Args:
        tensor (torch.Tensor): The tensor to be checked.

    Returns:
        list: The expert parallel group ranks of the given tensor.
    """
    return tensor.moe_info.ep_group_ranks


def get_dp_group_ranks(tensor: torch.Tensor) -> list:
    """
    Get the data parallel group ranks of the given tensor.

    Args:
        tensor (torch.Tensor): The tensor to be checked.

    Returns:
        list: The data parallel group ranks of the given tensor.
    """
    return tensor.moe_info.dp_group_ranks
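
Taken together, these helpers attach expert-parallel metadata to a tensor and read it back later. The sketch below is not part of the commit; it assumes an 8-rank torchrun launch with ep_size=2 and dp_size=4, and that a plain torch.distributed process group is sufficient for ProcessGroupMesh to build its sub-groups.

# Hypothetical usage sketch (not part of the diff).
# Assumed launch: torchrun --nproc_per_node=8 this_script.py
import torch
import torch.distributed as dist

from colossalai.tensor.moe_tensor.api import (
    get_ep_rank,
    get_ep_size,
    get_moe_info,
    is_moe_tensor,
    set_moe_tensor_info,
)

if __name__ == "__main__":
    dist.init_process_group(backend="gloo")  # or "nccl" on GPU

    # 2-way expert parallelism nested inside 4-way data parallelism (8 ranks total).
    moe_info = get_moe_info(ep_size=2, dp_size=4, pp_size=1, ep_inside=True)

    # Tag an expert parameter so later code can recover its parallel groups from it.
    expert_weight = torch.nn.Parameter(torch.randn(16, 16))
    set_moe_tensor_info(expert_weight, moe_info)

    assert is_moe_tensor(expert_weight)
    print(f"ep_size={get_ep_size(expert_weight)}, ep_rank={get_ep_rank(expert_weight)}")
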
colossalai/tensor/moe_tensor/moe_info.py | 28 | Normal file
@@ -0,0 +1,28 @@
from colossalai.cluster import ProcessGroupMesh


class MoeParallelInfo:
    """Moe parallelism information, storing parallel sizes and groups."""

    def __init__(self, ep_inside: bool, ep_size: int, dp_size: int, pp_size: int = 1):
        """
        Initialize MoeParallelInfo with the given parallel sizes.

        Args:
            ep_inside (bool): Use ep inside dp if True, dp inside ep if False.
            ep_size (int): expert parallel size
            dp_size (int): data parallel (zero) size
            pp_size (int, optional): pipeline parallel size. Defaults to 1.
        """
        self.pp_size, self.dp_size, self.ep_size = pp_size, dp_size, ep_size
        if ep_inside:
            self.pp_axis, self.dp_axis, self.ep_axis = 0, 1, 2
            self.pg = ProcessGroupMesh(self.pp_size, self.dp_size, self.ep_size)
        else:
            self.pp_axis, self.ep_axis, self.dp_axis = 0, 1, 2
            self.pg = ProcessGroupMesh(self.pp_size, self.ep_size, self.dp_size)

        self.ep_group = self.pg.get_group_along_axis(self.ep_axis)
        self.ep_group_ranks = self.pg.get_ranks_in_group(self.ep_group)
        self.dp_group = self.pg.get_group_along_axis(self.dp_axis)
        self.dp_group_ranks = self.pg.get_ranks_in_group(self.dp_group)
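
The only subtlety here is the axis ordering controlled by ep_inside. The illustration below is not part of the commit; it assumes ProcessGroupMesh lays ranks out row-major over the mesh axes and a 4-rank torchrun launch.

# Hypothetical illustration of ep_inside (not part of the diff).
# Assumed launch: torchrun --nproc_per_node=4 this_script.py
import torch.distributed as dist

from colossalai.tensor.moe_tensor.moe_info import MoeParallelInfo

if __name__ == "__main__":
    dist.init_process_group(backend="gloo")

    # With pp_size=1, dp_size=2, ep_size=2 on 4 ranks (row-major rank layout assumed):
    #   ep_inside=True  -> mesh (pp, dp, ep): ep groups {0,1}, {2,3}; dp groups {0,2}, {1,3}
    #   ep_inside=False -> mesh (pp, ep, dp): ep groups {0,2}, {1,3}; dp groups {0,1}, {2,3}
    # i.e. ep_inside=True keeps expert-parallel peers on adjacent ranks (typically the
    # same node), while ep_inside=False keeps data-parallel peers adjacent instead.
    info = MoeParallelInfo(ep_inside=True, ep_size=2, dp_size=2, pp_size=1)
    print(f"rank {dist.get_rank()}: ep peers {info.ep_group_ranks}, dp peers {info.dp_group_ranks}")
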