Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-06 19:40:28 +00:00)
[fp8]Moe support fp8 communication (#5977)
* fix
* support moe fp8
* fix
* fix
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* fix
* fix
* fix
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* fix
* fix
* fix fix fi
* fix
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* fix

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
@@ -53,7 +53,13 @@ class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
     def __init__(self, *args, **kwargs):
         raise RuntimeError(f"Please use `from_native_module` to create an instance of {self.__class__.__name__}")
 
-    def setup_process_groups(self, tp_group: ProcessGroup, moe_dp_group: ProcessGroup, ep_group: ProcessGroup):
+    def setup_process_groups(
+        self,
+        tp_group: ProcessGroup,
+        moe_dp_group: ProcessGroup,
+        ep_group: ProcessGroup,
+        fp8_communication: bool = False,
+    ):
         assert tp_group is not None
         assert moe_dp_group is not None
         assert ep_group is not None
@@ -62,6 +68,7 @@ class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
         self.ep_size = dist.get_world_size(ep_group)
         self.ep_rank = dist.get_rank(ep_group)
         self.ep_group = ep_group
+        self.fp8_communication = fp8_communication
 
         if self.num_experts % self.ep_size != 0:
             raise ValueError("The number of experts must be divisible by the number of expert parallel groups.")
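Note on the two hunks above: the new `fp8_communication` flag is plain plumbing. It is accepted with a backward-compatible default of `False` and stored on the block so that every later collective call site can consult it. A minimal sketch of the gating pattern such a flag typically enables (illustrative names, not ColossalAI internals; assumes a PyTorch build with `float8_e4m3fn`):

    import torch

    def maybe_cast_for_comm(t: torch.Tensor, fp8_communication: bool):
        # Hypothetical helper: if FP8 comm is enabled, ship an e4m3 payload
        # plus a per-tensor scale; otherwise send the tensor unchanged.
        if not fp8_communication:
            return t, None
        scale = t.abs().max().clamp(min=1e-12) / 448.0  # 448 ~ max finite e4m3 value
        payload = (t / scale).to(torch.float8_e4m3fn)
        return payload, scale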
@@ -80,9 +87,15 @@ class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
         self.tp_group = tp_group
         if self.tp_group.size() > 1:
             for expert in held_experts:
-                expert.w1 = Linear1D_Col.from_native_module(expert.w1, self.tp_group)
-                expert.w3 = Linear1D_Col.from_native_module(expert.w3, self.tp_group)
-                expert.w2 = Linear1D_Row.from_native_module(expert.w2, self.tp_group)
+                expert.w1 = Linear1D_Col.from_native_module(
+                    expert.w1, self.tp_group, fp8_communication=self.fp8_communication
+                )
+                expert.w3 = Linear1D_Col.from_native_module(
+                    expert.w3, self.tp_group, fp8_communication=self.fp8_communication
+                )
+                expert.w2 = Linear1D_Row.from_native_module(
+                    expert.w2, self.tp_group, fp8_communication=self.fp8_communication
+                )
 
         for p in self.experts.parameters():
             set_moe_tensor_ep_group(p, ep_group)
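Passing the flag into `Linear1D_Col`/`Linear1D_Row` lets the tensor-parallel collectives inside the expert MLPs quantize what travels over the wire. A self-contained round trip showing the precision cost of such a cast (assumes PyTorch with `float8_e4m3fn` support; the scaling scheme here is a generic per-tensor one, not necessarily the library's):

    import torch

    x = torch.randn(4, 16)
    scale = x.abs().max() / 448.0                  # map into the e4m3 range
    x_fp8 = (x / scale).to(torch.float8_e4m3fn)    # what would be communicated
    x_restored = x_fp8.to(torch.float32) * scale   # dequantized on the receiving side
    print((x - x_restored).abs().max())            # per-element quantization error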
@@ -99,7 +112,8 @@ class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
         # TODO: better init
         LazyInitContext.materialize(module)
         module.__class__ = EPMixtralSparseMoeBlock
-        module.setup_process_groups(tp_group, moe_dp_group, ep_group)
+        fp8_communication = kwargs.get("fp8_communication", False)
+        module.setup_process_groups(tp_group, moe_dp_group, ep_group, fp8_communication)
         return module
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
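`from_native_module` now forwards an optional `fp8_communication` kwarg into `setup_process_groups`; the `kwargs.get(..., False)` pattern keeps older call sites working unchanged. A stripped-down illustration of the pattern (toy classes, not the real module):

    class Carrier:
        pass

    class ToyBlock:
        @classmethod
        def from_native_module(cls, module, *args, **kwargs):
            # Optional kwarg with a safe default: callers that omit it
            # keep the old (non-FP8) behaviour.
            module.fp8_communication = kwargs.get("fp8_communication", False)
            return module

    m = ToyBlock.from_native_module(Carrier(), fp8_communication=True)
    assert m.fp8_communication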
@@ -120,6 +134,7 @@ class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
         input_split_sizes = selected_experts.bincount(minlength=self.num_experts)
+
         output_split_sizes = torch.zeros_like(input_split_sizes)
         dist.all_to_all_single(output_split_sizes, input_split_sizes, group=self.ep_group)
 
         with torch.no_grad():
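For context, the `dist.all_to_all_single` on the split sizes is what makes the later uneven all-to-all possible: each rank must learn how many tokens every peer will send it before buffers can be sized. Across ranks the exchange amounts to transposing the matrix of send counts; a single-process illustration:

    import torch

    # Row r = the per-destination token counts rank r would send.
    send_counts = torch.tensor([[3, 1],
                                [2, 4]])
    # After the all-to-all of counts, rank r holds column r of the matrix,
    # i.e. how many tokens each peer is about to send it.
    recv_counts = send_counts.t()
    print(recv_counts[0])  # rank 0 receives 3 tokens from rank 0, 2 from rank 1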
@@ -132,7 +147,13 @@ class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
         input_split_list = input_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist()
         output_split_list = output_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist()
 
-        output_states, _ = all_to_all_uneven(dispatch_states, input_split_list, output_split_list, self.ep_group)
+        output_states, _ = all_to_all_uneven(
+            dispatch_states,
+            input_split_list,
+            output_split_list,
+            self.ep_group,
+            fp8_communication=self.fp8_communication,
+        )
         # compute expert output
         output_states = EPGradScalerIn.apply(output_states, self.ep_size)
         if output_states.size(0) > 0:
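The dispatch all-to-all now carries FP8 payloads when the flag is set; the same kwarg is threaded through the combine pass in the next hunk. Independent of the wire dtype, the uneven exchange splits the routed tokens with per-rank chunk sizes, much like `torch.split`; a local sketch of what one rank contributes:

    import torch

    dispatch_states = torch.randn(6, 8)   # 6 routed tokens, hidden size 8
    input_split_list = [4, 2]             # tokens destined for rank 0 and rank 1
    chunks = dispatch_states.split(input_split_list, dim=0)
    assert [c.size(0) for c in chunks] == input_split_list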
@@ -162,7 +183,9 @@ class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
             output_states = torch.cat(output_states_list)
 
         output_states = EPGradScalerOut.apply(output_states, self.ep_size)
-        dispatch_states, _ = all_to_all_uneven(output_states, output_split_list, input_split_list, self.ep_group)
+        dispatch_states, _ = all_to_all_uneven(
+            output_states, output_split_list, input_split_list, self.ep_group, fp8_communication=self.fp8_communication
+        )
 
         recover_experts_idx = torch.empty_like(selected_experts_idx)
         recover_experts_idx[selected_experts_idx] = torch.arange(
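Beyond the FP8 change, the context lines above show how the original token order is restored after the combine pass: scattering `arange` through the sort indices builds the inverse permutation. A quick standalone check of that trick:

    import torch

    selected_experts_idx = torch.randperm(8)   # sort order used at dispatch
    recover_experts_idx = torch.empty_like(selected_experts_idx)
    recover_experts_idx[selected_experts_idx] = torch.arange(8)
    x = torch.randn(8, 4)
    # Applying the permutation and then its inverse is the identity.
    assert torch.equal(x[selected_experts_idx][recover_experts_idx], x)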
@@ -566,9 +589,9 @@ def get_mixtral_flash_attention_forward(shard_config, sp_mode=None, sp_size=None
 
         # sp: all-to-all comminucation when introducing sequence parallel
         if sp_mode == "all_to_all":
-            query_states = all_to_all_comm(query_states, sp_group)
-            key_states = all_to_all_comm(key_states, sp_group)
-            value_states = all_to_all_comm(value_states, sp_group)
+            query_states = all_to_all_comm(query_states, sp_group, fp8_communication=shard_config.fp8_communication)
+            key_states = all_to_all_comm(key_states, sp_group, fp8_communication=shard_config.fp8_communication)
+            value_states = all_to_all_comm(value_states, sp_group, fp8_communication=shard_config.fp8_communication)
             bsz, q_len, _ = query_states.size()
 
         query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
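In the all_to_all sequence-parallel path, the q/k/v projections are exchanged so that each rank trades sequence shards for attention heads; with the flag set, those exchanges can also travel as FP8. A single-process sketch of the shape transformation such an all-to-all performs (conceptual only, not the library's implementation):

    import torch

    sp_size, bsz, seq_shard, heads, head_dim = 2, 1, 4, 8, 16
    x = torch.randn(bsz, seq_shard, heads * head_dim)
    # Conceptually: scatter the head dimension across sp ranks and gather
    # the sequence dimension, so each rank ends up with a longer sequence
    # but fewer heads.
    pieces = x.chunk(sp_size, dim=-1)   # what each rank would send
    out = torch.cat(pieces, dim=1)      # what one rank would hold afterwards
    print(out.shape)                    # torch.Size([1, 8, 64])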
@@ -780,9 +803,13 @@ def get_mixtral_flash_attention_model_forward(shard_config, sp_mode=None, sp_siz
         )
 
         if sp_mode in ["ring", "split_gather"]:
-            inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group)
+            inputs_embeds = split_forward_gather_backward(
+                inputs_embeds, 1, sp_group, fp8_communication=shard_config.fp8_communication
+            )
         elif sp_mode == "all_to_all":
-            inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group, 1 / sp_size)
+            inputs_embeds = split_forward_gather_backward(
+                inputs_embeds, 1, sp_group, 1 / sp_size, fp8_communication=shard_config.fp8_communication
+            )
         hidden_states = inputs_embeds
 
         # decoder layers
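`split_forward_gather_backward` shards the embeddings along the sequence dimension in the forward pass and gathers gradients in the backward pass; the new kwarg lets that backward gather run in FP8 as well. A minimal single-process autograd analogue of the forward-split/backward-gather idea (toy code under simplifying assumptions, not the library's operator):

    import torch

    class SplitFwdGatherBwd(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x, rank, world_size):
            ctx.world_size = world_size
            # Keep only this rank's shard of the sequence dimension.
            return x.chunk(world_size, dim=1)[rank].contiguous()

        @staticmethod
        def backward(ctx, grad):
            # In the distributed op, the backward gathers shard gradients
            # from all ranks; with one process standing in for every rank,
            # that gather degenerates to tiling the local gradient.
            return grad.repeat(1, ctx.world_size, 1), None, None

    x = torch.randn(1, 8, 4, requires_grad=True)
    y = SplitFwdGatherBwd.apply(x, 0, 2)   # shard of shape (1, 4, 4)
    y.sum().backward()                     # x.grad has the full (1, 8, 4) shape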
@@ -831,9 +858,13 @@
         hidden_states = self.norm(hidden_states)
 
         if sp_mode == "ring" or sp_mode == "split_gather":
-            hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group)
+            hidden_states = gather_forward_split_backward(
+                hidden_states, 1, sp_group, fp8_communication=shard_config.fp8_communication
+            )
         elif sp_mode == "all_to_all":
-            hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group, grad_scale=sp_size)
+            hidden_states = gather_forward_split_backward(
+                hidden_states, 1, sp_group, grad_scale=sp_size, fp8_communication=shard_config.fp8_communication
+            )
 
         # add hidden states from the last decoder layer
         if output_hidden_states:
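`gather_forward_split_backward` mirrors the previous op (gather activations forward, split gradients backward, with `grad_scale` compensating for the sequence-parallel duplication). End to end, the feature is enabled by passing `fp8_communication=True` into the booster plugin, which populates `shard_config.fp8_communication` consulted at every call site above. The exact plugin signature below is an assumption inferred from this diff rather than documented API, so verify against the release you use:

    # Assumed wiring, inferred from this PR: the plugin flag lands on
    # shard_config.fp8_communication, which the patched forwards read.
    from colossalai.booster import Booster
    from colossalai.booster.plugin import MoeHybridParallelPlugin

    plugin = MoeHybridParallelPlugin(
        tp_size=1,
        pp_size=1,
        ep_size=2,
        fp8_communication=True,  # assumption: mirrors the kwarg threaded through this PR
    )
    booster = Booster(plugin=plugin)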