Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-08 12:30:42 +00:00)
[moe]: fix ep/tp tests, add hierarchical all2all (#4982)
* fix: add warning for EP different behavior
* fix: use shard_data in ep & tp model
* to: add used_capacity
* fix: fix router test
* feat: add create_ep_node_group
* feat: add create_ep_hierarchical_group fn
* feat: add HierarchicalAllToAll
* test: add hierarchical all2all test
* fix: fix test errors
* fix: simplify create_ep_hierarchical_group
* fix: add hierarchical_alltoall arg
* fix: fix environ typo
* revert: revert process mesh order
* to: add todo mark
* fix: skip hierarchical_comm if torch < 1.13.1
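For context, a minimal usage sketch of the new flag introduced by this commit. Only the constructor arguments visible in the diff below (the router/load-balance flags, enable_comm_overlap, enable_hierarchical_comm) come from this change; num_experts, the example sizes, and the import path are assumptions for illustration, and in practice the layer is created inside an initialized ColossalAI MoE / distributed setup.

import torch
from colossalai.moe.layers import SparseMLP  # import path assumed for this sketch

moe_layer = SparseMLP(
    num_experts=8,                    # not shown in this diff, assumed
    hidden_size=1024,
    intermediate_size=4096,
    router_top_k=1,
    enable_comm_overlap=False,
    enable_hierarchical_comm=True,    # new flag: use HierarchicalAllToAll for expert parallelism
)
x = torch.randn(2, 16, 1024)          # (batch_size, seq_len, hidden_size)
y = moe_layer(x)                      # forward() now returns a single torch.Tensor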
@@ -7,12 +7,12 @@ import torch.distributed as dist
 import torch.nn as nn
 import torch.nn.functional as F

-from colossalai.moe._operation import AllGather, AllToAll, MoeCombine, MoeDispatch, ReduceScatter
+from colossalai.moe._operation import AllGather, AllToAll, HierarchicalAllToAll, MoeCombine, MoeDispatch, ReduceScatter
 from colossalai.moe.experts import MLPExperts
 from colossalai.moe.load_balance import LoadBalancer
 from colossalai.moe.manager import MOE_MANAGER
 from colossalai.moe.routers import MoeRouter, get_router_cls
-from colossalai.moe.utils import get_noise_generator
+from colossalai.moe.utils import create_ep_hierarchical_group, get_noise_generator
 from colossalai.tensor.moe_tensor.api import get_dp_group, get_ep_group, get_ep_size


@@ -51,19 +51,20 @@ class SparseMLP(nn.Module):
         hidden_size: int,
         intermediate_size: int,
         router_top_k: int = 1,
-        router_capacity_factor_train: Optional[float] = 1.25,
-        router_capacity_factor_eval: Optional[float] = 2.0,
-        router_min_capacity: Optional[int] = 4,
+        router_capacity_factor_train: float = 1.25,
+        router_capacity_factor_eval: float = 2.0,
+        router_min_capacity: int = 4,
         router_noisy_policy: Optional[str] = None,
-        router_drop_tks: Optional[bool] = True,
+        router_drop_tks: bool = True,
         mlp_activation: Optional[str] = None,
-        mlp_gated: Optional[bool] = False,
-        enable_load_balance: Optional[bool] = False,
-        load_balance_tolerance: Optional[float] = 0.1,
-        load_balance_beam_width: Optional[int] = 8,
-        load_balance_group_swap_factor: Optional[float] = 0.4,
-        enable_kernel: Optional[bool] = False,
-        enable_comm_overlap: Optional[bool] = False,
+        mlp_gated: bool = False,
+        enable_load_balance: bool = False,
+        load_balance_tolerance: float = 0.1,
+        load_balance_beam_width: int = 8,
+        load_balance_group_swap_factor: float = 0.4,
+        enable_kernel: bool = False,
+        enable_comm_overlap: bool = False,
+        enable_hierarchical_comm: bool = False,
     ):
         super().__init__()
         self.hidden_size = hidden_size
@@ -104,6 +105,8 @@ class SparseMLP(nn.Module):
         if self.expert_parallel is not None:
             self.ep_group = get_ep_group(self.experts)
             self.ep_size = get_ep_size(self.experts)
+            self.ep_hierarchical_group = create_ep_hierarchical_group(
+                self.ep_group) if enable_hierarchical_comm else None
             self.dp_group = get_dp_group(self.experts)
         else:
             self.ep_group = None
@@ -132,7 +135,7 @@ class SparseMLP(nn.Module):
     def reset_parameters(self):
         torch.nn.init.normal_(self.gate_weight, std=math.sqrt(0.1 / self.hidden_size))

-    def forward(self, inputs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
         """
         Args:
             inputs (torch.Tensor): The input tensor of shape (batch_size, seq_len, hidden_size)
@@ -158,7 +161,8 @@ class SparseMLP(nn.Module):
             self.load_balancer.update_load(expert_load)

         # the result from the router
-        route_result_list = self.router(inputs=gate_output, use_kernel=self.enable_kernel, ep_group=self.ep_group)
+        used_capacity, *route_result_list = self.router(
+            inputs=gate_output, use_kernel=self.enable_kernel, ep_group=self.ep_group)

         # dispatch_data: (num_experts, capacity, hidden_size)
         if self.enable_kernel:
@@ -170,9 +174,17 @@ class SparseMLP(nn.Module):

         # expert_output: (num_groups, num_experts, capacity, hidden_size)
         if self.expert_parallel == "EP":
-            expert_output = self._ep_process(dispatch_data, overlap=self.enable_comm_overlap)
+            expert_output = self._ep_process(
+                dispatch_data,
+                used_capacity,
+                overlap=self.enable_comm_overlap
+            )
         elif self.expert_parallel == "TP":
-            expert_output = self._tp_process(dispatch_data, overlap=self.enable_comm_overlap)
+            expert_output = self._tp_process(
+                dispatch_data,
+                used_capacity,
+                overlap=self.enable_comm_overlap
+            )
         elif self.expert_parallel is None:
             expert_output = self._local_process(dispatch_data)
         else:
@@ -196,7 +208,12 @@ class SparseMLP(nn.Module):
         expert_out = self.experts(expert_in)
         return expert_out

-    def _ep_process(self, dispatch_data: torch.Tensor, overlap: bool = False) -> torch.Tensor:
+    def _ep_process(
+        self,
+        dispatch_data: torch.Tensor,
+        used_capacity: torch.Tensor,
+        overlap: bool = False
+    ) -> torch.Tensor:
         """
         Expert Parallel

@@ -207,12 +224,18 @@ class SparseMLP(nn.Module):
             torch.Tensor: (num_experts, capacity, hidden_size)
         """
         if not overlap or dist.get_world_size(self.ep_group) == 1:
-            expert_input = AllToAll.apply(dispatch_data, self.ep_group, False)[0]
-            expert_input = expert_input.reshape(self.ep_size, self.num_local_experts, -1, self.hidden_size)
-            expert_output = self.experts(expert_input)
-            expert_output = AllToAll.apply(expert_output, self.ep_group, False)[0]
-            return expert_output
-
+            if self.ep_hierarchical_group is not None:
+                expert_input = HierarchicalAllToAll.apply(dispatch_data, self.ep_hierarchical_group)
+                expert_input = expert_input.reshape(self.ep_size, self.num_local_experts, -1, self.hidden_size)
+                expert_output = self.experts(expert_input)
+                expert_output = HierarchicalAllToAll.apply(expert_output, self.ep_hierarchical_group)
+                return expert_output
+            else:
+                expert_input = AllToAll.apply(dispatch_data, self.ep_group, False)[0]
+                expert_input = expert_input.reshape(self.ep_size, self.num_local_experts, -1, self.hidden_size)
+                expert_output = self.experts(expert_input)
+                expert_output = AllToAll.apply(expert_output, self.ep_group, False)[0]
+                return expert_output
         else:

             @dataclasses.dataclass
@@ -261,7 +284,12 @@ class SparseMLP(nn.Module):

         return output

-    def _tp_process(self, dispatch_data: torch.Tensor, overlap: bool = False) -> torch.Tensor:
+    def _tp_process(
+        self,
+        dispatch_data: torch.Tensor,
+        used_capacity: torch.Tensor,
+        overlap: bool = False
+    ) -> torch.Tensor:
         """
         without overlap:
             | C |
@@ -295,8 +323,8 @@ class SparseMLP(nn.Module):
        NUM_CHUNK = 4
        NUM_STAGES = 4

-        assert (dispatch_data.shape[0] % NUM_CHUNK == 0
-               ), "arbitrary chunk num is not supported yet, please use chunk num that can divide num_experts"
+        assert dispatch_data.shape[0] % NUM_CHUNK == 0, \
+            "arbitrary chunk num is not supported yet, please use chunk num that can divide num_experts"
        chunk_size = dispatch_data.shape[0] // NUM_CHUNK
        chunk_data = torch.split(dispatch_data, chunk_size, dim=0)
        output = torch.empty_like(dispatch_data)
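The diff above routes _ep_process through HierarchicalAllToAll whenever ep_hierarchical_group is set, and per the commit message the hierarchical path is skipped entirely on torch < 1.13.1. As a rough illustration of the grouping idea behind create_ep_hierarchical_group, the hypothetical sketch below builds one intra-node process group per node plus an inter-node group of node "leaders" with torch.distributed; it is not the library's implementation, and the function name and signature are assumptions.

from typing import Optional, Tuple

import torch.distributed as dist


def build_hierarchical_groups(nproc_per_node: int) -> Tuple[Optional[dist.ProcessGroup], Optional[dist.ProcessGroup]]:
    """Return (intra_node_group, inter_node_group) for the calling rank.

    Every rank joins exactly one intra-node group; only the first rank of each
    node (its leader) receives a non-None inter-node group.
    """
    world_size = dist.get_world_size()
    rank = dist.get_rank()
    assert world_size % nproc_per_node == 0, "world size must be divisible by nproc_per_node"
    num_nodes = world_size // nproc_per_node

    intra_group = None
    for node in range(num_nodes):
        ranks = list(range(node * nproc_per_node, (node + 1) * nproc_per_node))
        group = dist.new_group(ranks)  # every rank must call new_group for every group
        if rank in ranks:
            intra_group = group

    leader_ranks = [node * nproc_per_node for node in range(num_nodes)]
    leaders_group = dist.new_group(leader_ranks)
    inter_group = leaders_group if rank in leader_ranks else None
    return intra_group, inter_group

A hierarchical all-to-all built on such groups can aggregate traffic inside each node before crossing the slower inter-node links, which is the motivation for the HierarchicalAllToAll operation this commit adds.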