Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-05 11:02:05 +00:00)
[moe] fix MoE bugs (#1628)
* remove forced FP32 modules
* correct no_shard-contexts' positions
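The test hunks below also replace layer.gate.weight with layer.gate_weight, which suggests the router gate is now stored as a raw nn.Parameter on the MoE layer rather than as an nn.Linear submodule. A minimal sketch of the two layouts implied by that rename (illustrative only; the class names and initialization here are assumptions, not the actual ColossalAI definition):

    import torch
    import torch.nn as nn

    class GateAsLinear(nn.Module):
        """Old layout implied by the test: the routing gradient lives at layer.gate.weight."""
        def __init__(self, hidden_size: int, num_experts: int):
            super().__init__()
            self.gate = nn.Linear(hidden_size, num_experts, bias=False)

        def forward(self, tokens: torch.Tensor) -> torch.Tensor:
            return self.gate(tokens)

    class GateAsParameter(nn.Module):
        """New layout implied by the test: the routing gradient lives at layer.gate_weight."""
        def __init__(self, hidden_size: int, num_experts: int):
            super().__init__()
            self.gate_weight = nn.Parameter(torch.empty(num_experts, hidden_size))
            nn.init.trunc_normal_(self.gate_weight, std=0.02)

        def forward(self, tokens: torch.Tensor) -> torch.Tensor:
            # Same routing logits as nn.Linear, computed with a plain matmul.
            return torch.nn.functional.linear(tokens, self.gate_weight)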
@@ -38,6 +38,7 @@ def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.f
     expert_factor = dict(in_features=hidden_size, out_features=hidden_size, device=get_current_device())
     expert = Experts(expert_module, NUM_EXPERTS, **expert_factor)
     layer = MoeLayer(hidden_size, NUM_EXPERTS, router(capacity_factor_train=1.0), expert)
+    layer = layer.to(get_current_device())
     if data_type == torch.float16:
         layer = layer.half()
 
@@ -50,11 +51,11 @@ def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.f
 
     # save all results
     o_tk_grad = tokens.grad.data.clone()
-    o_gt_grad = layer.gate.weight.grad.data.clone()
+    o_gt_grad = layer.gate_weight.grad.data.clone()
 
     # reset all gradients
     tokens.grad.zero_()
-    layer.gate.weight.grad.zero_()
+    layer.gate_weight.grad.zero_()
 
     layer.use_kernel = True
     new_out = layer(tokens)    # get outputs through colossal kernel
 
@@ -67,7 +68,7 @@ def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.f
 
     new_out.backward(grad)    # get new type gradient
     n_tk_grad = tokens.grad.data.clone()
-    n_gt_grad = layer.gate.weight.grad.data.clone()
+    n_gt_grad = layer.gate_weight.grad.data.clone()
 
     if data_type == torch.float32:
         check_equal(o_tk_grad, n_tk_grad)