Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-05 11:02:05 +00:00
[moe]: fix ep/tp tests, add hierarchical all2all (#4982)
* fix: add warning for EP different behavior
* fix: use shard_data in ep & tp model
* to: add used_capacity
* fix: fix router test
* feat: add create_ep_node_group
* feat: add create_ep_hierarchical_group fn
* feat: add HierarchicalAllToAll
* test: add hierarchical all2all test
* fix: fix test errors
* fix: simplify create_ep_hierarchical_group
* fix: add hierarchical_alltoall arg
* fix: fix environ typo
* revert: revert process mesh order
* to: add todo mark
* fix: skip hierarchical_comm if torch < 1.13.1
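The core idea behind create_ep_hierarchical_group and HierarchicalAllToAll is to split the flat expert-parallel (EP) process group into intra-node sub-groups plus one inter-node group of node-local "leader" ranks, so that the MoE token exchange crosses the node boundary only once per node pair. Below is a minimal sketch of that group construction, assuming a contiguous rank-to-node mapping; build_hierarchical_groups and its arguments are illustrative, not ColossalAI's actual API.

import torch.distributed as dist

def build_hierarchical_groups(ep_ranks, ranks_per_node):
    """Split a flat EP rank list into per-node groups plus one group of
    node-local leader ranks (the first rank of each node)."""
    intra_groups, leaders = [], []
    for start in range(0, len(ep_ranks), ranks_per_node):
        node_ranks = ep_ranks[start:start + ranks_per_node]
        # new_group is collective: every rank must call it for every group,
        # including the groups it does not belong to
        intra_groups.append(dist.new_group(node_ranks))
        leaders.append(node_ranks[0])
    inter_group = dist.new_group(leaders)
    return intra_groups, inter_group, leaders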
@@ -190,6 +190,12 @@ def parse_args():
         action="store_true",
         help="Use communication overlap for MoE. Recommended to enable for multi-node training.",
     )
+    # hierarchical all-to-all
+    parser.add_argument(
+        "--hierarchical_alltoall",
+        action="store_true",
+        help="Use hierarchical all-to-all for MoE. Recommended to enable for multi-node training.",
+    )
 
     args = parser.parse_args()
     return args
@@ -277,6 +283,7 @@ def main():
         z_loss_factor=args.z_loss_factor,
         enable_load_balance=args.load_balance,
         enable_comm_overlap=args.comm_overlap,
+        enable_hierarchical_alltoall=args.hierarchical_alltoall,
         enable_kernel=args.use_kernel,
     )
     with skip_init():
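With such groups in place, a hierarchical all-to-all can be staged as an intra-node gather onto the node leader, a single inter-node all-to-all between leaders, and an intra-node scatter back to the local ranks, so each node pair exchanges one large message instead of many small ones. The sketch below is a hedged illustration of that pattern built directly on torch.distributed collectives; it is not the HierarchicalAllToAll autograd op added by this commit. It assumes global rank = node * procs_per_node + local rank, and that every rank passes a [world_size, chunk] send buffer.

import torch
import torch.distributed as dist

def hierarchical_all_to_all(inp, intra_group, inter_group, procs_per_node):
    """inp: [world, chunk]; row j holds the data this rank sends to global rank j.
    Returns [world, chunk] where row j holds the data received from global rank j."""
    world, chunk = inp.shape
    rank = dist.get_rank()
    node = rank // procs_per_node
    nodes = world // procs_per_node
    leader = node * procs_per_node            # first rank of each node acts as leader

    # 1) intra-node gather: the leader collects every local rank's send buffer
    gathered = [torch.empty_like(inp) for _ in range(procs_per_node)] if rank == leader else None
    dist.gather(inp, gathered, dst=leader, group=intra_group)

    out = torch.empty_like(inp)
    if rank == leader:
        stacked = torch.stack(gathered)       # [local_src, world_dst, chunk]
        # 2) inter-node all-to-all between leaders: one message per node pair
        send = [stacked[:, d * procs_per_node:(d + 1) * procs_per_node].contiguous()
                for d in range(nodes)]        # each: [local_src, local_dst, chunk]
        recv = [torch.empty_like(send[0]) for _ in range(nodes)]
        dist.all_to_all(recv, send, group=inter_group)
        # 3) intra-node scatter: hand every local rank the rows addressed to it,
        # ordered by global source rank (src_node * procs_per_node + local_src)
        per_node = torch.stack(recv)          # [src_node, local_src, local_dst, chunk]
        scatter_list = [per_node[:, :, q].reshape(world, chunk).contiguous()
                        for q in range(procs_per_node)]
        dist.scatter(out, scatter_list, src=leader, group=intra_group)
    else:
        dist.scatter(out, None, src=leader, group=intra_group)
    return out

In the training script above, the new --hierarchical_alltoall flag toggles between this kind of staged exchange and a single flat all-to-all over the whole EP group; per the commit message, hierarchical communication is skipped on torch < 1.13.1.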