mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-02 01:28:31 +00:00
[fp8] support hybrid parallel plugin (#5982)
* support fp8 comm for qwen2 model * support fp8 comm for qwen2 model * support fp8 comm for qwen2 model * fp8 * fix * bert and bloom * chatglm and command * gpt2,gptj,bert, falcon,blip2 * mistral,opy,sam,t5,vit,whisper * fix * fix * fix
This commit is contained in:
@@ -143,9 +143,13 @@ class LlamaPipelineForwards:
|
||||
# Support SP + PP
|
||||
if stage_manager.is_first_stage():
|
||||
if sp_mode in ["ring", "split_gather"]:
|
||||
hidden_states = split_forward_gather_backward(hidden_states, 1, sp_group)
|
||||
hidden_states = split_forward_gather_backward(
|
||||
hidden_states, 1, sp_group, fp8_communication=shard_config.fp8_communication
|
||||
)
|
||||
elif sp_mode == "all_to_all":
|
||||
hidden_states = split_forward_gather_backward(hidden_states, 1, sp_group, 1 / sp_size)
|
||||
hidden_states = split_forward_gather_backward(
|
||||
hidden_states, 1, sp_group, 1 / sp_size, fp8_communication=shard_config.fp8_communication
|
||||
)
|
||||
|
||||
if self.gradient_checkpointing and self.training and use_cache:
|
||||
if use_cache:
|
||||
@@ -210,9 +214,13 @@ class LlamaPipelineForwards:
|
||||
if stage_manager.is_last_stage():
|
||||
hidden_states = self.norm(hidden_states)
|
||||
if sp_mode == "ring" or sp_mode == "split_gather":
|
||||
hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group)
|
||||
hidden_states = gather_forward_split_backward(
|
||||
hidden_states, 1, sp_group, fp8_communication=shard_config.fp8_communication
|
||||
)
|
||||
elif sp_mode == "all_to_all":
|
||||
hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group, grad_scale=sp_size)
|
||||
hidden_states = gather_forward_split_backward(
|
||||
hidden_states, 1, sp_group, grad_scale=sp_size, fp8_communication=shard_config.fp8_communication
|
||||
)
|
||||
|
||||
# add hidden states from the last decoder layer
|
||||
if output_hidden_states:
|
||||
|
Reference in New Issue
Block a user