Pass inference model shard configs for module init

Signed-off-by: char-1ee <xingjianli59@gmail.com>
This commit is contained in:
char-1ee
2024-06-07 08:28:19 +00:00
parent eec77e5702
commit 5f398fc000
11 changed files with 238 additions and 136 deletions

View File

@@ -1,18 +1,9 @@
from abc import ABC, abstractmethod
import torch
from colossalai.inference.utils import can_use_flash_attn2
from colossalai.kernel.kernel_loader import InferenceOpsLoader
from colossalai.inference.config import ModelShardInferenceConfig
from colossalai.inference.modeling.backends.attention_backend import AttentionMetaData
from colossalai.logging import get_dist_logger
from colossalai.kernel.triton import (
copy_k_to_blocked_cache,
decoding_fused_rotary_embedding,
rotary_embedding,
)
logger = get_dist_logger(__name__)
inference_ops = InferenceOpsLoader().load()
from colossalai.kernel.kernel_loader import InferenceOpsLoader
from colossalai.kernel.triton import copy_k_to_blocked_cache, decoding_fused_rotary_embedding, rotary_embedding
class PreAttentionBackend(ABC):
@@ -25,17 +16,25 @@ class PreAttentionBackend(ABC):
raise NotImplementedError
class CudaPreAttentionBackend(PreAttentionBackend):
class FlashPreAttentionBackend(PreAttentionBackend):
"""
FlashPreAttentionBackend handles KV cache initialization and positional encoding for FlashAttentionBackend.
"""
def __init__(self):
super().__init__()
self.inference_ops = InferenceOpsLoader().load()
def prefill(self, attn_metadata: AttentionMetaData, **kwargs):
if not attn_metadata.use_alibi_attn:
inference_ops.rotary_embedding(
self.inference_ops.rotary_embedding(
attn_metadata.query_states,
attn_metadata.key_states,
kwargs.get("cos", None),
kwargs.get("sin", None),
kwargs.get("high_precision", False),
)
inference_ops.context_kv_cache_memcpy(
self.inference_ops.context_kv_cache_memcpy(
attn_metadata.key_states,
attn_metadata.value_states,
attn_metadata.k_cache,
@@ -48,7 +47,7 @@ class CudaPreAttentionBackend(PreAttentionBackend):
def decode(self, attn_metadata: AttentionMetaData, **kwargs):
if not attn_metadata.use_alibi_attn:
inference_ops.rotary_embedding_and_cache_copy(
self.inference_ops.rotary_embedding_and_cache_copy(
attn_metadata.query_states,
attn_metadata.key_states,
attn_metadata.value_states,
@@ -61,7 +60,50 @@ class CudaPreAttentionBackend(PreAttentionBackend):
kwargs.get("high_precision", None),
)
else:
inference_ops.decode_kv_cache_memcpy(
self.inference_ops.decode_kv_cache_memcpy(
attn_metadata.key_states,
attn_metadata.value_states,
attn_metadata.k_cache,
attn_metadata.v_cache,
attn_metadata.sequence_lengths,
attn_metadata.block_tables,
)
class CudaPreAttentionBackend(PreAttentionBackend):
"""
CudaPreAttentionBackend handles KV cache initialization and positional encoding for CudaAttentionBackend.
"""
def __init__(self):
super().__init__()
self.inference_ops = InferenceOpsLoader().load()
def prefill(self, attn_metadata: AttentionMetaData, **kwargs):
if not attn_metadata.use_alibi_attn:
rotary_embedding(
attn_metadata.query_states,
attn_metadata.key_states,
kwargs.get("cos", None),
kwargs.get("sin", None),
)
def decode(self, attn_metadata: AttentionMetaData, **kwargs):
if not attn_metadata.use_alibi_attn:
self.inference_ops.rotary_embedding_and_cache_copy(
attn_metadata.query_states,
attn_metadata.key_states,
attn_metadata.value_states,
kwargs.get("cos", None),
kwargs.get("sin", None),
attn_metadata.k_cache,
attn_metadata.v_cache,
attn_metadata.sequence_lengths,
attn_metadata.block_tables,
kwargs.get("high_precision", None),
)
else:
self.inference_ops.decode_kv_cache_memcpy(
attn_metadata.key_states,
attn_metadata.value_states,
attn_metadata.k_cache,
@@ -72,6 +114,10 @@ class CudaPreAttentionBackend(PreAttentionBackend):
class TritonPreAttentionBackend(PreAttentionBackend):
"""
TritonPreAttentionBackend handles KV cache initialization and positional encoding for TritonAttentionBackend.
"""
def prefill(self, attn_metadata: AttentionMetaData, **kwargs):
if not attn_metadata.use_alibi_attn:
rotary_embedding(
@@ -94,7 +140,7 @@ class TritonPreAttentionBackend(PreAttentionBackend):
attn_metadata.block_tables,
attn_metadata.sequence_lengths,
)
else: # else if using speculative decoding
else: # else if using speculative decoding
if not attn_metadata.use_alibi_attn:
rotary_embedding(
attn_metadata.query_states,
@@ -119,13 +165,18 @@ class TritonPreAttentionBackend(PreAttentionBackend):
def get_pre_attention_backend(
use_spec_dec: bool, use_cuda_kernel: bool, dtype: torch.dtype
model_shard_infer_config: ModelShardInferenceConfig,
) -> PreAttentionBackend:
"""
Get the backend for pre-attention computations, including potisional encoding like RoPE and KV cache initialization.
Get the backend for pre-attention computations, including potisional encoding like
RoPE and KV cache initialization. It adopt the same selection logic as attention_backend/get_attention_backend.
"""
use_flash_attn = can_use_flash_attn2(dtype)
if use_cuda_kernel and use_flash_attn and not use_spec_dec:
return CudaPreAttentionBackend()
else:
if model_shard_infer_config.use_spec_dec:
return TritonPreAttentionBackend()
if model_shard_infer_config.use_cuda_kernel:
if model_shard_infer_config.use_flash_attn:
return FlashPreAttentionBackend()
return CudaPreAttentionBackend()
return TritonPreAttentionBackend()