Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-06 19:40:28 +00:00)
[moe] fix moe bugs (#1633)
```diff
@@ -32,7 +32,7 @@ def run_test(rank, world_size, port):
         moe_layer = MoeLayer(DIM, num_experts, router, exp)
         layer_list.append(moe_layer)
 
-    model = nn.Sequential(*layer_list)
+    model = nn.ModuleList(layer_list)
     model = model.to(get_current_device())
     sync_moe_model_param(model)
 
@@ -49,8 +49,9 @@ def run_test(rank, world_size, port):
     grad = torch.randn_like(data)
 
     MOE_CONTEXT.reset_loss()
-    outputs = model(data)
-    outputs.backward(grad)
+    for layer in layer_list:
+        data, _ = layer(data)
+    data.backward(grad)
     grad_handler.handle_gradient()
 
     assert_equal_in_group(layer_list[0].experts.experts[0].weight.grad, dist_dict[1].dp_group)
```
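For context, the diff swaps `nn.Sequential` for `nn.ModuleList` and drives the layers manually because each `MoeLayer` returns a tuple (the output tensor plus an auxiliary value), which `nn.Sequential` would feed whole into the next layer. Below is a minimal sketch of that failure mode, not the ColossalAI code itself; `ToyMoeLayer` is a hypothetical stand-in for the real `MoeLayer`:

```python
# Minimal sketch (hypothetical ToyMoeLayer, not ColossalAI's MoeLayer) of why
# nn.Sequential breaks when each layer returns a (output, aux) tuple.
import torch
import torch.nn as nn

class ToyMoeLayer(nn.Module):
    """Stand-in for an MoE layer that returns (output, auxiliary value)."""
    def __init__(self, dim: int):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        # Real MoE layers typically also return a load-balancing loss term.
        return self.proj(x), x.new_zeros(())

layers = [ToyMoeLayer(4) for _ in range(2)]

# nn.Sequential passes each layer's *entire* return value to the next layer,
# so the second layer would receive a tuple instead of a tensor and crash:
#   model = nn.Sequential(*layers); model(torch.randn(2, 4))  # TypeError

# nn.ModuleList still registers the sub-modules (so .to(device) and parameter
# sync see them) but leaves the call order to us, letting us unpack the tuple:
model = nn.ModuleList(layers)
data = torch.randn(2, 4, requires_grad=True)
for layer in model:
    data, _ = layer(data)        # keep the tensor, drop the aux value
data.backward(torch.randn_like(data))
```

This mirrors the shape of the fixed test: iterate the layer list, unpack each layer's tuple, and call backward on the final tensor.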