[moe] fix MoE bugs (#1628)

* remove forced FP32 modules

* correct no_shard-contexts' positions
Author: HELSON
Date: 2022-09-22 13:56:30 +08:00
Committed by: GitHub
Parent: 38c68b5b9a
Commit: f7f2248771
7 changed files with 26 additions and 33 deletions
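The two changes are easy to misread from the hunks alone. "Forced FP32 modules" refers to the MoE gate being pinned to float32 even when the rest of the model ran in half precision; after this commit the gate follows the model's dtype like any other parameter. The toy sketch below illustrates the removed behaviour in plain PyTorch; ToyMoeGate and force_fp32 are made-up names, not ColossalAI code.

import torch
import torch.nn as nn
import torch.nn.functional as F


class ToyMoeGate(nn.Module):
    """Toy router whose weight can optionally be pinned to fp32 (the old behaviour)."""

    def __init__(self, hidden_size: int, num_experts: int, force_fp32: bool = False):
        super().__init__()
        self.gate_weight = nn.Parameter(torch.empty(num_experts, hidden_size))
        nn.init.trunc_normal_(self.gate_weight, std=0.02)
        self.force_fp32 = force_fp32

    def half(self):
        # Old behaviour: skip the cast so the gate stays in fp32.
        if self.force_fp32:
            return self
        return super().half()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Routing logits of shape (num_tokens, num_experts).
        return F.linear(x, self.gate_weight)


old_gate = ToyMoeGate(8, 4, force_fp32=True).half()
new_gate = ToyMoeGate(8, 4, force_fp32=False).half()
print(old_gate.gate_weight.dtype)   # torch.float32 (forced)
print(new_gate.gate_weight.dtype)   # torch.float16 (follows the model)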

View File

@@ -38,6 +38,7 @@ def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.f
     expert_factor = dict(in_features=hidden_size, out_features=hidden_size, device=get_current_device())
     expert = Experts(expert_module, NUM_EXPERTS, **expert_factor)
     layer = MoeLayer(hidden_size, NUM_EXPERTS, router(capacity_factor_train=1.0), expert)
+    layer = layer.to(get_current_device())
     if data_type == torch.float16:
         layer = layer.half()
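The single added line above restores explicit device placement before the optional fp16 cast. A stand-alone version of that setup pattern, using a plain nn.Linear and a local device check in place of MoeLayer and get_current_device():

import torch
import torch.nn as nn

# Stand-in for the test setup: build the layer, move it to the active device,
# then optionally cast it to half precision.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
data_type = torch.float16

layer = nn.Linear(128, 128)
layer = layer.to(device)
if data_type == torch.float16:
    layer = layer.half()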
@@ -50,11 +51,11 @@ def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.f
     # save all results
     o_tk_grad = tokens.grad.data.clone()
-    o_gt_grad = layer.gate.weight.grad.data.clone()
+    o_gt_grad = layer.gate_weight.grad.data.clone()
     # reset all gradients
     tokens.grad.zero_()
-    layer.gate.weight.grad.zero_()
+    layer.gate_weight.grad.zero_()
     layer.use_kernel = True
     new_out = layer(tokens)    # get outputs through colossal kernel
@@ -67,7 +68,7 @@ def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.f
     new_out.backward(grad)    # get new type gradient
     n_tk_grad = tokens.grad.data.clone()
-    n_gt_grad = layer.gate.weight.grad.data.clone()
+    n_gt_grad = layer.gate_weight.grad.data.clone()
     if data_type == torch.float32:
         check_equal(o_tk_grad, n_tk_grad)
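Both replacements in this file track an interface change in MoeLayer: the routing weight is no longer reached through a gate submodule (layer.gate.weight) but exposed directly as the parameter layer.gate_weight. The stand-in modules below (GateAsLinear and GateAsParameter are hypothetical names) reproduce the two access patterns and show that only the lookup path for the gradient changes:

import torch
import torch.nn as nn
import torch.nn.functional as F


class GateAsLinear(nn.Module):
    """Old layout: the gate is an nn.Linear submodule."""

    def __init__(self, hidden_size, num_experts):
        super().__init__()
        self.gate = nn.Linear(hidden_size, num_experts, bias=False)

    def forward(self, x):
        return self.gate(x)


class GateAsParameter(nn.Module):
    """New layout: the gate weight is a bare nn.Parameter."""

    def __init__(self, hidden_size, num_experts):
        super().__init__()
        self.gate_weight = nn.Parameter(torch.randn(num_experts, hidden_size))

    def forward(self, x):
        return F.linear(x, self.gate_weight)


tokens = torch.randn(16, 128, requires_grad=True)
old_layer, new_layer = GateAsLinear(128, 4), GateAsParameter(128, 4)

old_layer(tokens).sum().backward()
old_grad = old_layer.gate.weight.grad    # old access path
tokens.grad.zero_()

new_layer(tokens).sum().backward()
new_grad = new_layer.gate_weight.grad    # new access path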

View File

@@ -58,15 +58,9 @@ def run_moe_zero_init(init_device_type, shard_strategy_class):
     for name, param in model.named_parameters():
         assert hasattr(param, 'colo_attr')
-        # the weights in the gate should be fp32
-        if 'gate' in name:
-            assert param.colo_attr.sharded_data_tensor.dtype == torch.float32
-        else:
-            assert param.colo_attr.sharded_data_tensor.dtype == torch.half
         # the parameters in moe experts and its gate should not be sharded
         if ('experts' in name) or ('gate' in name) or ('residual_combine' in name):
-            assert not param.colo_attr.sharded_data_tensor.is_sharded
+            assert not param.colo_attr.sharded_data_tensor.is_sharded, "`{}` parameter has problem".format(name)
         else:
             assert param.colo_attr.sharded_data_tensor.is_sharded
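With the dtype checks removed, the remaining invariant is the sharding policy: parameters belonging to the MoE experts, the gate, and the residual combiner must stay whole on each rank, while everything else is sharded by ZeRO. A plain-Python sketch of that selection rule (the keyword list comes straight from the assertion above; split_by_shard_policy is a made-up helper, not a ColossalAI API):

# Parameters whose names contain these keywords are expected to stay unsharded.
NO_SHARD_KEYWORDS = ('experts', 'gate', 'residual_combine')


def split_by_shard_policy(named_params):
    """Partition parameter names into (kept_local, sharded) by the rule above."""
    kept_local, sharded = [], []
    for name, _ in named_params:
        if any(key in name for key in NO_SHARD_KEYWORDS):
            kept_local.append(name)
        else:
            sharded.append(name)
    return kept_local, sharded


# Usage: kept_local, sharded = split_by_shard_policy(model.named_parameters())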

View File

@@ -94,12 +94,6 @@ def _run_test_sharded_optim_v2(cpu_offload,
     apex_model, apex_optimizer = convert_to_apex_amp(model, optim, amp_config)
     apex_grad_handler = MoeGradientHandler(model)
-    # Since MOE is not compatible with apex_amp now, we need to convert gate weight to fp32
-    for (n, p), zp in zip(apex_model.named_parameters(), zero_model.parameters()):
-        if 'gate' in n:
-            p.data = p.float()
-            p.data.copy_(zp.colo_attr.data_payload)
     for i, (data, label) in enumerate(train_dataloader):
         if i > 5:
             break
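The deleted block was a workaround for the forced-FP32 gate: apex_amp cast the whole model to fp16, so the test had to upcast the gate weight again and overwrite it with the ZeRO model's fp32 payload. With the gate no longer pinned to fp32, both copies share a dtype and the sync step is unnecessary. A self-contained paraphrase of the removed pattern, using toy nn.Linear models instead of MoeModel:

import torch
import torch.nn as nn

# Toy stand-ins: an apex-style fp16 copy and an fp32 master copy of one layer.
apex_model = nn.Linear(4, 4).half()
master_model = nn.Linear(4, 4)

# The removed workaround, in spirit: select the gate parameters by name,
# upcast them back to fp32, and overwrite them with the master copy.
for (name, p), master_p in zip(apex_model.named_parameters(), master_model.parameters()):
    if 'weight' in name:    # the real test matched 'gate' in the parameter name
        p.data = p.float()
        p.data.copy_(master_p.data)

assert apex_model.weight.dtype == torch.float32   # synced back to fp32
assert apex_model.bias.dtype == torch.float16     # untouched parameters stay fp16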

View File

@@ -135,5 +135,5 @@ def check_sharded_model_params(model, zero_model, loose=False, reuse_fp16_shard=
         else:
             zero_p = zero_p.colo_attr.data_payload.to(p.device)
-        assert p.dtype == zero_p.dtype
+        assert p.dtype == zero_p.dtype, "Parameter `{}`:\n{} vs {}".format(name, p.dtype, zero_p.dtype)
         assert allclose(p, zero_p, loose=loose), f'{p} vs {zero_p}'
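The last hunk only improves diagnostics: a dtype mismatch now reports the parameter name and both dtypes. For context, allclose(p, zero_p, loose=loose) is a tolerance-aware comparison helper from the test utilities; a plausible sketch of it (the actual tolerances may differ) is:

import torch


def allclose(tensor_a: torch.Tensor, tensor_b: torch.Tensor, loose: bool = False) -> bool:
    # `loose=True` relaxes the tolerances, e.g. for fp16/ZeRO comparisons;
    # the exact values here are assumptions, not the test suite's constants.
    if loose:
        return torch.allclose(tensor_a, tensor_b, atol=1e-2, rtol=1e-3)
    return torch.allclose(tensor_a, tensor_b)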