[shardformer] support ep for deepseek v3 (#6185)

* [feature] support ep for deepseek v3

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix test

* [shardformer] fix deepseek v3 init

* [lazy] fit lora for lazy init

* [example] support npu for deepseek v3

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Author: Hongxin Liu
Date: 2025-02-11 16:10:25 +08:00 (committed by GitHub)
Commit: 2b415e5999 (parent: 17062c83b9)
13 changed files with 612 additions and 22 deletions
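
For context, here is a minimal sketch of how the expert-parallel (EP) support added by this commit might be driven from user code, assuming a Hugging Face-style DeepSeek-V3 checkpoint and the standard Booster workflow; the plugin arguments mirror the ones exercised by the new test below, while the checkpoint id and optimizer choice are placeholders.

import torch
from transformers import AutoModelForCausalLM

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import MoeHybridParallelPlugin

colossalai.launch_from_torch()  # picks up the distributed env set by torchrun

# EP partitions the MoE experts across ep_size ranks; ZeRO-1 shards optimizer states.
plugin = MoeHybridParallelPlugin(
    tp_size=1,
    pp_size=1,
    ep_size=4,
    zero_stage=1,
    precision="bf16",
)
booster = Booster(plugin=plugin)

# Placeholder checkpoint id; DeepSeek-V3's custom modeling code needs trust_remote_code=True.
model = AutoModelForCausalLM.from_pretrained(
    "deepseek-ai/DeepSeek-V3", trust_remote_code=True, torch_dtype=torch.bfloat16
)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)

model, optimizer, *_ = booster.boost(model, optimizer)

With this configuration each rank holds roughly 1/ep_size of the experts, while non-expert parameters stay replicated and only optimizer states are sharded by ZeRO stage 1.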


@@ -223,7 +223,6 @@ def run_forward_backward_with_hybrid_plugin(
for k, v in data.items():
unshard_test_data[k] = data[k].clone()
sharded_model.train()
if booster.plugin.stage_manager is not None:
for k, v in shard_test_data.items():
if torch.is_tensor(v) or "Tensor" in v.__class__.__name__:
@@ -248,7 +247,6 @@ def run_forward_backward_with_hybrid_plugin(
sharded_loss = criterion(sharded_output)
sharded_optimizer.backward(sharded_loss)
org_model.train()
if booster.plugin.stage_manager is not None:
for k, v in unshard_test_data.items():
if torch.is_tensor(v) or "Tensor" in v.__class__.__name__:


@@ -0,0 +1,102 @@
from typing import Tuple

import pytest
import torch
import torch.distributed as dist
from torch.testing import assert_close

import colossalai
from colossalai.booster.plugin import MoeHybridParallelPlugin
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.testing.random import seed_all
from tests.kit.model_zoo import model_zoo
from tests.test_shardformer.test_model._utils import (
    build_model_from_hybrid_plugin,
    run_forward_backward_with_hybrid_plugin,
)


def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config):
    enable_gradient_checkpointing = test_config.pop("enable_gradient_checkpointing", False)
    seed_all(42)
    org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = build_model_from_hybrid_plugin(
        model_fn, loss_fn, test_config, pluggin_cls=MoeHybridParallelPlugin
    )
    if enable_gradient_checkpointing:
        # org_model.gradient_checkpointing_enable()
        sharded_model.unwrap().gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})

    org_model = org_model.to(torch.bfloat16)
    org_model.eval()
    sharded_model.eval()

    org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin(
        org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster
    )
    assert_close(org_loss, sharded_loss)

    # Compare gradients of the sharded model against the unsharded reference model.
    param_dict = {n: p for n, p in org_model.named_parameters()}
    for n, p in sharded_model.unwrap().named_parameters():
        if n in param_dict:
            if booster.plugin.zero_stage == 0:
                grad = p.grad
                target_grad = param_dict[n].grad
            else:
                grad = sharded_optimizer.get_working_grad_by_param_id(id(p))
                pg = sharded_optimizer.param_to_pg[p]
                target_grad = param_dict[n].grad
                if target_grad is None:
                    continue
                target_grad = target_grad.view(-1).chunk(dist.get_world_size(pg))[dist.get_rank(pg)]
            assert_close(grad, target_grad, atol=3e-1, rtol=0)


@parameterize(
    "config",
    [
        # (zero_stage, ep_size), both with ZeRO stage 1
        (1, 4),
        (1, 2),
    ],
)
def run_deepseek_v3_test(config: Tuple[int, ...]):
    zero_stage, ep_size = config
    plugin_config = dict(
        pp_size=1,
        tp_size=1,
        ep_size=ep_size,
        zero_stage=zero_stage,
        overlap_communication=False,
        precision="bf16",
        find_unused_parameters=True,
    )

    sub_model_zoo = model_zoo.get_sub_registry("transformers_deepseek_v3")
    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
        check_forward_backward(
            model_fn,
            data_gen_fn,
            output_transform_fn,
            loss_fn,
            plugin_config,
        )


def check_deepseek_v3(rank, world_size, port):
    colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
    run_deepseek_v3_test()


@pytest.mark.dist
@pytest.mark.parametrize("world_size", [4])
@rerun_if_address_is_in_use()
def test_deepseek_v3(world_size):
    spawn(check_deepseek_v3, world_size)


if __name__ == "__main__":
    test_deepseek_v3(world_size=4)
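
A note on the gradient check in check_forward_backward above: when ZeRO is enabled, the working gradient returned by the sharded optimizer covers only the slice of the parameter owned by the current rank, so the reference gradient is flattened and chunked across the parameter's process group before comparison. A tiny standalone illustration of that slicing, with made-up sizes:

import torch

world_size, rank = 4, 1           # hypothetical ZeRO group size and rank of this process
full_grad = torch.arange(16.0)    # stands in for param_dict[n].grad
local_shard = full_grad.view(-1).chunk(world_size)[rank]
print(local_shard)                # tensor([4., 5., 6., 7.]), the slice owned by rank 1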