[moe] fix moe bugs (#1633)

Author: HELSON
Date: 2022-09-23 15:33:57 +08:00
Committed by: GitHub
Parent: 702dbc5288
Commit: a088022efc
8 changed files with 287 additions and 249 deletions


@@ -44,7 +44,7 @@ def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.f
     # use matrix multiplication instead of COL_MOE_KERNL in MOE dispatch and combine
     layer.use_kernel = False
-    old_out = layer(tokens)
+    old_out, _ = layer(tokens)
     ech = old_out.shape
     grad = torch.randn(ech, device=get_current_device())
     old_out.backward(grad)    # get gradient
@@ -58,7 +58,7 @@ def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.f
     layer.gate_weight.grad.zero_()
     layer.use_kernel = True
-    new_out = layer(tokens)    # get outputs through colossal kernel
+    new_out, _ = layer(tokens)    # get outputs through colossal kernel
     if data_type == torch.float32:
         check_equal(old_out, new_out)
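
For context, both hunks adapt the test to the MoE layer's new return signature: forward() now returns a tuple (the layer output plus an auxiliary value, typically a load-balancing loss) instead of a bare tensor, so callers must unpack it. Below is a minimal, self-contained sketch of that calling pattern; ToyMoeLayer and its aux term are hypothetical stand-ins, not ColossalAI's actual implementation.

import torch
import torch.nn as nn


class ToyMoeLayer(nn.Module):
    """Hypothetical stand-in for an MoE layer whose forward() returns (output, aux)."""

    def __init__(self, hidden_size: int):
        super().__init__()
        self.expert = nn.Linear(hidden_size, hidden_size)

    def forward(self, tokens: torch.Tensor):
        out = self.expert(tokens)
        aux = out.abs().mean()    # placeholder for an auxiliary term such as a load-balancing loss
        return out, aux           # tuple return: callers must unpack


layer = ToyMoeLayer(hidden_size=128)
tokens = torch.randn(4, 128)

bad = layer(tokens)        # pre-fix pattern: `bad` is a tuple, so bad.backward(...) would fail
out, _ = layer(tokens)     # the commit's fix: unpack to keep `out` a plain tensor
grad = torch.randn(out.shape)
out.backward(grad)         # tensor ops like .shape and .backward() now work as before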