[fp8]Moe support fp8 communication (#5977)

* fix

* support moe fp8

* fix

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* fix

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* fix

* fix

fix

fi

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
flybird11111
2024-08-09 18:26:02 +08:00
committed by GitHub
parent e4aadeee20
commit f1a3a326c4
8 changed files with 160 additions and 52 deletions

View File

@@ -24,6 +24,7 @@ from colossalai.moe._operation import (
all_to_all_uneven,
)
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.quantization.fp8 import all_reduce_fp8
from colossalai.shardformer.layer._operation import (
all_to_all_comm,
gather_forward_split_backward,
@@ -61,7 +62,13 @@ class EPDeepseekMoE(nn.Module):
def __init__(self):
raise RuntimeError(f"Please use `from_native_module` to create an instance of {self.__class__.__name__}")
def setup_process_groups(self, tp_group: ProcessGroup, moe_dp_group: ProcessGroup, ep_group: ProcessGroup):
def setup_process_groups(
self,
tp_group: ProcessGroup,
moe_dp_group: ProcessGroup,
ep_group: ProcessGroup,
fp8_communication: bool = False,
):
assert tp_group is not None
assert moe_dp_group is not None
assert ep_group is not None
@@ -70,6 +77,7 @@ class EPDeepseekMoE(nn.Module):
self.ep_rank = dist.get_rank(ep_group)
self.num_experts = self.config.n_routed_experts
assert self.num_experts % self.ep_size == 0
self.fp8_communication = fp8_communication
self.ep_group = ep_group
self.num_experts_per_ep = self.num_experts // self.ep_size
@@ -86,9 +94,15 @@ class EPDeepseekMoE(nn.Module):
self.tp_group = tp_group
if self.tp_group.size() > 1:
for expert in held_experts:
expert.gate_proj = Linear1D_Col.from_native_module(expert.gate_proj, self.tp_group)
expert.up_proj = Linear1D_Col.from_native_module(expert.up_proj, self.tp_group)
expert.down_proj = Linear1D_Row.from_native_module(expert.down_proj, self.tp_group)
expert.gate_proj = Linear1D_Col.from_native_module(
expert.gate_proj, self.tp_group, fp8_communication=self.fp8_communication
)
expert.up_proj = Linear1D_Col.from_native_module(
expert.up_proj, self.tp_group, fp8_communication=self.fp8_communication
)
expert.down_proj = Linear1D_Row.from_native_module(
expert.down_proj, self.tp_group, fp8_communication=self.fp8_communication
)
for p in self.experts.parameters():
set_moe_tensor_ep_group(p, ep_group)
@@ -106,7 +120,8 @@ class EPDeepseekMoE(nn.Module):
if module.__class__.__name__ == "DeepseekMLP":
return module
module.__class__ = EPDeepseekMoE
module.setup_process_groups(tp_group, moe_dp_group, ep_group)
fp8_communication = kwargs.get("fp8_communication", False)
module.setup_process_groups(tp_group, moe_dp_group, ep_group, fp8_communication=fp8_communication)
return module
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
@@ -137,11 +152,21 @@ class EPDeepseekMoE(nn.Module):
for i in range(1, self.ep_size):
activate_experts += output_split_sizes[i * self.num_experts_per_ep : (i + 1) * self.num_experts_per_ep]
activate_experts = (activate_experts > 0).float()
dist.all_reduce(activate_experts, group=self.moe_dp_group)
if self.fp8_communication:
all_reduce_fp8(activate_experts, group=self.moe_dp_group)
else:
dist.all_reduce(activate_experts, group=self.moe_dp_group)
input_split_list = input_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist()
output_split_list = output_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist()
output_states, _ = all_to_all_uneven(dispatch_states, input_split_list, output_split_list, self.ep_group)
output_states, _ = all_to_all_uneven(
dispatch_states,
input_split_list,
output_split_list,
self.ep_group,
fp8_communication=self.fp8_communication,
)
output_states = EPGradScalerIn.apply(output_states, self.ep_size)
if output_states.size(0) > 0:
@@ -167,7 +192,9 @@ class EPDeepseekMoE(nn.Module):
output_states_list.append(split_states)
output_states = torch.cat(output_states_list)
output_states = EPGradScalerOut.apply(output_states, self.ep_size)
dispatch_states, _ = all_to_all_uneven(output_states, output_split_list, input_split_list, self.ep_group)
dispatch_states, _ = all_to_all_uneven(
output_states, output_split_list, input_split_list, self.ep_group, fp8_communication=self.fp8_communication
)
recover_token_idx = torch.empty_like(flat_topk_token_idx)
recover_token_idx[flat_topk_token_idx] = torch.arange(
flat_topk_token_idx.size(0), device=flat_topk_token_idx.device
@@ -534,9 +561,9 @@ def get_deepseek_flash_attention_forward(shard_config, sp_mode=None, sp_size=Non
# sp: all-to-all comminucation when introducing sequence parallel
if sp_mode == "all_to_all":
query_states = all_to_all_comm(query_states, sp_group)
key_states = all_to_all_comm(key_states, sp_group)
value_states = all_to_all_comm(value_states, sp_group)
query_states = all_to_all_comm(query_states, sp_group, fp8_communication=shard_config.fp8_communication)
key_states = all_to_all_comm(key_states, sp_group, fp8_communication=shard_config.fp8_communication)
value_states = all_to_all_comm(value_states, sp_group, fp8_communication=shard_config.fp8_communication)
bsz, q_len, _ = query_states.size()
# Flash attention requires the input to have the shape
# batch_size x seq_length x head_dim x hidden_dim
@@ -595,7 +622,9 @@ def get_deepseek_flash_attention_forward(shard_config, sp_mode=None, sp_size=Non
# sp: all-to-all comminucation when introducing sequence parallel
if sp_mode == "all_to_all":
attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim).contiguous() # (1, 8, 128)
attn_output = all_to_all_comm(attn_output, sp_group, scatter_dim=1, gather_dim=2) # (1, 4, 256)
attn_output = all_to_all_comm(
attn_output, sp_group, scatter_dim=1, gather_dim=2, fp8_communication=shard_config.fp8_communication
) # (1, 4, 256)
else:
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
@@ -685,9 +714,13 @@ def get_deepseek_flash_attention_model_forward(shard_config, sp_mode=None, sp_si
)
if sp_mode in ["ring", "split_gather"]:
inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group)
inputs_embeds = split_forward_gather_backward(
inputs_embeds, 1, sp_group, fp8_communication=shard_config.fp8_communication
)
elif sp_mode == "all_to_all":
inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group, 1 / sp_size)
inputs_embeds = split_forward_gather_backward(
inputs_embeds, 1, sp_group, 1 / sp_size, fp8_communication=shard_config.fp8_communication
)
# embed positions
hidden_states = inputs_embeds
@@ -731,9 +764,13 @@ def get_deepseek_flash_attention_model_forward(shard_config, sp_mode=None, sp_si
hidden_states = self.norm(hidden_states)
if sp_mode == "ring" or sp_mode == "split_gather":
hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group)
hidden_states = gather_forward_split_backward(
hidden_states, 1, sp_group, fp8_communication=shard_config.fp8_communication
)
elif sp_mode == "all_to_all":
hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group, grad_scale=sp_size)
hidden_states = gather_forward_split_backward(
hidden_states, 1, sp_group, grad_scale=sp_size, fp8_communication=shard_config.fp8_communication
)
# add hidden states from the last decoder layer
if output_hidden_states: