[fp8] support hybrid parallel plugin (#5982)
* support fp8 comm for qwen2 model
* fp8
* fix
* bert and bloom
* chatglm and command
* gpt2, gptj, bert, falcon, blip2
* mistral, opt, sam, t5, vit, whisper
* fix
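The flag that every changed call site reads is shard_config.fp8_communication, so turning the feature on is a single plugin option. A minimal usage sketch, assuming HybridParallelPlugin exposes an fp8_communication argument that it forwards into ShardConfig (the other arguments are ordinary plugin options; the exact combination required for a given sequence-parallel mode may differ):

    import colossalai
    from colossalai.booster import Booster
    from colossalai.booster.plugin import HybridParallelPlugin

    colossalai.launch_from_torch()  # run under torchrun or another distributed launcher

    # Sketch only: fp8_communication is assumed to be passed through to ShardConfig,
    # where the modeling code in this diff reads it as shard_config.fp8_communication.
    plugin = HybridParallelPlugin(
        tp_size=2,
        pp_size=1,
        sp_size=2,
        enable_sequence_parallelism=True,
        sequence_parallelism_mode="all_to_all",
        fp8_communication=True,  # quantize the collectives shown below to fp8
    )
    booster = Booster(plugin=plugin)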
@@ -206,6 +206,7 @@ class ChatGLMPipelineForwards:
                hidden_states,
                dim=0,
                process_group=shard_config.tensor_parallel_process_group,
+               fp8_communication=shard_config.fp8_communication,
            )
        elif shard_config.sequence_parallelism_mode == "all_to_all":
            hidden_states = split_forward_gather_backward(
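This first hunk is representative: each split_forward_gather_backward / gather_forward_split_backward / all_to_all_comm call site below gains the same fp8_communication=shard_config.fp8_communication argument, letting the collective quantize its payload to 8-bit floating point before it goes over the wire and dequantize it on arrival. A self-contained sketch of that idea with per-tensor scaling into torch.float8_e4m3fn (requires a PyTorch build that ships the float8 dtypes); it illustrates the concept only and is not ColossalAI's actual fp8 communication kernel:

    import torch

    def fp8_quantize(t: torch.Tensor):
        # Per-tensor scale so the values fit the e4m3 dynamic range (max is about 448).
        fp8_max = torch.finfo(torch.float8_e4m3fn).max
        scale = t.abs().max().clamp(min=1e-12) / fp8_max
        return (t / scale).to(torch.float8_e4m3fn), scale

    def fp8_dequantize(q: torch.Tensor, scale: torch.Tensor, dtype=torch.bfloat16):
        return q.to(dtype) * scale

    # Round trip: with fp8_communication=True the fp8 payload plus its scale is what
    # travels through the collective instead of the full bf16/fp32 tensor, roughly
    # halving communication volume at the cost of a small quantization error.
    x = torch.randn(4, 8, dtype=torch.bfloat16)
    q, s = fp8_quantize(x)
    print((x - fp8_dequantize(q, s)).abs().max())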
@@ -213,6 +214,7 @@ class ChatGLMPipelineForwards:
                dim=0,
                process_group=shard_config.sequence_parallel_process_group,
                grad_scale=1 / shard_config.sequence_parallel_size,
+               fp8_communication=shard_config.fp8_communication,
            )
        for idx in range(start_idx, end_idx):
            layer = self.encoder._get_layer(idx)
@@ -245,6 +247,7 @@ class ChatGLMPipelineForwards:
                hidden_states,
                dim=0,
                process_group=shard_config.tensor_parallel_process_group,
+               fp8_communication=shard_config.fp8_communication,
            )
        elif shard_config.sequence_parallelism_mode == "all_to_all":
            hidden_states = gather_forward_split_backward(
@@ -252,6 +255,7 @@ class ChatGLMPipelineForwards:
                dim=0,
                process_group=shard_config.sequence_parallel_process_group,
                grad_scale=shard_config.sequence_parallel_size,
+               fp8_communication=shard_config.fp8_communication,
            )
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)
@@ -414,6 +418,7 @@ def get_chatglm_sequence_parallel_forward_fn(shard_config: ShardConfig, sp_mode,
                inputs_embeds,
                dim=0,
                process_group=sp_group,
+               fp8_communication=shard_config.fp8_communication,
            )
        elif sp_mode == "all_to_all":
            inputs_embeds = split_forward_gather_backward(
@@ -421,6 +426,7 @@ def get_chatglm_sequence_parallel_forward_fn(shard_config: ShardConfig, sp_mode,
                dim=0,
                process_group=sp_group,
                grad_scale=1 / sp_size,
+               fp8_communication=shard_config.fp8_communication,
            )
        hidden_states, presents, all_hidden_states, all_self_attentions = self.encoder(
            inputs_embeds,
@@ -436,6 +442,7 @@ def get_chatglm_sequence_parallel_forward_fn(shard_config: ShardConfig, sp_mode,
                hidden_states,
                dim=0,
                process_group=shard_config.tensor_parallel_process_group,
+               fp8_communication=shard_config.fp8_communication,
            )
        elif sp_mode == "all_to_all":
            hidden_states = gather_forward_split_backward(
@@ -443,6 +450,7 @@ def get_chatglm_sequence_parallel_forward_fn(shard_config: ShardConfig, sp_mode,
                dim=0,
                process_group=sp_group,
                grad_scale=sp_size,
+               fp8_communication=shard_config.fp8_communication,
            )

        if not return_dict:
@@ -532,9 +540,24 @@ def get_chatglm_sequence_parallel_attention_forward(shard_config: ShardConfig, s
        key_layer = key_layer.reshape(sq, bs, -1)
        value_layer = value_layer.reshape(sq, bs, -1)

-       query_layer = all_to_all_comm(query_layer, sp_group, gather_dim=0)
-       key_layer = all_to_all_comm(key_layer, sp_group, gather_dim=0)
-       value_layer = all_to_all_comm(value_layer, sp_group, gather_dim=0)
+       query_layer = all_to_all_comm(
+           query_layer,
+           sp_group,
+           gather_dim=0,
+           fp8_communication=shard_config.fp8_communication,
+       )
+       key_layer = all_to_all_comm(
+           key_layer,
+           sp_group,
+           gather_dim=0,
+           fp8_communication=shard_config.fp8_communication,
+       )
+       value_layer = all_to_all_comm(
+           value_layer,
+           sp_group,
+           gather_dim=0,
+           fp8_communication=shard_config.fp8_communication,
+       )

        query_layer = query_layer.view(
            sq * sp_size,
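In the all_to_all sequence-parallel path the q/k/v tensors arrive sharded along the sequence dimension as [sq/sp, b, h]; the all-to-all above gathers the full sequence (gather_dim=0) while scattering the hidden/heads dimension, and the second all-to-all after attention (gather_dim=2, scatter_dim=0, next hunk) reverses it. A toy single-process sketch of that shape bookkeeping; simulated_all_to_all is a stand-in written for illustration rather than the library's all_to_all_comm, and it assumes the unspecified scatter_dim in the q/k/v calls splits the hidden dimension, which is what makes the round trip consistent:

    import torch

    def simulated_all_to_all(shards, scatter_dim, gather_dim):
        # Each "rank" splits its shard along scatter_dim, sends chunk j to rank j,
        # and concatenates the chunks it receives along gather_dim.
        sp = len(shards)
        pieces = [list(torch.chunk(s, sp, dim=scatter_dim)) for s in shards]
        return [
            torch.cat([pieces[src][dst] for src in range(sp)], dim=gather_dim)
            for dst in range(sp)
        ]

    sp, sq, b, h = 2, 8, 1, 16                                 # toy sizes
    qkv = [torch.randn(sq // sp, b, h) for _ in range(sp)]     # [sq/sp, b, h] per rank
    attn_in = simulated_all_to_all(qkv, scatter_dim=2, gather_dim=0)
    print(attn_in[0].shape)    # torch.Size([8, 1, 8]):  full sequence, a slice of heads
    attn_out = simulated_all_to_all(attn_in, scatter_dim=0, gather_dim=2)
    print(attn_out[0].shape)   # torch.Size([4, 1, 16]): back to the sequence shard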
@@ -610,7 +633,13 @@ def get_chatglm_sequence_parallel_attention_forward(shard_config: ShardConfig, s

        context_layer = self.core_attention(query_layer, key_layer, value_layer, attention_mask)
        if sp_mode == "all_to_all":
-           context_layer = all_to_all_comm(context_layer, sp_group, gather_dim=2, scatter_dim=0)
+           context_layer = all_to_all_comm(
+               context_layer,
+               sp_group,
+               gather_dim=2,
+               scatter_dim=0,
+               fp8_communication=shard_config.fp8_communication,
+           )

        # =================
        # Output. [sq, b, h]