[moe] fix moe bugs (#1633)

Author: HELSON (committed by GitHub)
Date: 2022-09-23 15:33:57 +08:00
parent 702dbc5288
commit a088022efc
8 changed files with 287 additions and 249 deletions


@@ -32,7 +32,7 @@ def run_test(rank, world_size, port):
         moe_layer = MoeLayer(DIM, num_experts, router, exp)
         layer_list.append(moe_layer)
-    model = nn.Sequential(*layer_list)
+    model = nn.ModuleList(layer_list)
     model = model.to(get_current_device())
     sync_moe_model_param(model)
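
The first hunk swaps the container. `nn.Sequential` pipes each layer's entire return value into the next layer, and (judging from the `data, _ = layer(data)` unpacking in the second hunk) `MoeLayer.forward` returns a tuple, so the second MoE layer would receive a tuple instead of a tensor. `nn.ModuleList` merely registers the submodules for `.to()`, `.parameters()` and `state_dict()` and leaves the forward wiring to the caller. A minimal sketch of the failure mode, using a hypothetical `TupleLayer` stand-in rather than the real `MoeLayer`:

```python
import torch
import torch.nn as nn


class TupleLayer(nn.Module):
    """Hypothetical stand-in for MoeLayer: forward returns (output, aux),
    like an MoE layer handing back activations plus a router/balance term."""

    def __init__(self, dim):
        super().__init__()
        self.linear = nn.Linear(dim, dim)

    def forward(self, x):
        out = self.linear(x)
        aux = out.abs().mean()  # placeholder for an auxiliary loss value
        return out, aux


layers = [TupleLayer(4) for _ in range(2)]

# nn.Sequential feeds layer 1's whole return value (a tuple) into layer 2,
# so the second forward call fails on a non-tensor input.
try:
    nn.Sequential(*layers)(torch.randn(2, 4))
except TypeError as err:
    print(f"Sequential fails: {err}")

# nn.ModuleList only registers the submodules and leaves the forward
# wiring to the caller, which is exactly what the fixed test needs.
model = nn.ModuleList(layers)
```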
@@ -49,8 +49,9 @@ def run_test(rank, world_size, port):
     grad = torch.randn_like(data)
     MOE_CONTEXT.reset_loss()
-    outputs = model(data)
-    outputs.backward(grad)
+    for layer in layer_list:
+        data, _ = layer(data)
+    data.backward(grad)
     grad_handler.handle_gradient()
     assert_equal_in_group(layer_list[0].experts.experts[0].weight.grad, dist_dict[1].dp_group)
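
The second hunk replaces the single `model(data)` call with a manual loop that unpacks each layer's tuple before the backward pass. A minimal sketch of that pattern, continuing the hypothetical `TupleLayer` setup above (not the real `MoeLayer`, `grad_handler`, or distributed groups):

```python
# Drive the ModuleList by hand, exactly as the new test loop does,
# and seed backward with an external gradient.
data = torch.randn(2, 4, requires_grad=True)
grad = torch.randn_like(data)

x = data
for layer in model:
    x, _ = layer(x)  # keep the activations, drop the auxiliary value
x.backward(grad)     # analogous to data.backward(grad) in the test

# The leaf tensor now holds gradients that flowed through every layer,
# which is what the gradient-handler / assert_equal_in_group check inspects.
print(data.grad.shape)  # torch.Size([2, 4])
```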