Mirror of https://github.com/hpcaitech/ColossalAI.git
Synced 2025-09-05 02:51:59 +00:00
[hotfix] moe hybrid parallelism benchmark & follow-up fix (#6048)
* [example] pass use_fp8_comm flag to all plugins
* [example] add mixtral benchmark
* [moe] refine assertion and check
* [moe] fix mixtral & add more tests
* [moe] consider checking dp * sp group and moe_dp_group
* [mixtral] remove gate tp & add more tests
* [deepseek] fix tp & sp for deepseek
* [mixtral] minor fix
* [deepseek] add deepseek benchmark
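The first bullet threads the benchmark script's fp8-communication switch into every plugin the script can construct. A minimal sketch of that pattern (the CLI flag name comes from the commit message; the plugin name and its fp8_communication parameter are assumptions about ColossalAI's API, since the benchmark script itself is not shown on this page):

import argparse

from colossalai.booster.plugin import MoeHybridParallelPlugin  # assumed plugin

parser = argparse.ArgumentParser()
parser.add_argument("--use_fp8_comm", action="store_true", default=False)
args = parser.parse_args()

# Forward the flag into the plugin so collective communication can run in fp8.
plugin = MoeHybridParallelPlugin(
    tp_size=1,
    pp_size=1,
    ep_size=2,
    fp8_communication=args.use_fp8_comm,  # assumed constructor parameter
)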
@@ -1,4 +1,12 @@
+import os
+import traceback
+from contextlib import contextmanager
+from time import sleep
+from typing import Callable, List, Optional
+
 import torch
 import torch.distributed as dist
+from torch.utils._pytree import tree_map
+
 
 def assert_loose_close(a, b, dtype: torch.dtype = torch.float32, name=""):
@@ -25,7 +33,66 @@ def loose_close(a, b, dtype: torch.dtype = torch.float32):
     return torch.allclose(a, b, rtol=rtol, atol=atol)
 
 
-def check_model_equal(model1, model2):
+def check_model_equal(model1, model2, dtype):
     assert set(model1.state_dict().keys()) == set(model2.state_dict().keys())
     for i, ((name, p1), p2) in enumerate(zip(model1.named_parameters(), model2.parameters())):
-        assert_loose_close(p1, p2, p1.dtype)
+        assert_loose_close(p1, p2, dtype, name=name)
+
+
+@contextmanager
+def distributed_debug_mode(num_stacks: int = 1, funcs_to_patch: Optional[List[Callable]] = None, enable=True):
+    if enable:
+        assert (
+            os.environ.get("CUDA_LAUNCH_BLOCKING", "0") == "1"
+        ), f"Expect CUDA_LAUNCH_BLOCKING=1, got {os.environ.get('CUDA_LAUNCH_BLOCKING', '0')}"
+        if funcs_to_patch is None:
+            funcs_to_patch = [
+                dist.all_reduce,
+                dist.all_reduce_coalesced,
+                dist.all_gather,
+                dist.all_gather_coalesced,
+                dist.all_gather_into_tensor,
+                dist.all_to_all,
+                dist.all_to_all_single,
+                dist.reduce_scatter,
+            ]
+
+    original_funcs = {}
+    patched_funcs = {}
+
+    def make_patched(func):
+        def patched_func(*args, **kwargs):
+            stack = traceback.format_stack()
+
+            def format_node(node):
+                if isinstance(node, torch.Tensor):
+                    return f"{node.shape}"
+                elif isinstance(node, list):
+                    return f"[{', '.join([format_node(n) for n in node])}]"
+
+                return str(node)
+
+            args_str, kwargs_str = tree_map(format_node, (args, kwargs))
+            en = len(stack) - 1
+            st = max(0, en - num_stacks)
+            dist.barrier()
+            sleep(0.001 * dist.get_rank())
+            print(
+                f"[Rank {dist.get_rank()}-{func.__name__}-{dist.get_process_group_ranks(kwargs.get('group', dist.group.WORLD))}]: Called from {''.join(stack[st:en])}args={args_str} kwargs={kwargs_str}\n"
+            )
+            dist.barrier()
+            return func(*args, **kwargs)
+
+        return patched_func
+
+    if enable:
+        for func in funcs_to_patch:
+            original_funcs[func.__name__] = getattr(dist, func.__name__)
+            patched_funcs[func.__name__] = make_patched(func)
+            setattr(dist, func.__name__, patched_funcs[func.__name__])
+
+    try:
+        yield
+    finally:
+        for func_name, original_func in original_funcs.items():
+            setattr(dist, func_name, original_func)
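For reference, check_model_equal now takes the comparison dtype explicitly (instead of each parameter's own dtype) and passes the parameter name through so a failure points at the offending weight. A minimal usage sketch (the tiny model and the import path are illustrative; the diffed file's path is not shown on this page):

import torch
import torch.nn as nn

from moe_utils import check_model_equal  # import path is illustrative

# Two identically initialized models should compare equal under bf16 tolerances.
model_a = nn.Linear(4, 4).to(torch.bfloat16)
model_b = nn.Linear(4, 4).to(torch.bfloat16)
model_b.load_state_dict(model_a.state_dict())

# dtype selects the rtol/atol used by assert_loose_close; name=<param name>
# makes the assertion message identify the first mismatching parameter.
check_model_equal(model_a, model_b, dtype=torch.bfloat16)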
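The new distributed_debug_mode context manager monkey-patches the listed torch.distributed collectives so that each call prints, per rank, the tail of its call stack, the argument shapes, and the process group before delegating to the real collective, then restores the originals on exit. A minimal driver sketch (launch command, backend, and import path are illustrative; CUDA_LAUNCH_BLOCKING=1 is required by the assert at entry):

# Run with, e.g.:
#   CUDA_LAUNCH_BLOCKING=1 torchrun --nproc_per_node=2 debug_demo.py
import torch
import torch.distributed as dist

from moe_utils import distributed_debug_mode  # import path is illustrative

dist.init_process_group(backend="nccl")
torch.cuda.set_device(dist.get_rank())

x = torch.ones(4, device="cuda")
with distributed_debug_mode(num_stacks=2):
    # Intercepted: each rank prints the last two stack frames and the tensor
    # shape, then the original dist.all_reduce runs.
    dist.all_reduce(x)

dist.destroy_process_group()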