[moe] refactor mesh assignment

This commit is contained in:
hxwang
2024-07-25 06:19:54 +00:00
committed by Hongxin Liu
parent 034020bd04
commit cb01c0d5ce
10 changed files with 277 additions and 170 deletions

View File

@@ -24,24 +24,28 @@ NUM_HEADS = 4
TOP_K = 2
CHECKED_CONFIG = [ # FOR_WORLD=8
(2, 1, 1, 4, 1),
(4, 1, 1, 2, 1),
(4, 1, 1, 1, 1),
(2, 1, 2, 1, 1),
CHECKED_CONFIG = [ # FOR_WORLD=4
(1, 4, 1, 1, 1),
(1, 1, 4, 1, 1),
(1, 1, 1, 4, 1),
(1, 1, 1, 1, 4),
(0, 1, 4, 1, 1),
(0, 1, 1, 4, 1),
(0, 1, 1, 1, 4),
(1, 2, 1, 1, 1),
]
@parameterize(
"config",
[
(2, 1, 2, 1, 1),
# (2, 1, 1, 2, 1),
# (2, 1, 1, 1, 2),
(1, 2, 2, 1, 1),
(1, 2, 1, 2, 1),
(1, 2, 1, 1, 2),
],
)
def run_zero_with_original_model(config: Tuple[int, ...]):
ep_size, stage, pp_size, tp_size, sp_size = config
stage, ep_size, pp_size, tp_size, sp_size = config
world_size = dist.get_world_size()
rank = dist.get_rank()
dtype, precision = torch.float16, "fp16"
@@ -53,7 +57,6 @@ def run_zero_with_original_model(config: Tuple[int, ...]):
tp_size=tp_size,
sp_size=sp_size,
ep_size=ep_size,
moe_tp_size=tp_size,
zero_stage=stage,
enable_sequence_parallelism=sp_size > 1,
sequence_parallelism_mode="all_to_all" if sp_size > 1 else None,

View File

@@ -25,24 +25,28 @@ NUM_HEADS = 4
TOP_K = 1
CHECKED_CONFIG = [ # FOR WORLD=4
(2, 1, 2, 2, 1),
(2, 1, 1, 2, 1),
(2, 1, 4, 1, 1),
(4, 1, 1, 1, 1),
(4, 1, 1, 2, 1),
(4, 1, 2, 1, 1),
(2, 1, 2, 1, 1),
(0, 1, 4, 1, 1),
(0, 1, 1, 4, 1),
(0, 1, 1, 1, 4),
(1, 4, 1, 1, 1),
(1, 1, 4, 1, 1),
(1, 1, 1, 4, 1),
(1, 1, 1, 1, 4),
(1, 2, 1, 1, 1),
]
@parameterize(
"config",
[
(2, 1, 1, 2, 1),
(1, 2, 2, 1, 1),
(1, 2, 1, 2, 1),
(1, 2, 1, 1, 2),
(0, 2, 1, 1, 1),
],
)
def run_zero_with_original_model(config: Tuple[int, ...]):
ep_size, stage, pp_size, tp_size, sp_size = config
stage, ep_size, pp_size, tp_size, sp_size = config
world_size = dist.get_world_size()
rank = dist.get_rank()
dtype, precision = torch.float16, "fp16"
@@ -54,7 +58,6 @@ def run_zero_with_original_model(config: Tuple[int, ...]):
tp_size=tp_size,
sp_size=sp_size,
ep_size=ep_size,
moe_tp_size=tp_size,
zero_stage=stage,
enable_sequence_parallelism=sp_size > 1,
sequence_parallelism_mode="all_to_all" if sp_size > 1 else None,