[chore] fix MoE checkpoint test failure and some other argument-passing failures

This commit is contained in:
hxwang
2024-07-22 03:40:34 +00:00
committed by Hongxin Liu
parent 52d346f2a5
commit 70c9924d0d
12 changed files with 101 additions and 79 deletions

View File

@@ -12,7 +12,7 @@ from colossalai.booster.booster import Booster
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.testing.random import seed_all
from tests.test_moe.moe_utils import loose_close
from tests.test_moe.moe_utils import assert_loose_close
NUM_BATCH = 4
NUM_TOK_PER_BATCH, NUM_EXPERTS = 7, 4
@@ -22,7 +22,7 @@ TOP_K = 1
@parameterize("stage", [1])
@parameterize("ep_size", [1, 2, 4])
@parameterize("ep_size", [2, 4])
def run_zero_with_original_model(stage: int, ep_size: int):
dtype = torch.bfloat16
@@ -76,7 +76,7 @@ def run_zero_with_original_model(stage: int, ep_size: int):
# torch-ddp forward
ddp_output = ddp_model(inputs_embeds=input_data.to(dtype)).last_hidden_state.mean()
loose_close(zero_output, ddp_output, dtype=dtype)
assert_loose_close(zero_output, ddp_output, dtype=dtype)
# torch-ddp backward
ddp_output.backward()
@@ -87,7 +87,7 @@ def run_zero_with_original_model(stage: int, ep_size: int):
if name_to_p[n].grad is None:
name_to_p[n].grad = torch.zeros_like(name_to_p[n].data)
continue
loose_close(zero_grad, name_to_p[n].grad, dtype=dtype, name=n)
assert_loose_close(zero_grad, name_to_p[n].grad, dtype=dtype, name=n)
# zero-dp step
zero_optimizer.step()
@@ -97,7 +97,7 @@ def run_zero_with_original_model(stage: int, ep_size: int):
# check updated param
for n, p in zero_model.named_parameters():
loose_close(p.data, name_to_p[n].data, dtype=dtype, name=n)
assert_loose_close(p.data, name_to_p[n].data, dtype=dtype, name=n)
print(f"{dist.get_rank()} test passed")
@@ -107,6 +107,7 @@ def run_dist(rank, world_size, port):
run_zero_with_original_model()
@pytest.mark.skip("tested in corresponding sharderformer")
@pytest.mark.dist
@pytest.mark.parametrize("world_size", [4])
@rerun_if_address_is_in_use()