[hotfix] moe hybrid parallelism benchmark & follow-up fix (#6048)

* [example] pass use_fp8_comm flag to all plugins

* [example] add mixtral benchmark

* [moe] refine assertion and check

* [moe] fix mixtral & add more tests

* [moe] consider checking dp * sp group and moe_dp_group

* [mixtral] remove gate tp & add more tests

* [deepseek] fix tp & sp for deepseek

* [mixtral] minor fix

* [deepseek] add deepseek benchmark
This commit is contained in:
botbw
2024-09-10 17:30:53 +08:00
committed by GitHub
parent 8fd25d6e09
commit c54c4fcd15
21 changed files with 907 additions and 99 deletions

View File

@@ -1,4 +1,12 @@
import os
import traceback
from contextlib import contextmanager
from time import sleep
from typing import Callable, List, Optional
import torch
import torch.distributed as dist
from torch.utils._pytree import tree_map
def assert_loose_close(a, b, dtype: torch.dtype = torch.float32, name=""):
@@ -25,7 +33,66 @@ def loose_close(a, b, dtype: torch.dtype = torch.float32):
return torch.allclose(a, b, rtol=rtol, atol=atol)
def check_model_equal(model1, model2):
def check_model_equal(model1, model2, dtype):
assert set(model1.state_dict().keys()) == set(model2.state_dict().keys())
for i, ((name, p1), p2) in enumerate(zip(model1.named_parameters(), model2.parameters())):
assert_loose_close(p1, p2, p1.dtype)
assert_loose_close(p1, p2, dtype, name=name)
@contextmanager
def distributed_debug_mode(num_stacks: int = 1, funcs_to_patch: Optional[List[Callable]] = None, enable=True):
if enable:
assert (
os.environ.get("CUDA_LAUNCH_BLOCKING", "0") == "1"
), f"Expect CUDA_LAUNCH_BLOCKING=1, got {os.environ.get('CUDA_LAUNCH_BLOCKING', '0')}"
if funcs_to_patch is None:
funcs_to_patch = [
dist.all_reduce,
dist.all_reduce_coalesced,
dist.all_gather,
dist.all_gather_coalesced,
dist.all_gather_into_tensor,
dist.all_to_all,
dist.all_to_all_single,
dist.reduce_scatter,
]
original_funcs = {}
patched_funcs = {}
def make_patched(func):
def patched_func(*args, **kwargs):
stack = traceback.format_stack()
def format_node(node):
if isinstance(node, torch.Tensor):
return f"{node.shape}"
elif isinstance(node, list):
return f"[{', '.join([format_node(n) for n in node])}]"
return str(node)
args_str, kwargs_str = tree_map(format_node, (args, kwargs))
en = len(stack) - 1
st = max(0, en - num_stacks)
dist.barrier()
sleep(0.001 * dist.get_rank())
print(
f"[Rank {dist.get_rank()}-{func.__name__}-{dist.get_process_group_ranks(kwargs.get('group', dist.group.WORLD))}]: Called from {''.join(stack[st:en])}args={args_str} kwargs={kwargs_str}\n"
)
dist.barrier()
return func(*args, **kwargs)
return patched_func
if enable:
for func in funcs_to_patch:
original_funcs[func.__name__] = getattr(dist, func.__name__)
patched_funcs[func.__name__] = make_patched(func)
setattr(dist, func.__name__, patched_funcs[func.__name__])
try:
yield
finally:
for func_name, original_func in original_funcs.items():
setattr(dist, func_name, original_func)