[chore] solve moe ckpt test failure and some other arg pass failure

This commit is contained in:
hxwang
2024-07-22 03:40:34 +00:00
parent 9f9e268265
commit 05a78d2f41
12 changed files with 101 additions and 79 deletions

View File

@@ -14,8 +14,7 @@ from colossalai.booster.booster import Booster
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.testing.random import seed_all
from tests.test_moe.moe_utils import loose_close
from tests.test_moe.test_moe_checkpoint import check_model_equal
from tests.test_moe.moe_utils import assert_loose_close, check_model_equal
NUM_BATCH = 8
NUM_TOK_PER_BATCH, NUM_EXPERTS = 4, 4
@@ -25,18 +24,21 @@ NUM_HEADS = 4
TOP_K = 1
# TODO only need to keep one or two cases
CHECKED_CONFIG = [ # FOR_WORLD=8
(2, 1, 1, 4, 1),
(4, 1, 1, 2, 1),
(4, 1, 1, 1, 1),
]
@parameterize(
"config",
[
(2, 1, 1, 4, 1),
# (2, 1, 2, 1, 1), # TODO debug deepseek pp
# (2, 1, 2, 2, 1), # TODO debug deepseek pp
(2, 1, 1, 2, 1),
# (2, 1, 1, 1, 2), # TODO support deepseek sp
# (2, 1, 4, 1, 1), # TODO debug deepseek pp
(4, 1, 1, 1, 1),
(4, 1, 1, 2, 1),
# (4, 1, 2, 1, 1), # TODO debug deepseek pp
],
)
@@ -66,9 +68,6 @@ def run_zero_with_original_model(config: Tuple[int, ...]):
booster = Booster(plugin=plugin)
# init model with the same seed
seed_all(10086)
assert pp_size <= NUM_LAYERS, "pp_size should be less than or equal to NUM_LAYERS"
config = AutoConfig.from_pretrained("deepseek-ai/deepseek-moe-16b-base", trust_remote_code=True)
config.hidden_size = HIDDEN_SIZE_PER_HEAD * NUM_HEADS
@@ -79,6 +78,9 @@ def run_zero_with_original_model(config: Tuple[int, ...]):
config.n_routed_experts = NUM_EXPERTS
config.num_experts_per_tok = TOP_K
# init model with the same seed
seed_all(10086)
torch_model = AutoModel.from_config(config, trust_remote_code=True).cuda().to(dtype)
torch_optimizer = torch.optim.SGD(torch_model.parameters(), lr=1)
@@ -148,7 +150,7 @@ def run_zero_with_original_model(config: Tuple[int, ...]):
torch_optimizer.step()
torch_optimizer.zero_grad()
loose_close(parallel_output, torch_output_sum, dtype=dtype)
assert_loose_close(parallel_output, torch_output_sum, dtype=dtype)
# use checkpoint to load sharded zero model
model_dir = "./test_mixtral"
@@ -175,7 +177,7 @@ def run_dist(rank, world_size, port):
@pytest.mark.dist
@pytest.mark.parametrize("world_size", [8])
@pytest.mark.parametrize("world_size", [4])
@rerun_if_address_is_in_use()
def test_mistral(world_size):
spawn(run_dist, world_size)

View File

@@ -15,8 +15,7 @@ from colossalai.booster.booster import Booster
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.testing.random import seed_all
from tests.test_moe.moe_utils import loose_close
from tests.test_moe.test_moe_checkpoint import check_model_equal
from tests.test_moe.moe_utils import assert_loose_close, check_model_equal
NUM_BATCH = 8
NUM_TOK_PER_BATCH, NUM_EXPERTS = 4, 4
@@ -25,20 +24,21 @@ HIDDEN_SIZE_PER_HEAD = 4
NUM_HEADS = 4
TOP_K = 1
CHECKED_CONFIG = [ # FOR WORLD=4
(2, 1, 2, 2, 1),
(2, 1, 1, 2, 1),
(2, 1, 4, 1, 1),
(4, 1, 1, 1, 1),
(4, 1, 1, 2, 1),
(4, 1, 2, 1, 1),
(2, 1, 2, 1, 1),
]
# TODO only need to keep one or two cases
@parameterize(
"config",
[
(2, 1, 1, 4, 1),
(2, 1, 2, 1, 1),
(2, 1, 2, 2, 1),
(2, 1, 1, 2, 1),
(2, 1, 1, 1, 2),
(2, 1, 4, 1, 1),
(4, 1, 1, 1, 1),
(4, 1, 1, 2, 1),
(4, 1, 2, 1, 1),
],
)
def run_zero_with_original_model(config: Tuple[int, ...]):
@@ -67,9 +67,6 @@ def run_zero_with_original_model(config: Tuple[int, ...]):
booster = Booster(plugin=plugin)
# init model with the same seed
seed_all(10086)
assert pp_size <= NUM_LAYERS, "pp_size should be less than or equal to NUM_LAYERS"
config = MixtralConfig(
hidden_size=HIDDEN_SIZE_PER_HEAD * NUM_HEADS,
@@ -82,6 +79,9 @@ def run_zero_with_original_model(config: Tuple[int, ...]):
attn_implementation="flash_attention_2",
)
# init model with the same seed
seed_all(10086)
torch_model = MixtralModel(config).to(dtype).cuda()
torch_optimizer = torch.optim.SGD(torch_model.parameters(), lr=1)
@@ -151,7 +151,7 @@ def run_zero_with_original_model(config: Tuple[int, ...]):
torch_optimizer.step()
torch_optimizer.zero_grad()
loose_close(parallel_output, torch_output_sum, dtype=dtype)
assert_loose_close(parallel_output, torch_output_sum, dtype=dtype)
# use checkpoint to load sharded zero model
model_dir = "./test_mixtral"
@@ -178,7 +178,7 @@ def run_dist(rank, world_size, port):
@pytest.mark.dist
@pytest.mark.parametrize("world_size", [8])
@pytest.mark.parametrize("world_size", [4])
@rerun_if_address_is_in_use()
def test_mistral(world_size):
spawn(run_dist, world_size)