[moe] test deepseek

Authored by hxwang on 2024-07-16 10:10:40 +00:00, committed by Hongxin Liu
parent dc583aa576
commit 74eccac0db
10 changed files with 276 additions and 68 deletions


@@ -14,21 +14,12 @@ from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
 from colossalai.testing.random import seed_all
 from tests.test_moe.moe_utils import loose_close
 
-NUM_BATCH=4
+NUM_BATCH = 4
 NUM_TOK_PER_BATCH, NUM_EXPERTS = 7, 4
 HIDDEN_SIZE_PER_HEAD = 4
-NUM_HEADS=2
+NUM_HEADS = 2
 TOP_K = 1
 
-
-def split_grad(grad, world_size):
-    with torch.no_grad():
-        grad = grad.clone().detach().flatten()
-        padding_size = (world_size - grad.numel() % world_size) % world_size
-        if padding_size > 0:
-            grad = torch.nn.functional.pad(grad, [0, padding_size])
-        splited_grad = grad.split(grad.numel() // world_size)
-    return splited_grad
 
 @parameterize("stage", [1])
 @parameterize("ep_size", [1, 2, 4])
@@ -39,12 +30,7 @@ def run_zero_with_original_model(stage: int, ep_size: int):
     torch.cuda.set_device(dist.get_rank())
 
     plugin = MoeHybridParallelPlugin(
-        pp_size=1,
-        tp_size=1,
-        ep_size=ep_size,
-        zero_stage=stage,
-        overlap_communication=False,
-        initial_scale=1
+        pp_size=1, tp_size=1, ep_size=ep_size, zero_stage=stage, overlap_communication=False, initial_scale=1
     )
     booster = Booster(plugin=plugin)
 
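For context, a minimal sketch (not part of this commit) of how a plugin configured this way is typically consumed. It assumes an already-initialized distributed environment (see the launcher sketch after the last hunk); the tiny linear model and Adam optimizer are illustrative stand-ins for the DeepSeek MoE model and optimizer that the real test builds.

import torch
from colossalai.booster import Booster
from colossalai.booster.plugin import MoeHybridParallelPlugin

# Illustrative stand-ins: a tiny linear layer instead of the DeepSeek MoE model;
# ep_size / zero_stage values mirror the test's parametrization.
model = torch.nn.Linear(8, 8).cuda()
optimizer = torch.optim.Adam(model.parameters())

plugin = MoeHybridParallelPlugin(
    pp_size=1, tp_size=1, ep_size=2, zero_stage=1, overlap_communication=False, initial_scale=1
)
booster = Booster(plugin=plugin)

# boost() wraps the model for expert parallelism + ZeRO and returns the sharded
# optimizer, i.e. the zero_model / zero_optimizer pair used in the loop below.
zero_model, zero_optimizer, *_ = booster.boost(model, optimizer)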
@@ -81,7 +67,9 @@ def run_zero_with_original_model(stage: int, ep_size: int):
     zero_model.train()
     for _ in range(2):
         # zero-dp forward
-        input_data = torch.rand(NUM_BATCH, NUM_TOK_PER_BATCH, HIDDEN_SIZE_PER_HEAD * NUM_HEADS, requires_grad=True).cuda()
+        input_data = torch.rand(
+            NUM_BATCH, NUM_TOK_PER_BATCH, HIDDEN_SIZE_PER_HEAD * NUM_HEADS, requires_grad=True
+        ).cuda()
         zero_output = zero_model(inputs_embeds=input_data.to(dtype)).last_hidden_state.mean()
         # zero-dp backward
         zero_optimizer.backward(zero_output)
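Finally, a minimal sketch of how a parametrized ZeRO/EP test body like this is typically driven by the ColossalAI test helpers imported above (spawn, rerun_if_address_is_in_use). The launcher and test function names are hypothetical, and the exact colossalai.launch signature may differ across versions.

import colossalai
from colossalai.testing import rerun_if_address_is_in_use, spawn

def run_dist(rank, world_size, port):
    # Assumed launch pattern: each spawned worker joins the process group, then
    # runs the @parameterize-decorated body so every (stage, ep_size) combination
    # executes on every rank.
    colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port)
    run_zero_with_original_model()

@rerun_if_address_is_in_use()
def test_deepseek_ep():
    # 4 processes so that ep_size in {1, 2, 4} all divide the world size.
    spawn(run_dist, nprocs=4)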