Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-03 18:19:58 +00:00
[Feature] llama shardformer fp8 support (#5938)
* add llama shardformer fp8
* Llama Shardformer Parity
* fix typo
* fix all reduce
* fix pytest failure
* fix reduce op and move function to fp8.py
* fix typo
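For context, every forward patched in this diff reads the new flag from `shard_config.fp8_communication`, so turning the feature on is a matter of building the `ShardConfig` with that field set. A minimal sketch, assuming the field is exposed as a constructor argument and that `enable_flash_attention` is the companion flag for the flash-attention forwards below; the surrounding setup is illustrative, not taken from this commit:

```python
# Hedged sketch: enable FP8 communication through ShardConfig.
# `fp8_communication` is the field consumed by the patched forwards in this
# diff; the other arguments shown here are assumptions for illustration.
from colossalai.shardformer import ShardConfig, ShardFormer

shard_config = ShardConfig(
    enable_flash_attention=True,   # assumed companion flag, not required by this diff
    fp8_communication=True,        # flag read as shard_config.fp8_communication below
)
shardformer = ShardFormer(shard_config=shard_config)
# model, _ = shardformer.optimize(model)  # shard a HF Llama model with this config
```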
@@ -460,7 +460,7 @@ class LlamaPipelineForwards:
         return {"hidden_states": hidden_states}


-def get_llama_flash_attention_forward(shard_config, sp_mode=None, sp_size=None, sp_group=None):
+def get_llama_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, sp_size=None, sp_group=None):
     def forward(
         self,
         hidden_states: torch.Tensor,
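The signature change above only matters because the returned closure captures the `ShardConfig` and can read `fp8_communication` later. As a generic illustration of that factory/closure pattern (not the policy code in this commit), a patched forward is typically bound onto the attention module like this:

```python
# Generic illustration: the outer factory captures shard_config and the
# sequence-parallel settings, and the inner `forward` replaces the module's
# bound method. `patch_attention_forward` is a hypothetical helper; the import
# path is assumed from this diff's file.
from types import MethodType

from colossalai.shardformer.modeling.llama import get_llama_flash_attention_forward

def patch_attention_forward(attn_module, shard_config, sp_mode=None, sp_size=None, sp_group=None):
    new_forward = get_llama_flash_attention_forward(shard_config, sp_mode, sp_size, sp_group)
    # Bind the closure so that `self` inside it resolves to the attention layer.
    attn_module.forward = MethodType(new_forward, attn_module)

# Usage sketch: patch every LlamaAttention layer in an already-sharded model.
# for layer in model.model.layers:
#     patch_attention_forward(layer.self_attn, shard_config, sp_mode="all_to_all",
#                             sp_size=4, sp_group=sp_group)
```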
@@ -510,9 +510,9 @@ def get_llama_flash_attention_forward(shard_config, sp_mode=None, sp_size=None,
 
         # sp: all-to-all comminucation when introducing sequence parallel
         if sp_mode == "all_to_all":
-            query_states = all_to_all_comm(query_states, sp_group)
-            key_states = all_to_all_comm(key_states, sp_group)
-            value_states = all_to_all_comm(value_states, sp_group)
+            query_states = all_to_all_comm(query_states, sp_group, fp8_communication=shard_config.fp8_communication)
+            key_states = all_to_all_comm(key_states, sp_group, fp8_communication=shard_config.fp8_communication)
+            value_states = all_to_all_comm(value_states, sp_group, fp8_communication=shard_config.fp8_communication)
             bsz, q_len, _ = query_states.size()
 
         query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
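When `fp8_communication` is set, the idea is that `all_to_all_comm` casts the activations to FP8 before the collective and restores them afterwards (the commit message notes the helpers were moved into fp8.py). A conceptual, single-process sketch of that cast-with-scale round trip, assuming PyTorch's `float8_e4m3fn` dtype is available; this is not ColossalAI's fp8.py implementation:

```python
# Conceptual FP8 round trip performed by an FP8-aware collective: scale into
# the representable range, cast to float8, (communicate), cast back, unscale.
# This shows the numerics only; the real helper also runs the collective and
# exchanges per-rank scales.
import torch

FP8_MAX = torch.finfo(torch.float8_e4m3fn).max  # ~448 for e4m3

def fp8_quantize(x: torch.Tensor):
    # Per-tensor scale so the largest magnitude lands near the FP8 max.
    amax = x.abs().max().clamp(min=1e-12)
    scale = FP8_MAX / amax
    return (x * scale).to(torch.float8_e4m3fn), scale

def fp8_dequantize(x_fp8: torch.Tensor, scale: torch.Tensor, dtype=torch.bfloat16):
    return (x_fp8.to(torch.float32) / scale).to(dtype)

x = torch.randn(4, 16, dtype=torch.bfloat16)
x_fp8, scale = fp8_quantize(x)           # this is what would go over the wire
x_restored = fp8_dequantize(x_fp8, scale)
print((x.float() - x_restored.float()).abs().max())  # small quantization error
```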
@@ -574,7 +574,9 @@ def get_llama_flash_attention_forward(shard_config, sp_mode=None, sp_size=None,
         # sp: all-to-all comminucation when introducing sequence parallel
         if sp_mode == "all_to_all":
             attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim)
-            attn_output = all_to_all_comm(attn_output, sp_group, scatter_dim=1, gather_dim=2)
+            attn_output = all_to_all_comm(
+                attn_output, sp_group, scatter_dim=1, gather_dim=2, fp8_communication=shard_config.fp8_communication
+            )
         else:
             attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
 
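For the attention output, the all-to-all scatters the sequence dimension (dim 1) and gathers the feature dimension (dim 2), turning head-parallel shards back into sequence-parallel shards. A hedged sketch of such a dim-wise all-to-all built directly on torch.distributed (a generic implementation, not ColossalAI's `all_to_all_comm`; the FP8 cast would wrap the marked collective):

```python
# Generic dim-wise all-to-all: split the input into world_size chunks along
# scatter_dim, exchange chunk i with rank i, and concatenate the received
# chunks along gather_dim. Shape effect per rank: scatter_dim shrinks by
# world_size, gather_dim grows by world_size, e.g. [b, s, h] -> [b, s/ws, h*ws]
# for scatter_dim=1, gather_dim=2.
import torch
import torch.distributed as dist

def all_to_all_by_dim(x: torch.Tensor, group, scatter_dim: int, gather_dim: int) -> torch.Tensor:
    world_size = dist.get_world_size(group)
    send_chunks = [c.contiguous() for c in x.chunk(world_size, dim=scatter_dim)]
    recv_chunks = [torch.empty_like(c) for c in send_chunks]
    dist.all_to_all(recv_chunks, send_chunks, group=group)  # <- FP8 cast/uncast would wrap this call
    return torch.cat(recv_chunks, dim=gather_dim)

# Usage sketch (run under torchrun with an initialized NCCL process group):
# attn_output = all_to_all_by_dim(attn_output, sp_group, scatter_dim=1, gather_dim=2)
```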
@@ -592,7 +594,7 @@ def get_llama_flash_attention_forward(shard_config, sp_mode=None, sp_size=None,
     return forward


-def get_llama_flash_attention_model_forward(shard_config, sp_mode=None, sp_size=None, sp_group=None):
+def get_llama_flash_attention_model_forward(shard_config: ShardConfig, sp_mode=None, sp_size=None, sp_group=None):
     logger = logging.get_logger(__name__)

     def forward(
@@ -659,9 +661,13 @@ def get_llama_flash_attention_model_forward(shard_config, sp_mode=None, sp_size=
         attention_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position)
 
         if sp_mode in ["ring", "split_gather"]:
-            inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group)
+            inputs_embeds = split_forward_gather_backward(
+                inputs_embeds, 1, sp_group, fp8_communication=shard_config.fp8_communication
+            )
         elif sp_mode == "all_to_all":
-            inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group, 1 / sp_size)
+            inputs_embeds = split_forward_gather_backward(
+                inputs_embeds, 1, sp_group, 1 / sp_size, fp8_communication=shard_config.fp8_communication
+            )
         hidden_states = inputs_embeds
 
         # decoder layers
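`split_forward_gather_backward` shards the embeddings along the sequence dimension in the forward pass and all-gathers the incoming gradient in the backward pass; the `1 / sp_size` argument in the all_to_all branch rescales that gradient. A hedged, generic sketch of the pattern (not ColossalAI's implementation; an FP8-aware variant would quantize the gradient chunks around the marked all-gather):

```python
# Generic split-forward / gather-backward along a dimension: forward keeps only
# this rank's slice; backward all-gathers the gradient slices and optionally
# rescales them (the `1 / sp_size` seen above).
import torch
import torch.distributed as dist

class SplitForwardGatherBackward(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, dim, group, grad_scale=None):
        ctx.dim, ctx.group, ctx.grad_scale = dim, group, grad_scale
        rank = dist.get_rank(group)
        world_size = dist.get_world_size(group)
        return x.chunk(world_size, dim=dim)[rank].contiguous()

    @staticmethod
    def backward(ctx, grad_output):
        world_size = dist.get_world_size(ctx.group)
        chunks = [torch.empty_like(grad_output) for _ in range(world_size)]
        dist.all_gather(chunks, grad_output.contiguous(), group=ctx.group)  # <- FP8 would wrap this
        grad = torch.cat(chunks, dim=ctx.dim)
        if ctx.grad_scale is not None:
            grad = grad * ctx.grad_scale
        return grad, None, None, None

# inputs_embeds = SplitForwardGatherBackward.apply(inputs_embeds, 1, sp_group, 1 / sp_size)
```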
@@ -706,9 +712,13 @@ def get_llama_flash_attention_model_forward(shard_config, sp_mode=None, sp_size=
         hidden_states = self.norm(hidden_states)
 
         if sp_mode == "ring" or sp_mode == "split_gather":
-            hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group)
+            hidden_states = gather_forward_split_backward(
+                hidden_states, 1, sp_group, fp8_communication=shard_config.fp8_communication
+            )
         elif sp_mode == "all_to_all":
-            hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group, grad_scale=sp_size)
+            hidden_states = gather_forward_split_backward(
+                hidden_states, 1, sp_group, grad_scale=sp_size, fp8_communication=shard_config.fp8_communication
+            )
 
         # add hidden states from the last decoder layer
         if output_hidden_states:
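The commit message mentions a parity check. A hedged sketch of what such a test could look like: run the same sharded Llama forward with `fp8_communication` off and on, and require the outputs to agree within an FP8-appropriate tolerance. The builder helper below is a placeholder, not an actual utility from the ColossalAI test suite:

```python
# Sketch of an FP8-communication parity check. `build_sharded_llama` is a
# hypothetical callable standing in for whatever constructs and shards the
# model in the real tests; only the comparison logic is the point here.
import torch

def check_fp8_parity(build_sharded_llama, input_ids, rtol=1e-2, atol=1e-2):
    model_ref = build_sharded_llama(fp8_communication=False)
    model_fp8 = build_sharded_llama(fp8_communication=True)
    model_fp8.load_state_dict(model_ref.state_dict())  # identical weights

    with torch.no_grad():
        out_ref = model_ref(input_ids).last_hidden_state
        out_fp8 = model_fp8(input_ids).last_hidden_state

    # FP8 communication only perturbs activations slightly, so a loose
    # tolerance should hold if the plumbing above is correct.
    torch.testing.assert_close(out_fp8, out_ref, rtol=rtol, atol=atol)
```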