[moe] refactor mesh assignment
@@ -61,13 +61,10 @@ class EPDeepseekMoE(nn.Module):
     def __init__(self):
         raise RuntimeError(f"Please use `from_native_module` to create an instance of {self.__class__.__name__}")

-    def setup_process_groups(
-        self, tp_group: ProcessGroup, moe_dp_group: ProcessGroup, ep_group: ProcessGroup, moe_tp_group: ProcessGroup
-    ):
+    def setup_process_groups(self, tp_group: ProcessGroup, moe_dp_group: ProcessGroup, ep_group: ProcessGroup):
         assert tp_group is not None
         assert moe_dp_group is not None
         assert ep_group is not None
-        assert moe_tp_group is not None

         self.ep_size = dist.get_world_size(ep_group)
         self.ep_rank = dist.get_rank(ep_group)
@@ -85,16 +82,13 @@ class EPDeepseekMoE(nn.Module):
         self.moe_dp_group = moe_dp_group
         self.moe_dp_size = moe_dp_group.size()

-        # setup global tp group
+        # setup tp group
         self.tp_group = tp_group
-
-        # setup moe tp group
-        self.moe_tp_group = moe_tp_group
-        if self.moe_tp_group.size() > 1:
+        if self.tp_group.size() > 1:
             for expert in held_experts:
-                expert.gate_proj = Linear1D_Col.from_native_module(expert.gate_proj, self.moe_tp_group)
-                expert.up_proj = Linear1D_Col.from_native_module(expert.up_proj, self.moe_tp_group)
-                expert.down_proj = Linear1D_Row.from_native_module(expert.down_proj, self.moe_tp_group)
+                expert.gate_proj = Linear1D_Col.from_native_module(expert.gate_proj, self.tp_group)
+                expert.up_proj = Linear1D_Col.from_native_module(expert.up_proj, self.tp_group)
+                expert.down_proj = Linear1D_Row.from_native_module(expert.down_proj, self.tp_group)

         for p in self.experts.parameters():
             set_moe_tensor_ep_group(p, ep_group)
@@ -105,7 +99,6 @@ class EPDeepseekMoE(nn.Module):
         tp_group: ProcessGroup,
         moe_dp_group: ProcessGroup,
         ep_group: ProcessGroup,
-        moe_tp_group: ProcessGroup,
         *args,
         **kwargs,
     ) -> "EPDeepseekMoE":
@@ -113,7 +106,7 @@ class EPDeepseekMoE(nn.Module):
         if module.__class__.__name__ == "DeepseekMLP":
             return module
         module.__class__ = EPDeepseekMoE
-        module.setup_process_groups(tp_group, moe_dp_group, ep_group, moe_tp_group)
+        module.setup_process_groups(tp_group, moe_dp_group, ep_group)
         return module

     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
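
With this change, `EPDeepseekMoE.from_native_module` takes only the three remaining groups, and the expert projections are sharded over `tp_group` directly. A minimal sketch of a call site, assuming the first positional argument is the native MoE block (as the method body above implies) and that the import path below is correct; the wrapper function and its argument names are illustrative only, not ColossalAI API:

from torch.distributed import ProcessGroup

from colossalai.shardformer.modeling.deepseek import EPDeepseekMoE  # import path assumed


def wrap_deepseek_moe(
    native_block,                # hypothetical: the original DeepseekMoE module
    tp_group: ProcessGroup,      # tensor-parallel group, now also used for expert sharding
    moe_dp_group: ProcessGroup,  # data-parallel group for MoE parameters
    ep_group: ProcessGroup,      # expert-parallel group
):
    # moe_tp_group is gone after this refactor; the real call sites are the
    # shard policies shown further down.
    return EPDeepseekMoE.from_native_module(
        native_block,
        tp_group=tp_group,
        moe_dp_group=moe_dp_group,
        ep_group=ep_group,
    )
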
@@ -53,13 +53,10 @@ class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
     def __init__(self, *args, **kwargs):
         raise RuntimeError(f"Please use `from_native_module` to create an instance of {self.__class__.__name__}")

-    def setup_process_groups(
-        self, tp_group: ProcessGroup, moe_dp_group: ProcessGroup, ep_group: ProcessGroup, moe_tp_group: ProcessGroup
-    ):
+    def setup_process_groups(self, tp_group: ProcessGroup, moe_dp_group: ProcessGroup, ep_group: ProcessGroup):
         assert tp_group is not None
         assert moe_dp_group is not None
         assert ep_group is not None
-        assert moe_tp_group is not None

         # setup ep group
         self.ep_size = dist.get_world_size(ep_group)
@@ -81,14 +78,11 @@ class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):

         # setup global tp group
         self.tp_group = tp_group
-
-        # setup moe tp group
-        self.moe_tp_group = moe_tp_group
-        if self.moe_tp_group.size() > 1:
+        if self.tp_group.size() > 1:
             for expert in held_experts:
-                expert.w1 = Linear1D_Col.from_native_module(expert.w1, self.moe_tp_group)
-                expert.w3 = Linear1D_Col.from_native_module(expert.w3, self.moe_tp_group)
-                expert.w2 = Linear1D_Row.from_native_module(expert.w2, self.moe_tp_group)
+                expert.w1 = Linear1D_Col.from_native_module(expert.w1, self.tp_group)
+                expert.w3 = Linear1D_Col.from_native_module(expert.w3, self.tp_group)
+                expert.w2 = Linear1D_Row.from_native_module(expert.w2, self.tp_group)

         for p in self.experts.parameters():
             set_moe_tensor_ep_group(p, ep_group)
@@ -99,14 +93,13 @@ class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
         tp_group: ProcessGroup,
         moe_dp_group: ProcessGroup,
         ep_group: ProcessGroup,
-        moe_tp_group: ProcessGroup,
         *args,
         **kwargs,
     ) -> "EPMixtralSparseMoeBlock":
         # TODO: better init
         LazyInitContext.materialize(module)
         module.__class__ = EPMixtralSparseMoeBlock
-        module.setup_process_groups(tp_group, moe_dp_group, ep_group, moe_tp_group)
+        module.setup_process_groups(tp_group, moe_dp_group, ep_group)
         return module

     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
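
The same column/row assignment now runs over `tp_group` for Mixtral's experts: `w1`/`w3` (gate/up) are split along their output features by `Linear1D_Col`, while `w2` (down) is split along its input features by `Linear1D_Row`, so each tensor-parallel rank computes a partial expert output and a single reduction restores the full result. A small self-contained check of that arithmetic in plain PyTorch (emulating the all-reduce with a Python sum; this is not ColossalAI's Linear1D code):

import torch
import torch.nn.functional as F

torch.manual_seed(0)
hidden, inter, tp = 8, 16, 2
x = torch.randn(4, hidden)
w1 = torch.randn(inter, hidden)  # gate proj -> Linear1D_Col (output dim split)
w3 = torch.randn(inter, hidden)  # up proj   -> Linear1D_Col (output dim split)
w2 = torch.randn(hidden, inter)  # down proj -> Linear1D_Row (input dim split)

# Reference: the full (unsharded) SwiGLU expert.
ref = (F.silu(x @ w1.t()) * (x @ w3.t())) @ w2.t()

# One weight slice per tensor-parallel rank.
w1_s, w3_s = w1.chunk(tp, dim=0), w3.chunk(tp, dim=0)
w2_s = w2.chunk(tp, dim=1)

# Each rank computes a partial output; summing emulates the final all-reduce.
out = sum(
    (F.silu(x @ w1_s[r].t()) * (x @ w3_s[r].t())) @ w2_s[r].t() for r in range(tp)
)
assert torch.allclose(out, ref, atol=1e-5)
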
@@ -154,7 +154,6 @@ class DeepseekPolicy(Policy):
                             "ep_group": self.shard_config.ep_group,
                             "tp_group": self.shard_config.tensor_parallel_process_group,
                             "moe_dp_group": self.shard_config.moe_dp_group,
-                            "moe_tp_group": self.shard_config.moe_tp_group,
                         },
                     )
                 ],
@@ -155,7 +155,6 @@ class MixtralPolicy(Policy):
                             "ep_group": self.shard_config.ep_group,
                             "tp_group": self.shard_config.tensor_parallel_process_group,
                             "moe_dp_group": self.shard_config.moe_dp_group,
-                            "moe_tp_group": self.shard_config.moe_tp_group,
                         },
                     )
                 ],
@@ -50,7 +50,6 @@ class ShardConfig:
     # for moe related
     moe_dp_group: Optional[ProcessGroup] = None
     ep_group: Optional[ProcessGroup] = None
-    moe_tp_group: Optional[ProcessGroup] = None

     # pipeline_parallel_size: int
     # data_parallel_size: int
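
After the diff, `ShardConfig` carries only `moe_dp_group` and `ep_group` for MoE, alongside the regular tensor-parallel group. For background, a minimal sketch in plain `torch.distributed` (not ColossalAI's plugin code) of how a flat world laid out as a (moe_dp, ep, tp) mesh with tp innermost could be carved into the three `ProcessGroup` handles; the rank layout here is an assumption for illustration only:

import torch.distributed as dist


def build_moe_groups(tp_size: int, ep_size: int):
    """Return (tp_group, moe_dp_group, ep_group) for the calling rank.

    Must be called by every rank after dist.init_process_group(), in the same
    order, because dist.new_group() is a collective call.
    """
    world_size = dist.get_world_size()
    rank = dist.get_rank()
    assert world_size % (tp_size * ep_size) == 0
    moe_dp_size = world_size // (tp_size * ep_size)

    def flat(d, e, t):
        # (moe_dp, ep, tp) mesh coordinates -> global rank, tp innermost.
        return (d * ep_size + e) * tp_size + t

    tp_group = moe_dp_group = ep_group = None
    # tp groups: vary t while fixing (d, e).
    for d in range(moe_dp_size):
        for e in range(ep_size):
            ranks = [flat(d, e, t) for t in range(tp_size)]
            g = dist.new_group(ranks)
            if rank in ranks:
                tp_group = g
    # ep groups: vary e while fixing (d, t).
    for d in range(moe_dp_size):
        for t in range(tp_size):
            ranks = [flat(d, e, t) for e in range(ep_size)]
            g = dist.new_group(ranks)
            if rank in ranks:
                ep_group = g
    # moe_dp groups: vary d while fixing (e, t).
    for e in range(ep_size):
        for t in range(tp_size):
            ranks = [flat(d, e, t) for d in range(moe_dp_size)]
            g = dist.new_group(ranks)
            if rank in ranks:
                moe_dp_group = g
    return tp_group, moe_dp_group, ep_group
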