[fp8] support hybrid parallel plugin (#5982)

* support fp8 comm for qwen2 model

* support fp8 comm for qwen2 model

* support fp8 comm for qwen2 model

* fp8

* fix

* bert and bloom

* chatglm and command

* gpt2, gptj, bert, falcon, blip2

* mistral, opt, sam, t5, vit, whisper

* fix

* fix

* fix

Author: Wang Binluo
Committed: 2024-08-12 18:17:05 +08:00 (via GitHub)
Parent: f1a3a326c4
Commit: b2483c8e31

27 changed files with 633 additions and 83 deletions
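
For context, a minimal usage sketch of how the new flag would be switched on from the plugin side. This is illustrative only: the fp8_communication kwarg is implied by the PR title, and the other HybridParallelPlugin arguments (tp_size, sequence-parallelism settings, precision) are assumptions chosen to match the sequence-parallel hunks below, not a verified configuration.

# Hypothetical usage sketch, not taken from this PR's tests or docs.
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

colossalai.launch_from_torch()  # launch arguments may differ by ColossalAI version

plugin = HybridParallelPlugin(
    tp_size=2,
    pp_size=1,
    precision="bf16",
    enable_sequence_parallelism=True,
    sequence_parallelism_mode="split_gather",
    fp8_communication=True,  # assumed kwarg; forwarded to ShardConfig and read in the patched forwards below
)
booster = Booster(plugin=plugin)
# model, optimizer, ... = booster.boost(model, optimizer, ...)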


@@ -185,6 +185,7 @@ class GPTJPipelineForwards:
                 hidden_states,
                 dim=1,
                 process_group=shard_config.tensor_parallel_process_group,
+                fp8_communication=shard_config.fp8_communication,
             )
         # Going through held blocks.
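
Each hunk adds the same one-line change: the fp8_communication flag from ShardConfig is threaded into the sequence-parallel split/gather helper. The enclosing call is cut off above the hunk header; reconstructed (the helper name and import path are inferred from the shardformer sequence-parallel utilities, presumably in colossalai/shardformer/modeling/gptj.py, and are not shown verbatim in the hunk), it reads roughly:

# Reconstructed call site; names inferred, not part of the diff itself.
from colossalai.shardformer.layer._operation import split_forward_gather_backward

hidden_states = split_forward_gather_backward(
    hidden_states,
    dim=1,  # split the sequence dimension across the tensor-parallel group
    process_group=shard_config.tensor_parallel_process_group,
    fp8_communication=shard_config.fp8_communication,  # the line this PR adds
)
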
@@ -236,6 +237,7 @@ class GPTJPipelineForwards:
                 hidden_states,
                 dim=1,
                 process_group=shard_config.tensor_parallel_process_group,
+                fp8_communication=shard_config.fp8_communication,
             )
         if stage_manager.is_last_stage():
@@ -915,6 +917,7 @@ def gptj_sequence_parallel_forward_fn(shard_config: ShardConfig):
             hidden_states,
             dim=1,
             process_group=shard_config.tensor_parallel_process_group,
+            fp8_communication=shard_config.fp8_communication,
         )
         for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
@@ -978,6 +981,7 @@ def gptj_sequence_parallel_forward_fn(shard_config: ShardConfig):
             hidden_states,
             dim=1,
             process_group=shard_config.tensor_parallel_process_group,
+            fp8_communication=shard_config.fp8_communication,
         )
         hidden_states = self.ln_f(hidden_states)
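
What the flag buys: when fp8_communication is enabled, the split/gather collectives can move float8 payloads instead of bf16/fp32 activations, cutting the bytes on the wire by roughly 2-4x. ColossalAI's own fp8 helpers are not shown in this diff; the snippet below is only a generic sketch of the quantize, all-gather, dequantize pattern, and all_gather_fp8 is a hypothetical name, not the library's API.

# Generic sketch of an fp8-compressed all-gather, NOT ColossalAI's implementation.
import torch
import torch.distributed as dist

def all_gather_fp8(tensor: torch.Tensor, dim: int = 1, group=None) -> torch.Tensor:
    world_size = dist.get_world_size(group)

    # Per-tensor scale so values fit float8_e4m3fn's dynamic range (max ~448).
    scale = tensor.abs().amax().clamp(min=1e-12).reshape(1) / 448.0
    fp8_local = (tensor / scale).to(torch.float8_e4m3fn)

    # Ship the raw bytes as uint8 in case the backend rejects float8 dtypes.
    payload = fp8_local.view(torch.uint8).contiguous()
    gathered = [torch.empty_like(payload) for _ in range(world_size)]
    scales = [torch.empty_like(scale) for _ in range(world_size)]
    dist.all_gather(gathered, payload, group=group)
    dist.all_gather(scales, scale, group=group)

    # Dequantize each shard with its own scale and stitch along `dim`.
    chunks = [
        g.view(torch.float8_e4m3fn).to(tensor.dtype) * s
        for g, s in zip(gathered, scales)
    ]
    return torch.cat(chunks, dim=dim)

A per-tensor scale keeps the sketch short; per-channel or per-block scaling trades a little extra metadata traffic for better numerics.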