Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-08 20:40:34 +00:00)
[shardformer] DeepseekMoE support (#5871)
* [Feature] deepseek moe expert parallel implement
* [misc] fix typo, remove redundant file (#5867)
* [misc] fix typo
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
  Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* [Feature] deepseek support & unit test
* [misc] remove debug code & useless print
* [misc] fix typos (#5872)
* [Feature] remove modeling file, use auto config. (#5884)
* [misc] fix typos
* [Feature] deepseek support via auto model, remove modeling file
* [misc] delete useless file
* [misc] fix typos
* [Deepseek] remove redundant code (#5888)
* [misc] fix typos
* [Feature] deepseek support via auto model, remove modeling file
* [misc] delete useless file
* [misc] fix typos
* [misc] remove redundant code
* [Feature/deepseek] resolve comment. (#5889)
* [misc] fix typos
* [Feature] deepseek support via auto model, remove modeling file
* [misc] delete useless file
* [misc] fix typos
* [misc] remove redundant code
* [misc] mv module replacement into if branch
* [misc] add some warning message and modify some code in unit test
* [misc] fix typos

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
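For orientation, here is a minimal sketch of how the expert-parallel wrapper added in this PR is applied, mirroring the unit test below. Only EPDeepseekMoE.from_native_module, MoeHybridParallelPlugin, and the Hugging Face repo id appear in this commit; the surrounding loop, the `.experts` attribute check, the shrunken config overrides, and the assumption that colossalai.launch has already initialized the distributed context are illustrative, not part of the diff.

# Illustrative sketch only (not part of the diff): apply the expert-parallel
# wrapper to every MoE block of a DeepSeek model. Assumes colossalai.launch(...)
# has already initialized the distributed context, as in the test below.
import torch.distributed as dist
from transformers import AutoConfig, AutoModel

from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
from colossalai.shardformer.modeling.deepseek import EPDeepseekMoE

plugin = MoeHybridParallelPlugin(precision="bf16", tp_size=1, pp_size=1, ep_size=dist.get_world_size())

config = AutoConfig.from_pretrained(
    "deepseek-ai/deepseek-moe-16b-base",
    num_hidden_layers=2,      # keep the sketch small
    first_k_dense_replace=1,  # layer 0 stays a dense MLP, layer 1 is MoE
    trust_remote_code=True,
)
model = AutoModel.from_config(config, trust_remote_code=True).cuda()

# Module replacement lives in an if branch (see the commit message): only layers
# whose MLP actually carries routed experts are swapped for EPDeepseekMoE.
for layer in model.layers:
    if hasattr(layer.mlp, "experts"):  # assumption: MoE blocks expose `.experts`
        layer.mlp = EPDeepseekMoE.from_native_module(layer.mlp, ep_group=plugin.ep_group)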
tests/test_moe/test_deepseek_layer.py (new file, 72 lines)
@@ -0,0 +1,72 @@
from copy import deepcopy

import pytest
import torch
import torch.distributed as dist
from torch.testing import assert_close
from transformers import AutoConfig, AutoModel

import colossalai
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
from colossalai.shardformer.modeling.deepseek import EPDeepseekMoE
from colossalai.testing.utils import spawn

tokens, n_experts = 7, 4
hidden_size = 8
top_k = 2


def check_deepseek_moe_layer():
    torch.cuda.set_device(dist.get_rank())
    plugin = MoeHybridParallelPlugin(
        precision="bf16",
        tp_size=1,
        pp_size=1,
        ep_size=dist.get_world_size(),
    )

    config = AutoConfig.from_pretrained(
        "deepseek-ai/deepseek-moe-16b-base",
        num_hidden_layers=1,
        n_routed_experts=n_experts,
        num_experts_per_tok=top_k,
        hidden_size=hidden_size,
        intermediate_size=hidden_size * 2,
        first_k_dense_replace=0,
        num_attention_heads=2,
        trust_remote_code=True,
    )
    torch.manual_seed(0)
    # get the moe layer in auto model
    orig_model = AutoModel.from_config(config, trust_remote_code=True).layers[0].mlp.cuda()
    x = torch.rand(1, tokens, hidden_size, requires_grad=True).cuda()
    orig_output = orig_model(x)
    model = deepcopy(orig_model)
    model = EPDeepseekMoE.from_native_module(model, ep_group=plugin.ep_group)
    ep_output = model(x)
    assert_close(orig_output, ep_output)
    orig_loss = orig_output.mean()
    orig_loss.backward()
    ep_loss = ep_output.mean()
    ep_loss.backward()
    assert_close(orig_loss, ep_loss)
    name_to_p = {n: p for n, p in orig_model.named_parameters()}
    for n, ep_p in model.named_parameters():
        p = name_to_p[n]
        if ep_p.grad is not None:
            assert_close(p.grad, ep_p.grad)


def run_dist(rank: int, world_size: int, port: int):
    colossalai.launch(rank, world_size, "localhost", port)
    check_deepseek_moe_layer()


# @pytest.mark.parametrize("world_size", [2, 4])
@pytest.mark.parametrize("world_size", [2])
def test_deepseek_moe_layer(world_size: int):
    spawn(run_dist, world_size)


if __name__ == "__main__":
    test_deepseek_moe_layer(2)
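Run note (the exact commands are an assumption about the local setup): the parametrized test expects two CUDA devices and can be launched with `pytest tests/test_moe/test_deepseek_layer.py`, or directly with `python tests/test_moe/test_deepseek_layer.py`, which calls `test_deepseek_moe_layer(2)`; `spawn` then starts one process per rank and `colossalai.launch` sets up the expert-parallel process group before the layer check runs.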