[hotfix] moe hybrid parallelism benchmark & follow-up fix (#6048)

* [example] pass use_fp8_comm flag to all plugins

* [example] add mixtral benchmark

* [moe] refine assertion and check

* [moe] fix mixtral & add more tests

* [moe] consider checking dp * sp group and moe_dp_group

* [mixtral] remove gate tp & add more tests

* [deepseek] fix tp & sp for deepseek

* [mixtral] minor fix

* [deepseek] add deepseek benchmark
botbw
2024-09-10 17:30:53 +08:00
committed by GitHub
parent 8fd25d6e09
commit c54c4fcd15
21 changed files with 907 additions and 99 deletions
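The commit message touches several plugin knobs at once (expert/tensor/sequence/pipeline parallel sizes, zero stage, flash attention, fp8 communication). For orientation only, here is a minimal sketch of how a benchmark script might forward those knobs to `MoeHybridParallelPlugin`. The kwargs mirror the test code in the diffs below; the parallel-size kwarg names and `fp8_communication` (the assumed plugin-side switch behind the `use_fp8_comm` example flag) are assumptions, and the sketch presumes `colossalai.launch` has already initialized the distributed backend.

```python
# Sketch only: most kwargs follow the tests in this commit; the size kwargs and
# fp8_communication are assumed names, not taken verbatim from this diff.
from colossalai.booster.booster import Booster
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin


def build_booster(ep_size=1, pp_size=1, tp_size=1, sp_size=1, zero_stage=1, use_fp8_comm=False):
    plugin = MoeHybridParallelPlugin(
        ep_size=ep_size,
        pp_size=pp_size,
        tp_size=tp_size,
        sp_size=sp_size,
        zero_stage=zero_stage,
        enable_sequence_parallelism=sp_size > 1,
        sequence_parallelism_mode="all_to_all" if sp_size > 1 else None,
        enable_flash_attention=True,
        overlap_communication=False,
        initial_scale=1,
        precision="bf16",
        find_unused_parameters=True,
        fp8_communication=use_fp8_comm,  # assumed kwarg name for the use_fp8_comm flag
    )
    return Booster(plugin=plugin)
```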

View File

@@ -1,4 +1,12 @@
import os
import traceback
from contextlib import contextmanager
from time import sleep
from typing import Callable, List, Optional
import torch
import torch.distributed as dist
from torch.utils._pytree import tree_map
def assert_loose_close(a, b, dtype: torch.dtype = torch.float32, name=""):
@@ -25,7 +33,66 @@ def loose_close(a, b, dtype: torch.dtype = torch.float32):
    return torch.allclose(a, b, rtol=rtol, atol=atol)
def check_model_equal(model1, model2):
def check_model_equal(model1, model2, dtype):
    assert set(model1.state_dict().keys()) == set(model2.state_dict().keys())
    for i, ((name, p1), p2) in enumerate(zip(model1.named_parameters(), model2.parameters())):
        assert_loose_close(p1, p2, p1.dtype)
        assert_loose_close(p1, p2, dtype, name=name)
@contextmanager
def distributed_debug_mode(num_stacks: int = 1, funcs_to_patch: Optional[List[Callable]] = None, enable=True):
    if enable:
        assert (
            os.environ.get("CUDA_LAUNCH_BLOCKING", "0") == "1"
        ), f"Expect CUDA_LAUNCH_BLOCKING=1, got {os.environ.get('CUDA_LAUNCH_BLOCKING', '0')}"
    if funcs_to_patch is None:
        funcs_to_patch = [
            dist.all_reduce,
            dist.all_reduce_coalesced,
            dist.all_gather,
            dist.all_gather_coalesced,
            dist.all_gather_into_tensor,
            dist.all_to_all,
            dist.all_to_all_single,
            dist.reduce_scatter,
        ]

    original_funcs = {}
    patched_funcs = {}

    def make_patched(func):
        def patched_func(*args, **kwargs):
            stack = traceback.format_stack()

            def format_node(node):
                if isinstance(node, torch.Tensor):
                    return f"{node.shape}"
                elif isinstance(node, list):
                    return f"[{', '.join([format_node(n) for n in node])}]"
                return str(node)

            args_str, kwargs_str = tree_map(format_node, (args, kwargs))

            en = len(stack) - 1
            st = max(0, en - num_stacks)
            dist.barrier()
            sleep(0.001 * dist.get_rank())
            print(
                f"[Rank {dist.get_rank()}-{func.__name__}-{dist.get_process_group_ranks(kwargs.get('group', dist.group.WORLD))}]: Called from {''.join(stack[st:en])}args={args_str} kwargs={kwargs_str}\n"
            )
            dist.barrier()
            return func(*args, **kwargs)

        return patched_func

    if enable:
        for func in funcs_to_patch:
            original_funcs[func.__name__] = getattr(dist, func.__name__)
            patched_funcs[func.__name__] = make_patched(func)
            setattr(dist, func.__name__, patched_funcs[func.__name__])

    try:
        yield
    finally:
        for func_name, original_func in original_funcs.items():
            setattr(dist, func_name, original_func)
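For reference, a minimal usage sketch of the helper above. It assumes the process group is already initialized (e.g. via `colossalai.launch`) and that `CUDA_LAUNCH_BLOCKING=1` was exported before the processes started; the context manager only asserts that env var, it does not set it. `model` and `batch` are placeholders.

```python
# Sketch: wrap a suspect code path so every collective issued inside it is
# printed with its caller stack, argument shapes, and process-group ranks.
import torch
import torch.distributed as dist

from tests.test_moe.moe_utils import distributed_debug_mode


def debug_one_step(model, batch):
    x = torch.ones(4, device="cuda")
    with distributed_debug_mode(num_stacks=2):
        dist.all_reduce(x)    # logged by the patched wrapper
        out = model(**batch)  # collectives inside the model are logged too
    dist.all_reduce(x)        # outside the block: original, silent behavior
    return out
```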

View File

@@ -130,7 +130,7 @@ def check_moe_checkpoint(test_config):
    dist.barrier()
    if dist.get_rank() == 0:
        saved_model = model_cls.from_pretrained(model_dir).cuda().to(dtype)
        check_model_equal(orig_model, saved_model)
        check_model_equal(orig_model, saved_model, dtype=dtype)
        saved_model.save_pretrained(hf_model_dir)
    dist.barrier()
    # check load model
@@ -138,7 +138,7 @@ def check_moe_checkpoint(test_config):
    new_optimizer = Adam(new_model.parameters(), lr=1e-3)
    new_model, new_optimizer, *_ = booster.boost(model=new_model, optimizer=new_optimizer)
    booster.load_model(new_model, hf_model_dir)
    check_model_equal(model, new_model)
    check_model_equal(model, new_model, dtype=dtype)
    # check save optimizer
    optimizer.step()

View File

@@ -12,43 +12,25 @@ from transformers import AutoConfig, AutoModel
import colossalai
from colossalai.booster.booster import Booster
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
from colossalai.shardformer.layer.utils import Randomizer
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.testing.random import seed_all
from tests.test_moe.moe_utils import assert_loose_close, check_model_equal
NUM_BATCH = 8
NUM_TOK_PER_BATCH, NUM_EXPERTS = 4, 2
NUM_TOK_PER_BATCH, NUM_EXPERTS = 64, 4
NUM_LAYERS = 4
HIDDEN_SIZE_PER_HEAD = 4
NUM_HEADS = 4
NUM_HEADS = 8
TOP_K = 2
CHECKED_CONFIG = [  # FOR_WORLD=4
    (1, 4, 1, 1, 1),
    (1, 1, 4, 1, 1),
    (1, 1, 1, 4, 1),
    (1, 1, 1, 1, 4),
    (0, 1, 4, 1, 1),
    (0, 1, 1, 4, 1),
    (0, 1, 1, 1, 4),
    (1, 2, 1, 1, 1),
]


@parameterize(
    "config",
    [
        (1, 2, 2, 1, 1),
        (1, 2, 1, 2, 1),
        (1, 2, 1, 1, 2),
    ],
)
def run_zero_with_original_model(config: Tuple[int, ...]):
def run_deepseek_common(config: Tuple[int, ...]):
    Randomizer.reset_index()
    stage, ep_size, pp_size, tp_size, sp_size = config
    world_size = dist.get_world_size()
    rank = dist.get_rank()
    dtype, precision = torch.float16, "fp16"
    dtype, precision = torch.bfloat16, "bf16"
    torch.cuda.set_device(dist.get_rank())

    plugin = MoeHybridParallelPlugin(
@@ -60,11 +42,11 @@ def run_zero_with_original_model(config: Tuple[int, ...]):
        zero_stage=stage,
        enable_sequence_parallelism=sp_size > 1,
        sequence_parallelism_mode="all_to_all" if sp_size > 1 else None,
        enable_flash_attention=sp_size > 1,
        overlap_communication=False,
        initial_scale=1,
        precision=precision,
        find_unused_parameters=True,
        enable_flash_attention=True,
    )

    dp_size = plugin.dp_size
@@ -171,7 +153,7 @@ def run_zero_with_original_model(config: Tuple[int, ...]):
    dist.barrier()
    saved_model = AutoModel.from_pretrained(model_dir, trust_remote_code=True).cuda()
    check_model_equal(torch_model, saved_model)
    check_model_equal(torch_model, saved_model, dtype=dtype)
    dist.barrier()

    if rank == world_size - 1:
@@ -180,17 +162,77 @@ def run_zero_with_original_model(config: Tuple[int, ...]):
    print(f"rank {dist.get_rank()} test passed")
def run_dist(rank, world_size, port):
@parameterize(
    "config",
    [
        # DDP: ep == 1 since ep * moe_dp == dp == moe_dp; sp == 1 since sp * dp == moe_dp == dp
        (0, 1, 4, 1, 1),
        (0, 1, 1, 4, 1),
        (0, 1, 2, 2, 1),
        # zero 1
        (1, 4, 1, 1, 1),
        (1, 1, 4, 1, 1),
        (1, 1, 1, 4, 1),
        (1, 2, 1, 1, 2),
        # zero 2
        (2, 4, 1, 1, 1),
        (2, 1, 4, 1, 1),
        (2, 1, 1, 4, 1),
        (2, 2, 1, 1, 2),
    ],
)
def run_deepseek_test(config: Tuple[int, ...]):
    run_deepseek_common(config)


@parameterize(
    "config",
    [
        # DDP: ep == 1 since ep * moe_dp == dp == moe_dp; sp == 1 since sp * dp == moe_dp == dp
        (0, 1, 2, 4, 1),
        (0, 1, 4, 2, 1),
        (0, 1, 1, 4, 1),
        (0, 1, 4, 1, 1),
        # zero 1:
        (1, 2, 1, 1, 2),
        (1, 2, 1, 4, 1),
        (1, 1, 1, 2, 2),
        (1, 2, 2, 2, 1),
        # zero 2
        (2, 2, 1, 1, 2),
        (2, 2, 1, 4, 1),
        (2, 1, 1, 2, 2),
        (2, 2, 2, 2, 1),
    ],
)
def run_deepseek_3d_test(config: Tuple[int, ...]):
    run_deepseek_common(config)


def check_deepseek(rank, world_size, port):
    colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
    run_zero_with_original_model()
    run_deepseek_test()


def check_deepseek_3d(rank, world_size, port):
    colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
    run_deepseek_3d_test()


@pytest.mark.dist
@pytest.mark.parametrize("world_size", [4])
@rerun_if_address_is_in_use()
def test_deepseek(world_size):
    spawn(run_dist, world_size)
    spawn(check_deepseek, world_size)


@pytest.mark.largedist
@pytest.mark.parametrize("world_size", [8])
@rerun_if_address_is_in_use()
def test_deepseek_3d(world_size):
    spawn(check_deepseek_3d, world_size)


if __name__ == "__main__":
    test_deepseek(world_size=4)
    test_deepseek(world_size=8)
    test_deepseek_3d(world_size=8)
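The `# DDP: ...` comments in the new config lists compress the rule that makes a `(zero_stage, ep, pp, tp, sp)` tuple legal. The sketch below spells that arithmetic out, under the assumption (suggested by the commit message's "dp * sp group and moe_dp_group" check) that the plugin derives `dp = world_size // (pp * tp * sp)` and `moe_dp = dp * sp // ep`, and that plain DDP (zero stage 0) additionally needs `moe_dp == dp`.

```python
# Sketch: sanity-check a parallel config the way the comments above describe.
# The derivation of dp and moe_dp is an assumption, not taken from this diff.
from typing import Tuple


def describe(config: Tuple[int, int, int, int, int], world_size: int) -> str:
    stage, ep, pp, tp, sp = config
    assert world_size % (pp * tp * sp) == 0, "world_size must be divisible by pp * tp * sp"
    dp = world_size // (pp * tp * sp)
    assert (dp * sp) % ep == 0, "ep must divide dp * sp"
    moe_dp = dp * sp // ep
    if stage == 0:
        # DDP path (per the comments above): moe_dp must equal dp and sp must be 1, hence ep == 1
        assert moe_dp == dp and sp == 1, f"invalid DDP config {config}"
    return f"stage={stage} ep={ep} pp={pp} tp={tp} sp={sp} -> dp={dp} moe_dp={moe_dp}"


if __name__ == "__main__":
    # one entry each from the 4-GPU list and the 8-GPU (3d) list above
    for cfg, ws in [((0, 1, 4, 1, 1), 4), ((1, 2, 1, 1, 2), 4), ((2, 2, 2, 2, 1), 8)]:
        print(describe(cfg, world_size=ws))
```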

View File

@@ -13,42 +13,25 @@ from transformers.models.mixtral.modeling_mixtral import MixtralModel
import colossalai
from colossalai.booster.booster import Booster
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
from colossalai.shardformer.layer.utils import Randomizer
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.testing.random import seed_all
from tests.test_moe.moe_utils import assert_loose_close, check_model_equal
NUM_BATCH = 8
NUM_TOK_PER_BATCH, NUM_EXPERTS = 4, 4
NUM_TOK_PER_BATCH, NUM_EXPERTS = 64, 4
NUM_LAYERS = 4
HIDDEN_SIZE_PER_HEAD = 4
NUM_HEADS = 4
TOP_K = 1
CHECKED_CONFIG = [  # FOR WORLD=4
    (0, 1, 4, 1, 1),
    (0, 1, 1, 4, 1),
    (0, 1, 1, 1, 4),
    (1, 4, 1, 1, 1),
    (1, 1, 4, 1, 1),
    (1, 1, 1, 4, 1),
    (1, 1, 1, 1, 4),
    (1, 2, 1, 1, 1),
]
NUM_HEADS = 8
TOP_K = 2
@parameterize(
    "config",
    [
        (1, 2, 2, 1, 1),
        (1, 2, 1, 2, 1),
        (1, 2, 1, 1, 2),
    ],
)
def run_zero_with_original_model(config: Tuple[int, ...]):
def run_mixtral_common(config: Tuple[int, ...]):
    Randomizer.reset_index()
    stage, ep_size, pp_size, tp_size, sp_size = config
    world_size = dist.get_world_size()
    rank = dist.get_rank()
    dtype, precision = torch.float16, "fp16"
    dtype, precision = torch.bfloat16, "bf16"
    torch.cuda.set_device(dist.get_rank())

    plugin = MoeHybridParallelPlugin(
@@ -165,7 +148,7 @@ def run_zero_with_original_model(config: Tuple[int, ...]):
    dist.barrier()
    saved_model = MixtralModel.from_pretrained(model_dir).cuda().to(dtype)
    check_model_equal(torch_model, saved_model)
    check_model_equal(torch_model, saved_model, dtype=dtype)
    dist.barrier()

    if rank == world_size - 1:
@@ -174,17 +157,78 @@ def run_zero_with_original_model(config: Tuple[int, ...]):
    print(f"rank {dist.get_rank()} test passed")
def run_dist(rank, world_size, port):
@parameterize(
    "config",
    [
        # DDP: ep == 1 since ep * moe_dp == dp == moe_dp; sp == 1 since sp * dp == moe_dp == dp
        (0, 1, 4, 1, 1),
        (0, 1, 1, 4, 1),
        (0, 1, 2, 2, 1),
        # zero 1
        (1, 4, 1, 1, 1),
        (1, 1, 4, 1, 1),
        (1, 1, 1, 4, 1),
        (1, 2, 1, 1, 2),
        # zero 2
        (2, 4, 1, 1, 1),
        (2, 1, 4, 1, 1),
        (2, 1, 1, 4, 1),
        (2, 2, 1, 1, 2),
    ],
)
def run_mixtral_test(config: Tuple[int, ...]):
    run_mixtral_common(config)


@parameterize(
    "config",
    [
        # DDP: ep == 1 since ep * moe_dp == dp == moe_dp; sp == 1 since sp * dp == moe_dp == dp
        (0, 1, 2, 4, 1),
        (0, 1, 4, 2, 1),
        (0, 1, 1, 4, 1),
        (0, 1, 4, 1, 1),
        # zero 1:
        (1, 2, 1, 1, 2),
        (1, 2, 1, 4, 1),
        (1, 1, 1, 2, 2),
        (1, 2, 2, 2, 1),
        # zero 2
        (2, 2, 1, 1, 2),
        (2, 2, 1, 4, 1),
        (2, 1, 1, 2, 2),
        (2, 2, 2, 2, 1),
    ],
)
def run_mixtral_3d_test(config: Tuple[int, ...]):
    print(f"{config=}")
    run_mixtral_common(config)


def check_mixtral(rank, world_size, port):
    colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
    run_zero_with_original_model()
    run_mixtral_test()


def check_mixtral_3d(rank, world_size, port):
    colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
    run_mixtral_3d_test()


@pytest.mark.dist
@pytest.mark.parametrize("world_size", [4])
@rerun_if_address_is_in_use()
def test_mixtral(world_size):
    spawn(run_dist, world_size)
    spawn(check_mixtral, world_size)


@pytest.mark.largedist
@pytest.mark.parametrize("world_size", [8])
@rerun_if_address_is_in_use()
def test_mixtral_3d(world_size):
    spawn(check_mixtral_3d, world_size)


if __name__ == "__main__":
    test_mixtral(world_size=4)
    test_mixtral(world_size=8)
    test_mixtral_3d(world_size=8)