mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-24 11:08:50 +00:00)
[moe] finalize test (no pp)
@@ -18,28 +18,34 @@ from tests.test_moe.moe_utils import loose_close
+from tests.test_moe.test_moe_checkpoint import check_model_equal

 NUM_BATCH = 4
-NUM_TOK_PER_BATCH, NUM_EXPERTS = 7, 4
+NUM_TOK_PER_BATCH, NUM_EXPERTS = 8, 4
 HIDDEN_SIZE_PER_HEAD = 4
 NUM_HEADS = 4
 TOP_K = 1


-@parameterize("config", [(0, 1, 1), (0, 1, 2), (0, 1, 4), (1, 1, 4), (1, 2, 2), (1, 4, 1)])
+@parameterize("config", [(2, 1, 2, 1, 2, 1), (2, 1, 2, 1, 1, 2), (4, 1, 1, 1, 2, 1), (4, 1, 2, 1, 1, 1)])
 def run_zero_with_original_model(config: Tuple[int, ...]):
-    stage, ep_size, tp_size = config
-    dtype, precision = torch.float16, "fp16"
+    ep_size, stage, dp_size, pp_size, tp_size, sp_size = config
+    print(config)
     rank = torch.distributed.get_rank()
+    dtype, precision = torch.float16, "fp16"
     torch.cuda.set_device(dist.get_rank())

     plugin = MoeHybridParallelPlugin(
-        pp_size=1,
+        pp_size=pp_size,
+        num_microbatches=pp_size,
         tp_size=tp_size,
-        moe_tp_size=tp_size,
+        sp_size=sp_size,
         ep_size=ep_size,
+        moe_tp_size=tp_size,
         zero_stage=stage,
+        enable_sequence_parallelism=sp_size > 1,
+        sequence_parallelism_mode="all_to_all" if sp_size > 1 else None,
         overlap_communication=False,
         initial_scale=1,
         precision=precision,
         find_unused_parameters=True,
     )
     booster = Booster(plugin=plugin)

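For reference, the new six-element tuples unpack as (ep_size, stage, dp_size, pp_size, tp_size, sp_size). A minimal sanity-check sketch, under the assumption (suggested by the new world size of 8 below) that the expert, data, pipeline, tensor, and sequence dimensions are meant to tile all ranks:

    # Sketch: check that each parameterized config covers all 8 ranks.
    # Tuple layout: (ep_size, stage, dp_size, pp_size, tp_size, sp_size);
    # the zero `stage` is an optimizer setting, not a parallel dimension.
    WORLD_SIZE = 8
    configs = [(2, 1, 2, 1, 2, 1), (2, 1, 2, 1, 1, 2),
               (4, 1, 1, 1, 2, 1), (4, 1, 2, 1, 1, 1)]
    for ep, stage, dp, pp, tp, sp in configs:
        assert ep * dp * pp * tp * sp == WORLD_SIZE, (ep, dp, pp, tp, sp)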
@@ -53,6 +59,7 @@ def run_zero_with_original_model(config: Tuple[int, ...]):
         num_key_value_heads=NUM_HEADS,
         num_local_experts=NUM_EXPERTS,
         num_experts_per_tok=TOP_K,
+        attn_implementation="flash_attention_2",
     )

     torch_model = MixtralModel(config).to(dtype).cuda()
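The visible keyword arguments belong to a tiny MixtralConfig built from the test constants. A hedged reconstruction of that construction (only the last four keyword arguments appear in the hunk; hidden_size and num_attention_heads are assumptions derived from the constants above):

    from transformers import MixtralConfig

    # Constants copied from the diff above.
    NUM_EXPERTS, HIDDEN_SIZE_PER_HEAD, NUM_HEADS, TOP_K = 4, 4, 4, 1

    config = MixtralConfig(
        hidden_size=HIDDEN_SIZE_PER_HEAD * NUM_HEADS,  # assumed
        num_attention_heads=NUM_HEADS,                 # assumed
        num_key_value_heads=NUM_HEADS,
        num_local_experts=NUM_EXPERTS,
        num_experts_per_tok=TOP_K,
        attn_implementation="flash_attention_2",       # added by this commit
    )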
@@ -72,7 +79,9 @@ def run_zero_with_original_model(config: Tuple[int, ...]):
     input_data = torch.rand(
         NUM_BATCH, NUM_TOK_PER_BATCH, HIDDEN_SIZE_PER_HEAD * NUM_HEADS, requires_grad=True
     ).cuda()
-    dist.all_reduce(input_data, group=plugin.tp_group) # tp requires duplicate input
+
+    dist.all_reduce(input_data, group=plugin.tp_group) # tp group requires duplicate input
+    dist.all_reduce(input_data, group=plugin.sp_group) # sp group requires duplicate input

     zero_output = zero_model(inputs_embeds=input_data.to(dtype)).last_hidden_state.mean()
     zero_optimizer.backward(zero_output)
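Both reductions serve the same purpose: after an all_reduce, every rank in the group holds an identical tensor, so the ranks that shard a single model replica (tp, and now sp) all see the same input. A standalone sketch of the trick, assuming an initialized process group:

    import torch
    import torch.distributed as dist

    def duplicate_across_group(t: torch.Tensor, group=None) -> torch.Tensor:
        # Summing the per-rank tensors leaves every rank in `group` with
        # identical values, a cheap way to give all ranks the same input
        # without choosing a broadcast source. Assumes
        # dist.init_process_group has already run.
        dist.all_reduce(t, op=dist.ReduceOp.SUM, group=group)
        return t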
@@ -124,11 +133,11 @@ def run_dist(rank, world_size, port):


 @pytest.mark.dist
-@pytest.mark.parametrize("world_size", [4])
+@pytest.mark.parametrize("world_size", [8])
 @rerun_if_address_is_in_use()
 def test_mistral(world_size):
     spawn(run_dist, world_size)


 if __name__ == "__main__":
-    test_mistral(world_size=4)
+    test_mistral(world_size=8)
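With the world size doubled from 4 to 8, the test now needs 8 devices. Hypothetical launch commands (this test module's filename is not shown in the excerpt, so <this_test>.py is a placeholder):

    # Direct run: the __main__ block spawns 8 processes itself.
    python tests/test_moe/<this_test>.py

    # Or via pytest on a node with 8 visible GPUs.
    pytest tests/test_moe/<this_test>.py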