mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-08 20:40:34 +00:00
[fp8] support hybrid parallel plugin (#5982)
* support fp8 comm for qwen2 model * support fp8 comm for qwen2 model * support fp8 comm for qwen2 model * fp8 * fix * bert and bloom * chatglm and command * gpt2,gptj,bert, falcon,blip2 * mistral,opy,sam,t5,vit,whisper * fix * fix * fix
This commit is contained in:
@@ -175,6 +175,7 @@ class Qwen2PipelineForwards:
|
||||
hidden_states,
|
||||
dim=1,
|
||||
process_group=shard_config.tensor_parallel_process_group,
|
||||
fp8_communication=shard_config.fp8_communication,
|
||||
)
|
||||
elif shard_config.sequence_parallelism_mode == "all_to_all":
|
||||
hidden_states = split_forward_gather_backward(
|
||||
@@ -182,6 +183,7 @@ class Qwen2PipelineForwards:
|
||||
dim=1,
|
||||
process_group=shard_config.sequence_parallel_process_group,
|
||||
grad_scale=1 / shard_config.sequence_parallel_size,
|
||||
fp8_communication=shard_config.fp8_communication,
|
||||
)
|
||||
|
||||
# decoder layers
|
||||
@@ -246,6 +248,7 @@ class Qwen2PipelineForwards:
|
||||
hidden_states,
|
||||
dim=1,
|
||||
process_group=shard_config.tensor_parallel_process_group,
|
||||
fp8_communication=shard_config.fp8_communication,
|
||||
)
|
||||
elif shard_config.sequence_parallelism_mode == "all_to_all":
|
||||
hidden_states = gather_forward_split_backward(
|
||||
@@ -253,6 +256,7 @@ class Qwen2PipelineForwards:
|
||||
dim=1,
|
||||
process_group=shard_config.sequence_parallel_process_group,
|
||||
grad_scale=shard_config.sequence_parallel_size,
|
||||
fp8_communication=shard_config.fp8_communication,
|
||||
)
|
||||
# add hidden states from the last decoder layer
|
||||
if output_hidden_states:
|
||||
@@ -516,9 +520,9 @@ def get_qwen2_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, s
|
||||
value_states = self.v_proj(hidden_states)
|
||||
# sp: all-to-all comminucation when introducing sequence parallel
|
||||
if sp_mode == "all_to_all":
|
||||
query_states = all_to_all_comm(query_states, sp_group)
|
||||
key_states = all_to_all_comm(key_states, sp_group)
|
||||
value_states = all_to_all_comm(value_states, sp_group)
|
||||
query_states = all_to_all_comm(query_states, sp_group, fp8_communication=shard_config.fp8_communication)
|
||||
key_states = all_to_all_comm(key_states, sp_group, fp8_communication=shard_config.fp8_communication)
|
||||
value_states = all_to_all_comm(value_states, sp_group, fp8_communication=shard_config.fp8_communication)
|
||||
bsz, q_len, _ = query_states.size()
|
||||
|
||||
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
@@ -604,7 +608,9 @@ def get_qwen2_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, s
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
if sp_mode == "all_to_all":
|
||||
attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim)
|
||||
attn_output = all_to_all_comm(attn_output, sp_group, scatter_dim=1, gather_dim=2)
|
||||
attn_output = all_to_all_comm(
|
||||
attn_output, sp_group, scatter_dim=1, gather_dim=2, fp8_communication=shard_config.fp8_communication
|
||||
)
|
||||
else:
|
||||
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
||||
|
||||
@@ -702,9 +708,13 @@ def get_qwen2_model_forward_for_flash_attn(shard_config: ShardConfig, sp_mode=No
|
||||
next_decoder_cache = None
|
||||
|
||||
if sp_mode in ["ring", "split_gather"]:
|
||||
hidden_states = split_forward_gather_backward(hidden_states, 1, sp_group)
|
||||
hidden_states = split_forward_gather_backward(
|
||||
hidden_states, 1, sp_group, fp8_communication=shard_config.fp8_communication
|
||||
)
|
||||
elif sp_mode == "all_to_all":
|
||||
hidden_states = split_forward_gather_backward(hidden_states, 1, sp_group, 1 / sp_size)
|
||||
hidden_states = split_forward_gather_backward(
|
||||
hidden_states, 1, sp_group, 1 / sp_size, fp8_communication=shard_config.fp8_communication
|
||||
)
|
||||
|
||||
for decoder_layer in self.layers:
|
||||
if output_hidden_states:
|
||||
@@ -741,9 +751,13 @@ def get_qwen2_model_forward_for_flash_attn(shard_config: ShardConfig, sp_mode=No
|
||||
hidden_states = self.norm(hidden_states)
|
||||
|
||||
if sp_mode == "ring" or sp_mode == "split_gather":
|
||||
hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group)
|
||||
hidden_states = gather_forward_split_backward(
|
||||
hidden_states, 1, sp_group, fp8_communication=shard_config.fp8_communication
|
||||
)
|
||||
elif sp_mode == "all_to_all":
|
||||
hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group, grad_scale=sp_size)
|
||||
hidden_states = gather_forward_split_backward(
|
||||
hidden_states, 1, sp_group, grad_scale=sp_size, fp8_communication=shard_config.fp8_communication
|
||||
)
|
||||
|
||||
# add hidden states from the last decoder layer
|
||||
if output_hidden_states:
|
||||
|
Reference in New Issue
Block a user