Mirror of https://github.com/hpcaitech/ColossalAI.git
[fp8] support hybrid parallel plugin (#5982)
* support fp8 comm for qwen2 model
* fp8
* fix
* bert and bloom
* chatglm and command
* gpt2, gptj, bert, falcon, blip2
* mistral, opt, sam, t5, vit, whisper
* fix
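The change set threads a single flag, shard_config.fp8_communication, through every sequence-parallel split/gather call in the BERT forward paths (and, per the message above, the other listed models), so FP8-compressed communication can be switched on in one place. Below is a minimal sketch of how that might look from the user side, assuming the HybridParallelPlugin exposes a matching fp8_communication argument that it forwards into its ShardConfig; only the ShardConfig field is visible in this diff, so the plugin keyword is an assumption.

from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

# Sketch only: fp8_communication is an assumed plugin argument that would end
# up as shard_config.fp8_communication; the rest is the plugin's standard
# hybrid-parallel configuration.
plugin = HybridParallelPlugin(
    tp_size=2,
    pp_size=1,
    enable_sequence_parallelism=True,
    sequence_parallelism_mode="split_gather",
    fp8_communication=True,
)
booster = Booster(plugin=plugin)
# model, optimizer, criterion, dataloader, scheduler = booster.boost(...)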
@@ -187,11 +187,17 @@ class BertPipelineForwards:
         if shard_config is not None and shard_config.enable_sequence_parallelism:
             if shard_config.sequence_parallelism_mode == "split_gather":
                 hidden_states = split_forward_gather_backward(
-                    hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group
+                    hidden_states,
+                    dim=1,
+                    process_group=shard_config.tensor_parallel_process_group,
+                    fp8_communication=shard_config.fp8_communication,
                 )
                 if encoder_hidden_states is not None:
                     encoder_hidden_states = split_forward_gather_backward(
-                        encoder_hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group
+                        encoder_hidden_states,
+                        dim=1,
+                        process_group=shard_config.tensor_parallel_process_group,
+                        fp8_communication=shard_config.fp8_communication,
                     )

         for idx, encoder_layer in enumerate(self.encoder.layer[start_idx:end_idx], start=start_idx):
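For orientation, split_forward_gather_backward shards the activation along the sequence dimension in the forward pass and all-gathers the gradient in the backward pass; the new keyword only changes how that collective is transported. A rough sketch of the idea, assuming a torch.distributed process group (names and structure are illustrative, not ColossalAI's actual implementation):

import torch
import torch.distributed as dist

class _SplitForwardGatherBackward(torch.autograd.Function):
    """Illustrative stand-in for a split-forward / gather-backward helper."""

    @staticmethod
    def forward(ctx, x, dim, group, fp8_communication=False):
        ctx.dim = dim
        ctx.group = group
        ctx.fp8_communication = fp8_communication
        rank = dist.get_rank(group)
        world_size = dist.get_world_size(group)
        # Forward: keep only this rank's slice of the sequence dimension.
        return x.chunk(world_size, dim=dim)[rank].contiguous()

    @staticmethod
    def backward(ctx, grad_output):
        world_size = dist.get_world_size(ctx.group)
        # Backward: reassemble the full gradient via all-gather; with
        # fp8_communication=True the payload would be quantized to fp8 first
        # (see the all-gather sketch after the last hunk).
        chunks = [torch.empty_like(grad_output) for _ in range(world_size)]
        dist.all_gather(chunks, grad_output.contiguous(), group=ctx.group)
        return torch.cat(chunks, dim=ctx.dim), None, None, None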
@@ -242,7 +248,10 @@ class BertPipelineForwards:
         if shard_config is not None and shard_config.enable_sequence_parallelism:
             if shard_config.sequence_parallelism_mode == "split_gather":
                 hidden_states = gather_forward_split_backward(
-                    hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group
+                    hidden_states,
+                    dim=1,
+                    process_group=shard_config.tensor_parallel_process_group,
+                    fp8_communication=shard_config.fp8_communication,
                 )

         if output_hidden_states:
@@ -1135,11 +1144,17 @@ def bert_sequence_parallel_forward_fn(shard_config: ShardConfig):
         # split the input tensor along sequence dimension
         # [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size]
         embedding_output = split_forward_gather_backward(
-            embedding_output, dim=1, process_group=shard_config.tensor_parallel_process_group
+            embedding_output,
+            dim=1,
+            process_group=shard_config.tensor_parallel_process_group,
+            fp8_communication=shard_config.fp8_communication,
         )
         if encoder_hidden_states is not None:
             encoder_hidden_states = split_forward_gather_backward(
-                encoder_hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group
+                encoder_hidden_states,
+                dim=1,
+                process_group=shard_config.tensor_parallel_process_group,
+                fp8_communication=shard_config.fp8_communication,
             )

         encoder_outputs = self.encoder(
@@ -1159,7 +1174,10 @@ def bert_sequence_parallel_forward_fn(shard_config: ShardConfig):

         # When sequence parallelism done, gather the output tensor in forward and split it in backward
         sequence_output = gather_forward_split_backward(
-            sequence_output, dim=1, process_group=shard_config.tensor_parallel_process_group
+            sequence_output,
+            dim=1,
+            process_group=shard_config.tensor_parallel_process_group,
+            fp8_communication=shard_config.fp8_communication,
         )

         pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
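What fp8_communication actually buys is a cheaper collective: the tensor is quantized to an 8-bit float format before the all-gather and dequantized on arrival, roughly halving the bytes on the wire versus bf16/fp16 activations. A minimal sketch of such an FP8 all-gather, assuming PyTorch >= 2.1 float8 dtypes (the function name and per-tensor scaling scheme are illustrative, not the library's implementation):

import torch
import torch.distributed as dist

def all_gather_fp8(tensor: torch.Tensor, dim: int, group=None) -> torch.Tensor:
    """Illustrative FP8 all-gather: quantize, ship bytes, dequantize, concat."""
    world_size = dist.get_world_size(group)

    # Per-tensor scale so values fit the e4m3 dynamic range.
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    scale = (tensor.abs().max() / fp8_max).clamp(min=1e-12).reshape(1)
    payload = (tensor / scale).to(torch.float8_e4m3fn)

    # NCCL has no fp8 dtype, so transport the raw bytes as uint8 and gather
    # the per-rank scales separately.
    byte_chunks = [torch.empty_like(payload.view(torch.uint8)) for _ in range(world_size)]
    dist.all_gather(byte_chunks, payload.view(torch.uint8), group=group)
    scales = [torch.empty_like(scale) for _ in range(world_size)]
    dist.all_gather(scales, scale, group=group)

    # Dequantize each shard back to the working dtype and re-assemble along `dim`.
    shards = [
        (c.view(torch.float8_e4m3fn).to(torch.float32) * s).to(tensor.dtype)
        for c, s in zip(byte_chunks, scales)
    ]
    return torch.cat(shards, dim=dim)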