Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-01 17:17:05 +00:00)
[moe] refactor mesh assignment
This commit is contained in:
@@ -24,24 +24,28 @@ NUM_HEADS = 4
|
||||
TOP_K = 2
|
||||
|
||||
|
||||
-CHECKED_CONFIG = [ # FOR_WORLD=8
-(2, 1, 1, 4, 1),
-(4, 1, 1, 2, 1),
-(4, 1, 1, 1, 1),
-(2, 1, 2, 1, 1),
+CHECKED_CONFIG = [ # FOR_WORLD=4
+(1, 4, 1, 1, 1),
+(1, 1, 4, 1, 1),
+(1, 1, 1, 4, 1),
+(1, 1, 1, 1, 4),
+(0, 1, 4, 1, 1),
+(0, 1, 1, 4, 1),
+(0, 1, 1, 1, 4),
+(1, 2, 1, 1, 1),
|
||||
]
|
||||
|
||||
|
||||
@parameterize(
|
||||
"config",
|
||||
[
|
||||
-(2, 1, 2, 1, 1),
-# (2, 1, 1, 2, 1),
-# (2, 1, 1, 1, 2),
+(1, 2, 2, 1, 1),
+(1, 2, 1, 2, 1),
+(1, 2, 1, 1, 2),
|
||||
],
|
||||
)
|
||||
def run_zero_with_original_model(config: Tuple[int, ...]):
|
||||
-ep_size, stage, pp_size, tp_size, sp_size = config
+stage, ep_size, pp_size, tp_size, sp_size = config
|
||||
world_size = dist.get_world_size()
|
||||
rank = dist.get_rank()
|
||||
dtype, precision = torch.float16, "fp16"
|
||||
@@ -53,7 +57,6 @@ def run_zero_with_original_model(config: Tuple[int, ...]):
|
||||
tp_size=tp_size,
|
||||
sp_size=sp_size,
|
||||
ep_size=ep_size,
|
||||
-moe_tp_size=tp_size,
|
||||
zero_stage=stage,
|
||||
enable_sequence_parallelism=sp_size > 1,
|
||||
sequence_parallelism_mode="all_to_all" if sp_size > 1 else None,
|
||||
|
@@ -25,24 +25,28 @@ NUM_HEADS = 4
|
||||
TOP_K = 1
|
||||
|
||||
CHECKED_CONFIG = [ # FOR WORLD=4
|
||||
-(2, 1, 2, 2, 1),
-(2, 1, 1, 2, 1),
-(2, 1, 4, 1, 1),
-(4, 1, 1, 1, 1),
-(4, 1, 1, 2, 1),
-(4, 1, 2, 1, 1),
-(2, 1, 2, 1, 1),
+(0, 1, 4, 1, 1),
+(0, 1, 1, 4, 1),
+(0, 1, 1, 1, 4),
+(1, 4, 1, 1, 1),
+(1, 1, 4, 1, 1),
+(1, 1, 1, 4, 1),
+(1, 1, 1, 1, 4),
+(1, 2, 1, 1, 1),
|
||||
]
|
||||
|
||||
|
||||
@parameterize(
|
||||
"config",
|
||||
[
|
||||
-(2, 1, 1, 2, 1),
+(1, 2, 2, 1, 1),
+(1, 2, 1, 2, 1),
+(1, 2, 1, 1, 2),
+(0, 2, 1, 1, 1),
|
||||
],
|
||||
)
|
||||
def run_zero_with_original_model(config: Tuple[int, ...]):
|
||||
-ep_size, stage, pp_size, tp_size, sp_size = config
+stage, ep_size, pp_size, tp_size, sp_size = config
|
||||
world_size = dist.get_world_size()
|
||||
rank = dist.get_rank()
|
||||
dtype, precision = torch.float16, "fp16"
|
||||
@@ -54,7 +58,6 @@ def run_zero_with_original_model(config: Tuple[int, ...]):
|
||||
tp_size=tp_size,
|
||||
sp_size=sp_size,
|
||||
ep_size=ep_size,
|
||||
-moe_tp_size=tp_size,
|
||||
zero_stage=stage,
|
||||
enable_sequence_parallelism=sp_size > 1,
|
||||
sequence_parallelism_mode="all_to_all" if sp_size > 1 else None,
|
||||
|
Reference in New Issue
Block a user