[moe] fix tests

This commit is contained in:
ver217
2024-02-08 12:46:37 +08:00
parent 65e5d6baa5
commit 06db94fbc9
4 changed files with 31 additions and 39 deletions

View File

@@ -4,15 +4,21 @@ import torch
from colossalai.moe.routers import MoeRouter, Top1Router, Top2Router, TopKRouter
@pytest.mark.parametrize(["router", "num_groups"], [
(Top1Router(), 1),
(Top2Router(), 1),
# (TopKRouter(num_selected_experts=3), 4),
])
@pytest.mark.parametrize(["batch_size", "seq_len", "num_experts"], [
(4, 5, 8),
(3, 4, 4),
])
@pytest.mark.parametrize(
["router", "num_groups"],
[
(Top1Router(), 1),
(Top2Router(), 1),
# (TopKRouter(num_selected_experts=3), 4),
],
)
@pytest.mark.parametrize(
["batch_size", "seq_len", "num_experts"],
[
(4, 5, 8),
(3, 4, 4),
],
)
def test_router_forward(router: MoeRouter, batch_size: int, seq_len: int, num_experts: int, num_groups: int):
x = torch.randn((batch_size * seq_len, num_experts)).cuda()
if num_groups > 1:
@@ -20,18 +26,18 @@ def test_router_forward(router: MoeRouter, batch_size: int, seq_len: int, num_ex
router.train()
if isinstance(router, TopKRouter):
_, combine_array, dispatch_mask = router(x, expert_capacity=2)
combine_array, dispatch_mask = router(x, expert_capacity=2)
else:
_, combine_array, dispatch_mask = router(x)
combine_array, dispatch_mask = router(x)[1:3]
assert combine_array.shape[:-1] == x.shape
assert dispatch_mask.shape[:-1] == x.shape
assert torch.all(dispatch_mask.sum(-1).sum(-1) <= router.k_value)
router.eval()
if isinstance(router, TopKRouter):
_, combine_array, dispatch_mask = router(x, expert_capacity=2)
combine_array, dispatch_mask = router(x, expert_capacity=2)
else:
_, combine_array, dispatch_mask = router(x)
combine_array, dispatch_mask = router(x)[1:3]
assert combine_array.shape[:-1] == x.shape
assert dispatch_mask.shape[:-1] == x.shape
assert torch.all(dispatch_mask.sum(-1).sum(-1) <= router.k_value)