Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-07 03:52:01 +00:00)
[moe] fix MoE bugs (#1628)
* remove forced FP32 modules
* correct no_shard-contexts' positions
@@ -38,6 +38,7 @@ def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.f
     expert_factor = dict(in_features=hidden_size, out_features=hidden_size, device=get_current_device())
     expert = Experts(expert_module, NUM_EXPERTS, **expert_factor)
     layer = MoeLayer(hidden_size, NUM_EXPERTS, router(capacity_factor_train=1.0), expert)
+    layer = layer.to(get_current_device())
     if data_type == torch.float16:
         layer = layer.half()
 
@@ -50,11 +51,11 @@ def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.f
 
     # save all results
     o_tk_grad = tokens.grad.data.clone()
-    o_gt_grad = layer.gate.weight.grad.data.clone()
+    o_gt_grad = layer.gate_weight.grad.data.clone()
 
     # reset all gradients
     tokens.grad.zero_()
-    layer.gate.weight.grad.zero_()
+    layer.gate_weight.grad.zero_()
 
     layer.use_kernel = True
     new_out = layer(tokens)    # get ouputs through colossal kernel
@@ -67,7 +68,7 @@ def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.f
 
     new_out.backward(grad)    # get new type gradient
     n_tk_grad = tokens.grad.data.clone()
-    n_gt_grad = layer.gate.weight.grad.data.clone()
+    n_gt_grad = layer.gate_weight.grad.data.clone()
 
     if data_type == torch.float32:
         check_equal(o_tk_grad, n_tk_grad)
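The renamed accesses above (`layer.gate.weight` becoming `layer.gate_weight`) indicate that the gating weight is now held directly as a parameter on the MoE layer rather than inside an `nn.Linear` submodule. A minimal sketch of that layout, assuming a hypothetical `ToyMoeGate` class whose shapes and initialization are purely illustrative and not ColossalAI code:

import torch
import torch.nn as nn
import torch.nn.functional as F


class ToyMoeGate(nn.Module):
    """Illustrative only: the gating weight kept as a plain parameter, not an nn.Linear child."""

    def __init__(self, hidden_size: int, num_experts: int):
        super().__init__()
        # storing the weight directly is what lets a test reach it as `layer.gate_weight`
        self.gate_weight = nn.Parameter(torch.empty(num_experts, hidden_size))
        nn.init.normal_(self.gate_weight, std=0.02)

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        # routing logits of shape (num_tokens, num_experts)
        return F.linear(tokens, self.gate_weight)


gate = ToyMoeGate(hidden_size=128, num_experts=4)
logits = gate(torch.randn(16, 128))
logits.sum().backward()
print(gate.gate_weight.grad.shape)    # the gradient lives on the parameter itself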
@@ -58,15 +58,9 @@ def run_moe_zero_init(init_device_type, shard_strategy_class):
     for name, param in model.named_parameters():
         assert hasattr(param, 'colo_attr')
 
-        # the weights in the gate should be fp32
-        if 'gate' in name:
-            assert param.colo_attr.sharded_data_tensor.dtype == torch.float32
-        else:
-            assert param.colo_attr.sharded_data_tensor.dtype == torch.half
-
         # the parameters in moe experts and its gate should not be sharded
         if ('experts' in name) or ('gate' in name) or ('residual_combine' in name):
-            assert not param.colo_attr.sharded_data_tensor.is_sharded
+            assert not param.colo_attr.sharded_data_tensor.is_sharded, "`{}` parameter has problem".format(name)
         else:
             assert param.colo_attr.sharded_data_tensor.is_sharded
 
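With the forced-FP32 gate removed, every parameter created under a half-precision ZeRO init is expected to share the same dtype, which is why the gate-specific dtype branch above could be dropped. A standalone sketch of that kind of uniform dtype check; the toy model below is an assumption for illustration, not the test's actual model:

import torch
import torch.nn as nn

# toy stand-in for a model whose parameters were all created under a half-precision init
model = nn.Sequential(nn.Linear(8, 32), nn.GELU(), nn.Linear(32, 8)).half()

for name, param in model.named_parameters():
    # with the forced-FP32 path gone, no parameter is exempt from the cast
    assert param.dtype == torch.half, "`{}` parameter has problem: {}".format(name, param.dtype)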
@@ -94,12 +94,6 @@ def _run_test_sharded_optim_v2(cpu_offload,
     apex_model, apex_optimizer = convert_to_apex_amp(model, optim, amp_config)
     apex_grad_handler = MoeGradientHandler(model)
 
-    # Since MOE is not compatible with apex_amp now, we need to convert gate weight to fp32
-    for (n, p), zp in zip(apex_model.named_parameters(), zero_model.parameters()):
-        if 'gate' in n:
-            p.data = p.float()
-            p.data.copy_(zp.colo_attr.data_payload)
-
     for i, (data, label) in enumerate(train_dataloader):
         if i > 5:
             break
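The deleted block above hand-copied FP32 gate weights from the ZeRO model into the apex_amp model; once the gate follows the normal precision path, the two models can be kept in sync without special-casing `gate` parameters. A generic, torch-only sketch of that kind of weight synchronisation between two identically structured models (no ColossalAI attributes involved; the module names are assumptions):

import copy

import torch
import torch.nn as nn

reference = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 8))
shadow = copy.deepcopy(reference)    # same structure, independent storage

with torch.no_grad():
    for (name, p), q in zip(reference.named_parameters(), shadow.parameters()):
        # one uniform copy; no special `if 'gate' in name` branch is needed any more
        p.copy_(q.to(dtype=p.dtype))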
@@ -135,5 +135,5 @@ def check_sharded_model_params(model, zero_model, loose=False, reuse_fp16_shard=
         else:
             zero_p = zero_p.colo_attr.data_payload.to(p.device)
 
-        assert p.dtype == zero_p.dtype
+        assert p.dtype == zero_p.dtype, "Parameter `{}`:\n{} vs {}".format(name, p.dtype, zero_p.dtype)
         assert allclose(p, zero_p, loose=loose), f'{p} vs {zero_p}'
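The `loose` flag in `allclose(p, zero_p, loose=loose)` presumably widens the comparison tolerance for half-precision runs. A plausible shape for such a helper, written here as an assumption rather than the actual test utility (the tolerance values are illustrative):

import torch


def allclose(tensor_a: torch.Tensor, tensor_b: torch.Tensor, loose: bool = False) -> bool:
    """Element-wise comparison that relaxes tolerances when `loose` is set (e.g. fp16 runs)."""
    if loose:
        return torch.allclose(tensor_a, tensor_b, rtol=1e-3, atol=1e-2)
    return torch.allclose(tensor_a, tensor_b)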