[moe] finalize test (no pp)
@@ -109,6 +109,9 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
         super().__init__(*args, **kwargs)

+        if ep_size <= 1:
+            raise ValueError("Use HybridParallelPlugin when ep_size <= 1")
+
         self.ep_size = ep_size
         self.moe_tp_size = moe_tp_size

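To make the new guard concrete, a minimal construction sketch follows. It is only a sketch: the import path and the parent HybridParallelPlugin keyword arguments (tp_size, pp_size, zero_stage) are assumptions; only ep_size and moe_tp_size are taken directly from this diff.

from colossalai.booster import Booster
from colossalai.booster.plugin import MoeHybridParallelPlugin  # assumed import path

# Assumes the distributed environment is already initialized
# (e.g. via colossalai.launch_from_torch under torchrun).
#
# ep_size must now be > 1; otherwise the plugin raises
# ValueError("Use HybridParallelPlugin when ep_size <= 1").
plugin = MoeHybridParallelPlugin(
    ep_size=2,       # expert parallel degree, the value this commit validates
    moe_tp_size=1,   # MoE tensor parallel degree, stored right after the check
    tp_size=1,       # parent HybridParallelPlugin arguments (assumed to exist)
    pp_size=1,
    zero_stage=1,    # zero_stage > 0 also avoids the DDP group mismatch below
)
booster = Booster(plugin=plugin)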
@@ -128,12 +131,12 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
|
||||
self.ddp_config["find_unused_parameters"] = True
|
||||
|
||||
if dist.get_process_group_ranks(self.dp_group) != dist.get_process_group_ranks(self.moe_dp_group):
|
||||
# TODO it might make sense to support non-moe with tp on but moe with tp off
|
||||
raise ValueError(
|
||||
f"if ddp is used, dp_group and moe_dp_group are expected to be the same since DDP can only reduce grad across a single group, but found dp_group {dist.get_process_group_ranks(self.dp_group)} and moe_dp_group {dist.get_process_group_ranks(self.moe_dp_group)}, you might want to set ep_size=1 or zero_stage > 0"
|
||||
f"if ddp is used, dp_group and moe_dp_group are expected to be the same since DDP can only reduce grad across a single group, but found dp_group {dist.get_process_group_ranks(self.dp_group)} and moe_dp_group {dist.get_process_group_ranks(self.moe_dp_group)}, you might want to use HybridParallelPlugin or set zero_stage > 0"
|
||||
)
|
||||
|
||||
# set ep_group after super().__init__()
|
||||
# TODO do it in a better way
|
||||
# set param group in shard config
|
||||
self.shard_config.ep_group = self.ep_group
|
||||
self.shard_config.moe_dp_group = self.moe_dp_group
|
||||
self.shard_config.moe_tp_group = self.moe_tp_group
|
||||
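A plain-Python sketch of when this branch trips, using hypothetical sizes rather than plugin internals (the dp_size formula here is an approximation; the moe_dp_size formula matches the one used further down in this diff):

# With zero_stage == 0 the plugin falls back to torch DDP, which can only
# all-reduce gradients over a single process group. Non-MoE parameters use
# dp_group, MoE parameters use moe_dp_group, so their rank sets must match.
world_size = 8
pp_size, tp_size, sp_size = 1, 1, 1
ep_size, moe_tp_size = 2, 1

dp_size = world_size // (pp_size * tp_size * sp_size)           # 8 (approximate formula)
moe_dp_size = world_size // (pp_size * ep_size * moe_tp_size)   # 4 (formula from this diff)

# dp_group spans 8 ranks while each moe_dp_group spans 4, so DDP cannot serve
# both and the ValueError above fires. The reworded message names the two ways
# out: use HybridParallelPlugin (no MoE-specific groups) or set zero_stage > 0.
assert dp_size != moe_dp_size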
@@ -149,9 +152,10 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
             # when sequence parallelism is enabled, ep_group reuses sp_group
             if self.ep_size != self.sp_size:
                 raise ValueError(
-                    f"ep_size={self.ep_size} should be equal to sp_size={self.sp_size} when sequence parallelism is enabled"
+                    f"ep_size={self.ep_size} should be equal to sp_size={self.sp_size} or turned off when sequence parallelism is enabled"
                 )

             # since we are reusing sp_group, moe_dp_group will be derived as dp_group
             self.moe_dp_size = self.dp_size
+            self.moe_dp_group = self.dp_group  # NOTE: sequence of value assignment matters
             self.dp_group = self.pg_mesh.create_group_along_axis([self.dp_axis, self.sp_axis])
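A hedged configuration sketch of the ep/sp coupling; ep_size and moe_tp_size come from this plugin, while sp_size, enable_sequence_parallelism, and sequence_parallelism_mode are assumed to be inherited from HybridParallelPlugin and may differ across versions:

# When sequence parallelism is enabled, ep_group reuses sp_group, so the
# relaxed message still requires ep_size == sp_size (or SP turned off).
plugin = MoeHybridParallelPlugin(
    ep_size=2,
    moe_tp_size=1,
    tp_size=1,
    pp_size=1,
    enable_sequence_parallelism=True,        # assumed parent-plugin flag
    sequence_parallelism_mode="all_to_all",  # assumed mode that decouples sp_size from tp_size
    sp_size=2,                               # must equal ep_size in this configuration
    zero_stage=1,
)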
@@ -165,7 +169,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
         else:
             self.moe_dp_size = world_size // (self.pp_size * self.ep_size * self.moe_tp_size)

-        if self.pp_size * self.moe_dp_size * self.ep_size * self.moe_tp_size * self.sp_size != world_size:
+        if self.moe_dp_size * self.pp_size * self.ep_size * self.moe_tp_size != world_size:
             raise ValueError(
                 f"world_size={world_size} is not divisible by pp_size={self.pp_size} * moe_dp_size={self.moe_dp_size} * ep_size={self.ep_size} * moe_tp_size={self.moe_tp_size}"
             )
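The corrected check, restated as standalone arithmetic with hypothetical sizes (note it no longer multiplies in sp_size, since the sequence-parallel case is handled in the branch above):

world_size = 16
pp_size, ep_size, moe_tp_size = 2, 2, 1

# Same derivation as the else-branch above.
moe_dp_size = world_size // (pp_size * ep_size * moe_tp_size)   # 4

# The rewritten condition: the MoE factorization must cover every rank.
assert moe_dp_size * pp_size * ep_size * moe_tp_size == world_size

# A non-factoring setup, e.g. world_size=12 with pp_size=2, ep_size=5,
# moe_tp_size=1, gives moe_dp_size=1 and 1*2*5*1 != 12, so the plugin
# raises the "not divisible" ValueError instead of silently dropping ranks.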
@@ -214,8 +218,8 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
             self.moe_tp_group = group

         if dist.get_process_group_ranks(self.tp_group) != dist.get_process_group_ranks(self.moe_tp_group):
-            # NOTE: different tp settings between moe and non moe param require complex comm logic, where all_to_all might not be suitable
-            # this assertion implies that dp_size == moe_dp_size * ep_size
+            # NOTE: different tp settings between moe and non moe param are complex to handle
+            # we simply reuse tp_group as moe_tp_group, this implies that dp_size == moe_dp_size * ep_size
             raise NotImplementedError(
                 f"Only support shared tp group between moe and non moe params, but found non-moe tp {dist.get_process_group_ranks(self.tp_group)}, moe tp {dist.get_process_group_ranks(self.moe_tp_group)}, please make sure tp_size == moe_tp_size"
             )
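The final check compares rank lists from the two tensor-parallel groups; a tiny standalone illustration with hypothetical ranks (not plugin code):

tp_group_ranks = [0, 1]    # ranks sharing non-MoE tensor-parallel shards (hypothetical)
moe_tp_group_ranks = [0]   # ranks sharing MoE tensor-parallel shards (hypothetical)

# Differing layouts would need the "complex comm logic" the old comment mentioned,
# so the plugin simply reuses tp_group as moe_tp_group and rejects anything else:
# a configuration like this would hit the NotImplementedError above.
print("would raise NotImplementedError:", tp_group_ranks != moe_tp_group_ranks)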