Mirror of https://github.com/hpcaitech/ColossalAI.git
[moe]: fix ep/tp tests, add hierarchical all2all (#4982)
* fix: add warning for EP different behavior
* fix: use shard_data in ep & tp model
* to: add used_capacity
* fix: fix router test
* feat: add create_ep_node_group
* feat: add create_ep_hierarchical_group fn
* feat: add HierarchicalAllToAll
* test: add hierarchical all2all test
* fix: fix test errors
* fix: simplify create_ep_hierarchical_group
* fix: add hierarchical_alltoall arg
* fix: fix environ typo
* revert: revert process mesh order
* to: add todo mark
* fix: skip hierarchical_comm if torch < 1.13.1
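
The commit message introduces a create_ep_hierarchical_group helper and a HierarchicalAllToAll op. As a rough illustration of the grouping idea only (the helper below is hypothetical and is not ColossalAI's actual API), an expert-parallel world can be split into per-node "intra" groups and cross-node "inter" groups with torch.distributed.new_group:

# Hypothetical sketch of hierarchical group construction (not the repo's
# create_ep_hierarchical_group); assumes the EP group spans the whole world
# and that torch.distributed is already initialized.
import torch.distributed as dist


def create_hierarchical_groups(local_size: int):
    """local_size: ranks per node. Returns (intra_group, inter_group) for this rank."""
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    n_nodes = world_size // local_size

    intra_group = None
    for node in range(n_nodes):
        ranks = list(range(node * local_size, (node + 1) * local_size))
        group = dist.new_group(ranks=ranks)  # collective: every process creates every group
        if rank in ranks:
            intra_group = group

    inter_group = None
    for local_idx in range(local_size):
        ranks = list(range(local_idx, world_size, local_size))
        group = dist.new_group(ranks=ranks)
        if rank in ranks:
            inter_group = group

    return intra_group, inter_group

Because dist.new_group is collective, every process builds every group in the same order and each rank keeps only the two groups it belongs to.
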
@@ -132,8 +132,10 @@ def parse_args():
     # load balance
     parser.add_argument("--load_balance", action="store_true")

-    # overlap
-    parser.add_argument("--overlap_alltoall", action="store_true")
+    # overlap communication
+    parser.add_argument("--overlap_comm", action="store_true")
+    # hierarchical all-to-all
+    parser.add_argument("--hierarchical_alltoall", action="store_true")
     args = parser.parse_args()
     return args

@@ -211,7 +213,8 @@ def main():
         moe_layer_interval=config.moe_layer_interval,
         enable_load_balance=args.load_balance,
         enable_kernel=args.use_kernel,
-        enable_comm_overlap=args.overlap_alltoall,
+        enable_comm_overlap=args.overlap_comm,
+        enable_hierarchical_alltoall=args.hierarchical_alltoall,
     )
     with skip_init():
         model = OpenMoeForCausalLM(config)
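
For context on what enable_hierarchical_alltoall switches on: a flat EP all-to-all sends every chunk directly to its destination rank, while the hierarchical variant exchanges data inside each node first and only then across nodes. The sketch below is an assumption-laden illustration of that two-stage pattern (uniform chunk sizes, EP group equal to the whole world, groups built as in the earlier sketch, a backend that supports all_to_all_single such as NCCL); it is not the HierarchicalAllToAll implemented in this PR:

# Two-stage all-to-all sketch: stage 1 exchanges inside a node, stage 2 exchanges
# across nodes between ranks that share a local index. Illustrative only.
import torch
import torch.distributed as dist


def hierarchical_all_to_all(x: torch.Tensor, intra_group, inter_group, local_size: int) -> torch.Tensor:
    """x: [world_size, chunk, hidden]; row d holds the data destined for global rank d,
    where global rank = node * local_size + local_rank. Returns the same layout with
    row s holding the data received from global rank s."""
    world_size = dist.get_world_size()
    n_nodes = world_size // local_size
    chunk, hidden = x.shape[1], x.shape[2]

    # Stage 1 (intra-node): group rows by the destination's local index, then
    # all-to-all inside the node so local rank k collects everything headed to
    # local index k on any node.
    send1 = (x.view(n_nodes, local_size, chunk, hidden)
               .transpose(0, 1).contiguous()
               .view(world_size, chunk, hidden))
    recv1 = torch.empty_like(send1)
    dist.all_to_all_single(recv1, send1, group=intra_group)

    # recv1 is ordered [source local rank, destination node, chunk, hidden].
    # Stage 2 (inter-node): group rows by destination node, then all-to-all
    # across nodes among the ranks that share this local index.
    send2 = (recv1.view(local_size, n_nodes, chunk, hidden)
                  .transpose(0, 1).contiguous()
                  .view(world_size, chunk, hidden))
    recv2 = torch.empty_like(send2)
    dist.all_to_all_single(recv2, send2, group=inter_group)

    # recv2 rows are ordered by (source node, source local rank) == source global rank.
    return recv2

The cross-node traffic is concentrated in a single collective per step, which is why the option is aimed at multi-node training.
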
@@ -70,6 +70,7 @@ def set_openmoe_args(
     load_balance_group_swap_factor: float = 0.4,
     enable_kernel: bool = False,
     enable_comm_overlap: bool = False,
+    enable_hierarchical_alltoall: bool = False,
 ) -> None:
     """
     MoE related arguments.
@@ -96,6 +97,7 @@ def set_openmoe_args(
         load_balance_group_swap_factor (float, optional): Expert load balance group swap factor. Longer value encourages less swap. Defaults to 0.4.
         enable_kernel (bool, optional): Use kernel optimization. Defaults to False.
         enable_comm_overlap (bool, optional): Use communication overlap for MoE. Recommended to enable for muiti-node training. Defaults to False.
+        enable_hierarchical_alltoall (bool, optional): Use hierarchical alltoall for MoE. Defaults to False.
     """
     moe_args = dict(
         num_experts=num_experts,
@@ -117,6 +119,7 @@ def set_openmoe_args(
         load_balance_group_swap_factor=load_balance_group_swap_factor,
         enable_kernel=enable_kernel,
         enable_comm_overlap=enable_comm_overlap,
+        enable_hierarchical_alltoall=enable_hierarchical_alltoall,
     )
     set_moe_args(config, moe_args)

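
The hunk above routes the new flag through moe_args into set_moe_args(config, moe_args). A plausible, unverified reading of that helper is that it simply attaches every entry to the config object, for example:

# Hedged guess at the set_moe_args pattern, shown only to explain where
# config.enable_hierarchical_alltoall ends up; the real helper may differ.
def set_moe_args(config, args: dict) -> None:
    for key, value in args.items():
        setattr(config, key, value)  # later read as config.enable_hierarchical_alltoall, etc.
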
@@ -190,6 +190,12 @@ def parse_args():
         action="store_true",
         help="Use communication overlap for MoE. Recommended to enable for muiti-node training.",
     )
+    # hierarchical all-to-all
+    parser.add_argument(
+        "--hierarchical_alltoall",
+        action="store_true",
+        help="Use hierarchical all-to-all for MoE. Recommended to enable for muiti-node training.",
+    )

     args = parser.parse_args()
     return args
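
The commit message also notes that hierarchical communication is skipped when torch is older than 1.13.1. A version guard of that kind (hypothetical helper name, not the repo's actual check) could look like:

# Hypothetical guard: fall back to the flat all-to-all when the installed torch
# is older than 1.13.1, as the commit message says hierarchical comm is skipped there.
import torch
from packaging import version


def hierarchical_comm_supported() -> bool:
    return version.parse(torch.__version__) >= version.parse("1.13.1")
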
@@ -277,6 +283,7 @@ def main():
         z_loss_factor=args.z_loss_factor,
         enable_load_balance=args.load_balance,
         enable_comm_overlap=args.comm_overlap,
+        enable_hierarchical_alltoall=args.hierarchical_alltoall,
         enable_kernel=args.use_kernel,
     )
     with skip_init():