[fp8] support hybrid parallel plugin (#5982)

* support fp8 comm for qwen2 model

* fp8

* fix

* bert and bloom

* chatglm and command

* gpt2, gptj, bert, falcon, blip2

* mistral, opt, sam, t5, vit, whisper

* fix

Author: Wang Binluo
Date: 2024-08-12 18:17:05 +08:00 (committed by GitHub)
parent f1a3a326c4
commit b2483c8e31
27 changed files with 633 additions and 83 deletions
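
For context, a minimal usage sketch of what this change enables. The plugin-level keyword below is an assumption inferred from the shard_config.fp8_communication field threaded through this diff; only that field is confirmed by the code here.

# Hypothetical usage sketch: turning on fp8 communication in the hybrid
# parallel plugin. `fp8_communication=True` is assumed to populate
# `ShardConfig.fp8_communication`, which the model forwards below read.
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

colossalai.launch_from_torch()

plugin = HybridParallelPlugin(
    tp_size=2,
    pp_size=1,
    sp_size=2,
    enable_sequence_parallelism=True,
    sequence_parallelism_mode="all_to_all",
    fp8_communication=True,  # assumed flag; routes collectives through fp8 casts
)
booster = Booster(plugin=plugin)
# model, optimizer, ... = booster.boost(model, optimizer, ...)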


@@ -206,6 +206,7 @@ class ChatGLMPipelineForwards:
hidden_states,
dim=0,
process_group=shard_config.tensor_parallel_process_group,
fp8_communication=shard_config.fp8_communication,
)
elif shard_config.sequence_parallelism_mode == "all_to_all":
hidden_states = split_forward_gather_backward(
@@ -213,6 +214,7 @@ class ChatGLMPipelineForwards:
dim=0,
process_group=shard_config.sequence_parallel_process_group,
grad_scale=1 / shard_config.sequence_parallel_size,
fp8_communication=shard_config.fp8_communication,
)
for idx in range(start_idx, end_idx):
layer = self.encoder._get_layer(idx)
@@ -245,6 +247,7 @@ class ChatGLMPipelineForwards:
hidden_states,
dim=0,
process_group=shard_config.tensor_parallel_process_group,
fp8_communication=shard_config.fp8_communication,
)
elif shard_config.sequence_parallelism_mode == "all_to_all":
hidden_states = gather_forward_split_backward(
@@ -252,6 +255,7 @@ class ChatGLMPipelineForwards:
dim=0,
process_group=shard_config.sequence_parallel_process_group,
grad_scale=shard_config.sequence_parallel_size,
fp8_communication=shard_config.fp8_communication,
)
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
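
The fp8_communication flag added to these gather/split calls presumably makes the underlying collective ship an fp8 payload instead of the working dtype. A rough, self-contained sketch of that idea follows; the function name, scaling scheme, and uint8 trick are assumptions for illustration, not ColossalAI's actual helper.

# Illustrative fp8 all-gather: quantize the local shard with a per-tensor
# scale, all-gather the fp8 bytes plus the scales, then dequantize.
import torch
import torch.distributed as dist

def fp8_all_gather(x: torch.Tensor, group=None, dim: int = 0) -> torch.Tensor:
    world = dist.get_world_size(group)
    # Per-tensor scale so values fit float8_e4m3fn's ~[-448, 448] range.
    scale = x.abs().max().float().clamp(min=1e-12) / 448.0
    payload = (x.float() / scale).to(torch.float8_e4m3fn)

    # NCCL has no fp8 dtype, so communicate the raw bytes as uint8.
    shards = [torch.empty_like(payload).view(torch.uint8) for _ in range(world)]
    dist.all_gather(shards, payload.view(torch.uint8), group=group)

    scales = [torch.empty(1, device=x.device) for _ in range(world)]
    dist.all_gather(scales, scale.reshape(1), group=group)

    # Dequantize each shard and stitch the results along `dim`.
    parts = [s.view(torch.float8_e4m3fn).float() * sc for s, sc in zip(shards, scales)]
    return torch.cat([p.to(x.dtype) for p in parts], dim=dim)
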
@@ -414,6 +418,7 @@ def get_chatglm_sequence_parallel_forward_fn(shard_config: ShardConfig, sp_mode,
inputs_embeds,
dim=0,
process_group=sp_group,
fp8_communication=shard_config.fp8_communication,
)
elif sp_mode == "all_to_all":
inputs_embeds = split_forward_gather_backward(
@@ -421,6 +426,7 @@ def get_chatglm_sequence_parallel_forward_fn(shard_config: ShardConfig, sp_mode,
dim=0,
process_group=sp_group,
grad_scale=1 / sp_size,
fp8_communication=shard_config.fp8_communication,
)
hidden_states, presents, all_hidden_states, all_self_attentions = self.encoder(
inputs_embeds,
@@ -436,6 +442,7 @@ def get_chatglm_sequence_parallel_forward_fn(shard_config: ShardConfig, sp_mode,
hidden_states,
dim=0,
process_group=shard_config.tensor_parallel_process_group,
fp8_communication=shard_config.fp8_communication,
)
elif sp_mode == "all_to_all":
hidden_states = gather_forward_split_backward(
@@ -443,6 +450,7 @@ def get_chatglm_sequence_parallel_forward_fn(shard_config: ShardConfig, sp_mode,
dim=0,
process_group=sp_group,
grad_scale=sp_size,
fp8_communication=shard_config.fp8_communication,
)
if not return_dict:
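
These hunks pair a split at the input (grad_scale=1 / sp_size) with a gather at the output (grad_scale=sp_size). A small shape walkthrough under assumed sizes; the numbers are illustrative only.

# Shape walkthrough for the sequence-parallel split/gather (hypothetical sizes).
import torch

seq_len, batch, hidden = 8, 2, 64
sp_size = 2
inputs_embeds = torch.zeros(seq_len, batch, hidden)

# split_forward_gather_backward(inputs_embeds, dim=0, ..., grad_scale=1/sp_size)
# leaves each rank with one sequence shard:
local = inputs_embeds.chunk(sp_size, dim=0)[0]
assert local.shape == (seq_len // sp_size, batch, hidden)   # [4, 2, 64]

# ... the encoder runs on the shard ...

# gather_forward_split_backward(hidden_states, dim=0, ..., grad_scale=sp_size)
# restores the full sequence on every rank:
restored = torch.cat([local] * sp_size, dim=0)
assert restored.shape == (seq_len, batch, hidden)           # [8, 2, 64]

# The paired grad_scale values (1/sp_size at the split, sp_size at the gather)
# are reciprocal, keeping overall gradient magnitude in line with a
# non-parallel run; fp8_communication only changes how the collectives ship
# the data, not this bookkeeping.
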
@@ -532,9 +540,24 @@ def get_chatglm_sequence_parallel_attention_forward(shard_config: ShardConfig, s
key_layer = key_layer.reshape(sq, bs, -1)
value_layer = value_layer.reshape(sq, bs, -1)
query_layer = all_to_all_comm(query_layer, sp_group, gather_dim=0)
key_layer = all_to_all_comm(key_layer, sp_group, gather_dim=0)
value_layer = all_to_all_comm(value_layer, sp_group, gather_dim=0)
query_layer = all_to_all_comm(
query_layer,
sp_group,
gather_dim=0,
fp8_communication=shard_config.fp8_communication,
)
key_layer = all_to_all_comm(
key_layer,
sp_group,
gather_dim=0,
fp8_communication=shard_config.fp8_communication,
)
value_layer = all_to_all_comm(
value_layer,
sp_group,
gather_dim=0,
fp8_communication=shard_config.fp8_communication,
)
query_layer = query_layer.view(
sq * sp_size,
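
In this hunk, q/k/v are exchanged with all_to_all_comm(..., gather_dim=0) so each rank holds the full sequence for its share of attention heads; the hunk that follows applies the inverse exchange (gather_dim=2, scatter_dim=0) to the attention output. Below is a single-process emulation of that layout swap with hypothetical shapes, assuming the default scatter dimension here is the hidden one; the real call runs a dist.all_to_all over sp_group.

# Emulate the all_to_all layout swap with chunk/cat (illustration only).
import torch

sp_size = 2
sq, bs, h = 4, 2, 8                     # per-rank sequence shard, batch, hidden

# One [sq, bs, h] tensor per rank before the exchange.
q_ranks = [torch.randn(sq, bs, h) for _ in range(sp_size)]

# gather_dim=0 with scatter along the hidden dim: rank i receives the i-th
# hidden chunk from every rank and concatenates them along the sequence dim.
received = []
for i in range(sp_size):
    pieces = [q.chunk(sp_size, dim=2)[i] for q in q_ranks]
    received.append(torch.cat(pieces, dim=0))

assert received[0].shape == (sq * sp_size, bs, h // sp_size)
# i.e. full sequence, but only 1/sp_size of the heads/channels per rank.
# The post-attention call (gather_dim=2, scatter_dim=0) undoes this, giving
# back [sq, bs, h] on every rank.
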
@@ -610,7 +633,13 @@ def get_chatglm_sequence_parallel_attention_forward(shard_config: ShardConfig, s
context_layer = self.core_attention(query_layer, key_layer, value_layer, attention_mask)
if sp_mode == "all_to_all":
context_layer = all_to_all_comm(context_layer, sp_group, gather_dim=2, scatter_dim=0)
context_layer = all_to_all_comm(
context_layer,
sp_group,
gather_dim=2,
scatter_dim=0,
fp8_communication=shard_config.fp8_communication,
)
# =================
# Output. [sq, b, h]
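
Since fp8_communication trades communication volume for quantization error, a quick single-process check of the round-trip cost of the e4m3 cast can be reassuring. This is illustrative only; the library's actual scaling strategy may differ.

# Round-trip error of a per-tensor-scaled float8_e4m3fn cast (illustrative).
import torch

x = torch.randn(4096, dtype=torch.bfloat16)
scale = x.abs().max().float() / 448.0            # 448 ~ max finite e4m3 value
x_fp8 = (x.float() / scale).to(torch.float8_e4m3fn)
x_back = (x_fp8.float() * scale).to(torch.bfloat16)
rel_err = (x.float() - x_back.float()).abs().mean() / x.float().abs().mean()
print(f"mean relative error after fp8 round-trip: {rel_err.item():.3%}")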