[Inference]Adapted to the triton attn kernels (#5264)
* adapted to the triton attn kernels
* fix pad input
* adapted to copy_kv_to_blocked_cache
* fix ci test
* update kv memcpy
* remove print
@@ -6,6 +6,7 @@ import torch.nn.functional as F

from transformers.modeling_attn_mask_utils import AttentionMaskConverter


@torch.no_grad
def copy_to_cache(source, cache, lengths, block_tables, type: str = "prefill"):
    """
    Func: copy key/value into key/value cache.
@@ -40,6 +41,7 @@ def copy_to_cache(source, cache, lengths, block_tables, type: str = "prefill"):
    return cache


@torch.no_grad
def convert_kvcache(cache, lengths, block_tables, pad_id=0):
    """
    Func: convert key/value cache for calculation
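The hunk above only shows the signature and docstring of copy_to_cache. As a rough illustration of what copying key/value states into a blocked (paged) cache via a block table involves, here is a minimal decode-step sketch; the cache layout [num_blocks, num_kv_heads, block_size, head_size], the argument names, and the per-sequence loop are assumptions made for illustration, not ColossalAI's actual implementation.

import torch

@torch.no_grad()
def copy_kv_to_paged_cache_sketch(new_kv, cache, lengths, block_tables):
    # new_kv:       [bsz, num_kv_heads, head_size], one freshly generated token per sequence
    # cache:        [num_blocks, num_kv_heads, block_size, head_size]  (assumed layout)
    # lengths:      [bsz], current sequence lengths including the new token
    # block_tables: [bsz, max_blocks_per_seq], logical block index -> physical block id
    block_size = cache.size(2)
    for i in range(new_kv.size(0)):
        token_idx = int(lengths[i]) - 1                      # slot of the newest token
        block_id = block_tables[i, token_idx // block_size]
        offset = token_idx % block_size
        cache[block_id, :, offset, :] = new_kv[i]            # write K (or V) into its block slot
    return cache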
@@ -79,6 +81,7 @@ class PagedAttention:
    """

    @staticmethod
    @torch.no_grad
    def pad_and_reshape(tensor, seq_lengths, max_seq_len, num_heads, head_size):
        """
        Transform 1D no_pad tensor into 2D padded tensor with shape [bsz, seq_len, num_heads, head_size]
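pad_and_reshape's docstring describes turning a packed, unpadded tensor into a padded [bsz, seq_len, num_heads, head_size] tensor. A straightforward way to do that is sketched below, assuming the input is packed as [sum(seq_lengths), num_heads, head_size]; the loop-based copy is illustrative only and not necessarily how the method is implemented.

import torch

@torch.no_grad()
def pad_and_reshape_sketch(tensor, seq_lengths, max_seq_len, num_heads, head_size):
    # tensor: [sum(seq_lengths), num_heads, head_size], sequences packed back to back
    bsz = len(seq_lengths)
    padded = torch.zeros(bsz, max_seq_len, num_heads, head_size,
                         dtype=tensor.dtype, device=tensor.device)
    offset = 0
    for i, length in enumerate(seq_lengths):
        length = int(length)
        padded[i, :length] = tensor[offset:offset + length]  # copy this sequence's tokens
        offset += length
    return padded                                            # [bsz, max_seq_len, num_heads, head_size]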
@@ -94,12 +97,14 @@ class PagedAttention:
        return padded_tensor

    @staticmethod
    @torch.no_grad
    def generate_padding_mask(lengths, max_seq_len):
        range_tensor = torch.arange(max_seq_len).expand(len(lengths), max_seq_len)
        padding_mask = range_tensor < lengths.unsqueeze(1)
        return padding_mask

    @staticmethod
    @torch.no_grad
    def repeat_kv(hidden_states: torch.Tensor, n_rep: int = 1) -> torch.Tensor:
        """
        Essential component for MQA. Equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep).
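generate_padding_mask is fully visible above: positions below a sequence's length come out True (real tokens) and the rest False (padding). repeat_kv's body is only partially shown, but the docstring states equivalence with torch.repeat_interleave(x, dim=1, repeats=n_rep); the usual expand-and-reshape formulation below satisfies that equivalence and is offered as a sketch, not as the exact ColossalAI code.

import torch

# Padding mask example: True marks real tokens, False marks padded slots.
lengths = torch.tensor([3, 1])
mask = torch.arange(4).expand(2, 4) < lengths.unsqueeze(1)
# tensor([[ True,  True,  True, False],
#         [ True, False, False, False]])

@torch.no_grad()
def repeat_kv_sketch(hidden_states: torch.Tensor, n_rep: int = 1) -> torch.Tensor:
    # hidden_states: [batch, num_kv_heads, seq_len, head_dim]
    batch, num_kv_heads, seq_len, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    expanded = hidden_states[:, :, None, :, :].expand(batch, num_kv_heads, n_rep, seq_len, head_dim)
    return expanded.reshape(batch, num_kv_heads * n_rep, seq_len, head_dim)

x = torch.randn(2, 4, 5, 8)
assert torch.equal(repeat_kv_sketch(x, 3), torch.repeat_interleave(x, repeats=3, dim=1))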
@@ -117,6 +122,7 @@ class PagedAttention:
        return hidden_states.reshape(batch, num_attention_heads, seq_len, head_dim)

    @staticmethod
    @torch.no_grad
    def nopad_context_forward(
        q: torch.Tensor,  # [num_tokens, num_heads, head_size]
        k: torch.Tensor,  # [num_tokens, num_kv_heads, head_size]
@@ -185,6 +191,7 @@ class PagedAttention:
        return attn_output

    @staticmethod
    @torch.no_grad
    def pad_context_forward(
        q: torch.Tensor,  # [batch_size, seq_len, num_heads, head_size]
        k: torch.Tensor,  # [batch_size, seq_len, num_kv_heads, head_size]
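pad_context_forward handles the padded prefill path: q/k/v come in as [batch_size, seq_len, heads, head_size], MQA/GQA key-value heads are repeated up to the query head count, and the result is flattened back to [bsz, seq_len, hidden_size], as the reshape in the next hunk shows. A condensed sketch of that shape flow, assuming k/v have already been expanded to num_heads and omitting the shape checks:

import math
import torch

@torch.no_grad()
def pad_context_attention_sketch(q, k, v, padding_mask=None):
    # q, k, v: [bsz, seq_len, num_heads, head_size]; padding_mask: [bsz, seq_len] bool
    bsz, seq_len, num_heads, head_size = q.shape
    q, k, v = (t.transpose(1, 2) for t in (q, k, v))              # [bsz, num_heads, seq_len, head_size]
    attn_weights = q @ k.transpose(2, 3) / math.sqrt(head_size)   # [bsz, num_heads, seq_len, seq_len]
    causal = torch.full((seq_len, seq_len), float("-inf"), dtype=q.dtype, device=q.device).triu(1)
    attn_weights = attn_weights + causal                          # mask out future positions
    if padding_mask is not None:                                  # mask out padded key positions
        attn_weights = attn_weights.masked_fill(~padding_mask[:, None, None, :], float("-inf"))
    attn_weights = torch.softmax(attn_weights, dim=-1)
    attn_output = attn_weights @ v                                # [bsz, num_heads, seq_len, head_size]
    return attn_output.transpose(1, 2).contiguous().reshape(bsz, seq_len, -1)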
@@ -239,11 +246,10 @@ class PagedAttention:

        attn_output = attn_output.transpose(1, 2).contiguous().reshape(bsz, seq_len, -1)

        del attn_weights

        return attn_output

    @staticmethod
    @torch.no_grad
    def pad_decoding_forward(
        q: torch.Tensor,  # [bsz, 1, num_heads, head_size]
        k: torch.Tensor,  # [bsz, 1, num_kv_heads, head_size]
@@ -297,11 +303,10 @@ class PagedAttention:
            raise ValueError(f"Got wrong attn_output, should be in shape {(bsz,num_heads,1,head_size)}.")
        attn_output = attn_output.transpose(1, 2).contiguous().reshape(bsz, 1, -1)

        del attn_weights

        return attn_output

    @staticmethod
    @torch.no_grad
    def no_pad_decoding_forward(
        self,
        q: torch.Tensor,  # [num_tokens, num_heads, head_size]
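The last two hunks cover the decoding path, where each sequence contributes a single new query token ([bsz, 1, num_heads, head_size]) that attends over the accumulated key/value cache before being flattened to [bsz, 1, hidden_size]. A minimal sketch of that step, assuming the cached keys/values have already been gathered into padded [bsz, kv_len, num_heads, head_size] tensors:

import math
import torch

@torch.no_grad()
def pad_decoding_attention_sketch(q, k_cache, v_cache, padding_mask=None):
    # q: [bsz, 1, num_heads, head_size]; k_cache, v_cache: [bsz, kv_len, num_heads, head_size]
    bsz, _, num_heads, head_size = q.shape
    q = q.transpose(1, 2)                                         # [bsz, num_heads, 1, head_size]
    k = k_cache.transpose(1, 2)                                   # [bsz, num_heads, kv_len, head_size]
    v = v_cache.transpose(1, 2)
    attn_weights = q @ k.transpose(2, 3) / math.sqrt(head_size)   # [bsz, num_heads, 1, kv_len]
    if padding_mask is not None:                                  # ignore padded cache slots
        attn_weights = attn_weights.masked_fill(~padding_mask[:, None, None, :], float("-inf"))
    attn_weights = torch.softmax(attn_weights, dim=-1)
    attn_output = attn_weights @ v                                # [bsz, num_heads, 1, head_size]
    return attn_output.transpose(1, 2).contiguous().reshape(bsz, 1, -1)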