Mirror of https://github.com/hpcaitech/ColossalAI.git
[hotfix] moe hybrid parallelism benchmark & follow-up fix (#6048)
* [example] pass use_fp8_comm flag to all plugins
* [example] add mixtral benchmark
* [moe] refine assertion and check
* [moe] fix mixtral & add more tests
* [moe] consider checking dp * sp group and moe_dp_group
* [mixtral] remove gate tp & add more tests
* [deepseek] fix tp & sp for deepseek
* [mixtral] minor fix
* [deepseek] add deepseek benchmark
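The refactor below threads an explicit comparison dtype through the MoE test helpers. As a rough usage sketch (the two Linear modules here are placeholders, not code from this commit), the updated helpers from tests/test_moe/moe_utils.py would be called like this:

# Hedged sketch, not part of the commit: exercising the updated helpers.
import torch

from tests.test_moe.moe_utils import assert_loose_close, check_model_equal

dtype = torch.bfloat16
model_a = torch.nn.Linear(8, 8).to(dtype)  # placeholder modules for illustration
model_b = torch.nn.Linear(8, 8).to(dtype)
model_b.load_state_dict(model_a.state_dict())

# check_model_equal now takes the comparison dtype explicitly ...
check_model_equal(model_a, model_b, dtype)
# ... and assert_loose_close accepts a name so a failing parameter can be identified.
assert_loose_close(model_a.weight, model_b.weight, dtype, name="weight")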
@@ -1,4 +1,12 @@
import os
import traceback
from contextlib import contextmanager
from time import sleep
from typing import Callable, List, Optional

import torch
import torch.distributed as dist
from torch.utils._pytree import tree_map


def assert_loose_close(a, b, dtype: torch.dtype = torch.float32, name=""):
@@ -25,7 +33,66 @@ def loose_close(a, b, dtype: torch.dtype = torch.float32):
    return torch.allclose(a, b, rtol=rtol, atol=atol)


def check_model_equal(model1, model2):
def check_model_equal(model1, model2, dtype):
    assert set(model1.state_dict().keys()) == set(model2.state_dict().keys())
    for i, ((name, p1), p2) in enumerate(zip(model1.named_parameters(), model2.parameters())):
        assert_loose_close(p1, p2, p1.dtype)
        assert_loose_close(p1, p2, dtype, name=name)


@contextmanager
def distributed_debug_mode(num_stacks: int = 1, funcs_to_patch: Optional[List[Callable]] = None, enable=True):
    if enable:
        assert (
            os.environ.get("CUDA_LAUNCH_BLOCKING", "0") == "1"
        ), f"Expect CUDA_LAUNCH_BLOCKING=1, got {os.environ.get('CUDA_LAUNCH_BLOCKING', '0')}"
    if funcs_to_patch is None:
        funcs_to_patch = [
            dist.all_reduce,
            dist.all_reduce_coalesced,
            dist.all_gather,
            dist.all_gather_coalesced,
            dist.all_gather_into_tensor,
            dist.all_to_all,
            dist.all_to_all_single,
            dist.reduce_scatter,
        ]

    original_funcs = {}
    patched_funcs = {}

    def make_patched(func):
        def patched_func(*args, **kwargs):
            stack = traceback.format_stack()

            def format_node(node):
                if isinstance(node, torch.Tensor):
                    return f"{node.shape}"
                elif isinstance(node, list):
                    return f"[{', '.join([format_node(n) for n in node])}]"

                return str(node)

            args_str, kwargs_str = tree_map(format_node, (args, kwargs))
            en = len(stack) - 1
            st = max(0, en - num_stacks)
            dist.barrier()
            sleep(0.001 * dist.get_rank())
            print(
                f"[Rank {dist.get_rank()}-{func.__name__}-{dist.get_process_group_ranks(kwargs.get('group', dist.group.WORLD))}]: Called from {''.join(stack[st:en])}args={args_str} kwargs={kwargs_str}\n"
            )
            dist.barrier()
            return func(*args, **kwargs)

        return patched_func

    if enable:
        for func in funcs_to_patch:
            original_funcs[func.__name__] = getattr(dist, func.__name__)
            patched_funcs[func.__name__] = make_patched(func)
            setattr(dist, func.__name__, patched_funcs[func.__name__])

    try:
        yield
    finally:
        for func_name, original_func in original_funcs.items():
            setattr(dist, func_name, original_func)
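Below is a hedged sketch (not part of the diff) of how the new distributed_debug_mode context manager might be exercised in a small two-rank script; the launch command and the all_reduce call are illustrative. The context manager asserts CUDA_LAUNCH_BLOCKING=1 when enable=True, so the variable must be exported before launch.

# Usage sketch, assuming a launch such as:
#   CUDA_LAUNCH_BLOCKING=1 torchrun --nproc_per_node=2 debug_collectives.py
import torch
import torch.distributed as dist

from tests.test_moe.moe_utils import distributed_debug_mode

dist.init_process_group(backend="nccl")
torch.cuda.set_device(dist.get_rank())

x = torch.ones(4, device="cuda")
# Inside the context, each patched collective prints the calling stack frames,
# the argument shapes, and the ranks of the group it runs on, one rank at a time.
with distributed_debug_mode(num_stacks=2):
    dist.all_reduce(x)

dist.destroy_process_group()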
@@ -130,7 +130,7 @@ def check_moe_checkpoint(test_config):
    dist.barrier()
    if dist.get_rank() == 0:
        saved_model = model_cls.from_pretrained(model_dir).cuda().to(dtype)
        check_model_equal(orig_model, saved_model)
        check_model_equal(orig_model, saved_model, dtype=dtype)
        saved_model.save_pretrained(hf_model_dir)
    dist.barrier()
    # check load model
@@ -138,7 +138,7 @@ def check_moe_checkpoint(test_config):
    new_optimizer = Adam(new_model.parameters(), lr=1e-3)
    new_model, new_optimizer, *_ = booster.boost(model=new_model, optimizer=new_optimizer)
    booster.load_model(new_model, hf_model_dir)
    check_model_equal(model, new_model)
    check_model_equal(model, new_model, dtype=dtype)

    # check save optimizer
    optimizer.step()
@@ -12,43 +12,25 @@ from transformers import AutoConfig, AutoModel
import colossalai
from colossalai.booster.booster import Booster
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
from colossalai.shardformer.layer.utils import Randomizer
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.testing.random import seed_all
from tests.test_moe.moe_utils import assert_loose_close, check_model_equal

NUM_BATCH = 8
NUM_TOK_PER_BATCH, NUM_EXPERTS = 4, 2
NUM_TOK_PER_BATCH, NUM_EXPERTS = 64, 4
NUM_LAYERS = 4
HIDDEN_SIZE_PER_HEAD = 4
NUM_HEADS = 4
NUM_HEADS = 8
TOP_K = 2


CHECKED_CONFIG = [ # FOR_WORLD=4
    (1, 4, 1, 1, 1),
    (1, 1, 4, 1, 1),
    (1, 1, 1, 4, 1),
    (1, 1, 1, 1, 4),
    (0, 1, 4, 1, 1),
    (0, 1, 1, 4, 1),
    (0, 1, 1, 1, 4),
    (1, 2, 1, 1, 1),
]


@parameterize(
    "config",
    [
        (1, 2, 2, 1, 1),
        (1, 2, 1, 2, 1),
        (1, 2, 1, 1, 2),
    ],
)
def run_zero_with_original_model(config: Tuple[int, ...]):
def run_deepseek_commom(config: Tuple[int, ...]):
    Randomizer.reset_index()
    stage, ep_size, pp_size, tp_size, sp_size = config
    world_size = dist.get_world_size()
    rank = dist.get_rank()
    dtype, precision = torch.float16, "fp16"
    dtype, precision = torch.bfloat16, "bf16"
    torch.cuda.set_device(dist.get_rank())

    plugin = MoeHybridParallelPlugin(
@@ -60,11 +42,11 @@ def run_zero_with_original_model(config: Tuple[int, ...]):
        zero_stage=stage,
        enable_sequence_parallelism=sp_size > 1,
        sequence_parallelism_mode="all_to_all" if sp_size > 1 else None,
        enable_flash_attention=sp_size > 1,
        overlap_communication=False,
        initial_scale=1,
        precision=precision,
        find_unused_parameters=True,
        enable_flash_attention=True,
    )
    dp_size = plugin.dp_size

@@ -171,7 +153,7 @@ def run_zero_with_original_model(config: Tuple[int, ...]):
    dist.barrier()

    saved_model = AutoModel.from_pretrained(model_dir, trust_remote_code=True).cuda()
    check_model_equal(torch_model, saved_model)
    check_model_equal(torch_model, saved_model, dtype=dtype)
    dist.barrier()

    if rank == world_size - 1:
@@ -180,17 +162,77 @@ def run_zero_with_original_model(config: Tuple[int, ...]):
    print(f"rank {dist.get_rank()} test passed")


def run_dist(rank, world_size, port):
@parameterize(
    "config",
    [
        # DDP: ep == 1 since ep * moe_dp == dp == moe_dp; sp == 1 since sp * dp == moe_dp == dp
        (0, 1, 4, 1, 1),
        (0, 1, 1, 4, 1),
        (0, 1, 2, 2, 1),
        # zero 1
        (1, 4, 1, 1, 1),
        (1, 1, 4, 1, 1),
        (1, 1, 1, 4, 1),
        (1, 2, 1, 1, 2),
        # zero 2
        (2, 4, 1, 1, 1),
        (2, 1, 4, 1, 1),
        (2, 1, 1, 4, 1),
        (2, 2, 1, 1, 2),
    ],
)
def run_deepseek_test(config: Tuple[int, ...]):
    run_deepseek_commom(config)


@parameterize(
    "config",
    [
        # DDP: ep == 1 since ep * moe_dp == dp == moe_dp; sp == 1 since sp * dp == moe_dp == dp
        (0, 1, 2, 4, 1),
        (0, 1, 4, 2, 1),
        (0, 1, 1, 4, 1),
        (0, 1, 4, 1, 1),
        # zero 1:
        (1, 2, 1, 1, 2),
        (1, 2, 1, 4, 1),
        (1, 1, 1, 2, 2),
        (1, 2, 2, 2, 1),
        # zero 2
        (2, 2, 1, 1, 2),
        (2, 2, 1, 4, 1),
        (2, 1, 1, 2, 2),
        (2, 2, 2, 2, 1),
    ],
)
def run_deepseek_3d_test(config: Tuple[int, ...]):
    run_deepseek_commom(config)


def check_deepseek(rank, world_size, port):
    colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
    run_zero_with_original_model()
    run_deepseek_test()


def check_deepseek_3d(rank, world_size, port):
    colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
    run_deepseek_3d_test()


@pytest.mark.dist
@pytest.mark.parametrize("world_size", [4])
@rerun_if_address_is_in_use()
def test_deepseek(world_size):
    spawn(run_dist, world_size)
    spawn(check_deepseek, world_size)


@pytest.mark.largedist
@pytest.mark.parametrize("world_size", [8])
@rerun_if_address_is_in_use()
def test_deepseek_3d(world_size):
    spawn(check_deepseek_3d, world_size)


if __name__ == "__main__":
    test_deepseek(world_size=4)
    test_deepseek(world_size=8)
    test_deepseek_3d(world_size=8)
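For reference, a condensed sketch of how one (stage, ep_size, pp_size, tp_size, sp_size) tuple maps onto the plugin construction shown in the hunk above. The pp_size/tp_size/ep_size/sp_size keyword names are assumptions about MoeHybridParallelPlugin's constructor (only the kwargs visible in the hunk are confirmed by the diff), and a distributed context must already be initialized, e.g. via colossalai.launch as in check_deepseek:

# Hedged sketch, not part of the commit. Assumes colossalai.launch(...) has
# already initialized the process group, as the test driver does before this.
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin

stage, ep_size, pp_size, tp_size, sp_size = (1, 2, 1, 1, 2)  # one tuple from the zero-1 group

plugin = MoeHybridParallelPlugin(
    pp_size=pp_size,  # assumed kwarg names for the parallel sizes
    tp_size=tp_size,
    ep_size=ep_size,
    sp_size=sp_size,
    zero_stage=stage,
    enable_sequence_parallelism=sp_size > 1,
    sequence_parallelism_mode="all_to_all" if sp_size > 1 else None,
    enable_flash_attention=True,
    overlap_communication=False,
    initial_scale=1,
    precision="bf16",
    find_unused_parameters=True,
)
dp_size = plugin.dp_size  # as read back in the hunk above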
@@ -13,42 +13,25 @@ from transformers.models.mixtral.modeling_mixtral import MixtralModel
import colossalai
from colossalai.booster.booster import Booster
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
from colossalai.shardformer.layer.utils import Randomizer
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.testing.random import seed_all
from tests.test_moe.moe_utils import assert_loose_close, check_model_equal

NUM_BATCH = 8
NUM_TOK_PER_BATCH, NUM_EXPERTS = 4, 4
NUM_TOK_PER_BATCH, NUM_EXPERTS = 64, 4
NUM_LAYERS = 4
HIDDEN_SIZE_PER_HEAD = 4
NUM_HEADS = 4
TOP_K = 1

CHECKED_CONFIG = [ # FOR WORLD=4
    (0, 1, 4, 1, 1),
    (0, 1, 1, 4, 1),
    (0, 1, 1, 1, 4),
    (1, 4, 1, 1, 1),
    (1, 1, 4, 1, 1),
    (1, 1, 1, 4, 1),
    (1, 1, 1, 1, 4),
    (1, 2, 1, 1, 1),
]
NUM_HEADS = 8
TOP_K = 2


@parameterize(
    "config",
    [
        (1, 2, 2, 1, 1),
        (1, 2, 1, 2, 1),
        (1, 2, 1, 1, 2),
    ],
)
def run_zero_with_original_model(config: Tuple[int, ...]):
def run_mixtral_commom(config: Tuple[int, ...]):
    Randomizer.reset_index()
    stage, ep_size, pp_size, tp_size, sp_size = config
    world_size = dist.get_world_size()
    rank = dist.get_rank()
    dtype, precision = torch.float16, "fp16"
    dtype, precision = torch.bfloat16, "bf16"
    torch.cuda.set_device(dist.get_rank())

    plugin = MoeHybridParallelPlugin(
@@ -165,7 +148,7 @@ def run_zero_with_original_model(config: Tuple[int, ...]):
    dist.barrier()

    saved_model = MixtralModel.from_pretrained(model_dir).cuda().to(dtype)
    check_model_equal(torch_model, saved_model)
    check_model_equal(torch_model, saved_model, dtype=dtype)
    dist.barrier()

    if rank == world_size - 1:
@@ -174,17 +157,78 @@ def run_zero_with_original_model(config: Tuple[int, ...]):
    print(f"rank {dist.get_rank()} test passed")


def run_dist(rank, world_size, port):
@parameterize(
    "config",
    [
        # DDP: ep == 1 since ep * moe_dp == dp == moe_dp; sp == 1 since sp * dp == moe_dp == dp
        (0, 1, 4, 1, 1),
        (0, 1, 1, 4, 1),
        (0, 1, 2, 2, 1),
        # zero 1
        (1, 4, 1, 1, 1),
        (1, 1, 4, 1, 1),
        (1, 1, 1, 4, 1),
        (1, 2, 1, 1, 2),
        # zero 2
        (2, 4, 1, 1, 1),
        (2, 1, 4, 1, 1),
        (2, 1, 1, 4, 1),
        (2, 2, 1, 1, 2),
    ],
)
def run_mixtral_test(config: Tuple[int, ...]):
    run_mixtral_commom(config)


@parameterize(
    "config",
    [
        # DDP: ep == 1 since ep * moe_dp == dp == moe_dp; sp == 1 since sp * dp == moe_dp == dp
        (0, 1, 2, 4, 1),
        (0, 1, 4, 2, 1),
        (0, 1, 1, 4, 1),
        (0, 1, 4, 1, 1),
        # zero 1:
        (1, 2, 1, 1, 2),
        (1, 2, 1, 4, 1),
        (1, 1, 1, 2, 2),
        (1, 2, 2, 2, 1),
        # zero 2
        (2, 2, 1, 1, 2),
        (2, 2, 1, 4, 1),
        (2, 1, 1, 2, 2),
        (2, 2, 2, 2, 1),
    ],
)
def run_mixtral_3d_test(config: Tuple[int, ...]):
    print(f"{config=}")
    run_mixtral_commom(config)


def check_mixtral(rank, world_size, port):
    colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
    run_zero_with_original_model()
    run_mixtral_test()


def check_mixtral_3d(rank, world_size, port):
    colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
    run_mixtral_3d_test()


@pytest.mark.dist
@pytest.mark.parametrize("world_size", [4])
@rerun_if_address_is_in_use()
def test_mixtral(world_size):
    spawn(run_dist, world_size)
    spawn(check_mixtral, world_size)


@pytest.mark.largedist
@pytest.mark.parametrize("world_size", [8])
@rerun_if_address_is_in_use()
def test_mixtral_3d(world_size):
    spawn(check_mixtral_3d, world_size)


if __name__ == "__main__":
    test_mixtral(world_size=4)
    test_mixtral(world_size=8)
    test_mixtral_3d(world_size=8)
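The "# DDP: ep == 1 ..." comments in the parameterize lists rest on a small piece of group-size arithmetic. A worked check for world_size = 8 follows; the relation moe_dp * ep == dp * sp is inferred from those comments and from the commit message's "checking dp * sp group and moe_dp_group", not stated verbatim in the diff.

# Hedged arithmetic sketch, not part of the commit.
WORLD_SIZE = 8

for stage, ep, pp, tp, sp in [(0, 1, 2, 4, 1), (1, 2, 1, 4, 1), (2, 2, 2, 2, 1)]:
    assert WORLD_SIZE % (pp * tp * sp) == 0
    dp = WORLD_SIZE // (pp * tp * sp)  # data-parallel size left over after pp/tp/sp
    assert (dp * sp) % ep == 0
    moe_dp = dp * sp // ep             # assumed relation: moe_dp * ep == dp * sp
    if stage == 0:
        # DDP path requires moe_dp == dp, which forces ep == 1 and sp == 1
        assert moe_dp == dp and ep == 1 and sp == 1
    print(f"stage={stage} ep={ep} pp={pp} tp={tp} sp={sp} -> dp={dp}, moe_dp={moe_dp}")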