diff --git a/colossalai/moe/routers.py b/colossalai/moe/routers.py
index 4d99e48d3..e40674c9b 100644
--- a/colossalai/moe/routers.py
+++ b/colossalai/moe/routers.py
@@ -47,7 +47,7 @@ class MoeRouter(nn.Module, ABC):
 
     def get_capacity(self, num_tokens, num_experts, ep_group=None):
         if ep_group is not None:
-            num_tokens_tensor = torch.tensor(num_tokens, device=get_current_device())
+            num_tokens_tensor = torch.tensor(num_tokens, device=get_accelerator().get_current_device())
             dist.all_reduce(num_tokens_tensor, group=ep_group)
             num_tokens = num_tokens_tensor.item() // dist.get_world_size(ep_group)
         capacity_factor = self.capacity_factor_train if self.training else self.capacity_factor_eval
diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py
index 511eb26e8..a2433d1b2 100644
--- a/colossalai/zero/low_level/low_level_optim.py
+++ b/colossalai/zero/low_level/low_level_optim.py
@@ -911,11 +911,14 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
                     master_param.copy_(working_param.chunk(self.extra_dp_pg_size)[self.extra_dp_pg_rank])
                 else:
                     master_param.copy_(working_param.chunk(self._world_size)[self._local_rank])
-        for master_moe_param, working_moe_param in zip(self.master_moe_params, self.working_moe_params):
-            master_moe_param.copy_(working_moe_param)
+        if hasattr(self, "master_moe_params"):
+            for master_moe_param, working_moe_param in zip(self.master_moe_params, self.working_moe_params):
+                master_moe_param.copy_(working_moe_param)
 
     def get_working_to_master_map(self) -> Dict[int, torch.Tensor]:
         return self._param_store.working_to_master_param
 
     def get_master_to_working_map(self) -> Dict[int, torch.Tensor]:
-        return {**self._param_store.master_to_working_param, **self.moe_master_to_working_map}
+        if hasattr(self, "moe_master_to_working_map"):
+            return {**self._param_store.master_to_working_param, **self.moe_master_to_working_map}
+        return self._param_store.master_to_working_param
diff --git a/tests/test_moe/test_moe_checkpoint.py b/tests/test_moe/test_moe_checkpoint.py
index 8f51e1663..d6dad2d7f 100644
--- a/tests/test_moe/test_moe_checkpoint.py
+++ b/tests/test_moe/test_moe_checkpoint.py
@@ -12,7 +12,6 @@ import colossalai
 from colossalai.accelerator import get_accelerator
 from colossalai.booster import Booster
 from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
-from colossalai.moe.manager import MOE_MANAGER
 from colossalai.testing import DummyDataloader, check_state_dict_equal, rerun_if_address_is_in_use, spawn
 
 sys.path.append(
@@ -95,6 +94,7 @@ def get_model(parallel):
             precision="bf16",
             tp_size=1,
             pp_size=1,
+            ep_size=1,
             zero_stage=2,
             custom_policy=OpenMoeForCausalLMPolicy(),
         )
@@ -103,6 +103,7 @@ def get_model(parallel):
             precision="bf16",
             tp_size=1,
             pp_size=1,
+            ep_size=dist.get_world_size(),
             zero_stage=2,
             custom_policy=OpenMoeForCausalLMPolicy(),
         )
@@ -111,6 +112,7 @@ def get_model(parallel):
             precision="bf16",
             tp_size=1,
             pp_size=1,
+            ep_size=2,
             zero_stage=2,
             extra_dp_size=2,
             custom_policy=OpenMoeForCausalLMPolicy(),
@@ -120,6 +122,7 @@ def get_model(parallel):
             precision="bf16",
             tp_size=1,
             pp_size=2,
+            ep_size=2,
             zero_stage=1,
             microbatch_size=1,
             custom_policy=OpenMoeForCausalLMPolicy(),
