mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-16 06:30:41 +00:00)
[moe] fix MoE bugs (#1628)
* remove forced FP32 modules
* correct no_shard-contexts' positions
@@ -24,6 +24,7 @@ class MoeExperts(nn.Module):
         self.num_local_experts, self.dist_info = MOE_CONTEXT.get_info(num_experts)
 
 
+@no_shard_zero_decrator(is_replicated=False)
 class Experts(MoeExperts):
     """A wrapper class to create experts. It will create E experts across the
     moe model parallel group, where E is the number of experts. Every expert
@@ -35,7 +36,6 @@ class Experts(MoeExperts):
         expert_args: Args used to initialize experts, the args could be found in corresponding expert class
     """
 
-    @no_shard_zero_decrator(is_replicated=False)
     def __init__(self, expert_cls: Type[nn.Module], num_experts: int, **expert_args):
         super().__init__("all_to_all", num_experts)
 
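The two hunks above move `@no_shard_zero_decrator` from `Experts.__init__` up to the `Experts` class itself (the `MoeLayer` hunks below do the same). The placement matters: decorating `__init__` wraps a method whose return value is always None, while decorating the class wraps construction itself, so the wrapper must pass the new instance through (the fix in the final hunk). A minimal sketch of the difference, with a hypothetical pass-through decorator standing in for the real ZeRO context:

```python
def passthrough(fn):
    """Stand-in for no_shard_zero_decrator: no context, just forwards the call."""

    def wrapped(*args, **kwargs):
        # the real wrapper opens no_shard_zero_context(...) around this call
        return fn(*args, **kwargs)

    return wrapped


class A:

    @passthrough    # old placement: wraps __init__, whose result is always None
    def __init__(self):
        self.x = 1


@passthrough    # new placement: wraps the class; `wrapped` must return the instance
class B:

    def __init__(self):
        self.x = 1


print(A().x, B().x)    # 1 1 -- B() only works because `wrapped` returns fn's result
```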
@@ -228,6 +228,7 @@ class FP32LinearGate(nn.Module):
         return F.linear(x, self.weight)
 
 
+@no_shard_zero_decrator(is_replicated=True)
 class MoeLayer(nn.Module):
     """A MoE layer, that puts its input tensor to its gate and uses the output logits
     to router all tokens, is mainly used to exchange all tokens for every expert across
@@ -241,12 +242,11 @@ class MoeLayer(nn.Module):
         experts (:class:`torch.nn.Module`): Instance of experts generated by Expert.
     """
 
-    @no_shard_zero_decrator(is_replicated=True)
     def __init__(self, dim_model: int, num_experts: int, router: nn.Module, experts: MoeExperts):
         super().__init__()
         self.d_model = dim_model
         self.num_experts = num_experts
-        self.gate = FP32LinearGate(dim_model, num_experts)
+        self.gate_weight = torch.nn.Parameter(torch.empty(num_experts, dim_model))
         self.router = router
         self.experts = experts
         self.use_kernel = True if COL_MOE_KERNEL_FLAG and MOE_CONTEXT.use_kernel_optim else False
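Here the forced-FP32 module is dropped: `FP32LinearGate` gives way to a bare `gate_weight` parameter, with the float32 cast deferred to `forward` (see the later hunks). A self-contained sketch of that pattern; `GateSketch` is illustrative, not the actual ColossalAI class:

```python
import math
import torch
import torch.nn.functional as F


class GateSketch(torch.nn.Module):
    """Routing gate as a bare parameter; the fp32 cast happens per call,
    so the parameter can live in the model's native dtype (e.g. fp16)."""

    def __init__(self, dim_model: int, num_experts: int):
        super().__init__()
        self.gate_weight = torch.nn.Parameter(torch.empty(num_experts, dim_model))
        torch.nn.init.trunc_normal_(self.gate_weight, std=math.sqrt(0.1 / dim_model))

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        # routing logits are computed in fp32 for numerical stability
        fp32_input = tokens.to(torch.float)
        fp32_weight = self.gate_weight.to(torch.float)
        return F.linear(fp32_input, fp32_weight)    # [num_tokens, num_experts]


gate = GateSketch(dim_model=8, num_experts=4).half()    # fp16 parameter
logits = gate(torch.randn(16, 8, dtype=torch.float16))
print(logits.dtype, logits.shape)    # torch.float32 torch.Size([16, 4])
```

Keeping the gate as an ordinary parameter lets it be sharded and updated like any other weight, while the per-call cast preserves the fp32 routing the old module enforced.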
@@ -254,16 +254,14 @@ class MoeLayer(nn.Module):
         self.ep_size = experts.dist_info.ep_size
         self.num_local_experts = experts.num_local_experts
 
+        nn.init.trunc_normal_(self.gate_weight, std=math.sqrt(0.1 / dim_model))
+
     def a2a_process(self, dispatch_data: torch.Tensor):
         expert_input = AllToAll.apply(dispatch_data, self.ep_group)
-
         input_shape = expert_input.shape
-
         expert_input = expert_input.reshape(self.ep_size, self.num_local_experts, -1, self.d_model)
-
         expert_output = self.experts(expert_input)
         expert_output = expert_output.reshape(input_shape)
-
         expert_output = AllToAll.apply(expert_output, self.ep_group)
         return expert_output
 
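The tidied `a2a_process` is shape bookkeeping around two all-to-alls: the received buffer is viewed as `[ep_size, num_local_experts, capacity, hidden]` so each rank can address its local experts, then flattened back before the reverse exchange. A shape-only sketch with identity stand-ins for `AllToAll.apply` and `self.experts` (no process group needed):

```python
import torch

ep_size, num_local_experts, capacity, d_model = 2, 3, 5, 8
num_experts = ep_size * num_local_experts

# dispatch_data arrives as [num_experts, capacity, d_model]
dispatch_data = torch.randn(num_experts, capacity, d_model)

expert_input = dispatch_data                  # identity stand-in for AllToAll.apply
input_shape = expert_input.shape

# dim 0 indexes the peer rank, dim 1 the local expert
expert_input = expert_input.reshape(ep_size, num_local_experts, -1, d_model)
assert expert_input.shape == (ep_size, num_local_experts, capacity, d_model)

expert_output = expert_input                  # identity stand-in for self.experts
expert_output = expert_output.reshape(input_shape)    # flatten before reverse exchange
assert expert_output.shape == (num_experts, capacity, d_model)
```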
@@ -274,16 +272,22 @@ class MoeLayer(nn.Module):
         return expert_out
 
     def forward(self, inputs: torch.Tensor) -> torch.Tensor:
         # reshape the input tokens
         tokens = inputs.reshape(-1, self.d_model)
-        fp32_input = tokens.to(torch.float32) if inputs.dtype != torch.float32 else tokens
-        gate_output = self.gate(fp32_input)
-        router_res = self.router(inputs=gate_output, use_kernel=self.use_kernel, ep_group=self.ep_group)
+
+        # the data type of the inputs in the gating should be fp32
+        fp32_input = tokens.to(torch.float)
+        fp32_weight = self.gate_weight.to(torch.float)
+        gate_output = F.linear(fp32_input, fp32_weight)
+
+        # the result from the router
+        route_result_list = self.router(inputs=gate_output, use_kernel=self.use_kernel, ep_group=self.ep_group)
 
         if self.use_kernel:
-            dispatch_data = MoeDispatch.apply(tokens, *router_res[1:])
+            dispatch_data = MoeDispatch.apply(tokens, *route_result_list[1:])
             dispatch_data = dispatch_data.reshape(self.num_experts, -1, self.d_model)
         else:
-            sec_mask_f = router_res[1].type_as(inputs)
+            sec_mask_f = route_result_list[1].type_as(inputs)
             dispatch_data = torch.matmul(sec_mask_f.permute(1, 2, 0), tokens)
 
         # dispatch_data [e, c, h]
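In the non-kernel path, dispatch is a dense contraction with the router's section mask. The `permute(1, 2, 0)` implies `sec_mask_f` has shape `[tokens, experts, capacity]` (an assumption consistent with the code, not stated in the diff): transposing it to `[e, c, s]` and multiplying by the `[s, h]` tokens drops each token into its expert/slot position. A toy check:

```python
import torch

s, e, c, h = 6, 2, 3, 4    # tokens, experts, capacity, hidden

tokens = torch.randn(s, h)

# sec_mask[t, i, j] == 1 iff token t occupies slot j of expert i
sec_mask = torch.zeros(s, e, c)
for t in range(s):
    sec_mask[t, t % e, t // e] = 1.0    # toy round-robin assignment

# dense dispatch: [e, c, s] @ [s, h] -> [e, c, h]
dispatch_data = torch.matmul(sec_mask.permute(1, 2, 0), tokens)

assert torch.equal(dispatch_data[0, 0], tokens[0])    # token 0 -> expert 0, slot 0
assert torch.equal(dispatch_data[1, 0], tokens[1])    # token 1 -> expert 1, slot 0
```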
@@ -295,12 +299,11 @@ class MoeLayer(nn.Module):
             raise NotImplementedError("This kind of communication has not been implemented yet.\n Please use Experts "
                                       "build function.")
         # expert_output [e, c, h]
-
         if self.use_kernel:
             expert_output = expert_output.reshape(-1, self.d_model)
-            ans = MoeCombine.apply(expert_output, *router_res)
+            ans = MoeCombine.apply(expert_output, *route_result_list)
         else:
-            combine_weights = router_res[0].type_as(inputs)
+            combine_weights = route_result_list[0].type_as(inputs)
             combine_weights = combine_weights.view(combine_weights.shape[0], -1)
             expert_output = expert_output.view(-1, expert_output.shape[-1])
             ans = torch.matmul(combine_weights, expert_output)
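Combine is the reverse contraction. The `view` calls imply `combine_weights` comes in as `[tokens, experts, capacity]` (again an assumption matching the code): flattening it to `[s, e*c]` and the expert output to `[e*c, h]` lets a single matmul gather and weight each token's expert results. Continuing the toy shapes from the dispatch sketch:

```python
import torch

s, e, c, h = 6, 2, 3, 4
expert_output = torch.randn(e, c, h)

# one (expert, slot) weight per token, zero elsewhere
combine_weights = torch.zeros(s, e, c)
for t in range(s):
    combine_weights[t, t % e, t // e] = 1.0

# [s, e*c] @ [e*c, h] -> [s, h]
cw = combine_weights.view(combine_weights.shape[0], -1)
ans = torch.matmul(cw, expert_output.view(-1, expert_output.shape[-1]))

assert ans.shape == (s, h)
assert torch.equal(ans[0], expert_output[0, 0])    # token 0 reads expert 0, slot 0
```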
@@ -258,7 +258,8 @@ def no_shard_zero_decrator(is_replicated: bool = True):
 
     def _no_shard(*args, **kwargs):
         with no_shard_zero_context(is_replicated):
-            init_func(*args, **kwargs)
+            ret = init_func(*args, **kwargs)
+        return ret
 
     return _no_shard
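This last hunk is what makes the class-level decorator placement above actually work: `_no_shard` previously discarded `init_func`'s return value, so a decorated class name was bound to a wrapper that always returned None. A minimal reproduction with the fix in place; the context manager is a stand-in, and the enclosing `_wrapper` layer is inferred from the hunk, not shown in it:

```python
from contextlib import contextmanager


@contextmanager
def no_shard_zero_context(is_replicated):
    yield    # stand-in: the real context suppresses ZeRO sharding during init


def no_shard_zero_decrator(is_replicated: bool = True):

    def _wrapper(init_func):

        def _no_shard(*args, **kwargs):
            with no_shard_zero_context(is_replicated):
                ret = init_func(*args, **kwargs)    # before the fix, this result was dropped
            return ret

        return _no_shard

    return _wrapper


@no_shard_zero_decrator(is_replicated=False)
class Experts:

    def __init__(self, num_experts: int):
        self.num_experts = num_experts


print(Experts(4).num_experts)    # 4; without `return ret`, Experts(4) would be None
```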