[chore] fix MoE checkpoint test failure and other argument-passing failures

This commit is contained in:
hxwang
2024-07-22 03:40:34 +00:00
committed by Hongxin Liu
parent 52d346f2a5
commit 70c9924d0d
12 changed files with 101 additions and 79 deletions

View File

@@ -22,6 +22,7 @@ def check_deepseek_moe_layer():
precision="bf16",
tp_size=1,
pp_size=1,
zero_stage=1,
ep_size=dist.get_world_size(),
)
@@ -42,7 +43,13 @@ def check_deepseek_moe_layer():
x = torch.rand(1, tokens, hidden_size, requires_grad=True).cuda()
orig_output = orig_model(x)
model = deepcopy(orig_model)
model = EPDeepseekMoE.from_native_module(model, ep_group=plugin.ep_group)
model = EPDeepseekMoE.from_native_module(
model,
ep_group=plugin.ep_group,
moe_dp_group=plugin.moe_dp_group,
moe_tp_group=plugin.moe_tp_group,
tp_group=plugin.tp_group,
)
ep_output = model(x)
assert_close(orig_output, ep_output)
orig_loss = orig_output.mean()
@@ -62,7 +69,7 @@ def run_dist(rank: int, world_size: int, port: int):
check_deepseek_moe_layer()
# @pytest.mark.parametrize("world_size", [2, 4])
@pytest.mark.skip("tested in corresponding sharderformer")
@pytest.mark.parametrize("world_size", [2])
def test_deepseek_moe_layer(world_size: int):
spawn(run_dist, world_size)