[moe] fix moe bugs (#1633)
@@ -44,7 +44,7 @@ def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.f
     # use matrix multiplication instead of COL_MOE_KERNL in MOE dispatch and combine
     layer.use_kernel = False
-    old_out = layer(tokens)
+    old_out, _ = layer(tokens)
     ech = old_out.shape
     grad = torch.randn(ech, device=get_current_device())
     old_out.backward(grad)    # get gradient
@@ -58,7 +58,7 @@ def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.f
     layer.gate_weight.grad.zero_()

     layer.use_kernel = True
-    new_out = layer(tokens)    # get outputs through colossal kernel
+    new_out, _ = layer(tokens)    # get outputs through colossal kernel

     if data_type == torch.float32:
         check_equal(old_out, new_out)
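
The fix itself is small: the MoE layer's forward now returns a tuple, so both call sites unpack it. A minimal sketch of the comparison pattern in run_routing, assuming a hypothetical layer object with a use_kernel flag and a forward that returns (output, aux); compare_kernel_paths and check_equal here are illustrative stand-ins, not the actual ColossalAI API:

import torch

def compare_kernel_paths(layer, tokens, check_equal):
    # Hypothetical helper mirroring run_routing's check: run the same
    # tokens through the matmul path and the fused-kernel path, then
    # compare. Assumes layer(tokens) returns an (output, aux) tuple.
    layer.use_kernel = False
    old_out, _ = layer(tokens)          # matmul-based dispatch/combine

    layer.use_kernel = True
    new_out, _ = layer(tokens)          # fused kernel path

    if tokens.dtype == torch.float32:
        check_equal(old_out, new_out)   # exact comparison only in fp32

Gating the equality check on float32 is presumably because the fused kernel and the matmul path can diverge numerically at lower precisions.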