[fp8] support hybrid parallel plugin (#5982)

* support fp8 comm for qwen2 model

* support fp8 comm for qwen2 model

* support fp8 comm for qwen2 model

* fp8

* fix

* bert and bloom

* chatglm and command

* gpt2,gptj,bert, falcon,blip2

* mistral,opt,sam,t5,vit,whisper

* fix

* fix

* fix
This commit is contained in:
Wang Binluo
2024-08-12 18:17:05 +08:00
committed by GitHub
parent f1a3a326c4
commit b2483c8e31
27 changed files with 633 additions and 83 deletions

View File

@@ -140,6 +140,7 @@ class CommandPipelineForwards:
hidden_states,
dim=1,
process_group=shard_config.tensor_parallel_process_group,
fp8_communication=shard_config.fp8_communication,
)
elif shard_config.sequence_parallelism_mode == "all_to_all":
hidden_states = split_forward_gather_backward(
@@ -147,6 +148,7 @@ class CommandPipelineForwards:
dim=1,
process_group=shard_config.sequence_parallel_process_group,
grad_scale=1 / shard_config.sequence_parallel_size,
fp8_communication=shard_config.fp8_communication,
)
# decoder layers
@@ -211,6 +213,7 @@ class CommandPipelineForwards:
hidden_states,
dim=1,
process_group=shard_config.tensor_parallel_process_group,
fp8_communication=shard_config.fp8_communication,
)
elif shard_config.sequence_parallelism_mode == "all_to_all":
hidden_states = gather_forward_split_backward(
@@ -218,6 +221,7 @@ class CommandPipelineForwards:
dim=1,
process_group=shard_config.sequence_parallel_process_group,
grad_scale=shard_config.sequence_parallel_size,
fp8_communication=shard_config.fp8_communication,
)
# add hidden states from the last decoder layer
@@ -382,9 +386,9 @@ def get_command_flash_attention_forward(shard_config, sp_mode=None, sp_size=None
# sp: all-to-all communication when introducing sequence parallel
if sp_mode == "all_to_all":
query_states = all_to_all_comm(query_states, sp_group)
key_states = all_to_all_comm(key_states, sp_group)
value_states = all_to_all_comm(value_states, sp_group)
query_states = all_to_all_comm(query_states, sp_group, fp8_communication=shard_config.fp8_communication)
key_states = all_to_all_comm(key_states, sp_group, fp8_communication=shard_config.fp8_communication)
value_states = all_to_all_comm(value_states, sp_group, fp8_communication=shard_config.fp8_communication)
bsz, q_len, _ = query_states.size()
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
@@ -446,7 +450,9 @@ def get_command_flash_attention_forward(shard_config, sp_mode=None, sp_size=None
# sp: all-to-all communication when introducing sequence parallel
if sp_mode == "all_to_all":
attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim)
attn_output = all_to_all_comm(attn_output, sp_group, scatter_dim=1, gather_dim=2)
attn_output = all_to_all_comm(
attn_output, sp_group, scatter_dim=1, gather_dim=2, fp8_communication=shard_config.fp8_communication
)
else:
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
@@ -526,9 +532,13 @@ def get_command_flash_attention_model_forward(shard_config, sp_mode=None, sp_siz
attention_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position)
if sp_mode in ["ring", "split_gather"]:
inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group)
inputs_embeds = split_forward_gather_backward(
inputs_embeds, 1, sp_group, fp8_communication=shard_config.fp8_communication
)
elif sp_mode == "all_to_all":
inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group, 1 / sp_size)
inputs_embeds = split_forward_gather_backward(
inputs_embeds, 1, sp_group, 1 / sp_size, fp8_communication=shard_config.fp8_communication
)
hidden_states = inputs_embeds
# decoder layers
@@ -573,9 +583,13 @@ def get_command_flash_attention_model_forward(shard_config, sp_mode=None, sp_siz
hidden_states = self.norm(hidden_states)
if sp_mode == "ring" or sp_mode == "split_gather":
hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group)
hidden_states = gather_forward_split_backward(
hidden_states, 1, sp_group, fp8_communication=shard_config.fp8_communication
)
elif sp_mode == "all_to_all":
hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group, grad_scale=sp_size)
hidden_states = gather_forward_split_backward(
hidden_states, 1, sp_group, grad_scale=sp_size, fp8_communication=shard_config.fp8_communication
)
# add hidden states from the last decoder layer
if output_hidden_states: