mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-07 03:52:01 +00:00
Refactor modeling by adding attention backend
Signed-off-by: char-1ee <xingjianli59@gmail.com>
This commit is contained in:
146
colossalai/inference/modeling/backends/attention_backend.py
Normal file
146
colossalai/inference/modeling/backends/attention_backend.py
Normal file
@@ -0,0 +1,146 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from flash_attn import flash_attn_varlen_func
|
||||
import torch
|
||||
|
||||
from colossalai.inference.config import InputMetaData
|
||||
from colossalai.inference.utils import can_use_flash_attn2
|
||||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.kernel.kernel_loader import InferenceOpsLoader
|
||||
from colossalai.kernel.triton import (
|
||||
context_attention_unpadded,
|
||||
flash_decoding_attention,
|
||||
)
|
||||
|
||||
logger = get_dist_logger(__name__)
|
||||
inference_ops = InferenceOpsLoader().load()
|
||||
|
||||
@dataclass
|
||||
class AttentionMetaData:
|
||||
query_states: torch.Tensor
|
||||
key_states: torch.Tensor
|
||||
value_states: torch.Tensor
|
||||
k_cache: torch.Tensor
|
||||
v_cache: torch.Tensor
|
||||
block_tables: torch.Tensor
|
||||
block_size: int
|
||||
kv_seq_len: int = None
|
||||
sequence_lengths: torch.Tensor = None
|
||||
cu_seqlens: torch.Tensor = None
|
||||
sm_scale: int = None
|
||||
alibi_slopes: torch.Tensor = None
|
||||
output_tensor: torch.Tensor = None
|
||||
use_spec_dec: bool = False
|
||||
use_alibi_attn: bool = False
|
||||
|
||||
|
||||
class AttentionBackend(ABC):
|
||||
@abstractmethod
|
||||
def prefill(self, attn_metadata: AttentionMetaData, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def decode(self, attn_metadatas: AttentionMetaData, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class CudaAttentionBackend(AttentionBackend):
|
||||
def prefill(self, attn_metadata: AttentionMetaData, **kwargs):
|
||||
if not attn_metadata.use_spec_dec:
|
||||
token_nums = kwargs.get('token_nums', -1)
|
||||
|
||||
attn_output = flash_attn_varlen_func(
|
||||
attn_metadata.query_states,
|
||||
attn_metadata.key_states,
|
||||
attn_metadata.value_states,
|
||||
cu_seqlens_q=attn_metadata.cu_seqlens,
|
||||
cu_seqlens_k=attn_metadata.cu_seqlens,
|
||||
max_seqlen_k=attn_metadata.kv_seq_len,
|
||||
max_seqlen_v=attn_metadata.kv_seq_len,
|
||||
dropout_p=0.0,
|
||||
softmax_scale=attn_metadata.sm_scale,
|
||||
causal=True,
|
||||
)
|
||||
attn_output = attn_output.view(token_nums, -1)
|
||||
else:
|
||||
attn_output = context_attention_unpadded(
|
||||
q=attn_metadata.query_states,
|
||||
k=attn_metadata.key_states,
|
||||
v=attn_metadata.value_states,
|
||||
k_cache=attn_metadata.k_cache,
|
||||
v_cache=attn_metadata.v_cache,
|
||||
context_lengths=attn_metadata.sequence_lengths,
|
||||
block_tables=attn_metadata.block_tables,
|
||||
block_size=attn_metadata.block_size,
|
||||
output=attn_metadata.output_tensor,
|
||||
max_seq_len=attn_metadata.kv_seq_len,
|
||||
sm_scale=attn_metadata.sm_scale,
|
||||
use_new_kcache_layout=True,
|
||||
)
|
||||
return attn_output
|
||||
|
||||
|
||||
def decode(self, attn_metadata: AttentionMetaData, **kwargs):
|
||||
fd_inter_tensor = kwargs.get('fd_inter_tensor', None)
|
||||
output_tensor = attn_metadata.output_tensor
|
||||
inference_ops.flash_decoding_attention(
|
||||
output_tensor,
|
||||
attn_metadata.query_states,
|
||||
attn_metadata.k_cache,
|
||||
attn_metadata.v_cache,
|
||||
attn_metadata.sequence_lengths,
|
||||
attn_metadata.block_tables,
|
||||
attn_metadata.block_size,
|
||||
attn_metadata.kv_seq_len,
|
||||
fd_inter_tensor.mid_output,
|
||||
fd_inter_tensor.exp_sums,
|
||||
fd_inter_tensor.max_logits,
|
||||
attn_metadata.alibi_slopes,
|
||||
attn_metadata.sm_scale,
|
||||
)
|
||||
return output_tensor
|
||||
|
||||
|
||||
class TritonAttentionBackend(AttentionBackend):
|
||||
def prefill(self, attn_metadata: AttentionMetaData, **kwargs):
|
||||
return context_attention_unpadded(
|
||||
q=attn_metadata.query_states,
|
||||
k=attn_metadata.key_states,
|
||||
v=attn_metadata.value_states,
|
||||
k_cache=attn_metadata.k_cache,
|
||||
v_cache=attn_metadata.v_cache,
|
||||
context_lengths=attn_metadata.sequence_lengths,
|
||||
block_tables=attn_metadata.block_tables,
|
||||
block_size=attn_metadata.block_size,
|
||||
output=attn_metadata.output_tensor,
|
||||
max_seq_len=attn_metadata.kv_seq_len,
|
||||
sm_scale=attn_metadata.sm_scale,
|
||||
use_new_kcache_layout=False,
|
||||
)
|
||||
|
||||
def decode(self, attn_metadata: AttentionMetaData, **kwargs):
|
||||
fd_inter_tensor = kwargs.get('fd_inter_tensor', None)
|
||||
return flash_decoding_attention(
|
||||
q=attn_metadata.query_states,
|
||||
k_cache=attn_metadata.k_cache,
|
||||
v_cache=attn_metadata.v_cache,
|
||||
kv_seq_len=attn_metadata.sequence_lengths,
|
||||
block_tables=attn_metadata.block_tables,
|
||||
block_size=attn_metadata.block_size,
|
||||
max_seq_len_in_batch=attn_metadata.kv_seq_len,
|
||||
output=attn_metadata.output_tensor,
|
||||
mid_output=fd_inter_tensor.mid_output,
|
||||
mid_output_lse=fd_inter_tensor.mid_output_lse,
|
||||
sm_scale=attn_metadata.sm_scale,
|
||||
kv_group_num=kwargs.get('num_key_value_groups', 0),
|
||||
q_len=kwargs.get('q_len', 1),
|
||||
)
|
||||
|
||||
|
||||
def get_attention_backend(use_spec_dec: bool, use_cuda_kernel: bool, dtype: torch.dtype) -> AttentionBackend:
|
||||
use_flash_attn = can_use_flash_attn2(dtype)
|
||||
if use_cuda_kernel and use_flash_attn and not use_spec_dec:
|
||||
return CudaAttentionBackend()
|
||||
else:
|
||||
return TritonAttentionBackend()
|
||||
|
Reference in New Issue
Block a user