Pass inference model shard configs for module init

Signed-off-by: char-1ee <xingjianli59@gmail.com>
This commit is contained in:
char-1ee
2024-06-07 08:28:19 +00:00
parent eec77e5702
commit 5f398fc000
11 changed files with 238 additions and 136 deletions

View File

@@ -1,19 +1,12 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass
from flash_attn import flash_attn_varlen_func
import torch
from flash_attn import flash_attn_varlen_func
from colossalai.inference.config import InputMetaData
from colossalai.inference.utils import can_use_flash_attn2
from colossalai.logging import get_dist_logger
from colossalai.inference.config import ModelShardInferenceConfig
from colossalai.kernel.kernel_loader import InferenceOpsLoader
from colossalai.kernel.triton import (
context_attention_unpadded,
flash_decoding_attention,
)
logger = get_dist_logger(__name__)
inference_ops = InferenceOpsLoader().load()
from colossalai.kernel.triton import context_attention_unpadded, flash_decoding_attention
@dataclass
@@ -33,7 +26,6 @@ class AttentionMetaData:
output_tensor: torch.Tensor = None
use_spec_dec: bool = False
use_alibi_attn: bool = False
use_cuda_kernel: bool = False
class AttentionBackend(ABC):
@@ -46,7 +38,16 @@ class AttentionBackend(ABC):
raise NotImplementedError
class CudaAttentionBackend(AttentionBackend):
class FlashAttentionBackend(AttentionBackend):
"""
Attention backend when use_cuda_kernel is True and flash-attn is installed. It uses
`flash_attn_varlen_func` for prefilling and our cuda op `flash_decoding_attention` for decoding.
"""
def __init__(self):
super().__init__()
self.inference_ops = InferenceOpsLoader().load()
def prefill(self, attn_metadata: AttentionMetaData, **kwargs):
token_nums = kwargs.get("token_nums", -1)
@@ -69,7 +70,55 @@ class CudaAttentionBackend(AttentionBackend):
def decode(self, attn_metadata: AttentionMetaData, **kwargs):
fd_inter_tensor = kwargs.get("fd_inter_tensor", None)
output_tensor = attn_metadata.output_tensor
inference_ops.flash_decoding_attention(
self.inference_ops.flash_decoding_attention(
output_tensor,
attn_metadata.query_states,
attn_metadata.k_cache,
attn_metadata.v_cache,
attn_metadata.sequence_lengths,
attn_metadata.block_tables,
attn_metadata.block_size,
attn_metadata.kv_seq_len,
fd_inter_tensor.mid_output,
fd_inter_tensor.exp_sums,
fd_inter_tensor.max_logits,
attn_metadata.alibi_slopes,
attn_metadata.sm_scale,
)
return output_tensor
class CudaAttentionBackend(AttentionBackend):
"""
Attention backend when use_cuda_kernel is True but flash-attn not found. If flash-attn is not found,
it uses Triton op `context_attention_unpadded` for prefilling and our cuda op `flash_decoding_attention` for decoding.
"""
def __init__(self):
super().__init__()
self.inference_ops = InferenceOpsLoader().load()
def prefill(self, attn_metadata: AttentionMetaData, **kwargs):
return context_attention_unpadded(
q=attn_metadata.query_states,
k=attn_metadata.key_states,
v=attn_metadata.value_states,
k_cache=attn_metadata.k_cache,
v_cache=attn_metadata.v_cache,
context_lengths=attn_metadata.sequence_lengths,
block_tables=attn_metadata.block_tables,
block_size=attn_metadata.block_size,
output=attn_metadata.output_tensor,
alibi_slopes=attn_metadata.alibi_slopes,
max_seq_len=attn_metadata.kv_seq_len,
sm_scale=attn_metadata.sm_scale,
use_new_kcache_layout=True, # use new k cache layout for cuda kernels in this triton op
)
def decode(self, attn_metadata: AttentionMetaData, **kwargs):
fd_inter_tensor = kwargs.get("fd_inter_tensor", None)
output_tensor = attn_metadata.output_tensor
self.inference_ops.flash_decoding_attention(
output_tensor,
attn_metadata.query_states,
attn_metadata.k_cache,
@@ -88,6 +137,10 @@ class CudaAttentionBackend(AttentionBackend):
class TritonAttentionBackend(AttentionBackend):
"""
Attention backend when use_cuda_kernel is False. It uses pure Triton ops for prefilling and decoding.
"""
def prefill(self, attn_metadata: AttentionMetaData, **kwargs):
return context_attention_unpadded(
q=attn_metadata.query_states,
@@ -102,7 +155,7 @@ class TritonAttentionBackend(AttentionBackend):
alibi_slopes=attn_metadata.alibi_slopes,
max_seq_len=attn_metadata.kv_seq_len,
sm_scale=attn_metadata.sm_scale,
use_new_kcache_layout=attn_metadata.use_cuda_kernel,
use_new_kcache_layout=False,
)
def decode(self, attn_metadata: AttentionMetaData, **kwargs):
@@ -126,17 +179,24 @@ class TritonAttentionBackend(AttentionBackend):
def get_attention_backend(
use_spec_dec: bool, use_cuda_kernel: bool, dtype: torch.dtype
model_shard_infer_config: ModelShardInferenceConfig,
) -> AttentionBackend:
"""
Get the attention backend based on the inference configurations. Only when:
Get the attention backend based on the inference configurations. The modeling will use CUDA-kernel-based backend
for attention module calculation only when:
1. using CUDA kernel (use_cuda_kernel=True)
2. can use flash attention (flash-attn installed and dtype is fp16 or bf16)
3. not using speculative decoding (currently cuda kernel not support speculative decoding)
will the CUDA-kernel-based backend be used for attention layer computations. Otherwise, use Triton attention backend.
Otherwise, use Triton attention backend. If found flash-attn not installed while `use_cuda_kernel` is True,
the Triton backend will use a new k cache layout for Triton kernels.
"""
use_flash_attn = can_use_flash_attn2(dtype)
if use_cuda_kernel and use_flash_attn and not use_spec_dec:
return CudaAttentionBackend()
else:
# Currently only triton kernels support speculative decoding
if model_shard_infer_config.use_spec_dec:
return TritonAttentionBackend()
if model_shard_infer_config.use_cuda_kernel:
if model_shard_infer_config.use_flash_attn:
return FlashAttentionBackend()
return CudaAttentionBackend()
return TritonAttentionBackend()