@@ -130,27 +133,6 @@ def get_model(parallel):
 
 
 def _test_moe_checkpoint(rank, parallel):
-    if parallel == None:
-        MOE_MANAGER.setup(
-            parallel=None,
-        )
-    elif parallel == "ep":
-        MOE_MANAGER.setup(
-            parallel="EP",
-        )
-    elif parallel == "ep_zero":
-        MOE_MANAGER.setup(
-            parallel="EP",
-            max_ep_size=2,
-        )
-    elif parallel == "hybrid":
-        MOE_MANAGER.setup(
-            parallel="EP",
-            mode="fixed",
-            fixed_dp_size=1,
-            fixed_ep_size=2,
-            fixed_pp_size=2,
-        )
     model1, booster1, optim1 = get_model(parallel)
     model2, booster2, optim2 = get_model(parallel)
     model3, booster3, optim3 = get_model(parallel)
@@ -207,6 +189,7 @@ def _run_dist(rank, world_size, port, parallel):
     _test_moe_checkpoint(rank, parallel)
 
 
+@pytest.mark.skip(reason="This is tested in ColossalMOE")
 @pytest.mark.dist
 @pytest.mark.parametrize("world_size", [4])
 @pytest.mark.parametrize("parallel", [None, "ep", "ep_zero", "hybrid"])
diff --git a/tests/test_moe/test_moe_router.py b/tests/test_moe/test_moe_router.py
index 7ba7fa6f6..9f6167692 100644
--- a/tests/test_moe/test_moe_router.py
+++ b/tests/test_moe/test_moe_router.py
@@ -4,15 +4,21 @@ import torch
 from colossalai.moe.routers import MoeRouter, Top1Router, Top2Router, TopKRouter
 
 
-@pytest.mark.parametrize(["router", "num_groups"], [
-    (Top1Router(), 1),
-    (Top2Router(), 1),
-    # (TopKRouter(num_selected_experts=3), 4),
-])
-@pytest.mark.parametrize(["batch_size", "seq_len", "num_experts"], [
-    (4, 5, 8),
-    (3, 4, 4),
-])
+@pytest.mark.parametrize(
+    ["router", "num_groups"],
+    [
+        (Top1Router(), 1),
+        (Top2Router(), 1),
+        # (TopKRouter(num_selected_experts=3), 4),
+    ],
+)
+@pytest.mark.parametrize(
+    ["batch_size", "seq_len", "num_experts"],
+    [
+        (4, 5, 8),
+        (3, 4, 4),
+    ],
+)
 def test_router_forward(router: MoeRouter, batch_size: int, seq_len: int, num_experts: int, num_groups: int):
     x = torch.randn((batch_size * seq_len, num_experts)).cuda()
     if num_groups > 1:
@@ -20,18 +26,18 @@ def test_router_forward(router: MoeRouter, batch_size: int, seq_len: int, num_ex
 
     router.train()
     if isinstance(router, TopKRouter):
-        _, combine_array, dispatch_mask = router(x, expert_capacity=2)
+        combine_array, dispatch_mask = router(x, expert_capacity=2)
     else:
-        _, combine_array, dispatch_mask = router(x)
+        combine_array, dispatch_mask = router(x)[1:3]
     assert combine_array.shape[:-1] == x.shape
     assert dispatch_mask.shape[:-1] == x.shape
     assert torch.all(dispatch_mask.sum(-1).sum(-1) <= router.k_value)
 
     router.eval()
     if isinstance(router, TopKRouter):
-        _, combine_array, dispatch_mask = router(x, expert_capacity=2)
+        combine_array, dispatch_mask = router(x, expert_capacity=2)
     else:
-        _, combine_array, dispatch_mask = router(x)
+        combine_array, dispatch_mask = router(x)[1:3]
     assert combine_array.shape[:-1] == x.shape
     assert dispatch_mask.shape[:-1] == x.shape
     assert torch.all(dispatch_mask.sum(-1).sum(-1) <= router.k_value)