mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-05 19:13:01 +00:00
[fp8] support hybrid parallel plugin (#5982)
* support fp8 comm for qwen2 model * support fp8 comm for qwen2 model * support fp8 comm for qwen2 model * fp8 * fix * bert and bloom * chatglm and command * gpt2,gptj,bert, falcon,blip2 * mistral,opy,sam,t5,vit,whisper * fix * fix * fix
This commit is contained in:
@@ -140,6 +140,7 @@ class CommandPipelineForwards:
|
||||
hidden_states,
|
||||
dim=1,
|
||||
process_group=shard_config.tensor_parallel_process_group,
|
||||
fp8_communication=shard_config.fp8_communication,
|
||||
)
|
||||
elif shard_config.sequence_parallelism_mode == "all_to_all":
|
||||
hidden_states = split_forward_gather_backward(
|
||||
@@ -147,6 +148,7 @@ class CommandPipelineForwards:
|
||||
dim=1,
|
||||
process_group=shard_config.sequence_parallel_process_group,
|
||||
grad_scale=1 / shard_config.sequence_parallel_size,
|
||||
fp8_communication=shard_config.fp8_communication,
|
||||
)
|
||||
|
||||
# decoder layers
|
||||
@@ -211,6 +213,7 @@ class CommandPipelineForwards:
|
||||
hidden_states,
|
||||
dim=1,
|
||||
process_group=shard_config.tensor_parallel_process_group,
|
||||
fp8_communication=shard_config.fp8_communication,
|
||||
)
|
||||
elif shard_config.sequence_parallelism_mode == "all_to_all":
|
||||
hidden_states = gather_forward_split_backward(
|
||||
@@ -218,6 +221,7 @@ class CommandPipelineForwards:
|
||||
dim=1,
|
||||
process_group=shard_config.sequence_parallel_process_group,
|
||||
grad_scale=shard_config.sequence_parallel_size,
|
||||
fp8_communication=shard_config.fp8_communication,
|
||||
)
|
||||
|
||||
# add hidden states from the last decoder layer
|
||||
@@ -382,9 +386,9 @@ def get_command_flash_attention_forward(shard_config, sp_mode=None, sp_size=None
|
||||
|
||||
# sp: all-to-all comminucation when introducing sequence parallel
|
||||
if sp_mode == "all_to_all":
|
||||
query_states = all_to_all_comm(query_states, sp_group)
|
||||
key_states = all_to_all_comm(key_states, sp_group)
|
||||
value_states = all_to_all_comm(value_states, sp_group)
|
||||
query_states = all_to_all_comm(query_states, sp_group, fp8_communication=shard_config.fp8_communication)
|
||||
key_states = all_to_all_comm(key_states, sp_group, fp8_communication=shard_config.fp8_communication)
|
||||
value_states = all_to_all_comm(value_states, sp_group, fp8_communication=shard_config.fp8_communication)
|
||||
bsz, q_len, _ = query_states.size()
|
||||
|
||||
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
@@ -446,7 +450,9 @@ def get_command_flash_attention_forward(shard_config, sp_mode=None, sp_size=None
|
||||
# sp: all-to-all comminucation when introducing sequence parallel
|
||||
if sp_mode == "all_to_all":
|
||||
attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim)
|
||||
attn_output = all_to_all_comm(attn_output, sp_group, scatter_dim=1, gather_dim=2)
|
||||
attn_output = all_to_all_comm(
|
||||
attn_output, sp_group, scatter_dim=1, gather_dim=2, fp8_communication=shard_config.fp8_communication
|
||||
)
|
||||
else:
|
||||
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
||||
|
||||
@@ -526,9 +532,13 @@ def get_command_flash_attention_model_forward(shard_config, sp_mode=None, sp_siz
|
||||
attention_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position)
|
||||
|
||||
if sp_mode in ["ring", "split_gather"]:
|
||||
inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group)
|
||||
inputs_embeds = split_forward_gather_backward(
|
||||
inputs_embeds, 1, sp_group, fp8_communication=shard_config.fp8_communication
|
||||
)
|
||||
elif sp_mode == "all_to_all":
|
||||
inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group, 1 / sp_size)
|
||||
inputs_embeds = split_forward_gather_backward(
|
||||
inputs_embeds, 1, sp_group, 1 / sp_size, fp8_communication=shard_config.fp8_communication
|
||||
)
|
||||
hidden_states = inputs_embeds
|
||||
|
||||
# decoder layers
|
||||
@@ -573,9 +583,13 @@ def get_command_flash_attention_model_forward(shard_config, sp_mode=None, sp_siz
|
||||
hidden_states = self.norm(hidden_states)
|
||||
|
||||
if sp_mode == "ring" or sp_mode == "split_gather":
|
||||
hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group)
|
||||
hidden_states = gather_forward_split_backward(
|
||||
hidden_states, 1, sp_group, fp8_communication=shard_config.fp8_communication
|
||||
)
|
||||
elif sp_mode == "all_to_all":
|
||||
hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group, grad_scale=sp_size)
|
||||
hidden_states = gather_forward_split_backward(
|
||||
hidden_states, 1, sp_group, grad_scale=sp_size, fp8_communication=shard_config.fp8_communication
|
||||
)
|
||||
|
||||
# add hidden states from the last decoder layer
|
||||
if output_hidden_states:
|
||||
|
Reference in New Issue
Block a user