[moe] fix MoE bugs (#1628)

* remove forced FP32 modules

* correct the positions of no_shard contexts
HELSON authored 2022-09-22 13:56:30 +08:00, committed by GitHub
parent 38c68b5b9a
commit f7f2248771
7 changed files with 26 additions and 33 deletions
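The first bullet, "remove forced FP32 modules", is why the test below now moves and casts the layer explicitly (`layer.to(get_current_device())`, `layer.half()`) and reads the gate parameter as `layer.gate_weight` instead of `layer.gate.weight`. As a toy illustration only (plain PyTorch, hypothetical names, not ColossalAI's actual MoeLayer), a gate stored as an ordinary parameter follows the layer when the whole module is cast, with no fp32 special case to handle separately:

```python
import torch
import torch.nn as nn


class TinyGatedLayer(nn.Module):
    """Hypothetical stand-in for a MoE layer whose gate is a plain parameter."""

    def __init__(self, hidden_size: int, num_experts: int):
        super().__init__()
        # gate stored as an ordinary parameter, not a separate fp32 Linear module
        self.gate_weight = nn.Parameter(torch.randn(num_experts, hidden_size))


layer = TinyGatedLayer(hidden_size=128, num_experts=4)
layer = layer.half()            # casting the whole layer also casts the gate
print(layer.gate_weight.dtype)  # torch.float16

# a submodule explicitly kept in fp32 would have to be cast (and accessed)
# through its own path, which is the special case this commit removes
```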


@@ -38,6 +38,7 @@ def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.f
     expert_factor = dict(in_features=hidden_size, out_features=hidden_size, device=get_current_device())
     expert = Experts(expert_module, NUM_EXPERTS, **expert_factor)
     layer = MoeLayer(hidden_size, NUM_EXPERTS, router(capacity_factor_train=1.0), expert)
+    layer = layer.to(get_current_device())
     if data_type == torch.float16:
         layer = layer.half()
@@ -50,11 +51,11 @@ def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.f
     # save all results
     o_tk_grad = tokens.grad.data.clone()
-    o_gt_grad = layer.gate.weight.grad.data.clone()
+    o_gt_grad = layer.gate_weight.grad.data.clone()
     # reset all gradients
     tokens.grad.zero_()
-    layer.gate.weight.grad.zero_()
+    layer.gate_weight.grad.zero_()
     layer.use_kernel = True
     new_out = layer(tokens)    # get outputs through colossal kernel
@@ -67,7 +68,7 @@ def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.f
     new_out.backward(grad)    # get new type gradient
     n_tk_grad = tokens.grad.data.clone()
-    n_gt_grad = layer.gate.weight.grad.data.clone()
+    n_gt_grad = layer.gate_weight.grad.data.clone()
     if data_type == torch.float32:
         check_equal(o_tk_grad, n_tk_grad)
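The test body updated above follows a reference-vs-kernel equivalence pattern: run the layer once without the fused kernel, save the token and gate gradients, zero them, flip `use_kernel`, run again, and compare. A condensed sketch of that pattern, assuming a layer object that exposes `gate_weight` and a `use_kernel` flag as in the test (the helper name and tolerance here are hypothetical, not the actual ColossalAI test code):

```python
import torch


def check_kernel_equivalence(layer, tokens, grad, atol=1e-2):
    """Compare gradients from the plain path and the fused-kernel path."""
    # reference path (fused kernel disabled)
    layer.use_kernel = False
    out = layer(tokens)
    out.backward(grad)
    ref_tk_grad = tokens.grad.data.clone()
    ref_gt_grad = layer.gate_weight.grad.data.clone()

    # reset gradients, then rerun through the fused kernel
    tokens.grad.zero_()
    layer.gate_weight.grad.zero_()
    layer.use_kernel = True
    layer(tokens).backward(grad)

    # both paths should produce (near-)identical gradients;
    # a tolerance is needed when data_type is torch.float16
    assert torch.allclose(ref_tk_grad, tokens.grad.data, atol=atol)
    assert torch.allclose(ref_gt_grad, layer.gate_weight.grad.data, atol=atol)
```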