[moe] finalize test (no pp)

This commit is contained in:
hxwang
2024-07-18 13:36:18 +00:00
committed by Hongxin Liu
parent 2cddeac717
commit 7077d38d5a
2 changed files with 29 additions and 16 deletions

View File

@@ -18,28 +18,34 @@ from tests.test_moe.moe_utils import loose_close
from tests.test_moe.test_moe_checkpoint import check_model_equal
NUM_BATCH = 4
NUM_TOK_PER_BATCH, NUM_EXPERTS = 7, 4
NUM_TOK_PER_BATCH, NUM_EXPERTS = 8, 4
HIDDEN_SIZE_PER_HEAD = 4
NUM_HEADS = 4
TOP_K = 1
@parameterize("config", [(0, 1, 1), (0, 1, 2), (0, 1, 4), (1, 1, 4), (1, 2, 2), (1, 4, 1)])
@parameterize("config", [(2, 1, 2, 1, 2, 1), (2, 1, 2, 1, 1, 2), (4, 1, 1, 1, 2, 1), (4, 1, 2, 1, 1, 1)])
def run_zero_with_original_model(config: Tuple[int, ...]):
stage, ep_size, tp_size = config
dtype, precision = torch.float16, "fp16"
ep_size, stage, dp_size, pp_size, tp_size, sp_size = config
print(config)
rank = torch.distributed.get_rank()
dtype, precision = torch.float16, "fp16"
torch.cuda.set_device(dist.get_rank())
plugin = MoeHybridParallelPlugin(
pp_size=1,
pp_size=pp_size,
num_microbatches=pp_size,
tp_size=tp_size,
moe_tp_size=tp_size,
sp_size=sp_size,
ep_size=ep_size,
moe_tp_size=tp_size,
zero_stage=stage,
enable_sequence_parallelism=sp_size > 1,
sequence_parallelism_mode="all_to_all" if sp_size > 1 else None,
overlap_communication=False,
initial_scale=1,
precision=precision,
find_unused_parameters=True,
)
booster = Booster(plugin=plugin)
@@ -53,6 +59,7 @@ def run_zero_with_original_model(config: Tuple[int, ...]):
num_key_value_heads=NUM_HEADS,
num_local_experts=NUM_EXPERTS,
num_experts_per_tok=TOP_K,
attn_implementation="flash_attention_2",
)
torch_model = MixtralModel(config).to(dtype).cuda()
@@ -72,7 +79,9 @@ def run_zero_with_original_model(config: Tuple[int, ...]):
input_data = torch.rand(
NUM_BATCH, NUM_TOK_PER_BATCH, HIDDEN_SIZE_PER_HEAD * NUM_HEADS, requires_grad=True
).cuda()
dist.all_reduce(input_data, group=plugin.tp_group) # tp requires duplicate input
dist.all_reduce(input_data, group=plugin.tp_group) # tp group requires duplicate input
dist.all_reduce(input_data, group=plugin.sp_group) # sp group requires duplicate input
zero_output = zero_model(inputs_embeds=input_data.to(dtype)).last_hidden_state.mean()
zero_optimizer.backward(zero_output)
@@ -124,11 +133,11 @@ def run_dist(rank, world_size, port):
@pytest.mark.dist
@pytest.mark.parametrize("world_size", [4])
@pytest.mark.parametrize("world_size", [8])
@rerun_if_address_is_in_use()
def test_mistral(world_size):
spawn(run_dist, world_size)
if __name__ == "__main__":
test_mistral(world_size=4)
test_mistral(world_size=8)