Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-13 21:22:49 +00:00)
[chore] solve moe ckpt test failure and some other arg pass failure
@@ -14,8 +14,7 @@ from colossalai.booster.booster import Booster
 from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
 from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
 from colossalai.testing.random import seed_all
-from tests.test_moe.moe_utils import loose_close
-from tests.test_moe.test_moe_checkpoint import check_model_equal
+from tests.test_moe.moe_utils import assert_loose_close, check_model_equal

 NUM_BATCH = 8
 NUM_TOK_PER_BATCH, NUM_EXPERTS = 4, 4
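Note on the import change above: the comparison helper is now imported as assert_loose_close, and check_model_equal moves into tests/test_moe/moe_utils as well. The helper's body is not part of this diff; as a rough sketch of what a dtype-aware tolerance check of this kind usually looks like (the tolerance values below are illustrative assumptions, not the repository's):

    import torch

    def assert_loose_close_sketch(actual, expected, dtype=torch.float32):
        # Use looser tolerances for low-precision dtypes (values are assumptions).
        if dtype in (torch.float16, torch.bfloat16):
            rtol, atol = 5e-2, 5e-2
        else:
            rtol, atol = 1e-5, 1e-5
        torch.testing.assert_close(actual, expected, rtol=rtol, atol=atol, check_dtype=False)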
@@ -25,18 +24,21 @@ NUM_HEADS = 4
 TOP_K = 1


 # TODO only need to keep one or two cases
 CHECKED_CONFIG = [  # FOR_WORLD=8
     (2, 1, 1, 4, 1),
     (4, 1, 1, 2, 1),
     (4, 1, 1, 1, 1),
 ]


 @parameterize(
     "config",
     [
         (2, 1, 1, 4, 1),
         # (2, 1, 2, 1, 1), # TODO debug deepseek pp
         # (2, 1, 2, 2, 1), # TODO debug deepseek pp
         (2, 1, 1, 2, 1),
         # (2, 1, 1, 1, 2), # TODO support deepseek sp
         # (2, 1, 4, 1, 1), # TODO debug deepseek pp
         (4, 1, 1, 1, 1),
         (4, 1, 1, 2, 1),
         # (4, 1, 2, 1, 1), # TODO debug deepseek pp
     ],
 )
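For context on the decorator used above: colossalai.testing.parameterize re-invokes the wrapped function once per listed value inside the current process, which is why it can drive test bodies that later run inside spawned distributed workers. A simplified stand-in (not the library's actual implementation) behaves roughly like this:

    import functools

    def parameterize_sketch(arg_name, values):
        # Call the wrapped test once per value, passed as a keyword argument.
        def decorator(fn):
            @functools.wraps(fn)
            def wrapper(*args, **kwargs):
                for value in values:
                    fn(*args, **{**kwargs, arg_name: value})
            return wrapper
        return decorator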
@@ -66,9 +68,6 @@ def run_zero_with_original_model(config: Tuple[int, ...]):

     booster = Booster(plugin=plugin)

-    # init model with the same seed
-    seed_all(10086)
-
     assert pp_size <= NUM_LAYERS, "pp_size should be less than or equal to NUM_LAYERS"
     config = AutoConfig.from_pretrained("deepseek-ai/deepseek-moe-16b-base", trust_remote_code=True)
     config.hidden_size = HIDDEN_SIZE_PER_HEAD * NUM_HEADS
@@ -79,6 +78,9 @@ def run_zero_with_original_model(config: Tuple[int, ...]):
     config.n_routed_experts = NUM_EXPERTS
     config.num_experts_per_tok = TOP_K

+    # init model with the same seed
+    seed_all(10086)
+
     torch_model = AutoModel.from_config(config, trust_remote_code=True).cuda().to(dtype)
     torch_optimizer = torch.optim.SGD(torch_model.parameters(), lr=1)

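The two hunks above relocate the "init model with the same seed" block so that seed_all(10086) now runs right after the tiny deepseek config is prepared and immediately before the reference model is constructed, which is what the comment promises. seed_all itself is not shown in this diff; a minimal equivalent would presumably look like the following (an assumption, not ColossalAI's implementation):

    import random

    import numpy as np
    import torch

    def seed_all_sketch(seed: int):
        # Pin every RNG that weight initialization may touch.
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)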
@@ -148,7 +150,7 @@ def run_zero_with_original_model(config: Tuple[int, ...]):
     torch_optimizer.step()
     torch_optimizer.zero_grad()

-    loose_close(parallel_output, torch_output_sum, dtype=dtype)
+    assert_loose_close(parallel_output, torch_output_sum, dtype=dtype)

     # use checkpoint to load sharded zero model
     model_dir = "./test_mixtral"
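The tail of this hunk leads into the checkpoint part of the test: the sharded zero model is saved under model_dir and loaded back. The actual calls sit outside the displayed lines; a hedged sketch of the usual Booster sharded-checkpoint round-trip such a test exercises (the helper name below is hypothetical, and the argument names follow the public Booster API as commonly used rather than this file):

    import torch.distributed as dist
    from colossalai.booster import Booster

    def save_and_reload_sketch(booster: Booster, model, model_dir: str) -> None:
        # Hypothetical helper; mirrors the round-trip the test performs.
        booster.save_model(model, model_dir, shard=True)  # each rank writes its own shard
        dist.barrier()                                    # make sure every shard is on disk
        booster.load_model(model, model_dir)              # load the sharded checkpoint back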
@@ -175,7 +177,7 @@ def run_dist(rank, world_size, port):


 @pytest.mark.dist
-@pytest.mark.parametrize("world_size", [8])
+@pytest.mark.parametrize("world_size", [4])
 @rerun_if_address_is_in_use()
 def test_mistral(world_size):
     spawn(run_dist, world_size)
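The first file's diff ends with the world size for the test dropping from 8 to 4 processes. spawn from colossalai.testing launches that many worker processes, each of which calls run_dist(rank, world_size, port); the entry point in these tests typically has roughly the shape below (assumed from the usual pattern, not copied from this file):

    import colossalai

    def run_dist(rank, world_size, port):
        # Bring up the process group for this worker, then run the test body.
        colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
        run_zero_with_original_model()

The remaining hunks apply the same import, seeding, and world-size changes to a second test file, the variant that builds its reference model from MixtralConfig instead of the deepseek AutoConfig.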
@@ -15,8 +15,7 @@ from colossalai.booster.booster import Booster
 from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
 from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
 from colossalai.testing.random import seed_all
-from tests.test_moe.moe_utils import loose_close
-from tests.test_moe.test_moe_checkpoint import check_model_equal
+from tests.test_moe.moe_utils import assert_loose_close, check_model_equal

 NUM_BATCH = 8
 NUM_TOK_PER_BATCH, NUM_EXPERTS = 4, 4
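The second file gets the same import change. check_model_equal, now also imported from tests/test_moe/moe_utils, is not shown in this diff; a rough sketch of what a parameter-by-parameter comparison helper of that name typically does (an assumption about its shape, not the repository's code, with illustrative tolerances):

    import torch

    def check_model_equal_sketch(model_a: torch.nn.Module, model_b: torch.nn.Module) -> None:
        params_a = dict(model_a.named_parameters())
        params_b = dict(model_b.named_parameters())
        assert params_a.keys() == params_b.keys(), "models expose different parameter names"
        for name, param_a in params_a.items():
            # Compare each pair of parameters by name.
            torch.testing.assert_close(param_a, params_b[name], rtol=1e-5, atol=1e-5, msg=f"mismatch in {name}")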
@@ -25,20 +24,21 @@ HIDDEN_SIZE_PER_HEAD = 4
 NUM_HEADS = 4
 TOP_K = 1

 CHECKED_CONFIG = [  # FOR WORLD=4
     (2, 1, 2, 2, 1),
     (2, 1, 1, 2, 1),
     (2, 1, 4, 1, 1),
     (4, 1, 1, 1, 1),
     (4, 1, 1, 2, 1),
     (4, 1, 2, 1, 1),
     (2, 1, 2, 1, 1),
 ]


 # TODO only need to keep one or two cases
 @parameterize(
     "config",
     [
         (2, 1, 1, 4, 1),
         (2, 1, 2, 1, 1),
         (2, 1, 2, 2, 1),
         (2, 1, 1, 2, 1),
         (2, 1, 1, 1, 2),
         (2, 1, 4, 1, 1),
         (4, 1, 1, 1, 1),
         (4, 1, 1, 2, 1),
         (4, 1, 2, 1, 1),
     ],
 )
 def run_zero_with_original_model(config: Tuple[int, ...]):
@@ -67,9 +67,6 @@ def run_zero_with_original_model(config: Tuple[int, ...]):

     booster = Booster(plugin=plugin)

-    # init model with the same seed
-    seed_all(10086)
-
     assert pp_size <= NUM_LAYERS, "pp_size should be less than or equal to NUM_LAYERS"
     config = MixtralConfig(
         hidden_size=HIDDEN_SIZE_PER_HEAD * NUM_HEADS,
@@ -82,6 +79,9 @@ def run_zero_with_original_model(config: Tuple[int, ...]):
         attn_implementation="flash_attention_2",
     )

+    # init model with the same seed
+    seed_all(10086)
+
     torch_model = MixtralModel(config).to(dtype).cuda()
     torch_optimizer = torch.optim.SGD(torch_model.parameters(), lr=1)

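As in the first file, the seed is now set right before the reference model is built. That reference model is a deliberately tiny MixtralModel: hidden_size works out to HIDDEN_SIZE_PER_HEAD * NUM_HEADS = 4 * 4 = 16 with the constants visible in this diff. A self-contained sketch of the same idea, with the fields not visible in these hunks filled in by assumption:

    from transformers import MixtralConfig, MixtralModel

    config = MixtralConfig(
        hidden_size=16,          # HIDDEN_SIZE_PER_HEAD * NUM_HEADS = 4 * 4
        num_attention_heads=4,   # NUM_HEADS
        num_key_value_heads=4,   # assumed equal to NUM_HEADS
        num_hidden_layers=2,     # NUM_LAYERS is not visible in these hunks; small value assumed
        intermediate_size=32,    # assumed
        num_local_experts=4,     # NUM_EXPERTS
        num_experts_per_tok=1,   # TOP_K
    )
    model = MixtralModel(config)  # the test additionally casts to dtype and moves to GPU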
@@ -151,7 +151,7 @@ def run_zero_with_original_model(config: Tuple[int, ...]):
     torch_optimizer.step()
     torch_optimizer.zero_grad()

-    loose_close(parallel_output, torch_output_sum, dtype=dtype)
+    assert_loose_close(parallel_output, torch_output_sum, dtype=dtype)

     # use checkpoint to load sharded zero model
     model_dir = "./test_mixtral"
@@ -178,7 +178,7 @@ def run_dist(rank, world_size, port):


 @pytest.mark.dist
-@pytest.mark.parametrize("world_size", [8])
+@pytest.mark.parametrize("world_size", [4])
 @rerun_if_address_is_in_use()
 def test_mistral(world_size):
     spawn(run_dist, world_size)