[FP8] rebase main (#5963)

* add SimPO

* fix dataloader

* remove debug code

* add orpo

* fix style

* fix colossalai, transformers version

* fix colossalai, transformers version

* fix colossalai, transformers version

* fix torch colossalai version

* update transformers version

* [shardformer] DeepseekMoE support (#5871)

* [Feature] deepseek moe expert parallel implement

* [misc] fix typo, remove redundant file (#5867)

* [misc] fix typo

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [Feature] deepseek support & unit test

* [misc] remove debug code & useless print

* [misc] fix typos (#5872)

* [Feature] remove modeling file, use auto config. (#5884)

* [misc] fix typos

* [Feature] deepseek support via auto model, remove modeling file

* [misc] delete useless file

* [misc] fix typos

* [Deepseek] remove redundant code (#5888)

* [misc] fix typos

* [Feature] deepseek support via auto model, remove modeling file

* [misc] delete useless file

* [misc] fix typos

* [misc] remove redundant code

* [Feature/deepseek] resolve comment. (#5889)

* [misc] fix typos

* [Feature] deepseek support via auto model, remove modeling file

* [misc] delete useless file

* [misc] fix typos

* [misc] remove redundant code

* [misc] mv module replacement into if branch

* [misc] add some warning message and modify some code in unit test

* [misc] fix typos

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [Hoxfix] Fix CUDA_DEVICE_MAX_CONNECTIONS for comm overlap

Co-authored-by: Edenzzzz <wtan45@wisc.edu>

* [Feat] Diffusion Model(PixArtAlpha/StableDiffusion3) Support (#5838)

* Diffusion Model Inference support

* Stable Diffusion 3 Support

* pixartalpha support

* [HotFix] CI,import,requirements-test for #5838 (#5892)

* [Hot Fix] CI,import,requirements-test

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [Feature] Enable PP + SP for llama (#5868)

* fix cross-PP-stage position id length diff bug

* fix typo

* fix typo

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* use a one cross entropy func for all shardformer models

---------

Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [ShardFormer] Add Ulysses Sequence Parallelism support for Command-R, Qwen2 and ChatGLM (#5897)

* add benchmark for sft, dpo, simpo, orpo. Add benchmarking result. Support lora with gradient checkpoint

* fix style

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix eval

* hotfix citation

* [zero] support all-gather overlap (#5898)

* [zero] support all-gather overlap

* [zero] add overlap all-gather flag

* [misc] fix typo

* [zero] update api

* fix orpo cross entropy loss

* [Auto Parallel]: Speed up intra-op plan generation by 44% (#5446)

* Remove unnecessary calls to deepcopy

* Build DimSpec's difference dict only once

This change considerably speeds up construction speed of DimSpec objects. The difference_dict is the same for each DimSpec object, so a single copy of it is enough.

* Fix documentation of DimSpec's difference method

* [ShardFormer] fix qwen2 sp (#5903)

* [compatibility] support torch 2.2 (#5875)

* Support Pytorch 2.2.2

* keep build_on_pr file and update .compatibility

* fix object_to_tensor usage when torch>=2.3.0 (#5820)

* [misc] support torch2.3 (#5893)

* [misc] support torch2.3

* [devops] update compatibility ci

* [devops] update compatibility ci

* [devops] add debug

* [devops] add debug

* [devops] add debug

* [devops] add debug

* [devops] remove debug

* [devops] remove debug

* [release] update version (#5912)

* [plugin] support all-gather overlap for hybrid parallel (#5919)

* [plugin] fixed all-gather overlap support for hybrid parallel

* add kto

* fix style, add kto data sample

* [Examples] Add lazy init to OPT and GPT examples (#5924)

Co-authored-by: Edenzzzz <wtan45@wisc.edu>

* [ColossalChat] Hotfix for ColossalChat (#5910)

* add ignore and tiny llama

* fix path issue

* run style

* fix issue

* update bash

* add ignore and tiny llama

* fix path issue

* run style

* fix issue

* update bash

* fix ddp issue

* add Qwen 1.5 32B

* refactor tokenization

* [FIX BUG] UnboundLocalError: cannot access local variable 'default_conversation' where it is not associated with a value (#5931)

* cannot access local variable 'default_conversation' where it is not associated with a value

set default value for 'default_conversation'

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* fix test data

* refactor evaluation

* remove real data path

* remove real data path

* Add n_fused as an input from native_module (#5894)

* [FIX BUG] convert env param to int in (#5934)

* [Hotfix] Fix ZeRO typo #5936

Co-authored-by: Edenzzzz <wtan45@wisc.edu>

* [Feature] Add a switch to control whether the model checkpoint needs to be saved after each epoch ends (#5941)

* Add a switch to control whether the model checkpoint needs to be saved after each epoch ends

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* fix style

* fix style

* fix style

* [shardformer] hotfix attn mask (#5945)

* [shardformer] hotfix attn mask (#5947)

* [Feat] Distrifusion Acceleration Support for Diffusion Inference (#5895)

* Distrifusion Support source

* comp comm overlap optimization

* sd3 benchmark

* pixart distrifusion bug fix

* sd3 bug fix and benchmark

* generation bug fix

* naming fix

* add docstring, fix counter and shape error

* add reference

* readme and requirement

* [zero] hotfix update master params (#5951)

* [release] update version (#5952)

* [Chat] Fix lora (#5946)

* fix merging

* remove filepath

* fix style

* Update README.md (#5958)

* [hotfix] Remove unused plan section (#5957)

* remove readme

* fix readme

* update

* [test] add mixtral for sequence classification

* [test] add mixtral transformer test

* [moe] fix plugin

* [test] mixtra pp shard test

* [chore] handle non member group

* [zero] solve hang

* [test] pass mixtral shardformer test

* [moe] implement transit between non moe tp and ep

* [zero] solve hang

* [misc] solve booster hang by rename the variable

* solve hang when parallel mode = pp + dp

* [moe] implement submesh initialization

* [moe] add mixtral dp grad scaling when not all experts are activated

* [chore] manually revert unintended commit

* [chore] trivial fix

* [chore] arg pass & remove drop token

* [test] add mixtral modelling test

* [moe] implement tp

* [moe] test deepseek

* [moe] clean legacy code

* [Feature] MoE Ulysses Support (#5918)

* moe sp support

* moe sp bug solve

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [chore] minor fix

* [moe] init moe plugin comm setting with sp

* moe sp + ep bug fix

* [moe] finalize test (no pp)

* [moe] full test for deepseek and mixtral (pp + sp to fix)

* [chore] minor fix after rebase

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* [chore] solve moe ckpt test failure and some other arg pass failure

* [moe] remove ops

* [test] fix test: test_zero1_2

* [bug] fix: somehow logger hangs the program

* [moe] deepseek moe sp support

* [test] add check

* [deepseek] replace attn (a workaround for bug in transformers)

* [misc] skip redunant test

* [misc] remove debug/print code

* [moe] refactor mesh assignment

* Revert "[moe] implement submesh initialization"

This reverts commit 2f9bce6686.

* [chore] change moe_pg_mesh to private

* [misc] remove incompatible test config

* [misc] fix ci failure: change default value to false in moe plugin

* [misc] remove useless condition

* [chore] docstring

* [moe] remove force_overlap_comm flag and add warning instead

* [doc] add MoeHybridParallelPlugin docstring

* [moe] solve dp axis issue

* [chore] remove redundant test case, print string & reduce test tokens

* [feat] Dist Loader for Eval (#5950)

* support auto distributed data loader

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* support auto distributed data loader

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix tp error

* remove unused parameters

* remove unused

* update inference

* update docs

* update inference

---------

Co-authored-by: Michelle <qianranma8@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [lora] lora support hybrid parallel plugin (#5956)

* lora support hybrid plugin

* fix

* fix

* fix

* fix

* fp8 operators for compressed communication

cast_to_fp8, cast_from_fp8, all_reduce_fp8

* fix scaling algorithm in FP8 casting

* support fp8 communication in pipeline parallelism

* add fp8_communication flag in the script

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix typo

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* shardformer fp8

* fix rebase

* remove all to all

* fix shardformer fp8 communication training degradation

* [fp8] support all-gather flat tensor (#5932)

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* Update low_level_optim.py

---------

Co-authored-by: YeAnbang <anbangy2@outlook.com>
Co-authored-by: Haze188 <haze188@qq.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Edenzzzz <wenxuan.tan@wisc.edu>
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: Runyu Lu <77330637+LRY89757@users.noreply.github.com>
Co-authored-by: Guangyao Zhang <xjtu521@qq.com>
Co-authored-by: YeAnbang <44796419+YeAnbang@users.noreply.github.com>
Co-authored-by: Hongxin Liu <lhx0217@gmail.com>
Co-authored-by: Stephan Kö <stephankoe@users.noreply.github.com>
Co-authored-by: アマデウス <kurisusnowdeng@users.noreply.github.com>
Co-authored-by: Tong Li <tong.li352711588@gmail.com>
Co-authored-by: zhurunhua <1281592874@qq.com>
Co-authored-by: Insu Jang <insujang@umich.edu>
Co-authored-by: Gao, Ruiyuan <905370712@qq.com>
Co-authored-by: hxwang <wang1570@e.ntu.edu.sg>
Co-authored-by: Michelle <qianranma8@gmail.com>
Co-authored-by: Wang Binluo <32676639+wangbluo@users.noreply.github.com>
Co-authored-by: HangXu <hangxu0304@gmail.com>
This commit is contained in:
flybird11111
2024-08-06 16:29:37 +08:00
committed by GitHub
parent 53cb9606bd
commit 0c10afd372
208 changed files with 10962 additions and 2892 deletions

View File

@@ -1,238 +1,132 @@
import os
import warnings
from typing import Dict
from copy import deepcopy
import pytest
import torch
import torch.distributed as dist
from transformers.models.mixtral.configuration_mixtral import MixtralConfig
from transformers.models.mixtral.modeling_mixtral import MixtralModel
import colossalai
from colossalai.accelerator import get_accelerator
from colossalai.moe.manager import MOE_MANAGER
from colossalai.moe.utils import sync_moe_model_param
from colossalai.booster.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.testing.random import seed_all
from tests.test_moe.moe_utils import assert_loose_close
# from colossalai.shardformer.layer import SparseMLP
from colossalai.tensor.moe_tensor.api import get_ep_group, get_ep_rank, get_ep_size, is_moe_tensor
from colossalai.testing import assert_equal_in_group, rerun_if_address_is_in_use, spawn
from tests.test_moe.moe_utils import MoeGradientHandler
NUM_BATCH = 4
NUM_TOK_PER_BATCH, NUM_EXPERTS = 7, 4
HIDDEN_SIZE_PER_HEAD = 4
NUM_HEADS = 4
TOP_K = 2
def sync_tp_from_local(tp_model, local_model, assert_grad_flag: bool = False) -> None:
"""Sync the parameters of tp model from local model
@parameterize("stage", [1])
@parameterize("ep_size", [2])
def run_zero_with_original_model(stage: int, ep_size: int):
tp_size = dist.get_world_size() // ep_size
dtype = torch.bfloat16
Args:
tp_model (MoeModule)
local_model (MoeModule)
"""
for (tp_name, tp_param), (local_name, local_param) in zip(
tp_model.named_parameters(), local_model.named_parameters()
):
assert tp_name == local_name
if not is_moe_tensor(tp_param):
if assert_grad_flag:
assert torch.allclose(tp_param, local_param)
assert torch.allclose(tp_param.grad, local_param.grad)
else:
tp_param.data.copy_(local_param.data)
continue
rank = torch.distributed.get_rank()
torch.cuda.set_device(dist.get_rank())
tp_rank = get_ep_rank(tp_param)
tp_dim = [i for i, (d1, d2) in enumerate(zip(tp_param.shape, local_param.shape)) if d1 != d2][0]
tp_slice = [slice(None)] * tp_dim + [
slice(tp_param.shape[tp_dim] * tp_rank, tp_param.shape[tp_dim] * (tp_rank + 1))
]
seed_all(10086)
if assert_grad_flag:
assert torch.allclose(tp_param, local_param[tuple(tp_slice)])
assert torch.allclose(tp_param.grad, local_param.grad[tuple(tp_slice)])
else:
tp_param.data.copy_(local_param[tuple(tp_slice)].data)
def sync_tp_from_ep(tp_model, ep_model, assert_grad_flag: bool = False) -> None:
"""Sync the parameters of tp model from ep model
Args:
tp_model (MoeModule)
ep_model (MoeModule)
"""
for (tp_name, tp_param), (ep_name, ep_param) in zip(tp_model.named_parameters(), ep_model.named_parameters()):
assert tp_name == ep_name
if not is_moe_tensor(tp_param):
if assert_grad_flag:
assert torch.allclose(tp_param, ep_param)
assert torch.allclose(tp_param.grad, ep_param.grad)
else:
tp_param.data.copy_(ep_param.data)
continue
# gather param from ep model
param_list = [torch.zeros_like(ep_param) for _ in range(get_ep_size(ep_param))]
dist.all_gather(param_list, ep_param, group=get_ep_group(ep_param))
all_param = torch.cat(param_list, dim=0)
if assert_grad_flag:
grad_list = [torch.zeros_like(ep_param) for _ in range(get_ep_size(ep_param))]
dist.all_gather(grad_list, ep_param.grad, group=get_ep_group(ep_param))
all_grad = torch.cat(grad_list, dim=0)
# get tp param
tp_dim = [i for i, (d1, d2) in enumerate(zip(tp_param.shape[1:], all_param.shape[1:])) if d1 != d2][0] + 1
tp_rank = get_ep_rank(tp_param)
tp_slice = [slice(None)] * tp_dim + [
slice(tp_param.shape[tp_dim] * tp_rank, tp_param.shape[tp_dim] * (tp_rank + 1))
]
new_tp_param = all_param[tuple(tp_slice)]
if assert_grad_flag:
new_grad = all_grad[tuple(tp_slice)]
if assert_grad_flag:
assert torch.allclose(tp_param, new_tp_param)
assert torch.allclose(tp_param.grad, new_grad)
else:
tp_param.data.copy_(new_tp_param.data)
def sync_local_from_ep(local_model, ep_model, assert_grad_flag: bool = False) -> None:
"""Sync the parameters of tp model from ep model
Args:
local_model (MoeModule)
ep_model (MoeModule)
"""
for (local_name, local_param), (ep_name, ep_param) in zip(
local_model.named_parameters(), ep_model.named_parameters()
):
assert local_name == ep_name
if "experts" not in local_name:
if assert_grad_flag:
assert torch.allclose(local_param, ep_param)
assert torch.allclose(local_param.grad, ep_param.grad)
else:
local_param.data.copy_(ep_param.data)
continue
# gather param from ep model
param_list = [torch.zeros_like(ep_param) for _ in range(get_ep_size(ep_param))]
dist.all_gather(param_list, ep_param, group=get_ep_group(ep_param))
all_param = torch.cat(param_list, dim=0)
if assert_grad_flag:
grad_list = [torch.zeros_like(ep_param) for _ in range(get_ep_size(ep_param))]
dist.all_gather(grad_list, ep_param.grad, group=get_ep_group(ep_param))
all_grad = torch.cat(grad_list, dim=0)
if assert_grad_flag:
assert torch.allclose(local_param, all_param)
assert torch.allclose(local_param.grad, all_grad)
else:
local_param.data.copy_(all_param.data)
def run_test(rank: int, world_size: int, port: int, num_experts: int, batch_size: int, dim: int, config: Dict):
assert batch_size % world_size == 0
colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
MOE_MANAGER.__init__()
MOE_MANAGER.setup(parallel=None)
local_model = SparseMLP(num_experts=num_experts, hidden_size=dim, intermediate_size=dim * 2)
MOE_MANAGER.__init__()
MOE_MANAGER.setup(parallel="EP")
enable_hierarchical_comm = config.get("enable_hierarchical_comm", False)
if enable_hierarchical_comm:
os.environ["LOCAL_WORLD_SIZE"] = str(world_size)
ep_model = SparseMLP(
num_experts=num_experts,
hidden_size=dim,
intermediate_size=dim * 2,
enable_hierarchical_comm=enable_hierarchical_comm,
config = MixtralConfig(
hidden_size=HIDDEN_SIZE_PER_HEAD * NUM_HEADS,
intermediate_size=HIDDEN_SIZE_PER_HEAD * NUM_HEADS * 2,
num_hidden_layers=2,
num_attention_heads=NUM_HEADS,
num_key_value_heads=NUM_HEADS,
num_local_experts=NUM_EXPERTS,
num_experts_per_tok=TOP_K,
)
MOE_MANAGER.__init__()
MOE_MANAGER.setup(parallel="TP")
tp_model = SparseMLP(num_experts=num_experts, hidden_size=dim, intermediate_size=dim * 2)
ep_model = ep_model.to(get_accelerator().get_current_device())
tp_model = tp_model.to(get_accelerator().get_current_device())
local_model = local_model.to(get_accelerator().get_current_device())
torch_model = MixtralModel(config).to(dtype).cuda()
# sync ep param
sync_moe_model_param(ep_model)
dist_dict = MOE_MANAGER.parallel_info_dict
assert_equal_in_group(ep_model.experts.wi.data, dist_dict[world_size].dp_group)
assert_equal_in_group(ep_model.experts.wo.data, dist_dict[world_size].dp_group)
ep_grad_handler = MoeGradientHandler(ep_model)
# sync local param
sync_local_from_ep(local_model, ep_model)
# sync tp param
sync_tp_from_ep(tp_model, ep_model)
tp_grad_handler = MoeGradientHandler(tp_model)
rank = dist.get_rank()
input_data = torch.randn(batch_size, dim, device=get_accelerator().get_current_device())
micro_batch_size = batch_size // world_size
index = rank * micro_batch_size
# NOTE: ep & tp takes in sharded data for each process
shard_data = input_data.detach()[index : index + micro_batch_size]
out_local = local_model(input_data)
MOE_MANAGER.reset_loss()
out_tp = tp_model(shard_data)
MOE_MANAGER.reset_loss()
out_ep = ep_model(shard_data)
MOE_MANAGER.reset_loss()
assert torch.allclose(
out_tp, out_ep, atol=1e-6
), f"Rank {rank} failed, max diff: {torch.max(torch.abs(out_tp - out_ep))}"
try:
out_local_slice = out_local[index : index + micro_batch_size]
assert torch.allclose(
out_ep, out_local_slice, atol=1e-6
), f"Rank {rank} failed, max diff: {torch.max(torch.abs(out_ep - out_local_slice))}"
except AssertionError:
"""
e.g., in local model, tokens = 4, capacity = 2, experts = 2, topk = 1
router yields [01] --> [0], [23] --> [1], this is valid as capacity is 2
However, in ep mode, there are 2 separate routers dealing with sharded data.
Assume router 0 handles token [01] and router 1 handles token [23].
Note that for each router the capacity is only 1 !!!
Thus, router 0 may yields [0] --> [0] or [1] --> [0], but not both.
The same thing happens on router 1. And finally some tokens are dropped due to the sharded nature.
"""
warnings.warn(
"EP & TP may result in different behavior from local model. " "Please check the comments for details."
zero_model = deepcopy(torch_model).to(dtype)
zero_optimizer = torch.optim.SGD(zero_model.parameters(), lr=1)
moe_booster = Booster(
plugin=MoeHybridParallelPlugin(
tp_size=tp_size,
moe_tp_size=tp_size,
pp_size=1,
ep_size=ep_size,
zero_stage=stage,
overlap_communication=False,
initial_scale=1,
)
)
zero_model, zero_optimizer, _, _, _ = moe_booster.boost(zero_model, zero_optimizer)
out_local.mean().backward()
out_tp.mean().backward()
tp_grad_handler.handle_gradient()
out_ep.mean().backward()
ep_grad_handler.handle_gradient()
assert_equal_in_group(ep_model.experts.wi.grad, dist_dict[world_size].dp_group)
assert_equal_in_group(ep_model.experts.wo.grad, dist_dict[world_size].dp_group)
sync_tp_from_ep(tp_model, ep_model, assert_grad_flag=True)
try:
sync_local_from_ep(local_model, ep_model, assert_grad_flag=True)
except AssertionError:
warnings.warn(
"EP & TP may result in different behavior from local model. " "Please check the comments for details."
hybird_booster = Booster(
plugin=HybridParallelPlugin(
tp_size=tp_size,
pp_size=1,
zero_stage=stage,
overlap_communication=False,
initial_scale=1,
)
)
hybrid_model, hybrid_optimizer, _, _, _ = hybird_booster.boost(
torch_model, torch.optim.SGD(torch_model.parameters(), lr=1)
)
# create different input
seed_all(1453 + rank)
hybrid_model.train()
zero_model.train()
for _ in range(2):
# zero-dp forward
input_data = torch.rand(
NUM_BATCH, NUM_TOK_PER_BATCH, HIDDEN_SIZE_PER_HEAD * NUM_HEADS, requires_grad=True
).cuda()
zero_output = zero_model(inputs_embeds=input_data.to(dtype)).last_hidden_state.mean()
# zero-dp backward
zero_optimizer.backward(zero_output)
# torch-ddp forward
hybrid_output = hybrid_model(inputs_embeds=input_data.to(dtype)).last_hidden_state.mean()
assert_loose_close(zero_output, hybrid_output, dtype=dtype)
# torch-ddp backward
hybrid_optimizer.backward(hybrid_output)
# check grad
name_to_p = {n: p for n, p in hybrid_model.named_parameters()}
for n, p in zero_model.named_parameters():
zero_grad = zero_optimizer.get_param_grad(p)
if name_to_p[n].grad is None:
name_to_p[n].grad = torch.zeros_like(name_to_p[n])
continue
if zero_grad.shape != name_to_p[n].grad.shape: # TODO check sharded and sliced moe
continue
assert_loose_close(zero_grad, name_to_p[n].grad, dtype=dtype, name=n)
# zero-dp step
zero_optimizer.step()
# original model step
hybrid_optimizer.step()
# check updated param
for n, p in zero_model.named_parameters():
if p.data.shape != name_to_p[n].data.shape: # TODO check sharded and sliced moe
continue
assert_loose_close(p.data, name_to_p[n].data, dtype=dtype, name=n)
print(f"{dist.get_rank()} test passed")
@pytest.mark.skip(reason="moe need to be refactored")
def run_dist(rank, world_size, port):
colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
run_zero_with_original_model()
@pytest.mark.skip("tested in corresponding sharderformer")
@pytest.mark.dist
@pytest.mark.parametrize("num_experts", [4, 64])
@pytest.mark.parametrize("batch_size", [16])
@pytest.mark.parametrize("dim", [64])
@pytest.mark.parametrize(
"config",
[
{"enable_hierarchical_comm": False},
{"enable_hierarchical_comm": True},
],
)
@pytest.mark.parametrize("world_size", [4])
@rerun_if_address_is_in_use()
def test_moe_ep_tp(num_experts: int, batch_size: int, dim: int, config: Dict):
spawn(run_test, 2, num_experts=num_experts, batch_size=batch_size, dim=dim, config=config)
def test_moe_ep_tp(world_size):
spawn(run_dist, world_size)
if __name__ == "__main__":
test_moe_ep_tp(num_experts=8, batch_size=32, dim=32)
test_moe_ep_tp(world_size=4)