[fp8]Moe support fp8 communication (#5977)

* fix

* support moe fp8

* fix

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* fix

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* fix

* fix

fix

fi

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
flybird11111
2024-08-09 18:26:02 +08:00
committed by GitHub
parent e4aadeee20
commit f1a3a326c4
8 changed files with 160 additions and 52 deletions

View File

@@ -53,7 +53,13 @@ class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
def __init__(self, *args, **kwargs):
raise RuntimeError(f"Please use `from_native_module` to create an instance of {self.__class__.__name__}")
def setup_process_groups(self, tp_group: ProcessGroup, moe_dp_group: ProcessGroup, ep_group: ProcessGroup):
def setup_process_groups(
self,
tp_group: ProcessGroup,
moe_dp_group: ProcessGroup,
ep_group: ProcessGroup,
fp8_communication: bool = False,
):
assert tp_group is not None
assert moe_dp_group is not None
assert ep_group is not None
@@ -62,6 +68,7 @@ class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
self.ep_size = dist.get_world_size(ep_group)
self.ep_rank = dist.get_rank(ep_group)
self.ep_group = ep_group
self.fp8_communication = fp8_communication
if self.num_experts % self.ep_size != 0:
raise ValueError("The number of experts must be divisible by the number of expert parallel groups.")
@@ -80,9 +87,15 @@ class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
self.tp_group = tp_group
if self.tp_group.size() > 1:
for expert in held_experts:
expert.w1 = Linear1D_Col.from_native_module(expert.w1, self.tp_group)
expert.w3 = Linear1D_Col.from_native_module(expert.w3, self.tp_group)
expert.w2 = Linear1D_Row.from_native_module(expert.w2, self.tp_group)
expert.w1 = Linear1D_Col.from_native_module(
expert.w1, self.tp_group, fp8_communication=self.fp8_communication
)
expert.w3 = Linear1D_Col.from_native_module(
expert.w3, self.tp_group, fp8_communication=self.fp8_communication
)
expert.w2 = Linear1D_Row.from_native_module(
expert.w2, self.tp_group, fp8_communication=self.fp8_communication
)
for p in self.experts.parameters():
set_moe_tensor_ep_group(p, ep_group)
@@ -99,7 +112,8 @@ class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
# TODO: better init
LazyInitContext.materialize(module)
module.__class__ = EPMixtralSparseMoeBlock
module.setup_process_groups(tp_group, moe_dp_group, ep_group)
fp8_communication = kwargs.get("fp8_communication", False)
module.setup_process_groups(tp_group, moe_dp_group, ep_group, fp8_communication)
return module
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
@@ -120,6 +134,7 @@ class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
input_split_sizes = selected_experts.bincount(minlength=self.num_experts)
output_split_sizes = torch.zeros_like(input_split_sizes)
dist.all_to_all_single(output_split_sizes, input_split_sizes, group=self.ep_group)
with torch.no_grad():
@@ -132,7 +147,13 @@ class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
input_split_list = input_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist()
output_split_list = output_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist()
output_states, _ = all_to_all_uneven(dispatch_states, input_split_list, output_split_list, self.ep_group)
output_states, _ = all_to_all_uneven(
dispatch_states,
input_split_list,
output_split_list,
self.ep_group,
fp8_communication=self.fp8_communication,
)
# compute expert output
output_states = EPGradScalerIn.apply(output_states, self.ep_size)
if output_states.size(0) > 0:
@@ -162,7 +183,9 @@ class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
output_states = torch.cat(output_states_list)
output_states = EPGradScalerOut.apply(output_states, self.ep_size)
dispatch_states, _ = all_to_all_uneven(output_states, output_split_list, input_split_list, self.ep_group)
dispatch_states, _ = all_to_all_uneven(
output_states, output_split_list, input_split_list, self.ep_group, fp8_communication=self.fp8_communication
)
recover_experts_idx = torch.empty_like(selected_experts_idx)
recover_experts_idx[selected_experts_idx] = torch.arange(
@@ -566,9 +589,9 @@ def get_mixtral_flash_attention_forward(shard_config, sp_mode=None, sp_size=None
# sp: all-to-all comminucation when introducing sequence parallel
if sp_mode == "all_to_all":
query_states = all_to_all_comm(query_states, sp_group)
key_states = all_to_all_comm(key_states, sp_group)
value_states = all_to_all_comm(value_states, sp_group)
query_states = all_to_all_comm(query_states, sp_group, fp8_communication=shard_config.fp8_communication)
key_states = all_to_all_comm(key_states, sp_group, fp8_communication=shard_config.fp8_communication)
value_states = all_to_all_comm(value_states, sp_group, fp8_communication=shard_config.fp8_communication)
bsz, q_len, _ = query_states.size()
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
@@ -780,9 +803,13 @@ def get_mixtral_flash_attention_model_forward(shard_config, sp_mode=None, sp_siz
)
if sp_mode in ["ring", "split_gather"]:
inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group)
inputs_embeds = split_forward_gather_backward(
inputs_embeds, 1, sp_group, fp8_communication=shard_config.fp8_communication
)
elif sp_mode == "all_to_all":
inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group, 1 / sp_size)
inputs_embeds = split_forward_gather_backward(
inputs_embeds, 1, sp_group, 1 / sp_size, fp8_communication=shard_config.fp8_communication
)
hidden_states = inputs_embeds
# decoder layers
@@ -831,9 +858,13 @@ def get_mixtral_flash_attention_model_forward(shard_config, sp_mode=None, sp_siz
hidden_states = self.norm(hidden_states)
if sp_mode == "ring" or sp_mode == "split_gather":
hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group)
hidden_states = gather_forward_split_backward(
hidden_states, 1, sp_group, fp8_communication=shard_config.fp8_communication
)
elif sp_mode == "all_to_all":
hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group, grad_scale=sp_size)
hidden_states = gather_forward_split_backward(
hidden_states, 1, sp_group, grad_scale=sp_size, fp8_communication=shard_config.fp8_communication
)
# add hidden states from the last decoder layer
if output_hidden_states: