mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-06 19:40:28 +00:00
[chore] solve moe ckpt test failure and some other arg pass failure
This commit is contained in:
@@ -22,6 +22,7 @@ def check_deepseek_moe_layer():
|
||||
precision="bf16",
|
||||
tp_size=1,
|
||||
pp_size=1,
|
||||
zero_stage=1,
|
||||
ep_size=dist.get_world_size(),
|
||||
)
|
||||
|
||||
@@ -42,7 +43,13 @@ def check_deepseek_moe_layer():
|
||||
x = torch.rand(1, tokens, hidden_size, requires_grad=True).cuda()
|
||||
orig_output = orig_model(x)
|
||||
model = deepcopy(orig_model)
|
||||
model = EPDeepseekMoE.from_native_module(model, ep_group=plugin.ep_group)
|
||||
model = EPDeepseekMoE.from_native_module(
|
||||
model,
|
||||
ep_group=plugin.ep_group,
|
||||
moe_dp_group=plugin.moe_dp_group,
|
||||
moe_tp_group=plugin.moe_tp_group,
|
||||
tp_group=plugin.tp_group,
|
||||
)
|
||||
ep_output = model(x)
|
||||
assert_close(orig_output, ep_output)
|
||||
orig_loss = orig_output.mean()
|
||||
@@ -62,7 +69,7 @@ def run_dist(rank: int, world_size: int, port: int):
|
||||
check_deepseek_moe_layer()
|
||||
|
||||
|
||||
# @pytest.mark.parametrize("world_size", [2, 4])
|
||||
@pytest.mark.skip("tested in corresponding sharderformer")
|
||||
@pytest.mark.parametrize("world_size", [2])
|
||||
def test_deepseek_moe_layer(world_size: int):
|
||||
spawn(run_dist, world_size)
|
||||
|
Reference in New Issue
Block a user