[moe] test deepseek
@@ -14,21 +14,12 @@ from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
 from colossalai.testing.random import seed_all
 from tests.test_moe.moe_utils import loose_close
 
-NUM_BATCH=4
+NUM_BATCH = 4
 NUM_TOK_PER_BATCH, NUM_EXPERTS = 7, 4
 HIDDEN_SIZE_PER_HEAD = 4
-NUM_HEADS=2
+NUM_HEADS = 2
 TOP_K = 1
 
-def split_grad(grad, world_size):
-    with torch.no_grad():
-        grad = grad.clone().detach().flatten()
-        padding_size = (world_size - grad.numel() % world_size) % world_size
-        if padding_size > 0:
-            grad = torch.nn.functional.pad(grad, [0, padding_size])
-        splited_grad = grad.split(grad.numel() // world_size)
-    return splited_grad
-
 
 @parameterize("stage", [1])
 @parameterize("ep_size", [1, 2, 4])
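For context, the deleted split_grad helper flattens a gradient, pads it so it divides evenly across ranks, and splits it into equal shards. Below is a minimal standalone sketch of that behavior; the example tensor size and world_size are illustrative assumptions, not values from the test.

import torch

def split_grad(grad: torch.Tensor, world_size: int):
    with torch.no_grad():
        grad = grad.clone().detach().flatten()
        # Pad so the flattened gradient divides evenly across ranks.
        padding_size = (world_size - grad.numel() % world_size) % world_size
        if padding_size > 0:
            grad = torch.nn.functional.pad(grad, [0, padding_size])
        # Rank i would take shard i of this tuple.
        splited_grad = grad.split(grad.numel() // world_size)
    return splited_grad

shards = split_grad(torch.randn(10), world_size=4)  # 10 elements -> padded to 12 -> 4 shards of 3
print([s.numel() for s in shards])                  # [3, 3, 3, 3]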
@@ -39,12 +30,7 @@ def run_zero_with_original_model(stage: int, ep_size: int):
     torch.cuda.set_device(dist.get_rank())
 
     plugin = MoeHybridParallelPlugin(
-        pp_size=1,
-        tp_size=1,
-        ep_size=ep_size,
-        zero_stage=stage,
-        overlap_communication=False,
-        initial_scale=1
+        pp_size=1, tp_size=1, ep_size=ep_size, zero_stage=stage, overlap_communication=False, initial_scale=1
     )
     booster = Booster(plugin=plugin)
 
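The plugin only describes the parallel layout (expert parallelism across ep_size ranks, ZeRO at the given stage, no pipeline or tensor parallelism); the model and optimizer still have to be wrapped by the booster. A rough sketch of that wiring follows, under assumptions: it must run inside a colossalai-launched process group, the import paths are the ones commonly used in colossalai, and the tiny model/optimizer are placeholders rather than the test's DeepSeek MoE model.

import torch
import torch.nn as nn
from colossalai.booster import Booster
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin

stage, ep_size = 1, 2  # assumed values for the sketch
plugin = MoeHybridParallelPlugin(
    pp_size=1, tp_size=1, ep_size=ep_size, zero_stage=stage, overlap_communication=False, initial_scale=1
)
booster = Booster(plugin=plugin)

model = nn.Linear(8, 8).cuda()                    # placeholder for the MoE model under test
optimizer = torch.optim.Adam(model.parameters())  # any torch optimizer
# boost() returns (model, optimizer, criterion, dataloader, lr_scheduler);
# the returned model/optimizer are the ZeRO/EP-aware wrappers used below.
zero_model, zero_optimizer, *_ = booster.boost(model, optimizer)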
@@ -81,7 +67,9 @@ def run_zero_with_original_model(stage: int, ep_size: int):
     zero_model.train()
     for _ in range(2):
         # zero-dp forward
-        input_data = torch.rand(NUM_BATCH, NUM_TOK_PER_BATCH, HIDDEN_SIZE_PER_HEAD * NUM_HEADS, requires_grad=True).cuda()
+        input_data = torch.rand(
+            NUM_BATCH, NUM_TOK_PER_BATCH, HIDDEN_SIZE_PER_HEAD * NUM_HEADS, requires_grad=True
+        ).cuda()
         zero_output = zero_model(inputs_embeds=input_data.to(dtype)).last_hidden_state.mean()
         # zero-dp backward
         zero_optimizer.backward(zero_output)
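The loose_close helper imported at the top of the file is what this kind of test typically uses to compare the ZeRO/EP run against an unsharded baseline with dtype-dependent tolerances. A hedged sketch of that comparison step; torch_model and its backward call are placeholders for the baseline side, and the exact loose_close keyword usage may differ from the test's code.

# Run the same batch through the unsharded baseline (placeholder name torch_model),
# then compare the two scalar outputs within a tolerance chosen from dtype.
torch_output = torch_model(inputs_embeds=input_data.to(dtype)).last_hidden_state.mean()
torch_output.backward()
# The ZeRO side already called zero_optimizer.backward(zero_output) above.
loose_close(zero_output, torch_output, dtype=dtype)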