[moe]: fix ep/tp tests, add hierarchical all2all (#4982)

* fix: add warning for EP different behavior

* fix: use shard_data in ep & tp model

* to: add used_capacity

* fix: fix router test

* feat: add create_ep_node_group

* feat: add create_ep_hierarchical_group fn

* feat: add HierarchicalAllToAll

* test: add hierarchical all2all test

* fix: fix test errors

* fix: simplify create_ep_hierarchical_group

* fix: add hierarchical_alltoall arg

* fix: fix environ typo

* revert: revert process mesh order

* to: add todo mark

* fix: skip hierarchical_comm if torch < 1.13.1
This commit is contained in:
Wenhao Chen
2023-11-09 14:31:00 +08:00
committed by GitHub
parent 239cd92eff
commit 724441279b
10 changed files with 388 additions and 164 deletions

View File

@@ -132,8 +132,10 @@ def parse_args():
# load balance
parser.add_argument("--load_balance", action="store_true")
# overlap
parser.add_argument("--overlap_alltoall", action="store_true")
# overlap communication
parser.add_argument("--overlap_comm", action="store_true")
# hierarchical all-to-all
parser.add_argument("--hierarchical_alltoall", action="store_true")
args = parser.parse_args()
return args
@@ -211,7 +213,8 @@ def main():
moe_layer_interval=config.moe_layer_interval,
enable_load_balance=args.load_balance,
enable_kernel=args.use_kernel,
enable_comm_overlap=args.overlap_alltoall,
enable_comm_overlap=args.overlap_comm,
enable_hierarchical_alltoall=args.hierarchical_alltoall,
)
with skip_init():
model = OpenMoeForCausalLM(config)

View File

@@ -70,6 +70,7 @@ def set_openmoe_args(
load_balance_group_swap_factor: float = 0.4,
enable_kernel: bool = False,
enable_comm_overlap: bool = False,
enable_hierarchical_alltoall: bool = False,
) -> None:
"""
MoE related arguments.
@@ -96,6 +97,7 @@ def set_openmoe_args(
load_balance_group_swap_factor (float, optional): Expert load balance group swap factor. Longer value encourages less swap. Defaults to 0.4.
enable_kernel (bool, optional): Use kernel optimization. Defaults to False.
enable_comm_overlap (bool, optional): Use communication overlap for MoE. Recommended to enable for muiti-node training. Defaults to False.
enable_hierarchical_alltoall (bool, optional): Use hierarchical alltoall for MoE. Defaults to False.
"""
moe_args = dict(
num_experts=num_experts,
@@ -117,6 +119,7 @@ def set_openmoe_args(
load_balance_group_swap_factor=load_balance_group_swap_factor,
enable_kernel=enable_kernel,
enable_comm_overlap=enable_comm_overlap,
enable_hierarchical_alltoall=enable_hierarchical_alltoall,
)
set_moe_args(config, moe_args)

View File

@@ -190,6 +190,12 @@ def parse_args():
action="store_true",
help="Use communication overlap for MoE. Recommended to enable for muiti-node training.",
)
# hierarchical all-to-all
parser.add_argument(
"--hierarchical_alltoall",
action="store_true",
help="Use hierarchical all-to-all for MoE. Recommended to enable for muiti-node training.",
)
args = parser.parse_args()
return args
@@ -277,6 +283,7 @@ def main():
z_loss_factor=args.z_loss_factor,
enable_load_balance=args.load_balance,
enable_comm_overlap=args.comm_overlap,
enable_hierarchical_alltoall=args.hierarchical_alltoall,
enable_kernel=args.use_kernel,
)
with skip_init():