[moe]: fix ep/tp tests, add hierarchical all2all (#4982)

* fix: add warning for EP different behavior

* fix: use shard_data in ep & tp model

* to: add used_capacity

* fix: fix router test

* feat: add create_ep_node_group

* feat: add create_ep_hierarchical_group fn

* feat: add HierarchicalAllToAll

* test: add hierarchical all2all test

* fix: fix test errors

* fix: simplify create_ep_hierarchical_group

* fix: add hierarchical_alltoall arg

* fix: fix environ typo

* revert: revert process mesh order

* to: add todo mark

* fix: skip hierarchical_comm if torch < 1.13.1
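The last item above is a runtime guard rather than a feature. A minimal sketch of such a gate, assuming a packaging-based version check (the exact code in the PR may differ):

    # Fall back to the flat all-to-all when torch is too old for the
    # hierarchical path (the PR skips it for torch < 1.13.1).
    from packaging import version
    import torch

    if version.parse(torch.__version__) < version.parse("1.13.1"):
        enable_hierarchical_comm = False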
Author: Wenhao Chen
Date: 2023-11-09 14:31:00 +08:00
Committed by: GitHub
Parent: 239cd92eff
Commit: 724441279b
10 changed files with 388 additions and 164 deletions


@@ -7,12 +7,12 @@ import torch.distributed as dist
 import torch.nn as nn
 import torch.nn.functional as F
-from colossalai.moe._operation import AllGather, AllToAll, MoeCombine, MoeDispatch, ReduceScatter
+from colossalai.moe._operation import AllGather, AllToAll, HierarchicalAllToAll, MoeCombine, MoeDispatch, ReduceScatter
 from colossalai.moe.experts import MLPExperts
 from colossalai.moe.load_balance import LoadBalancer
 from colossalai.moe.manager import MOE_MANAGER
 from colossalai.moe.routers import MoeRouter, get_router_cls
-from colossalai.moe.utils import get_noise_generator
+from colossalai.moe.utils import create_ep_hierarchical_group, get_noise_generator
 from colossalai.tensor.moe_tensor.api import get_dp_group, get_ep_group, get_ep_size
@@ -51,19 +51,20 @@ class SparseMLP(nn.Module):
         hidden_size: int,
         intermediate_size: int,
         router_top_k: int = 1,
-        router_capacity_factor_train: Optional[float] = 1.25,
-        router_capacity_factor_eval: Optional[float] = 2.0,
-        router_min_capacity: Optional[int] = 4,
+        router_capacity_factor_train: float = 1.25,
+        router_capacity_factor_eval: float = 2.0,
+        router_min_capacity: int = 4,
         router_noisy_policy: Optional[str] = None,
-        router_drop_tks: Optional[bool] = True,
+        router_drop_tks: bool = True,
         mlp_activation: Optional[str] = None,
-        mlp_gated: Optional[bool] = False,
-        enable_load_balance: Optional[bool] = False,
-        load_balance_tolerance: Optional[float] = 0.1,
-        load_balance_beam_width: Optional[int] = 8,
-        load_balance_group_swap_factor: Optional[float] = 0.4,
-        enable_kernel: Optional[bool] = False,
-        enable_comm_overlap: Optional[bool] = False,
+        mlp_gated: bool = False,
+        enable_load_balance: bool = False,
+        load_balance_tolerance: float = 0.1,
+        load_balance_beam_width: int = 8,
+        load_balance_group_swap_factor: float = 0.4,
+        enable_kernel: bool = False,
+        enable_comm_overlap: bool = False,
+        enable_hierarchical_comm: bool = False,
     ):
         super().__init__()
         self.hidden_size = hidden_size
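For reference, a hypothetical construction using the new flag; the num_experts argument and the surrounding MOE_MANAGER setup are assumed from context, since this hunk does not show them:

    # Opt into the two-stage all-to-all; every other argument keeps the
    # defaults from the signature above.
    mlp = SparseMLP(
        num_experts=8,
        hidden_size=1024,
        intermediate_size=4096,
        router_top_k=2,
        enable_hierarchical_comm=True,
    )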
@@ -104,6 +105,8 @@ class SparseMLP(nn.Module):
         if self.expert_parallel is not None:
             self.ep_group = get_ep_group(self.experts)
             self.ep_size = get_ep_size(self.experts)
+            self.ep_hierarchical_group = create_ep_hierarchical_group(
+                self.ep_group) if enable_hierarchical_comm else None
             self.dp_group = get_dp_group(self.experts)
         else:
             self.ep_group = None
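A plausible sketch of the grouping create_ep_hierarchical_group sets up; the helper's real signature and return value live in colossalai.moe.utils and are not shown in this diff, so the names and shapes here are assumptions:

    import torch.distributed as dist

    def make_hierarchical_groups(ep_ranks, ranks_per_node):
        # Partition the EP ranks node by node (assumes contiguous placement).
        nodes = [ep_ranks[i:i + ranks_per_node]
                 for i in range(0, len(ep_ranks), ranks_per_node)]
        # dist.new_group must be entered by every process, so each rank
        # creates all groups and later uses only the ones it belongs to.
        intra = [dist.new_group(node) for node in nodes]        # one per node
        inter = [dist.new_group([node[l] for node in nodes])    # one per local rank
                 for l in range(ranks_per_node)]
        return intra, inter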
@@ -132,7 +135,7 @@ class SparseMLP(nn.Module):
     def reset_parameters(self):
         torch.nn.init.normal_(self.gate_weight, std=math.sqrt(0.1 / self.hidden_size))

-    def forward(self, inputs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
         """
         Args:
             inputs (torch.Tensor): The input tensor of shape (batch_size, seq_len, hidden_size)
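With the new signature the caller receives a single tensor instead of a tuple. An illustrative call, reusing the mlp instance sketched earlier:

    import torch

    x = torch.randn(2, 16, 1024)   # (batch_size, seq_len, hidden_size)
    y = mlp(x)                     # one tensor of the same shape, no tuple unpacking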
@@ -158,7 +161,8 @@ class SparseMLP(nn.Module):
             self.load_balancer.update_load(expert_load)

         # the result from the router
-        route_result_list = self.router(inputs=gate_output, use_kernel=self.enable_kernel, ep_group=self.ep_group)
+        used_capacity, *route_result_list = self.router(
+            inputs=gate_output, use_kernel=self.enable_kernel, ep_group=self.ep_group)

         # dispatch_data: (num_experts, capacity, hidden_size)
         if self.enable_kernel:
@@ -170,9 +174,17 @@ class SparseMLP(nn.Module):
         # expert_output: (num_groups, num_experts, capacity, hidden_size)
         if self.expert_parallel == "EP":
-            expert_output = self._ep_process(dispatch_data, overlap=self.enable_comm_overlap)
+            expert_output = self._ep_process(
+                dispatch_data,
+                used_capacity,
+                overlap=self.enable_comm_overlap
+            )
         elif self.expert_parallel == "TP":
-            expert_output = self._tp_process(dispatch_data, overlap=self.enable_comm_overlap)
+            expert_output = self._tp_process(
+                dispatch_data,
+                used_capacity,
+                overlap=self.enable_comm_overlap
+            )
         elif self.expert_parallel is None:
             expert_output = self._local_process(dispatch_data)
         else:
@@ -196,7 +208,12 @@ class SparseMLP(nn.Module):
         expert_out = self.experts(expert_in)
         return expert_out

-    def _ep_process(self, dispatch_data: torch.Tensor, overlap: bool = False) -> torch.Tensor:
+    def _ep_process(
+        self,
+        dispatch_data: torch.Tensor,
+        used_capacity: torch.Tensor,
+        overlap: bool = False
+    ) -> torch.Tensor:
         """
         Expert Parallel
@@ -207,12 +224,18 @@ class SparseMLP(nn.Module):
             torch.Tensor: (num_experts, capacity, hidden_size)
         """
         if not overlap or dist.get_world_size(self.ep_group) == 1:
-            expert_input = AllToAll.apply(dispatch_data, self.ep_group, False)[0]
-            expert_input = expert_input.reshape(self.ep_size, self.num_local_experts, -1, self.hidden_size)
-            expert_output = self.experts(expert_input)
-            expert_output = AllToAll.apply(expert_output, self.ep_group, False)[0]
-            return expert_output
+            if self.ep_hierarchical_group is not None:
+                expert_input = HierarchicalAllToAll.apply(dispatch_data, self.ep_hierarchical_group)
+                expert_input = expert_input.reshape(self.ep_size, self.num_local_experts, -1, self.hidden_size)
+                expert_output = self.experts(expert_input)
+                expert_output = HierarchicalAllToAll.apply(expert_output, self.ep_hierarchical_group)
+                return expert_output
+            else:
+                expert_input = AllToAll.apply(dispatch_data, self.ep_group, False)[0]
+                expert_input = expert_input.reshape(self.ep_size, self.num_local_experts, -1, self.hidden_size)
+                expert_output = self.experts(expert_input)
+                expert_output = AllToAll.apply(expert_output, self.ep_group, False)[0]
+                return expert_output
         else:

             @dataclasses.dataclass
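A minimal sketch of the two-stage exchange behind HierarchicalAllToAll; the group handles and data layout are assumptions, since the real class is an autograd Function in colossalai.moe._operation. Ranks are assumed laid out as rank = node_id * ranks_per_node + local_id, with intra_group covering one node and inter_group linking the ranks that share a local index across nodes:

    import torch
    import torch.distributed as dist

    def two_stage_all_to_all(x: torch.Tensor,
                             intra_group: dist.ProcessGroup,
                             inter_group: dist.ProcessGroup) -> torch.Tensor:
        # x: (world_size, *rest) -- slice i is destined for EP rank i.
        n_local = dist.get_world_size(intra_group)   # ranks per node
        n_nodes = dist.get_world_size(inter_group)   # number of nodes
        rest = x.shape[1:]
        # Stage 1 (fast intra-node links): route each slice to the local
        # rank that matches its destination's local index.
        x = x.reshape(n_nodes, n_local, *rest).transpose(0, 1).contiguous()
        y = torch.empty_like(x)
        dist.all_to_all_single(y.view(-1, *rest), x.view(-1, *rest), group=intra_group)
        # y[p, node] now holds peer p's slice bound for (node, my_local_id).
        # Stage 2 (slow inter-node links): a single exchange across nodes.
        y = y.transpose(0, 1).contiguous()
        z = torch.empty_like(y)
        dist.all_to_all_single(z.view(-1, *rest), y.view(-1, *rest), group=inter_group)
        return z.view(n_nodes * n_local, *rest)      # ordered by source rank

The payoff is that cross-node traffic is aggregated into one exchange of larger messages between node pairs, instead of many small messages between every pair of EP ranks.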
@@ -261,7 +284,12 @@ class SparseMLP(nn.Module):
             return output

-    def _tp_process(self, dispatch_data: torch.Tensor, overlap: bool = False) -> torch.Tensor:
+    def _tp_process(
+        self,
+        dispatch_data: torch.Tensor,
+        used_capacity: torch.Tensor,
+        overlap: bool = False
+    ) -> torch.Tensor:
         """
         without overlap:
         | C |
@@ -295,8 +323,8 @@ class SparseMLP(nn.Module):
             NUM_CHUNK = 4
             NUM_STAGES = 4

-            assert (dispatch_data.shape[0] % NUM_CHUNK == 0
-                    ), "arbitrary chunk num is not supported yet, please use chunk num that can divide num_experts"
+            assert dispatch_data.shape[0] % NUM_CHUNK == 0, \
+                "arbitrary chunk num is not supported yet, please use chunk num that can divide num_experts"
             chunk_size = dispatch_data.shape[0] // NUM_CHUNK
             chunk_data = torch.split(dispatch_data, chunk_size, dim=0)
             output = torch.empty_like(dispatch_data)
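For context, a sketch of the chunked compute/communication pipeline this assertion guards; comm_fn and compute_fn are placeholders (the real code drives async all-to-all handles and CUDA streams through NUM_STAGES stages):

    import torch

    def chunked_overlap(dispatch_data, comm_fn, compute_fn, num_chunk=4):
        # comm_fn(chunk) -> (handle, buf); handle.wait() completes the transfer.
        assert dispatch_data.shape[0] % num_chunk == 0, \
            "chunk num must divide the number of experts"
        chunk_size = dispatch_data.shape[0] // num_chunk
        chunks = torch.split(dispatch_data, chunk_size, dim=0)
        outputs = []
        handle, buf = comm_fn(chunks[0])              # kick off the first transfer
        for i in range(num_chunk):
            cur_handle, cur_buf = handle, buf
            if i + 1 < num_chunk:
                handle, buf = comm_fn(chunks[i + 1])  # prefetch the next chunk
            cur_handle.wait()                         # waits while the prefetch runs
            outputs.append(compute_fn(cur_buf))
        return torch.cat(outputs, dim=0)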