[moe] fix MoE bugs (#1628)

* remove forced FP32 modules

* correct the positions of no_shard contexts
HELSON 2022-09-22 13:56:30 +08:00 committed by GitHub
parent 38c68b5b9a
commit f7f2248771
7 changed files with 26 additions and 33 deletions

View File

@@ -24,6 +24,7 @@ class MoeExperts(nn.Module):
         self.num_local_experts, self.dist_info = MOE_CONTEXT.get_info(num_experts)


+@no_shard_zero_decrator(is_replicated=False)
 class Experts(MoeExperts):
     """A wrapper class to create experts. It will create E experts across the
     moe model parallel group, where E is the number of experts. Every expert
@@ -35,7 +36,6 @@ class Experts(MoeExperts):
         expert_args: Args used to initialize experts, the args could be found in corresponding expert class
     """

-    @no_shard_zero_decrator(is_replicated=False)
     def __init__(self, expert_cls: Type[nn.Module], num_experts: int, **expert_args):
         super().__init__("all_to_all", num_experts)

View File

@@ -228,6 +228,7 @@ class FP32LinearGate(nn.Module):
         return F.linear(x, self.weight)


+@no_shard_zero_decrator(is_replicated=True)
 class MoeLayer(nn.Module):
     """A MoE layer, that puts its input tensor to its gate and uses the output logits
     to router all tokens, is mainly used to exchange all tokens for every expert across
@@ -241,12 +242,11 @@ class MoeLayer(nn.Module):
         experts (:class:`torch.nn.Module`): Instance of experts generated by Expert.
     """

-    @no_shard_zero_decrator(is_replicated=True)
     def __init__(self, dim_model: int, num_experts: int, router: nn.Module, experts: MoeExperts):
         super().__init__()
         self.d_model = dim_model
         self.num_experts = num_experts
-        self.gate = FP32LinearGate(dim_model, num_experts)
+        self.gate_weight = torch.nn.Parameter(torch.empty(num_experts, dim_model))
         self.router = router
         self.experts = experts
         self.use_kernel = True if COL_MOE_KERNEL_FLAG and MOE_CONTEXT.use_kernel_optim else False
@@ -254,16 +254,14 @@ class MoeLayer(nn.Module):
         self.ep_size = experts.dist_info.ep_size
         self.num_local_experts = experts.num_local_experts

+        nn.init.trunc_normal_(self.gate_weight, std=math.sqrt(0.1 / dim_model))
+
     def a2a_process(self, dispatch_data: torch.Tensor):
         expert_input = AllToAll.apply(dispatch_data, self.ep_group)
         input_shape = expert_input.shape
         expert_input = expert_input.reshape(self.ep_size, self.num_local_experts, -1, self.d_model)
         expert_output = self.experts(expert_input)
         expert_output = expert_output.reshape(input_shape)
         expert_output = AllToAll.apply(expert_output, self.ep_group)
         return expert_output
@@ -274,16 +272,22 @@ class MoeLayer(nn.Module):
         return expert_out

     def forward(self, inputs: torch.Tensor) -> torch.Tensor:
+        # reshape the input tokens
         tokens = inputs.reshape(-1, self.d_model)
-        fp32_input = tokens.to(torch.float32) if inputs.dtype != torch.float32 else tokens
-        gate_output = self.gate(fp32_input)
-        router_res = self.router(inputs=gate_output, use_kernel=self.use_kernel, ep_group=self.ep_group)
+        # the data type of the inputs in the gating should be fp32
+        fp32_input = tokens.to(torch.float)
+        fp32_weight = self.gate_weight.to(torch.float)
+        gate_output = F.linear(fp32_input, fp32_weight)
+        # the result from the router
+        route_result_list = self.router(inputs=gate_output, use_kernel=self.use_kernel, ep_group=self.ep_group)
         if self.use_kernel:
-            dispatch_data = MoeDispatch.apply(tokens, *router_res[1:])
+            dispatch_data = MoeDispatch.apply(tokens, *route_result_list[1:])
             dispatch_data = dispatch_data.reshape(self.num_experts, -1, self.d_model)
         else:
-            sec_mask_f = router_res[1].type_as(inputs)
+            sec_mask_f = route_result_list[1].type_as(inputs)
             dispatch_data = torch.matmul(sec_mask_f.permute(1, 2, 0), tokens)
         # dispatch_data [e, c, h]
@@ -295,12 +299,11 @@ class MoeLayer(nn.Module):
             raise NotImplementedError("This kind of communication has not been implemented yet.\n Please use Experts "
                                       "build function.")
         # expert_output [e, c, h]
         if self.use_kernel:
             expert_output = expert_output.reshape(-1, self.d_model)
-            ans = MoeCombine.apply(expert_output, *router_res)
+            ans = MoeCombine.apply(expert_output, *route_result_list)
         else:
-            combine_weights = router_res[0].type_as(inputs)
+            combine_weights = route_result_list[0].type_as(inputs)
             combine_weights = combine_weights.view(combine_weights.shape[0], -1)
             expert_output = expert_output.view(-1, expert_output.shape[-1])
             ans = torch.matmul(combine_weights, expert_output)
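
The net effect of this file's changes: the forced FP32LinearGate module is gone, MoeLayer now owns a plain gate_weight parameter that follows the model dtype, and the fp32 arithmetic happens at gating time through explicit casts. A minimal, self-contained sketch of that gating path (the class name is illustrative and the router/dispatch logic is omitted):

    import math
    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class GateSketch(nn.Module):
        """Gate weight follows the model dtype; logits are always computed in fp32."""

        def __init__(self, dim_model: int, num_experts: int):
            super().__init__()
            self.gate_weight = nn.Parameter(torch.empty(num_experts, dim_model))
            nn.init.trunc_normal_(self.gate_weight, std=math.sqrt(0.1 / dim_model))

        def forward(self, inputs: torch.Tensor) -> torch.Tensor:
            tokens = inputs.reshape(-1, inputs.shape[-1])
            # cast both operands to fp32 on the fly instead of keeping an fp32 module
            fp32_input = tokens.to(torch.float)
            fp32_weight = self.gate_weight.to(torch.float)
            return F.linear(fp32_input, fp32_weight)    # fp32 logits of shape [tokens, experts]

    gate = GateSketch(dim_model=8, num_experts=4).half()    # gate_weight becomes fp16
    logits = gate(torch.randn(2, 3, 8, dtype=torch.float16))
    print(logits.dtype)    # torch.float32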

View File

@@ -258,7 +258,8 @@ def no_shard_zero_decrator(is_replicated: bool = True):

         def _no_shard(*args, **kwargs):
             with no_shard_zero_context(is_replicated):
-                init_func(*args, **kwargs)
+                ret = init_func(*args, **kwargs)
+            return ret

         return _no_shard
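
Because no_shard_zero_decrator is now applied to the Experts and MoeLayer classes themselves, init_func inside the wrapper is a class, and whatever the wrapper returns is what the caller receives as the constructed object; that is why the result is captured and returned here. A rough sketch of the mechanism, with a stand-in context manager rather than the real ColossalAI one:

    from contextlib import contextmanager

    @contextmanager
    def no_shard_region(is_replicated: bool):
        # stand-in for no_shard_zero_context: in ColossalAI, parameters created
        # inside this region are kept unsharded by the ZeRO init context
        yield

    def no_shard_decorator_sketch(is_replicated: bool = True):
        def _wrapper(init_func):                      # a class, when used as below
            def _no_shard(*args, **kwargs):
                with no_shard_region(is_replicated):
                    ret = init_func(*args, **kwargs)  # constructs the instance
                return ret                            # without this, callers would get None
            return _no_shard
        return _wrapper

    @no_shard_decorator_sketch(is_replicated=False)
    class ExpertsSketch:                              # simplified stand-in for Experts
        def __init__(self, num_experts: int):
            self.num_experts = num_experts

    experts = ExpertsSketch(4)    # the whole constructor runs inside the no-shard region
    assert experts is not None and experts.num_experts == 4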

View File

@@ -38,6 +38,7 @@ def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.f
     expert_factor = dict(in_features=hidden_size, out_features=hidden_size, device=get_current_device())
     expert = Experts(expert_module, NUM_EXPERTS, **expert_factor)
     layer = MoeLayer(hidden_size, NUM_EXPERTS, router(capacity_factor_train=1.0), expert)
+    layer = layer.to(get_current_device())

     if data_type == torch.float16:
         layer = layer.half()
@@ -50,11 +51,11 @@ def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.f
     # save all results
     o_tk_grad = tokens.grad.data.clone()
-    o_gt_grad = layer.gate.weight.grad.data.clone()
+    o_gt_grad = layer.gate_weight.grad.data.clone()

     # reset all gradients
     tokens.grad.zero_()
-    layer.gate.weight.grad.zero_()
+    layer.gate_weight.grad.zero_()

     layer.use_kernel = True
     new_out = layer(tokens)    # get ouputs through colossal kernel
@@ -67,7 +68,7 @@ def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.f
     new_out.backward(grad)    # get new type gradient
     n_tk_grad = tokens.grad.data.clone()
-    n_gt_grad = layer.gate.weight.grad.data.clone()
+    n_gt_grad = layer.gate_weight.grad.data.clone()

     if data_type == torch.float32:
         check_equal(o_tk_grad, n_tk_grad)
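
The test only changes how the gate parameter is reached (layer.gate_weight instead of layer.gate.weight); the comparison pattern itself is unchanged. Stripped of the MoE specifics, it looks roughly like this (plain PyTorch, illustrative names, an ordinary second forward pass standing in for the use_kernel=True path):

    import torch
    import torch.nn as nn

    torch.manual_seed(0)
    layer = nn.Linear(8, 8)
    tokens = torch.randn(4, 8, requires_grad=True)
    grad = torch.randn(4, 8)

    # reference path: forward/backward once and stash the gradients
    layer(tokens).backward(grad)
    o_tk_grad = tokens.grad.data.clone()
    o_gt_grad = layer.weight.grad.data.clone()

    # reset gradients, then run the alternative path and collect them again
    tokens.grad.zero_()
    layer.weight.grad.zero_()
    layer(tokens).backward(grad)
    n_tk_grad = tokens.grad.data.clone()
    n_gt_grad = layer.weight.grad.data.clone()

    # in fp32 the two paths should match exactly (the MoE test uses check_equal)
    assert torch.allclose(o_tk_grad, n_tk_grad)
    assert torch.allclose(o_gt_grad, n_gt_grad)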

View File

@@ -58,15 +58,9 @@ def run_moe_zero_init(init_device_type, shard_strategy_class):
     for name, param in model.named_parameters():
         assert hasattr(param, 'colo_attr')

-        # the weights in the gate should be fp32
-        if 'gate' in name:
-            assert param.colo_attr.sharded_data_tensor.dtype == torch.float32
-        else:
-            assert param.colo_attr.sharded_data_tensor.dtype == torch.half
-
         # the parameters in moe experts and its gate should not be sharded
         if ('experts' in name) or ('gate' in name) or ('residual_combine' in name):
-            assert not param.colo_attr.sharded_data_tensor.is_sharded
+            assert not param.colo_attr.sharded_data_tensor.is_sharded, "`{}` parameter has problem".format(name)
         else:
             assert param.colo_attr.sharded_data_tensor.is_sharded
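
The deleted dtype branch reflected the old design, where the gate was forced to stay in fp32 while the rest of a half-precision model was fp16; with the gate now a plain parameter, every MoE parameter is expected to follow the model-wide dtype. A tiny plain-PyTorch illustration of the difference (hypothetical module names, no ZeRO involved):

    import torch
    import torch.nn as nn

    class OldStyleMoE(nn.Module):
        # old design: a dedicated gate module whose weight is forced back to fp32
        def __init__(self):
            super().__init__()
            self.gate = nn.Linear(8, 4)
            self.expert = nn.Linear(8, 8)

        def half(self):
            super().half()
            self.gate.float()    # keep the gate in fp32
            return self

    class NewStyleMoE(nn.Module):
        # new design: a plain gate parameter that follows the model dtype
        def __init__(self):
            super().__init__()
            self.gate_weight = nn.Parameter(torch.empty(4, 8))
            self.expert = nn.Linear(8, 8)

    old, new = OldStyleMoE().half(), NewStyleMoE().half()
    print({n: p.dtype for n, p in old.named_parameters()})    # mixed fp32 / fp16
    print({n: p.dtype for n, p in new.named_parameters()})    # uniformly fp16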

View File

@@ -94,12 +94,6 @@ def _run_test_sharded_optim_v2(cpu_offload,
     apex_model, apex_optimizer = convert_to_apex_amp(model, optim, amp_config)
     apex_grad_handler = MoeGradientHandler(model)

-    # Since MOE is not compatible with apex_amp now, we need to convert gate weight to fp32
-    for (n, p), zp in zip(apex_model.named_parameters(), zero_model.parameters()):
-        if 'gate' in n:
-            p.data = p.float()
-            p.data.copy_(zp.colo_attr.data_payload)
-
     for i, (data, label) in enumerate(train_dataloader):
         if i > 5:
             break

View File

@@ -135,5 +135,5 @@ def check_sharded_model_params(model, zero_model, loose=False, reuse_fp16_shard=
         else:
             zero_p = zero_p.colo_attr.data_payload.to(p.device)

-        assert p.dtype == zero_p.dtype
+        assert p.dtype == zero_p.dtype, "Parameter `{}`:\n{} vs {}".format(name, p.dtype, zero_p.dtype)
         assert allclose(p, zero_p, loose=loose), f'{p} vs {zero_p}'