[moe] refactor mesh assignment

authored by hxwang on 2024-07-25 06:19:54 +00:00
committed by Hongxin Liu
parent 034020bd04
commit cb01c0d5ce
10 changed files with 277 additions and 170 deletions

View File

@@ -61,13 +61,10 @@ class EPDeepseekMoE(nn.Module):
     def __init__(self):
         raise RuntimeError(f"Please use `from_native_module` to create an instance of {self.__class__.__name__}")
 
-    def setup_process_groups(
-        self, tp_group: ProcessGroup, moe_dp_group: ProcessGroup, ep_group: ProcessGroup, moe_tp_group: ProcessGroup
-    ):
+    def setup_process_groups(self, tp_group: ProcessGroup, moe_dp_group: ProcessGroup, ep_group: ProcessGroup):
         assert tp_group is not None
         assert moe_dp_group is not None
         assert ep_group is not None
-        assert moe_tp_group is not None
 
         self.ep_size = dist.get_world_size(ep_group)
         self.ep_rank = dist.get_rank(ep_group)
@@ -85,16 +82,13 @@ class EPDeepseekMoE(nn.Module):
         self.moe_dp_group = moe_dp_group
         self.moe_dp_size = moe_dp_group.size()
 
-        # setup global tp group
+        # setup tp group
         self.tp_group = tp_group
-
-        # setup moe tp group
-        self.moe_tp_group = moe_tp_group
-        if self.moe_tp_group.size() > 1:
+        if self.tp_group.size() > 1:
             for expert in held_experts:
-                expert.gate_proj = Linear1D_Col.from_native_module(expert.gate_proj, self.moe_tp_group)
-                expert.up_proj = Linear1D_Col.from_native_module(expert.up_proj, self.moe_tp_group)
-                expert.down_proj = Linear1D_Row.from_native_module(expert.down_proj, self.moe_tp_group)
+                expert.gate_proj = Linear1D_Col.from_native_module(expert.gate_proj, self.tp_group)
+                expert.up_proj = Linear1D_Col.from_native_module(expert.up_proj, self.tp_group)
+                expert.down_proj = Linear1D_Row.from_native_module(expert.down_proj, self.tp_group)
 
         for p in self.experts.parameters():
             set_moe_tensor_ep_group(p, ep_group)
@@ -105,7 +99,6 @@ class EPDeepseekMoE(nn.Module):
         tp_group: ProcessGroup,
         moe_dp_group: ProcessGroup,
         ep_group: ProcessGroup,
-        moe_tp_group: ProcessGroup,
         *args,
         **kwargs,
     ) -> "EPDeepseekMoE":
@@ -113,7 +106,7 @@ class EPDeepseekMoE(nn.Module):
         if module.__class__.__name__ == "DeepseekMLP":
             return module
         module.__class__ = EPDeepseekMoE
-        module.setup_process_groups(tp_group, moe_dp_group, ep_group, moe_tp_group)
+        module.setup_process_groups(tp_group, moe_dp_group, ep_group)
         return module
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
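Aside: a minimal, self-contained sketch (plain torch, not the repo's Linear1D_Col/Linear1D_Row classes, and assuming a SiLU-gated expert MLP) of why gate_proj/up_proj are sharded column-wise and down_proj row-wise over tp_group: the column split keeps the elementwise gating local to each rank, and the row split only needs one reduction over the partial outputs.

import torch
import torch.nn.functional as F

x = torch.randn(4, 8)        # [tokens, hidden]
w_gate = torch.randn(16, 8)  # gate_proj weight
w_up = torch.randn(16, 8)    # up_proj weight
w_down = torch.randn(8, 16)  # down_proj weight
tp = 2                       # pretend tensor-parallel world size

# Column-parallel: split gate/up along the output dimension.
gate_shards = w_gate.chunk(tp, dim=0)
up_shards = w_up.chunk(tp, dim=0)
# Row-parallel: split down along the input dimension.
down_shards = w_down.chunk(tp, dim=1)

# Each "rank" works only on its own slice; the final sum stands in for the
# all-reduce performed by the row-parallel layer.
out = sum(
    (F.silu(x @ gate_shards[r].T) * (x @ up_shards[r].T)) @ down_shards[r].T
    for r in range(tp)
)
ref = (F.silu(x @ w_gate.T) * (x @ w_up.T)) @ w_down.T
assert torch.allclose(out, ref, atol=1e-5)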

View File

@@ -53,13 +53,10 @@ class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
     def __init__(self, *args, **kwargs):
         raise RuntimeError(f"Please use `from_native_module` to create an instance of {self.__class__.__name__}")
 
-    def setup_process_groups(
-        self, tp_group: ProcessGroup, moe_dp_group: ProcessGroup, ep_group: ProcessGroup, moe_tp_group: ProcessGroup
-    ):
+    def setup_process_groups(self, tp_group: ProcessGroup, moe_dp_group: ProcessGroup, ep_group: ProcessGroup):
         assert tp_group is not None
         assert moe_dp_group is not None
         assert ep_group is not None
-        assert moe_tp_group is not None
 
         # setup ep group
         self.ep_size = dist.get_world_size(ep_group)
@@ -81,14 +78,11 @@ class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
 
         # setup global tp group
         self.tp_group = tp_group
-
-        # setup moe tp group
-        self.moe_tp_group = moe_tp_group
-        if self.moe_tp_group.size() > 1:
+        if self.tp_group.size() > 1:
             for expert in held_experts:
-                expert.w1 = Linear1D_Col.from_native_module(expert.w1, self.moe_tp_group)
-                expert.w3 = Linear1D_Col.from_native_module(expert.w3, self.moe_tp_group)
-                expert.w2 = Linear1D_Row.from_native_module(expert.w2, self.moe_tp_group)
+                expert.w1 = Linear1D_Col.from_native_module(expert.w1, self.tp_group)
+                expert.w3 = Linear1D_Col.from_native_module(expert.w3, self.tp_group)
+                expert.w2 = Linear1D_Row.from_native_module(expert.w2, self.tp_group)
 
         for p in self.experts.parameters():
             set_moe_tensor_ep_group(p, ep_group)
@@ -99,14 +93,13 @@ class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
         tp_group: ProcessGroup,
         moe_dp_group: ProcessGroup,
         ep_group: ProcessGroup,
-        moe_tp_group: ProcessGroup,
         *args,
         **kwargs,
     ) -> "EPMixtralSparseMoeBlock":
         # TODO: better init
         LazyInitContext.materialize(module)
         module.__class__ = EPMixtralSparseMoeBlock
-        module.setup_process_groups(tp_group, moe_dp_group, ep_group, moe_tp_group)
+        module.setup_process_groups(tp_group, moe_dp_group, ep_group)
         return module
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
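With the extra argument gone, a caller hands the block only the three remaining groups. A rough usage sketch, assuming an already-initialized torch.distributed environment; the group handles and native_block are placeholders, not code from this commit:

import torch.distributed as dist

from colossalai.shardformer.modeling.mixtral import EPMixtralSparseMoeBlock  # import path assumed

# Placeholder groups; in practice they come from the plugin's process-group mesh.
tp_group = dist.new_group(ranks=[0, 1])
moe_dp_group = dist.new_group(ranks=[0])
ep_group = dist.new_group(ranks=[0, 1])

# moe_tp_group is no longer accepted; expert linears are sharded over tp_group.
block = EPMixtralSparseMoeBlock.from_native_module(
    native_block,  # a MixtralSparseMoeBlock instance (placeholder)
    tp_group=tp_group,
    moe_dp_group=moe_dp_group,
    ep_group=ep_group,
)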

View File

@@ -154,7 +154,6 @@ class DeepseekPolicy(Policy):
                             "ep_group": self.shard_config.ep_group,
                             "tp_group": self.shard_config.tensor_parallel_process_group,
                             "moe_dp_group": self.shard_config.moe_dp_group,
-                            "moe_tp_group": self.shard_config.moe_tp_group,
                         },
                     )
                 ],

View File

@@ -155,7 +155,6 @@ class MixtralPolicy(Policy):
                             "ep_group": self.shard_config.ep_group,
                             "tp_group": self.shard_config.tensor_parallel_process_group,
                             "moe_dp_group": self.shard_config.moe_dp_group,
-                            "moe_tp_group": self.shard_config.moe_tp_group,
                         },
                     )
                 ],
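In both policies these kwargs feed a SubModuleReplacementDescription, so the surviving entry looks roughly like the sketch below. It is written as it would appear inside the policy (self.shard_config is the policy's shard config); the import path and the suffix value are assumptions, since this hunk only shows the kwargs dict:

from colossalai.shardformer.policies.base_policy import SubModuleReplacementDescription  # import path assumed

SubModuleReplacementDescription(
    suffix="block_sparse_moe",  # assumed submodule name for the Mixtral MoE block
    target_module=EPMixtralSparseMoeBlock,
    kwargs={
        "ep_group": self.shard_config.ep_group,
        "tp_group": self.shard_config.tensor_parallel_process_group,
        "moe_dp_group": self.shard_config.moe_dp_group,
        # the "moe_tp_group" entry is gone after this commit
    },
)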

View File

@@ -50,7 +50,6 @@ class ShardConfig:
     # for moe related
     moe_dp_group: Optional[ProcessGroup] = None
     ep_group: Optional[ProcessGroup] = None
-    moe_tp_group: Optional[ProcessGroup] = None
 
     # pipeline_parallel_size: int
     # data_parallel_size: int
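Downstream, the dataclass is simply constructed without the removed field. A minimal sketch (group values are placeholders; only the fields visible in this diff plus the standard tensor-parallel ones are shown):

from colossalai.shardformer import ShardConfig  # import path assumed

shard_config = ShardConfig(
    tensor_parallel_process_group=tp_group,       # global TP group, now reused for expert TP
    enable_tensor_parallelism=tp_group.size() > 1,
    moe_dp_group=moe_dp_group,
    ep_group=ep_group,
    # moe_tp_group=...  # no longer a ShardConfig field
)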