[hotfix] moe hybrid parallelism benchmark & follow-up fix (#6048)

* [example] pass use_fp8_comm flag to all plugins (flag-forwarding sketch after this list)

* [example] add mixtral benchmark

* [moe] refine assertion and check

* [moe] fix mixtral & add more tests

* [moe] consider checking dp * sp group and moe_dp_group

* [mixtral] remove gate tp & add more tests

* [deepseek] fix tp & sp for deepseek

* [mixtral] minor fix

* [deepseek] add deepseek benchmark
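
For the first item in the list above, a minimal sketch of the flag-forwarding pattern the commit describes. The CLI flag name `use_fp8_comm` comes from the commit message; the `fp8_communication` keyword and the constructor values below are assumptions about the plugin API, not code from this commit:

    import argparse

    from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin

    parser = argparse.ArgumentParser()
    parser.add_argument("--use_fp8_comm", action="store_true", help="use fp8 for collective communication")
    args = parser.parse_args()

    # Forward the flag to whichever plugin the benchmark constructs; the
    # fp8_communication keyword is an assumed plugin parameter.
    plugin = MoeHybridParallelPlugin(
        tp_size=1,
        pp_size=1,
        ep_size=1,
        fp8_communication=args.use_fp8_comm,
    )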
Author: botbw
Date:   2024-09-10 17:30:53 +08:00
Committed by: GitHub
Parent: 8fd25d6e09
Commit: c54c4fcd15

21 changed files with 907 additions and 99 deletions


@@ -13,42 +13,25 @@ from transformers.models.mixtral.modeling_mixtral import MixtralModel
 import colossalai
 from colossalai.booster.booster import Booster
 from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
 from colossalai.shardformer.layer.utils import Randomizer
 from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
 from colossalai.testing.random import seed_all
 from tests.test_moe.moe_utils import assert_loose_close, check_model_equal

 NUM_BATCH = 8
-NUM_TOK_PER_BATCH, NUM_EXPERTS = 4, 4
+NUM_TOK_PER_BATCH, NUM_EXPERTS = 64, 4
 NUM_LAYERS = 4
 HIDDEN_SIZE_PER_HEAD = 4
-NUM_HEADS = 4
-TOP_K = 1
-CHECKED_CONFIG = [  # FOR WORLD=4
-    (0, 1, 4, 1, 1),
-    (0, 1, 1, 4, 1),
-    (0, 1, 1, 1, 4),
-    (1, 4, 1, 1, 1),
-    (1, 1, 4, 1, 1),
-    (1, 1, 1, 4, 1),
-    (1, 1, 1, 1, 4),
-    (1, 2, 1, 1, 1),
-]
+NUM_HEADS = 8
+TOP_K = 2


-@parameterize(
-    "config",
-    [
-        (1, 2, 2, 1, 1),
-        (1, 2, 1, 2, 1),
-        (1, 2, 1, 1, 2),
-    ],
-)
-def run_zero_with_original_model(config: Tuple[int, ...]):
+def run_mixtral_commom(config: Tuple[int, ...]):
     Randomizer.reset_index()
     stage, ep_size, pp_size, tp_size, sp_size = config
     world_size = dist.get_world_size()
     rank = dist.get_rank()
-    dtype, precision = torch.float16, "fp16"
+    dtype, precision = torch.bfloat16, "bf16"
     torch.cuda.set_device(dist.get_rank())

     plugin = MoeHybridParallelPlugin(
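
The hunk above ends inside the truncated constructor call. As orientation only, a hypothetical sketch of how the unpacked config tuple plausibly feeds the plugin; the keyword names and values beyond pp/tp/sp/ep/zero_stage/precision are guesses, not the commit's actual argument list:

    # Hypothetical continuation of the truncated call above; not the commit's code.
    plugin = MoeHybridParallelPlugin(
        pp_size=pp_size,
        tp_size=tp_size,
        sp_size=sp_size,
        ep_size=ep_size,
        zero_stage=stage,
        precision=precision,  # "bf16" after this change
        enable_sequence_parallelism=sp_size > 1,  # assumed: sp likely needs explicit enabling
    )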
@@ -165,7 +148,7 @@ def run_zero_with_original_model(config: Tuple[int, ...]):
     dist.barrier()

     saved_model = MixtralModel.from_pretrained(model_dir).cuda().to(dtype)
-    check_model_equal(torch_model, saved_model)
+    check_model_equal(torch_model, saved_model, dtype=dtype)
     dist.barrier()

     if rank == world_size - 1:
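
The new `dtype=dtype` argument lets the equality check choose tolerances that suit bf16. A sketch of what such a dtype-aware helper can look like; the real `check_model_equal` lives in `tests/test_moe/moe_utils.py` and its tolerances may differ:

    import torch

    # Illustrative per-dtype tolerances; the values in moe_utils.py may differ.
    _TOL = {
        torch.bfloat16: dict(rtol=4e-3, atol=4e-3),
        torch.float16: dict(rtol=2e-3, atol=2e-3),
    }

    def check_model_equal(model_a, model_b, dtype=torch.float32):
        tol = _TOL.get(dtype, dict(rtol=1e-5, atol=1e-8))
        for (name_a, p_a), (name_b, p_b) in zip(model_a.named_parameters(), model_b.named_parameters()):
            assert name_a == name_b, f"parameter order differs: {name_a} vs {name_b}"
            assert torch.allclose(p_a.float(), p_b.float(), **tol), f"mismatch in {name_a}"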
@@ -174,17 +157,78 @@ def run_zero_with_original_model(config: Tuple[int, ...]):
print(f"rank {dist.get_rank()} test passed")
def run_dist(rank, world_size, port):
@parameterize(
"config",
[
# DDP: ep == 1 since ep * moe_dp == dp == moe_dp; sp == 1 since sp * dp == moe_dp == dp
(0, 1, 4, 1, 1),
(0, 1, 1, 4, 1),
(0, 1, 2, 2, 1),
# zero 1
(1, 4, 1, 1, 1),
(1, 1, 4, 1, 1),
(1, 1, 1, 4, 1),
(1, 2, 1, 1, 2),
# zero 2
(2, 4, 1, 1, 1),
(2, 1, 4, 1, 1),
(2, 1, 1, 4, 1),
(2, 2, 1, 1, 2),
],
)
def run_mixtral_test(config: Tuple[int, ...]):
run_mixtral_commom(config)
@parameterize(
"config",
[
# DDP: ep == 1 since ep * moe_dp == dp == moe_dp; sp == 1 since sp * dp == moe_dp == dp
(0, 1, 2, 4, 1),
(0, 1, 4, 2, 1),
(0, 1, 1, 4, 1),
(0, 1, 4, 1, 1),
# zero 1:
(1, 2, 1, 1, 2),
(1, 2, 1, 4, 1),
(1, 1, 1, 2, 2),
(1, 2, 2, 2, 1),
# zero 2
(2, 2, 1, 1, 2),
(2, 2, 1, 4, 1),
(2, 1, 1, 2, 2),
(2, 2, 2, 2, 1),
],
)
def run_mixtral_3d_test(config: Tuple[int, ...]):
print(f"{config=}")
run_mixtral_commom(config)
def check_mixtral(rank, world_size, port):
colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
run_zero_with_original_model()
run_mixtral_test()
def check_mixtral_3d(rank, world_size, port):
colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
run_mixtral_3d_test()
@pytest.mark.dist
@pytest.mark.parametrize("world_size", [4])
@rerun_if_address_is_in_use()
def test_mixtral(world_size):
spawn(run_dist, world_size)
spawn(check_mixtral, world_size)
@pytest.mark.largedist
@pytest.mark.parametrize("world_size", [8])
@rerun_if_address_is_in_use()
def test_mixtral_3d(world_size):
spawn(check_mixtral_3d, world_size)
if __name__ == "__main__":
test_mixtral(world_size=4)
test_mixtral(world_size=8)
test_mixtral_3d(world_size=8)
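
The inline `DDP: ep == 1 ...` comments compress the group-size constraints into one line. Spelled out as runnable arithmetic, under two assumptions that match the comments but are not verified against the plugin source: dp is derived as world_size / (pp * tp * sp), and the expert-parallel group splits the combined dp * sp group:

    def check_moe_config(config, world_size):
        # Sanity-check one (stage, ep, pp, tp, sp) tuple from the lists above.
        stage, ep, pp, tp, sp = config
        assert world_size % (pp * tp * sp) == 0
        dp = world_size // (pp * tp * sp)  # assumed derivation of the dp degree
        assert (dp * sp) % ep == 0
        moe_dp = dp * sp // ep  # assumed: experts shard the combined dp*sp group
        if stage == 0:
            # DDP path: experts replicate across the whole dp*sp group (ep == 1),
            # and sp == 1 so that sp * dp == dp == moe_dp.
            assert ep == 1 and sp == 1
        return dp, moe_dp

    # A few of the 3D configs above, which run on world_size == 8:
    for cfg in [(0, 1, 2, 4, 1), (1, 2, 2, 2, 1), (2, 2, 1, 4, 1)]:
        print(cfg, "->", check_moe_config(cfg, world_size=8))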