[fp8]Moe support fp8 communication (#5977)
* fix
* support moe fp8
* fix
* fix
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci
* fix
* fix
* fix
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci
* fix
* fix
* fix fix fi
* fix
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci
* fix

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
@@ -24,6 +24,7 @@ from colossalai.moe._operation import (
     all_to_all_uneven,
 )
 from colossalai.pipeline.stage_manager import PipelineStageManager
+from colossalai.quantization.fp8 import all_reduce_fp8
 from colossalai.shardformer.layer._operation import (
     all_to_all_comm,
     gather_forward_split_backward,
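The only change in this hunk is the new import of `all_reduce_fp8` from `colossalai.quantization.fp8`. For readers new to fp8-compressed collectives, the sketch below shows one way such an all-reduce can work: quantize with a per-rank scale to `float8_e4m3fn`, exchange the one-byte payload, then dequantize and reduce in the original precision. This is an illustrative assumption, not the library's implementation; only the in-place `(tensor, group=...)` call shape mirrors how `all_reduce_fp8` is used later in this diff.

```python
# Illustrative sketch only; colossalai.quantization.fp8.all_reduce_fp8 may use a
# different strategy (e.g. reduce-scatter) and extra options. Assumes an
# initialized torch.distributed process group.
import torch
import torch.distributed as dist

FP8_MAX = torch.finfo(torch.float8_e4m3fn).max


def all_reduce_fp8_sketch(tensor: torch.Tensor, group=None) -> None:
    """In-place sum all-reduce that ships 1-byte fp8 payloads instead of fp32."""
    world_size = dist.get_world_size(group)

    # Per-rank scale so the local tensor fits the fp8 dynamic range.
    scale = (tensor.abs().max().clamp(min=1e-12) / FP8_MAX).reshape(1)
    fp8_payload = (tensor / scale).to(torch.float8_e4m3fn)

    # Ship raw bytes: every backend understands uint8, and the view keeps the shape.
    byte_buf = fp8_payload.view(torch.uint8)
    gathered_bytes = [torch.empty_like(byte_buf) for _ in range(world_size)]
    gathered_scales = [torch.empty_like(scale) for _ in range(world_size)]
    dist.all_gather(gathered_bytes, byte_buf, group=group)
    dist.all_gather(gathered_scales, scale, group=group)

    # Dequantize each rank's contribution and reduce in the original dtype.
    tensor.zero_()
    for chunk, s in zip(gathered_bytes, gathered_scales):
        tensor += chunk.view(torch.float8_e4m3fn).to(tensor.dtype) * s.to(tensor.dtype)
```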
@@ -61,7 +62,13 @@ class EPDeepseekMoE(nn.Module):
     def __init__(self):
         raise RuntimeError(f"Please use `from_native_module` to create an instance of {self.__class__.__name__}")

-    def setup_process_groups(self, tp_group: ProcessGroup, moe_dp_group: ProcessGroup, ep_group: ProcessGroup):
+    def setup_process_groups(
+        self,
+        tp_group: ProcessGroup,
+        moe_dp_group: ProcessGroup,
+        ep_group: ProcessGroup,
+        fp8_communication: bool = False,
+    ):
         assert tp_group is not None
         assert moe_dp_group is not None
         assert ep_group is not None
@@ -70,6 +77,7 @@ class EPDeepseekMoE(nn.Module):
         self.ep_rank = dist.get_rank(ep_group)
         self.num_experts = self.config.n_routed_experts
         assert self.num_experts % self.ep_size == 0
+        self.fp8_communication = fp8_communication

         self.ep_group = ep_group
         self.num_experts_per_ep = self.num_experts // self.ep_size
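Storing `self.fp8_communication` is what later selects the compressed collectives. The trade-off the flag controls, a 2-4x smaller payload (fp8 is one byte per element versus two for bf16 or four for fp32) in exchange for a bounded quantization error, can be seen with a single-process round trip. The check below is illustrative only and is not code from this commit.

```python
# Single-process sanity check of the fp8 round trip that compressed
# communication relies on; no process group needed, numbers are illustrative.
import torch

FP8_MAX = torch.finfo(torch.float8_e4m3fn).max

x = torch.randn(4096, dtype=torch.float32)
scale = x.abs().max() / FP8_MAX
x_fp8 = (x / scale).to(torch.float8_e4m3fn)   # 1 byte per element on the wire
x_back = x_fp8.to(torch.float32) * scale      # dequantized on the receiver

rel_err = (x - x_back).abs().mean() / x.abs().mean()
print("payload bytes:", x_fp8.element_size() * x_fp8.numel(), "vs", x.element_size() * x.numel())
print(f"mean relative error after round trip: {rel_err.item():.4f}")
```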
@@ -86,9 +94,15 @@ class EPDeepseekMoE(nn.Module):
         self.tp_group = tp_group
         if self.tp_group.size() > 1:
             for expert in held_experts:
-                expert.gate_proj = Linear1D_Col.from_native_module(expert.gate_proj, self.tp_group)
-                expert.up_proj = Linear1D_Col.from_native_module(expert.up_proj, self.tp_group)
-                expert.down_proj = Linear1D_Row.from_native_module(expert.down_proj, self.tp_group)
+                expert.gate_proj = Linear1D_Col.from_native_module(
+                    expert.gate_proj, self.tp_group, fp8_communication=self.fp8_communication
+                )
+                expert.up_proj = Linear1D_Col.from_native_module(
+                    expert.up_proj, self.tp_group, fp8_communication=self.fp8_communication
+                )
+                expert.down_proj = Linear1D_Row.from_native_module(
+                    expert.down_proj, self.tp_group, fp8_communication=self.fp8_communication
+                )

         for p in self.experts.parameters():
             set_moe_tensor_ep_group(p, ep_group)
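`Linear1D_Col` and `Linear1D_Row` now receive the flag so the collectives inside tensor parallelism can also run over fp8. The class below is a hypothetical sketch of the row-parallel pattern only; the real `Linear1D_Row` in `colossalai.shardformer` additionally handles weight sharding, biases, sequence parallelism and the backward pass. `RowParallelLinearSketch` and its constructor are assumptions, while the in-place `all_reduce_fp8(tensor, group=...)` call matches the usage in the `forward()` hunk further below.

```python
# Hypothetical sketch of a row-parallel linear whose output reduction can be
# swapped for an fp8-compressed all-reduce; not the real Linear1D_Row.
import torch
import torch.nn as nn
import torch.distributed as dist

from colossalai.quantization.fp8 import all_reduce_fp8  # imported by this commit


class RowParallelLinearSketch(nn.Module):
    def __init__(self, in_features_per_rank: int, out_features: int, group, fp8_communication: bool = False):
        super().__init__()
        self.proj = nn.Linear(in_features_per_rank, out_features, bias=False)
        self.group = group
        self.fp8_communication = fp8_communication

    def forward(self, x_shard: torch.Tensor) -> torch.Tensor:
        partial = self.proj(x_shard)  # each rank holds a partial sum over its input shard
        if self.fp8_communication:
            all_reduce_fp8(partial, group=self.group)   # compressed collective
        else:
            dist.all_reduce(partial, group=self.group)  # plain bf16/fp32 collective
        return partial
```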
@@ -106,7 +120,8 @@ class EPDeepseekMoE(nn.Module):
         if module.__class__.__name__ == "DeepseekMLP":
             return module
         module.__class__ = EPDeepseekMoE
-        module.setup_process_groups(tp_group, moe_dp_group, ep_group)
+        fp8_communication = kwargs.get("fp8_communication", False)
+        module.setup_process_groups(tp_group, moe_dp_group, ep_group, fp8_communication=fp8_communication)
         return module

     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
@@ -137,11 +152,21 @@ class EPDeepseekMoE(nn.Module):
         for i in range(1, self.ep_size):
             activate_experts += output_split_sizes[i * self.num_experts_per_ep : (i + 1) * self.num_experts_per_ep]
         activate_experts = (activate_experts > 0).float()
-        dist.all_reduce(activate_experts, group=self.moe_dp_group)
+
+        if self.fp8_communication:
+            all_reduce_fp8(activate_experts, group=self.moe_dp_group)
+        else:
+            dist.all_reduce(activate_experts, group=self.moe_dp_group)

         input_split_list = input_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist()
         output_split_list = output_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist()
-        output_states, _ = all_to_all_uneven(dispatch_states, input_split_list, output_split_list, self.ep_group)
+        output_states, _ = all_to_all_uneven(
+            dispatch_states,
+            input_split_list,
+            output_split_list,
+            self.ep_group,
+            fp8_communication=self.fp8_communication,
+        )
         output_states = EPGradScalerIn.apply(output_states, self.ep_size)

         if output_states.size(0) > 0:
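Expert-parallel token dispatch now also passes `fp8_communication` to `all_to_all_uneven`. The function below is a conceptual sketch of an fp8-compressed uneven all-to-all under the same `(tensor, input_split_sizes, output_split_sizes, group)` interface; the name `all_to_all_uneven_fp8_sketch`, the scale handling, and the single return value are assumptions (the library call above also returns a handle).

```python
# Conceptual sketch of an fp8-compressed uneven all-to-all over dim 0; the real
# all_to_all_uneven kernel may pack scales and payloads differently.
import torch
import torch.distributed as dist

FP8_MAX = torch.finfo(torch.float8_e4m3fn).max


def all_to_all_uneven_fp8_sketch(inp, input_split_sizes, output_split_sizes, group=None):
    """Uneven all-to-all that ships fp8 payloads plus one scale per sending rank."""
    world_size = dist.get_world_size(group)  # split lists have one entry per rank
    out_shape = (sum(output_split_sizes),) + inp.shape[1:]

    # Quantize the whole send buffer with one per-rank scale.
    scale = (inp.abs().max().clamp(min=1e-12) / FP8_MAX).reshape(1)
    send_fp8 = (inp / scale).to(torch.float8_e4m3fn)

    # The uint8 views keep the same shape, so the dim-0 split sizes still apply.
    recv_fp8 = torch.empty(out_shape, dtype=torch.float8_e4m3fn, device=inp.device)
    dist.all_to_all_single(
        recv_fp8.view(torch.uint8),
        send_fp8.view(torch.uint8),
        output_split_sizes=output_split_sizes,
        input_split_sizes=input_split_sizes,
        group=group,
    )

    # Every receiver needs every sender's scale to dequantize its chunk.
    scales = [torch.empty_like(scale) for _ in range(world_size)]
    dist.all_gather(scales, scale, group=group)

    chunks = torch.split(recv_fp8, output_split_sizes, dim=0)
    return torch.cat([c.to(inp.dtype) * s.to(inp.dtype) for c, s in zip(chunks, scales)], dim=0)
```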
@@ -167,7 +192,9 @@ class EPDeepseekMoE(nn.Module):
                 output_states_list.append(split_states)
             output_states = torch.cat(output_states_list)
         output_states = EPGradScalerOut.apply(output_states, self.ep_size)
-        dispatch_states, _ = all_to_all_uneven(output_states, output_split_list, input_split_list, self.ep_group)
+        dispatch_states, _ = all_to_all_uneven(
+            output_states, output_split_list, input_split_list, self.ep_group, fp8_communication=self.fp8_communication
+        )
         recover_token_idx = torch.empty_like(flat_topk_token_idx)
         recover_token_idx[flat_topk_token_idx] = torch.arange(
             flat_topk_token_idx.size(0), device=flat_topk_token_idx.device
@@ -534,9 +561,9 @@ def get_deepseek_flash_attention_forward(shard_config, sp_mode=None, sp_size=Non

         # sp: all-to-all comminucation when introducing sequence parallel
         if sp_mode == "all_to_all":
-            query_states = all_to_all_comm(query_states, sp_group)
-            key_states = all_to_all_comm(key_states, sp_group)
-            value_states = all_to_all_comm(value_states, sp_group)
+            query_states = all_to_all_comm(query_states, sp_group, fp8_communication=shard_config.fp8_communication)
+            key_states = all_to_all_comm(key_states, sp_group, fp8_communication=shard_config.fp8_communication)
+            value_states = all_to_all_comm(value_states, sp_group, fp8_communication=shard_config.fp8_communication)
             bsz, q_len, _ = query_states.size()
         # Flash attention requires the input to have the shape
         # batch_size x seq_length x head_dim x hidden_dim
@@ -595,7 +622,9 @@ def get_deepseek_flash_attention_forward(shard_config, sp_mode=None, sp_size=Non
         # sp: all-to-all comminucation when introducing sequence parallel
         if sp_mode == "all_to_all":
             attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim).contiguous()  # (1, 8, 128)
-            attn_output = all_to_all_comm(attn_output, sp_group, scatter_dim=1, gather_dim=2)  # (1, 4, 256)
+            attn_output = all_to_all_comm(
+                attn_output, sp_group, scatter_dim=1, gather_dim=2, fp8_communication=shard_config.fp8_communication
+            )  # (1, 4, 256)
         else:
             attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

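For reference, the scatter/gather semantics behind `all_to_all_comm` boil down to the layout swap below (a minimal sketch assuming an evenly divisible scatter dimension; `all_to_all_swap` is a hypothetical name, and `fp8_communication=True` would additionally compress the payload as in the earlier sketches). With `sp_size=2`, a `(1, 8, 128)` tensor scattered on dim 1 and gathered on dim 2 becomes `(1, 4, 256)`, matching the shape comments in the hunk above.

```python
# Minimal sketch of the sequence-parallel scatter/gather swap; assumes the
# scatter dimension is divisible by the group size.
import torch
import torch.distributed as dist


def all_to_all_swap(x: torch.Tensor, group, scatter_dim: int, gather_dim: int) -> torch.Tensor:
    """Scatter x across the group along scatter_dim and gather the pieces along gather_dim."""
    world_size = dist.get_world_size(group)
    send_chunks = [c.contiguous() for c in torch.chunk(x, world_size, dim=scatter_dim)]
    recv_chunks = [torch.empty_like(send_chunks[0]) for _ in range(world_size)]
    dist.all_to_all(recv_chunks, send_chunks, group=group)
    return torch.cat(recv_chunks, dim=gather_dim)
```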
@@ -685,9 +714,13 @@ def get_deepseek_flash_attention_model_forward(shard_config, sp_mode=None, sp_si
             )

         if sp_mode in ["ring", "split_gather"]:
-            inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group)
+            inputs_embeds = split_forward_gather_backward(
+                inputs_embeds, 1, sp_group, fp8_communication=shard_config.fp8_communication
+            )
         elif sp_mode == "all_to_all":
-            inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group, 1 / sp_size)
+            inputs_embeds = split_forward_gather_backward(
+                inputs_embeds, 1, sp_group, 1 / sp_size, fp8_communication=shard_config.fp8_communication
+            )
         # embed positions
         hidden_states = inputs_embeds

@@ -731,9 +764,13 @@ def get_deepseek_flash_attention_model_forward(shard_config, sp_mode=None, sp_si
         hidden_states = self.norm(hidden_states)

         if sp_mode == "ring" or sp_mode == "split_gather":
-            hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group)
+            hidden_states = gather_forward_split_backward(
+                hidden_states, 1, sp_group, fp8_communication=shard_config.fp8_communication
+            )
         elif sp_mode == "all_to_all":
-            hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group, grad_scale=sp_size)
+            hidden_states = gather_forward_split_backward(
+                hidden_states, 1, sp_group, grad_scale=sp_size, fp8_communication=shard_config.fp8_communication
+            )

         # add hidden states from the last decoder layer
         if output_hidden_states:
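`gather_forward_split_backward` brings the sequence shards back together before the language-model head; its forward pass is essentially an all-gather along the sequence dimension, which is what `fp8_communication` compresses here. The sketch below is forward-only and hypothetical (`gather_sequence_dim` is an assumed name); the real `autograd.Function` also splits the gradient and applies `grad_scale` in the backward pass.

```python
# Forward-only sketch of gathering sequence shards along dim 1, optionally over
# fp8; not the real gather_forward_split_backward implementation.
import torch
import torch.distributed as dist

FP8_MAX = torch.finfo(torch.float8_e4m3fn).max


def gather_sequence_dim(hidden: torch.Tensor, group, dim: int = 1, fp8_communication: bool = False) -> torch.Tensor:
    """All-gather sequence shards and concatenate them along `dim`."""
    world_size = dist.get_world_size(group)
    if fp8_communication:
        # Quantize the local shard, then gather 1-byte payloads plus per-rank scales.
        scale = (hidden.abs().max().clamp(min=1e-12) / FP8_MAX).reshape(1)
        payload = (hidden / scale).to(torch.float8_e4m3fn).contiguous()
        shards = [torch.empty_like(payload) for _ in range(world_size)]
        scales = [torch.empty_like(scale) for _ in range(world_size)]
        dist.all_gather([s.view(torch.uint8) for s in shards], payload.view(torch.uint8), group=group)
        dist.all_gather(scales, scale, group=group)
        full = [p.to(hidden.dtype) * s.to(hidden.dtype) for p, s in zip(shards, scales)]
    else:
        full = [torch.empty_like(hidden) for _ in range(world_size)]
        dist.all_gather(full, hidden.contiguous(), group=group)
    return torch.cat(full, dim=dim)
```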