[Inference] Replace the Attention and MLP layers via shardformer to optimize the weight transpose operation; add fused_qkv and fused linear_add (#5340)

* add fused qkv

* replace attn and mlp by shardformer

* fix bugs in mlp

* add docstrings

* fix test_inference_engine.py

* add optimized unbind

* add fused_addmm

* rm squeeze(1)

* refactor codes

* fix ci bugs

* rename ShardFormerLlamaMLP and ShardFormerLlamaAttention

* Removed the dependency on LlamaFlashAttention2

* rollback test_inference_engine.py
yuehuayingxueluo
2024-02-01 15:49:39 +08:00
committed by GitHub
parent f8e456d202
commit 249644c23b
8 changed files with 510 additions and 341 deletions
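The commit message above covers three related optimizations: projection weights are stored pre-transposed so each projection is a plain torch.mm with no per-step transpose, the q/k/v projections are fused into a single torch.bmm over a stacked weight tensor, and the residual add is folded into the following linear layer with torch.addmm (the "fused linear_add"). Below is a minimal, self-contained sketch of these ideas with made-up shapes; it mirrors the pattern used in the diff but is not the repository's code.

import torch

# Minimal sketch (not the repository's code) of the three optimizations:
#   1) pre-transpose weights once so the forward pass uses torch.mm directly,
#   2) fuse the q/k/v projections into one batched matmul over stacked weights,
#   3) fuse the residual add into a linear layer with torch.addmm.
token_num, hidden_size, num_heads = 16, 128, 8
head_dim = hidden_size // num_heads

hidden_states = torch.randn(token_num, hidden_size)
residual = hidden_states.clone()

# nn.Linear stores weight as (out_features, in_features), so x @ W.T is needed;
# transposing once up front removes the transpose from every forward step.
w_q = torch.randn(hidden_size, hidden_size)  # already transposed: (in, out)
w_k = torch.randn(hidden_size, hidden_size)
w_v = torch.randn(hidden_size, hidden_size)
w_o = torch.randn(hidden_size, hidden_size)

# Fused qkv: stack the three pre-transposed weights and run a single bmm.
qkv_weight = torch.stack([w_q, w_k, w_v], dim=0)              # (3, in, out)
qkv = torch.bmm(hidden_states.expand(3, -1, -1), qkv_weight)  # (3, token_num, out)
q, k, v = qkv.view(3, token_num, num_heads, head_dim).unbind(0)
assert torch.allclose(q, (hidden_states @ w_q).view(token_num, num_heads, head_dim), atol=1e-4)

# Fused linear + residual add: addmm(residual, x, W) == residual + x @ W.
attn_out = torch.randn(token_num, hidden_size)  # stand-in for the attention result
fused = torch.addmm(residual, attn_out, w_o)
assert torch.allclose(fused, residual + attn_out @ w_o, atol=1e-4)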

View File

@@ -2,8 +2,10 @@
from typing import List, Optional, Tuple
import torch
from torch.nn import Parameter
from transformers.models.llama.modeling_llama import (
LlamaAttention,
LlamaConfig,
LlamaDecoderLayer,
LlamaForCausalLM,
LlamaMLP,
@@ -39,6 +41,14 @@ def llama_causal_lm_forward(
k_caches: List[torch.Tensor] = None,
v_caches: List[torch.Tensor] = None,
):
"""This function will replace the forward function of LlamaForCausalLM.
Args:
batch (BatchInfo, optional): It stores the necessary input information for this inference. Defaults to None.
k_caches (List[torch.Tensor], optional): It holds the GPU memory for the key cache. Defaults to None.
v_caches (List[torch.Tensor], optional): It holds the GPU memory for the value cache. Defaults to None.
"""
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
hidden_states = llama_model_forward(
self.model,
@@ -46,7 +56,7 @@ def llama_causal_lm_forward(
k_caches=k_caches,
v_caches=v_caches,
)
logits = torch.mm(hidden_states, self.lm_head.weight.transpose(0, 1))
logits = torch.mm(hidden_states, self.lm_head.weight)
return logits
@@ -57,6 +67,13 @@ def llama_model_forward(
k_caches: List[torch.Tensor] = None,
v_caches: List[torch.Tensor] = None,
):
"""This function will replace the forward function of LlamaModel.
Args:
batch (BatchInfo, optional): It stores the necessary input information for this inference. Defaults to None.
k_caches (List[torch.Tensor], optional): It holds the GPU memory for the key cache. Defaults to None.
v_caches (List[torch.Tensor], optional): It holds the GPU memory for the value cache. Defaults to None.
"""
input_ids = batch.get_1D_inputs()
block_tables = batch.get_block_table_tensor()
@@ -74,7 +91,7 @@ def llama_model_forward(
)
else:
output_tensor = torch.zeros(
(batch_size, 1, batch.num_heads, batch.head_dim), dtype=batch.dtype, device=batch.device
(batch_size, batch.num_heads, batch.head_dim), dtype=batch.dtype, device=batch.device
)
sm_scale = 1.0 / (batch.head_dim**0.5)
@@ -116,12 +133,30 @@ def llama_decoder_layer_forward(
output_tensor: torch.Tensor = None,
sm_scale: int = None,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
"""This function will replace the forward function of LlamaDecoderLayer.
Args:
hidden_states (torch.Tensor): input to the layer of shape `(token_num, embed_dim)`.
block_tables (torch.Tensor, optional): A 2D tensor of shape [batch_size, max_blocks_per_sequence],
storing mapping of token_position_id -> block_id. Defaults to None.
k_cache (torch.Tensor, optional): It holds the GPU memory for the key cache. Defaults to None.
v_cache (torch.Tensor, optional): It holds the GPU memory for the value cache. Defaults to None.
is_prompts (bool, optional): Whether the current inference process is in the context input phase. Defaults to True.
sequence_lengths (torch.Tensor, optional): Holding the sequence length of each sequence. Defaults to None.
kv_seq_len (int, optional): The max sequence length of input sequences. Defaults to 0.
cos_sin (Tuple[torch.Tensor], optional): Holding cos and sin. Defaults to None.
fd_inter_tensor (FDIntermTensors, optional): Holding tensors used for
storing intermediate values in flash-decoding. Defaults to None.
output_tensor (torch.Tensor, optional): A pre-allocated tensor that holds the output of attention. Defaults to None.
sm_scale (int, optional): The scale factor applied before the attention softmax, typically 1 / sqrt(head_dim). Defaults to None.
"""
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
# Self Attention
hidden_states = self.self_attn(
hidden_states=hidden_states,
residual=residual,
block_tables=block_tables,
k_cache=k_cache,
v_cache=v_cache,
@@ -134,88 +169,213 @@ def llama_decoder_layer_forward(
sm_scale=sm_scale,
)
hidden_states = residual + hidden_states
# Fully Connected
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
hidden_states = self.mlp(hidden_states, residual)
return hidden_states
# Replace transformers.models.llama.modeling_llama.LlamaAttention.forward
@torch.no_grad()
def llama_attn_forward(
self: LlamaAttention,
hidden_states: torch.Tensor,
block_tables: torch.Tensor = None,
k_cache: torch.Tensor = None,
v_cache: torch.Tensor = None,
is_prompts: bool = True,
sequence_lengths: torch.Tensor = None,
kv_seq_len: int = 0,
cos_sin: Tuple[torch.Tensor] = None,
fd_inter_tensor: FDIntermTensors = None,
output_tensor: torch.Tensor = None,
sm_scale: int = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
query_states = torch.mm(hidden_states, self.q_proj.weight.transpose(0, 1)).view(-1, self.num_heads, self.head_dim)
key_states = torch.mm(hidden_states, self.k_proj.weight.transpose(0, 1)).view(
-1, self.num_key_value_heads, self.head_dim
)
value_states = torch.mm(hidden_states, self.v_proj.weight.transpose(0, 1)).view(
-1, self.num_key_value_heads, self.head_dim
)
class NopadLlamaAttention(LlamaAttention):
def __init__(
self,
config: LlamaConfig,
layer_idx: Optional[int] = None,
attn_qproj_w: torch.Tensor = None,
attn_kproj_w: torch.Tensor = None,
attn_vproj_w: torch.Tensor = None,
attn_oproj_w: torch.Tensor = None,
):
"""This layer will replace the LlamaAttention.
rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1])
Args:
config (LlamaConfig): Holding the Llama model config.
layer_idx (Optional[int], optional): The index of the decoder layer this attention layer belongs to. Defaults to None.
attn_qproj_w (torch.Tensor, optional): The transposed q_proj weight. Defaults to None.
attn_kproj_w (torch.Tensor, optional): The transposed k_proj weight. Defaults to None.
attn_vproj_w (torch.Tensor, optional): The transposed v_proj weight. Defaults to None.
attn_oproj_w (torch.Tensor, optional): The transposed o_proj weight. Defaults to None.
"""
super().__init__(config, layer_idx)
self.q_proj.weight = Parameter(attn_qproj_w, requires_grad=False)
self.k_proj.weight = Parameter(attn_kproj_w, requires_grad=False)
self.v_proj.weight = Parameter(attn_vproj_w, requires_grad=False)
self.o_proj.weight = Parameter(attn_oproj_w, requires_grad=False)
if self.num_heads == self.num_key_value_heads:
qkv_weight_list = [self.q_proj.weight, self.k_proj.weight, self.v_proj.weight]
self.qkv_weight = torch.stack(qkv_weight_list, dim=0)
self.q_proj = None
self.k_proj = None
self.v_proj = None
block_size = k_cache.size(-2)
@staticmethod
def from_native_module(module: LlamaAttention, *args, **kwargs) -> LlamaAttention:
"""Used for initialize the weight of NopadLlamaAttention by origin LlamaAttention.
if is_prompts:
attn_output = context_attention_unpadded(
q=query_states,
k=key_states,
v=value_states,
k_cache=k_cache,
v_cache=v_cache,
context_lengths=sequence_lengths,
block_tables=block_tables,
block_size=block_size,
output=output_tensor,
max_seq_len=kv_seq_len,
sm_scale=sm_scale,
Args:
module (LlamaAttention): The original LlamaAttention layer.
"""
config = module.config
layer_idx = module.layer_idx
attn_qproj_w = module.q_proj.weight.transpose(0, 1)
attn_kproj_w = module.k_proj.weight.transpose(0, 1)
attn_vproj_w = module.v_proj.weight.transpose(0, 1)
attn_oproj_w = module.o_proj.weight.transpose(0, 1)
attn_layer = NopadLlamaAttention(
config=config,
layer_idx=layer_idx,
attn_qproj_w=attn_qproj_w,
attn_kproj_w=attn_kproj_w,
attn_vproj_w=attn_vproj_w,
attn_oproj_w=attn_oproj_w,
)
else:
copy_kv_to_blocked_cache(key_states, k_cache, kv_lengths=sequence_lengths, block_tables=block_tables)
copy_kv_to_blocked_cache(value_states, v_cache, kv_lengths=sequence_lengths, block_tables=block_tables)
attn_output = flash_decoding_attention(
q=query_states,
k_cache=k_cache,
v_cache=v_cache,
kv_seq_len=sequence_lengths,
block_tables=block_tables,
block_size=block_size,
max_seq_len_in_batch=kv_seq_len,
output=output_tensor,
mid_output=fd_inter_tensor.mid_output,
mid_output_lse=fd_inter_tensor.mid_output_lse,
sm_scale=sm_scale,
return attn_layer
# Replace transformers.models.llama.modeling_llama.LlamaAttention.forward
@torch.no_grad()
def forward(
self,
hidden_states: torch.Tensor,
residual: torch.Tensor,
block_tables: torch.Tensor = None,
k_cache: torch.Tensor = None,
v_cache: torch.Tensor = None,
is_prompts: bool = True,
sequence_lengths: torch.Tensor = None,
kv_seq_len: int = 0,
cos_sin: Tuple[torch.Tensor] = None,
fd_inter_tensor: FDIntermTensors = None,
output_tensor: torch.Tensor = None,
sm_scale: int = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""
Args:
hidden_states (torch.Tensor): input to the layer of shape `(token_num, embed_dim)`
residual (torch.Tensor): shape `(token_num, embed_dim)`, the residual that is added to the attention output inside o_proj (via torch.addmm).
block_tables (torch.Tensor, optional): A 2D tensor of shape [batch_size, max_blocks_per_sequence],
storing mapping of token_position_id -> block_id. Defaults to None.
k_cache (torch.Tensor, optional): It holds the GPU memory for the key cache. Defaults to None.
v_cache (torch.Tensor, optional): It holds the GPU memory for the value cache. Defaults to None.
is_prompts (bool, optional): Whether the current inference process is in the context input phase. Defaults to True.
sequence_lengths (torch.Tensor, optional): Holding the sequence length of each sequence. Defaults to None.
kv_seq_len (int, optional): The max sequence length of input sequences. Defaults to 0.
cos_sin (Tuple[torch.Tensor], optional): Holding cos and sin. Defaults to None.
fd_inter_tensor (FDIntermTensors, optional): Holding tensors used for
storing intermediate values in flash-decoding. Defaults to None.
output_tensor (torch.Tensor, optional): A pre-allocated tensor that holds the output of attention. Defaults to None.
sm_scale (int, optional): The scale factor applied before the attention softmax, typically 1 / sqrt(head_dim). Defaults to None.
"""
if self.num_heads != self.num_key_value_heads:
query_states = torch.mm(hidden_states, self.q_proj.weight).view(-1, self.num_heads, self.head_dim)
key_states = torch.mm(hidden_states, self.k_proj.weight).view(-1, self.num_key_value_heads, self.head_dim)
value_states = torch.mm(hidden_states, self.v_proj.weight).view(-1, self.num_key_value_heads, self.head_dim)
else:
# fused qkv
token_nums = hidden_states.size(0)
hidden_states = hidden_states.expand(3, -1, -1)
query_states, key_states, value_states = (
torch.bmm(hidden_states, self.qkv_weight).view(3, token_nums, self.num_heads, self.head_dim).unbind(0)
)
rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1])
block_size = k_cache.size(-2)
if is_prompts:
attn_output = context_attention_unpadded(
q=query_states,
k=key_states,
v=value_states,
k_cache=k_cache,
v_cache=v_cache,
context_lengths=sequence_lengths,
block_tables=block_tables,
block_size=block_size,
output=output_tensor,
max_seq_len=kv_seq_len,
sm_scale=sm_scale,
)
else:
copy_kv_to_blocked_cache(key_states, k_cache, kv_lengths=sequence_lengths, block_tables=block_tables)
copy_kv_to_blocked_cache(value_states, v_cache, kv_lengths=sequence_lengths, block_tables=block_tables)
attn_output = flash_decoding_attention(
q=query_states,
k_cache=k_cache,
v_cache=v_cache,
kv_seq_len=sequence_lengths,
block_tables=block_tables,
block_size=block_size,
max_seq_len_in_batch=kv_seq_len,
output=output_tensor,
mid_output=fd_inter_tensor.mid_output,
mid_output_lse=fd_inter_tensor.mid_output_lse,
sm_scale=sm_scale,
)
attn_output = attn_output.reshape(-1, self.hidden_size)
attn_output = torch.addmm(residual, attn_output, self.o_proj.weight)
return attn_output
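The from_native_module hook above builds a NopadLlamaAttention from a stock LlamaAttention by transposing its projection weights once; when num_heads equals num_key_value_heads it additionally stacks them into qkv_weight so the forward runs one torch.bmm instead of three torch.mm calls. A hedged usage sketch follows; the config values are made up, and the snippet assumes a transformers version whose LlamaAttention takes (config, layer_idx), as this diff does, and that it runs where NopadLlamaAttention is importable.

from transformers.models.llama.modeling_llama import LlamaAttention, LlamaConfig

# Illustrative config; the values are made up for this sketch.
config = LlamaConfig(
    hidden_size=128,
    num_attention_heads=8,
    num_key_value_heads=8,
    intermediate_size=256,
    num_hidden_layers=2,
)
native_attn = LlamaAttention(config, layer_idx=0)
nopad_attn = NopadLlamaAttention.from_native_module(native_attn)

# With num_heads == num_key_value_heads, the separate q/k/v projections are
# dropped and replaced by a stacked, pre-transposed qkv_weight.
assert nopad_attn.q_proj is None
assert nopad_attn.qkv_weight.shape == (3, config.hidden_size, config.hidden_size)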
# NOTE This will cause the result to differ from the transformers implementation in some cases.
class NopadLlamaMLP(LlamaMLP):
def __init__(
self,
config: LlamaConfig,
mlp_gproj_w: torch.Tensor = None,
mlp_uproj_w: torch.Tensor = None,
mlp_dproj_w: torch.Tensor = None,
):
"""This layer will replace the LlamaAttention.
Args:
config (LlamaConfig): Holding the Llama model config.
mlp_gproj_w (torch.Tensor, optional): The transposed gate_proj weight. Defaults to None.
mlp_uproj_w (torch.Tensor, optional): The transposed up_proj weight. Defaults to None.
mlp_dproj_w (torch.Tensor, optional): The transposed down_proj weight. Defaults to None.
"""
super().__init__(config)
self.gate_proj.weight = Parameter(mlp_gproj_w, requires_grad=False)
self.up_proj.weight = Parameter(mlp_uproj_w, requires_grad=False)
self.down_proj.weight = Parameter(mlp_dproj_w, requires_grad=False)
@staticmethod
def from_native_module(module: LlamaMLP, *args, **kwargs) -> LlamaMLP:
"""Used for initialize the weight of NopadLlamaMLP by origin LlamaMLP.
Args:
module (LlamaMLP): The original LlamaMLP layer.
"""
config = module.config
mlp_gproj_w = module.gate_proj.weight.transpose(0, 1)
mlp_uproj_w = module.up_proj.weight.transpose(0, 1)
mlp_dproj_w = module.down_proj.weight.transpose(0, 1)
mlp_layer = NopadLlamaMLP(
config=config,
mlp_gproj_w=mlp_gproj_w,
mlp_uproj_w=mlp_uproj_w,
mlp_dproj_w=mlp_dproj_w,
)
attn_output = attn_output.squeeze(1)
attn_output = attn_output.view(-1, self.num_heads, self.head_dim)
attn_output = attn_output.reshape(-1, self.hidden_size)
attn_output = torch.mm(attn_output, self.o_proj.weight.transpose(0, 1))
return mlp_layer
return attn_output
@torch.no_grad()
def nopad_mlp(self: LlamaMLP, hidden_states: torch.Tensor):
gate_proj_out = torch.mm(hidden_states, self.gate_proj.weight.transpose(0, 1))
act_out = torch.nn.functional.silu(gate_proj_out, inplace=True)
up_proj_out = torch.mm(hidden_states, self.up_proj.weight.transpose(0, 1))
tmp_out = act_out * up_proj_out
return torch.mm(tmp_out, self.down_proj.weight.transpose(0, 1))
@torch.no_grad()
def forward(self, hidden_states: torch.Tensor, residual: torch.Tensor) -> torch.Tensor:
"""
Args:
hidden_states (torch.Tensor): input to the layer of shape `(token_num, embed_dim)`.
residual (torch.Tensor): shape `(token_num, embed_dim)`, the residual that is added to the MLP output inside down_proj (via torch.addmm).
"""
gate_proj_out = torch.mm(hidden_states, self.gate_proj.weight)
act_out = torch.nn.functional.silu(gate_proj_out, inplace=True)
up_proj_out = torch.mm(hidden_states, self.up_proj.weight)
tmp_out = act_out * up_proj_out
return torch.addmm(residual, tmp_out, self.down_proj.weight)
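NopadLlamaMLP above keeps pre-transposed gate/up/down weights and folds the residual add into down_proj with torch.addmm; the NOTE earlier flags that this path can yield slightly different results than stock transformers. Below is a minimal sketch, with made-up shapes and not the repository's code, comparing the fused path against an unfused reference.

import torch

# Sketch of the gated-MLP math used above; shapes are made up.
token_num, hidden_size, inter_size = 16, 128, 256
x = torch.randn(token_num, hidden_size)
residual = torch.randn(token_num, hidden_size)
w_gate = torch.randn(hidden_size, inter_size)  # pre-transposed: (in, out)
w_up = torch.randn(hidden_size, inter_size)
w_down = torch.randn(inter_size, hidden_size)

# Fused path: silu(x @ Wg) * (x @ Wu), with addmm folding in the residual.
gate = torch.nn.functional.silu(x @ w_gate)
fused = torch.addmm(residual, gate * (x @ w_up), w_down)

# Unfused reference: gated MLP followed by an explicit residual add.
ref = residual + (torch.nn.functional.silu(x @ w_gate) * (x @ w_up)) @ w_down
assert torch.allclose(fused, ref, atol=1e-4)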

View File

@@ -2,7 +2,13 @@
from typing import List, Optional, Tuple
import torch
from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaForCausalLM, LlamaModel
from transformers.models.llama.modeling_llama import (
LlamaAttention,
LlamaConfig,
LlamaDecoderLayer,
LlamaForCausalLM,
LlamaModel,
)
from colossalai.inference.flash_decoding_utils import FDIntermTensors
from colossalai.inference.modeling.layers.attention import PagedAttention
@@ -53,6 +59,14 @@ def llama_causal_lm_forward(
k_caches: List[torch.Tensor] = None,
v_caches: List[torch.Tensor] = None,
):
"""This function will replace the forward function of LlamaForCausalLM.
Args:
batch (BatchInfo, optional): It stores the necessary input information for this inference. Defaults to None.
k_caches (List[torch.Tensor], optional): It holds the GPU memory for the key cache. Defaults to None.
v_caches (List[torch.Tensor], optional): It holds the GPU memory for the value cache. Defaults to None.
"""
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
hidden_states = llama_model_forward(
self.model,
@@ -71,6 +85,13 @@ def llama_model_forward(
k_caches: List[torch.Tensor] = None,
v_caches: List[torch.Tensor] = None,
):
"""This function will replace the forward function of LlamaModel.
Args:
batch (BatchInfo, optional): It stores the necessary input information for this inference. Defaults to None.
k_caches (List[torch.Tensor], optional): It holds the GPU memory for the key cache. Defaults to None.
v_caches (List[torch.Tensor], optional): It holds the GPU memory for the value cache. Defaults to None.
"""
input_ids = batch.get_batch_inputs()
block_tables = batch.get_block_table_tensor()
attention_mask = batch.get_attn_mask()
@@ -110,7 +131,7 @@ def llama_model_forward(
)
else:
output_tensor = torch.zeros(
(batch_size, 1, batch.num_heads, batch.head_dim), dtype=batch.dtype, device=batch.device
(batch_size, batch.num_heads, batch.head_dim), dtype=batch.dtype, device=batch.device
)
sm_scale = 1.0 / (batch.head_dim**0.5)
@@ -131,7 +152,8 @@ def llama_model_forward(
sm_scale=sm_scale,
)
hidden_states = hidden_states[:, -1, :].unsqueeze(dim=1).contiguous()
if batch.is_prompts:
hidden_states = hidden_states[:, -1, :].unsqueeze(dim=1).contiguous()
hidden_states = self.norm(hidden_states)
return hidden_states
@@ -154,6 +176,23 @@ def llama_decoder_layer_forward(
output_tensor: torch.Tensor = None,
sm_scale: int = None,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
"""This function will replace the forward function of LlamaDecoderLayer.
Args:
hidden_states (torch.Tensor): input to the layer of shape `(batch_size, seq_len, embed_dim)`.
position_ids (torch.LongTensor): The position ids of input sequences.
block_tables (torch.Tensor, optional): A 2D tensor of shape [batch_size, max_blocks_per_sequence],
storing mapping of token_position_id -> block_id. Defaults to None.
k_cache (torch.Tensor, optional): It holds the GPU memory for the key cache. Defaults to None.
v_cache (torch.Tensor, optional): It holds the GPU memory for the value cache. Defaults to None.
is_prompts (bool, optional): Whether the current inference process is in the context input phase. Defaults to True.
sequence_lengths (torch.Tensor, optional): Holding the sequence length of each sequence. Defaults to None.
kv_seq_len (int, optional): The max sequence length of input sequences. Defaults to 0.
cos_sin (Tuple[torch.Tensor], optional): Holding cos and sin. Defaults to None.
fd_inter_tensor (FDIntermTensors, optional): Holding tensors used for storing intermediate values in flash-decoding. Defaults to None.
output_tensor (torch.Tensor, optional): A pre-allocated tensor that holds the output of attention. Defaults to None.
sm_scale (int, optional): The scale factor applied before the attention softmax, typically 1 / sqrt(head_dim). Defaults to None.
"""
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
@@ -185,108 +224,192 @@ def llama_decoder_layer_forward(
return hidden_states
# Replace transformers.models.llama.modeling_llama.LlamaAttention.forward
@torch.no_grad()
def llama_attn_forward(
self: LlamaAttention,
hidden_states: torch.Tensor,
position_ids: torch.LongTensor,
block_tables: torch.Tensor = None,
k_cache: torch.Tensor = None,
v_cache: torch.Tensor = None,
is_prompts: bool = True,
sequence_lengths: torch.Tensor = None,
attention_mask: torch.Tensor = None,
kv_seq_len: int = 0,
cos_sin: Tuple[torch.Tensor] = None,
fd_inter_tensor: FDIntermTensors = None,
output_tensor: torch.Tensor = None,
sm_scale: int = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()
class PadLlamaAttention(LlamaAttention):
def __init__(
self,
config: LlamaConfig,
layer_idx: Optional[int] = None,
attn_qproj_w: torch.nn.Parameter = None,
attn_kproj_w: torch.nn.Parameter = None,
attn_vproj_w: torch.nn.Parameter = None,
attn_oproj_w: torch.nn.Parameter = None,
):
"""This layer will replace the LlamaAttention.
query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim)
key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim)
value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim)
Args:
config (LlamaConfig): Holding the Llama model config.
layer_idx (Optional[int], optional): The index of the decoder layer this attention layer belongs to. Defaults to None.
attn_qproj_w (torch.nn.Parameter, optional): The q_proj weight. Defaults to None.
attn_kproj_w (torch.nn.Parameter, optional): The k_proj weight. Defaults to None.
attn_vproj_w (torch.nn.Parameter, optional): The v_proj weight. Defaults to None.
attn_oproj_w (torch.nn.Parameter, optional): The o_proj weight. Defaults to None.
"""
super().__init__(config, layer_idx)
self.q_proj.weight = attn_qproj_w
self.k_proj.weight = attn_kproj_w
self.v_proj.weight = attn_vproj_w
self.o_proj.weight = attn_oproj_w
if HAS_TRITON:
if is_prompts:
if attention_mask is not None:
query_states, key_states, value_states, indices = unpading_input(
query_states, key_states, value_states, attention_mask
@staticmethod
def from_native_module(module: LlamaAttention, *args, **kwargs) -> LlamaAttention:
"""Used for initialize the weight of NopadLlamaAttention by origin LlamaAttention
Args:
module (LlamaAttention): The original LlamaAttention layer.
"""
config = module.config
layer_idx = module.layer_idx
attn_qproj_w = module.q_proj.weight
attn_kproj_w = module.k_proj.weight
attn_vproj_w = module.v_proj.weight
attn_oproj_w = module.o_proj.weight
attn_layer = PadLlamaAttention(
config=config,
layer_idx=layer_idx,
attn_qproj_w=attn_qproj_w,
attn_kproj_w=attn_kproj_w,
attn_vproj_w=attn_vproj_w,
attn_oproj_w=attn_oproj_w,
)
return attn_layer
@torch.no_grad()
def forward(
self,
hidden_states: torch.Tensor,
position_ids: torch.LongTensor,
block_tables: torch.Tensor = None,
k_cache: torch.Tensor = None,
v_cache: torch.Tensor = None,
is_prompts: bool = True,
sequence_lengths: torch.Tensor = None,
attention_mask: torch.Tensor = None,
kv_seq_len: int = 0,
cos_sin: Tuple[torch.Tensor] = None,
fd_inter_tensor: FDIntermTensors = None,
output_tensor: torch.Tensor = None,
sm_scale: int = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""
Args:
hidden_states (torch.Tensor): input to the layer of shape `(batch_size, seq_len, embed_dim)`
position_ids (torch.LongTensor): The position ids of input sequences.
block_tables (torch.Tensor, optional): A 2D tensor of shape [batch_size, max_blocks_per_sequence],
storing mapping of token_position_id -> block_id. Defaults to None.
k_cache (torch.Tensor, optional): It holds the GPU memory for the key cache. Defaults to None.
v_cache (torch.Tensor, optional): It holds the GPU memory for the value cache. Defaults to None.
is_prompts (bool, optional): Whether the current inference process is in the context input phase. Defaults to True.
sequence_lengths (torch.Tensor, optional): Holding the sequence length of each sequence. Defaults to None.
attention_mask (torch.Tensor, optional): The padding mask - corresponds to a tensor of size `(batch_size, seq_len)`
where 0 stands for the position of padding tokens and 1 for the position of non-padding tokens.
kv_seq_len (int, optional): The max sequence length of input sequences. Defaults to 0.
cos_sin (Tuple[torch.Tensor], optional): Holding cos and sin. Defaults to None.
fd_inter_tensor (FDIntermTensors, optional): Holding tensors used for
storing intermediate values in flash-decoding. Defaults to None.
output_tensor (torch.Tensor, optional): A pre-allocated tensor that holds the output of attention. Defaults to None.
sm_scale (int, optional): The scale factor applied before the attention softmax, typically 1 / sqrt(head_dim). Defaults to None.
"""
bsz, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim)
key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim)
value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim)
if HAS_TRITON:
if is_prompts:
if attention_mask is not None:
query_states, key_states, value_states, indices = unpading_input(
query_states, key_states, value_states, attention_mask
)
else:
query_states = query_states.view(-1, self.num_heads, self.head_dim)
key_states = key_states.view(-1, self.num_heads, self.head_dim)
value_states = value_states.view(-1, self.num_heads, self.head_dim)
else:
query_states = query_states.squeeze(dim=1)
key_states = key_states.squeeze(dim=1)
value_states = value_states.squeeze(dim=1)
rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1])
block_size = k_cache.size(-2)
if is_prompts:
attn_output = context_attention_unpadded(
q=query_states,
k=key_states,
v=value_states,
k_cache=k_cache,
v_cache=v_cache,
context_lengths=sequence_lengths,
block_tables=block_tables,
block_size=block_size,
output=output_tensor,
max_seq_len=kv_seq_len,
sm_scale=sm_scale,
)
if attention_mask is not None:
attn_output = pad_input(attn_output, indices, bsz, q_len)
else:
copy_kv_to_blocked_cache(key_states, k_cache, kv_lengths=sequence_lengths, block_tables=block_tables)
copy_kv_to_blocked_cache(value_states, v_cache, kv_lengths=sequence_lengths, block_tables=block_tables)
attn_output = flash_decoding_attention(
q=query_states,
k_cache=k_cache,
v_cache=v_cache,
kv_seq_len=sequence_lengths,
block_tables=block_tables,
block_size=block_size,
max_seq_len_in_batch=kv_seq_len,
output=output_tensor,
mid_output=fd_inter_tensor.mid_output,
mid_output_lse=fd_inter_tensor.mid_output_lse,
sm_scale=sm_scale,
)
attn_output = attn_output.squeeze(1)
else:
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
if is_prompts:
attn_output = PagedAttention.pad_context_forward(
query_states,
key_states,
value_states,
k_cache,
v_cache,
sequence_lengths,
block_tables,
attention_mask,
)
else:
query_states = query_states.view(-1, self.num_heads, self.head_dim)
key_states = key_states.view(-1, self.num_heads, self.head_dim)
value_states = value_states.view(-1, self.num_heads, self.head_dim)
else:
query_states = query_states.squeeze(dim=1)
key_states = key_states.squeeze(dim=1)
value_states = value_states.squeeze(dim=1)
attn_output = PagedAttention.pad_decoding_forward(
query_states,
key_states,
value_states,
k_cache,
v_cache,
sequence_lengths,
block_tables,
attention_mask,
)
rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1])
attn_output = attn_output.view(bsz, q_len, self.num_heads, self.head_dim)
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
attn_output = self.o_proj(attn_output)
block_size = k_cache.size(-2)
if is_prompts:
attn_output = context_attention_unpadded(
q=query_states,
k=key_states,
v=value_states,
k_cache=k_cache,
v_cache=v_cache,
context_lengths=sequence_lengths,
block_tables=block_tables,
block_size=block_size,
output=output_tensor,
max_seq_len=kv_seq_len,
sm_scale=sm_scale,
)
if attention_mask is not None:
attn_output = pad_input(attn_output, indices, bsz, q_len)
else:
copy_kv_to_blocked_cache(key_states, k_cache, kv_lengths=sequence_lengths, block_tables=block_tables)
copy_kv_to_blocked_cache(value_states, v_cache, kv_lengths=sequence_lengths, block_tables=block_tables)
attn_output = flash_decoding_attention(
q=query_states,
k_cache=k_cache,
v_cache=v_cache,
kv_seq_len=sequence_lengths,
block_tables=block_tables,
block_size=block_size,
max_seq_len_in_batch=kv_seq_len,
output=output_tensor,
mid_output=fd_inter_tensor.mid_output,
mid_output_lse=fd_inter_tensor.mid_output_lse,
sm_scale=sm_scale,
)
attn_output = attn_output.squeeze(1)
else:
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
if is_prompts:
attn_output = PagedAttention.pad_context_forward(
query_states, key_states, value_states, k_cache, v_cache, sequence_lengths, block_tables, attention_mask
)
else:
attn_output = PagedAttention.pad_decoding_forward(
query_states, key_states, value_states, k_cache, v_cache, sequence_lengths, block_tables, attention_mask
)
attn_output = attn_output.view(bsz, q_len, self.num_heads, self.head_dim)
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
attn_output = self.o_proj(attn_output)
return attn_output
return attn_output
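In the HAS_TRITON prefill branch above, a padded batch is packed into a flat (token_num, num_heads, head_dim) layout before context_attention_unpadded and scattered back to the padded shape afterwards with pad_input. The sketch below illustrates that unpad/re-pad round trip with plain torch indexing; the helper names used in the diff and their exact signatures are not reproduced here.

import torch

bsz, seq_len, num_heads, head_dim = 2, 4, 8, 16
q = torch.randn(bsz, seq_len, num_heads, head_dim)
# attention_mask: 1 marks real tokens, 0 marks padding (as described in the docstring above).
attention_mask = torch.tensor([[1, 1, 1, 0],
                               [1, 1, 0, 0]])

# Unpad: keep only the non-padding rows, remembering their flat positions.
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
q_unpad = q.reshape(bsz * seq_len, num_heads, head_dim)[indices]  # (5, num_heads, head_dim)

# ... the attention kernel runs on the packed (token_num, ...) layout ...

# Re-pad: scatter the packed rows back into the (bsz, seq_len, ...) tensor.
out = torch.zeros(bsz * seq_len, num_heads, head_dim)
out[indices] = q_unpad
out = out.view(bsz, seq_len, num_heads, head_dim)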
@torch.no_grad()