Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-08 04:24:47 +00:00
[moe] fix tests
@@ -4,15 +4,21 @@ import torch
 
 from colossalai.moe.routers import MoeRouter, Top1Router, Top2Router, TopKRouter
 
 
-@pytest.mark.parametrize(["router", "num_groups"], [
-    (Top1Router(), 1),
-    (Top2Router(), 1),
-    # (TopKRouter(num_selected_experts=3), 4),
-])
-@pytest.mark.parametrize(["batch_size", "seq_len", "num_experts"], [
-    (4, 5, 8),
-    (3, 4, 4),
-])
+@pytest.mark.parametrize(
+    ["router", "num_groups"],
+    [
+        (Top1Router(), 1),
+        (Top2Router(), 1),
+        # (TopKRouter(num_selected_experts=3), 4),
+    ],
+)
+@pytest.mark.parametrize(
+    ["batch_size", "seq_len", "num_experts"],
+    [
+        (4, 5, 8),
+        (3, 4, 4),
+    ],
+)
 def test_router_forward(router: MoeRouter, batch_size: int, seq_len: int, num_experts: int, num_groups: int):
     x = torch.randn((batch_size * seq_len, num_experts)).cuda()
     if num_groups > 1:
@@ -20,18 +26,18 @@ def test_router_forward(router: MoeRouter, batch_size: int, seq_len: int, num_ex
 
     router.train()
     if isinstance(router, TopKRouter):
-        _, combine_array, dispatch_mask = router(x, expert_capacity=2)
+        combine_array, dispatch_mask = router(x, expert_capacity=2)
     else:
-        _, combine_array, dispatch_mask = router(x)
+        combine_array, dispatch_mask = router(x)[1:3]
     assert combine_array.shape[:-1] == x.shape
     assert dispatch_mask.shape[:-1] == x.shape
     assert torch.all(dispatch_mask.sum(-1).sum(-1) <= router.k_value)
 
     router.eval()
     if isinstance(router, TopKRouter):
-        _, combine_array, dispatch_mask = router(x, expert_capacity=2)
+        combine_array, dispatch_mask = router(x, expert_capacity=2)
     else:
-        _, combine_array, dispatch_mask = router(x)
+        combine_array, dispatch_mask = router(x)[1:3]
     assert combine_array.shape[:-1] == x.shape
     assert dispatch_mask.shape[:-1] == x.shape
     assert torch.all(dispatch_mask.sum(-1).sum(-1) <= router.k_value)
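As background for the shape checks above, here is a minimal, self-contained sketch of what a top-k router's combine_array and dispatch_mask can look like. This is an illustrative toy under assumed shapes, not the colossalai.moe.routers implementation; the helper toy_topk_route and its exact layout are assumptions chosen only to make the test's assertions concrete.

import torch
import torch.nn.functional as F


def toy_topk_route(logits: torch.Tensor, k: int = 2):
    # logits: (num_tokens, num_experts) routing scores, one row per token.
    probs = torch.softmax(logits, dim=-1)
    topk_probs, topk_idx = probs.topk(k, dim=-1)                 # both (num_tokens, k)
    # dispatch_mask: which expert each of the k slots sends the token to,
    # shape (num_tokens, num_experts, k), so dispatch_mask.shape[:-1] == logits.shape.
    dispatch_mask = F.one_hot(topk_idx, num_classes=probs.size(-1)).permute(0, 2, 1).bool()
    # combine_array: the routing weight placed at each dispatched position,
    # same leading shape as dispatch_mask.
    combine_array = dispatch_mask.float() * topk_probs.unsqueeze(1)
    return combine_array, dispatch_mask


x = torch.randn(12, 8)                        # 12 tokens, 8 experts
combine_array, dispatch_mask = toy_topk_route(x, k=2)
assert combine_array.shape[:-1] == x.shape    # mirrors the shape checks in the test
assert dispatch_mask.shape[:-1] == x.shape
assert torch.all(dispatch_mask.sum(-1).sum(-1) <= 2)   # at most k experts per token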