[Inference] Adapted to the Triton attn kernels (#5264)

* adapted to the triton attn kernels

* fix pad input

* adapted to copy_kv_to_blocked_cache

* fix ci test

* update kv memcpy

* remove print
yuehuayingxueluo authored on 2024-01-17 16:03:10 +08:00, committed by GitHub
parent 0f2b46a41c
commit 86b63f720c
7 changed files with 221 additions and 101 deletions


@@ -6,6 +6,7 @@ import torch.nn.functional as F
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
+@torch.no_grad
def copy_to_cache(source, cache, lengths, block_tables, type: str = "prefill"):
"""
Func: copy key/value into key/value cache.
@@ -40,6 +41,7 @@ def copy_to_cache(source, cache, lengths, block_tables, type: str = "prefill"):
return cache
+@torch.no_grad
def convert_kvcache(cache, lengths, block_tables, pad_id=0):
"""
Func: convert key/value cache for calculation
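For orientation, here is a minimal, self-contained sketch of the blocked ("paged") cache copy that `copy_to_cache` performs and that the `copy_kv_to_blocked_cache` kernel named in the commit message accelerates. The cache layout `[num_blocks, num_heads, block_size, head_size]` and the function name are illustrative assumptions, not the repository's exact internals:

```python
import torch

def copy_to_blocked_cache_sketch(source, cache, lengths, block_tables):
    # source: [num_tokens, num_heads, head_size], all sequences packed along dim 0
    # cache:  [num_blocks, num_heads, block_size, head_size]  (assumed layout)
    # block_tables: [bsz, max_blocks_per_seq], physical block id per logical block
    block_size = cache.shape[2]
    offset = 0
    for seq_idx, seq_len in enumerate(lengths.tolist()):
        seq_tokens = source[offset : offset + seq_len]
        for start in range(0, seq_len, block_size):
            n = min(block_size, seq_len - start)
            block_id = block_tables[seq_idx, start // block_size]
            # [n, num_heads, head_size] -> [num_heads, n, head_size]
            cache[block_id, :, :n] = seq_tokens[start : start + n].transpose(0, 1)
        offset += seq_len
    return cache
```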
@@ -79,6 +81,7 @@ class PagedAttention:
"""
@staticmethod
+@torch.no_grad
def pad_and_reshape(tensor, seq_lengths, max_seq_len, num_heads, head_size):
"""
Transform 1D no_pad tensor into 2D padded tensor with shape [bsz,seq_len,num_heads,head_size]
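The docstring above describes regrouping a packed token tensor into per-sequence rows with right-padding; a minimal sketch of that transform, assuming `tensor` packs all sequences' tokens contiguously along dim 0 (names illustrative):

```python
import torch

def pad_and_reshape_sketch(tensor, seq_lengths, max_seq_len, num_heads, head_size):
    # tensor: [sum(seq_lengths), num_heads, head_size]
    bsz = len(seq_lengths)
    padded = tensor.new_zeros(bsz, max_seq_len, num_heads, head_size)
    offset = 0
    for i, seq_len in enumerate(seq_lengths.tolist()):
        padded[i, :seq_len] = tensor[offset : offset + seq_len]
        offset += seq_len
    return padded  # [bsz, max_seq_len, num_heads, head_size]
```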
@@ -94,12 +97,14 @@ class PagedAttention:
return padded_tensor
@staticmethod
+@torch.no_grad
def generate_padding_mask(lengths, max_seq_len):
range_tensor = torch.arange(max_seq_len).expand(len(lengths), max_seq_len)
padding_mask = range_tensor < lengths.unsqueeze(1)
return padding_mask
@staticmethod
+@torch.no_grad
def repeat_kv(hidden_states: torch.Tensor, n_rep: int = 1) -> torch.Tensor:
"""
Essential component for MQA. Equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep).
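The usual expand-then-reshape pattern behind such a helper looks like the sketch below; this follows the common Hugging Face formulation, and the repository's version may differ in details:

```python
import torch

def repeat_kv_sketch(hidden_states: torch.Tensor, n_rep: int = 1) -> torch.Tensor:
    # hidden_states: [batch, num_kv_heads, seq_len, head_dim]
    batch, num_kv_heads, seq_len, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    # Insert a repeat axis and expand without copying, then flatten the head axes.
    hidden_states = hidden_states[:, :, None, :, :].expand(
        batch, num_kv_heads, n_rep, seq_len, head_dim
    )
    return hidden_states.reshape(batch, num_kv_heads * n_rep, seq_len, head_dim)
```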
@@ -117,6 +122,7 @@ class PagedAttention:
return hidden_states.reshape(batch, num_attention_heads, seq_len, head_dim)
@staticmethod
+@torch.no_grad
def nopad_context_forward(
q: torch.Tensor, # [num_tokens, num_heads, head_size]
k: torch.Tensor, # [num_tokens, num_kv_heads, head_size]
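Given the packed layout in the signature above, an illustrative dense-PyTorch equivalent of unpadded ("nopad") context attention is sketched below, using `F.scaled_dot_product_attention` (PyTorch >= 2.0) as a stand-in for the explicit matmul/softmax in the source:

```python
import torch
import torch.nn.functional as F

def nopad_context_attention_sketch(q, k, v, context_lengths, n_rep: int = 1):
    # q: [num_tokens, num_heads, head_size]; k/v: [num_tokens, num_kv_heads, head_size]
    outputs = []
    offset = 0
    for seq_len in context_lengths.tolist():
        # Slice one sequence and move heads to the batch dim: [heads, seq_len, head_size]
        q_i = q[offset : offset + seq_len].transpose(0, 1)
        k_i = k[offset : offset + seq_len].transpose(0, 1).repeat_interleave(n_rep, dim=0)
        v_i = v[offset : offset + seq_len].transpose(0, 1).repeat_interleave(n_rep, dim=0)
        out = F.scaled_dot_product_attention(q_i, k_i, v_i, is_causal=True)
        outputs.append(out.transpose(0, 1))
        offset += seq_len
    return torch.cat(outputs, dim=0)  # [num_tokens, num_heads, head_size]
```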
@@ -185,6 +191,7 @@ class PagedAttention:
return attn_output
@staticmethod
+@torch.no_grad
def pad_context_forward(
q: torch.Tensor, # [batch_size, seq_len, num_heads, head_size]
k: torch.Tensor, # [batch_size, seq_len, num_kv_heads, head_size]
@@ -239,11 +246,10 @@ class PagedAttention:
attn_output = attn_output.transpose(1, 2).contiguous().reshape(bsz, seq_len, -1)
del attn_weights
return attn_output
@staticmethod
+@torch.no_grad
def pad_decoding_forward(
q: torch.Tensor, # [bsz, 1, num_heads, head_size]
k: torch.Tensor, # [bsz, 1, num_kv_heads, head_size]
@@ -297,11 +303,10 @@ class PagedAttention:
raise ValueError(f"Got wrong attn_output, should be in shape {(bsz,num_heads,1,head_size)}.")
attn_output = attn_output.transpose(1, 2).contiguous().reshape(bsz, 1, -1)
del attn_weights
return attn_output
@staticmethod
+@torch.no_grad
def no_pad_decoding_forward(
self,
q: torch.Tensor, # [num_tokens, num_heads, head_size]