mirror of https://github.com/hpcaitech/ColossalAI.git
Pass inference model shard configs for module init
Signed-off-by: char-1ee <xingjianli59@gmail.com>
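In outline, the commit replaces the scalar arguments of get_pre_attention_backend (use_spec_dec, use_cuda_kernel, dtype) with a single ModelShardInferenceConfig object resolved at module init, and moves the module-level inference_ops singleton into each backend's __init__. A minimal stand-in for the config, modeling only the three flags the diff actually reads (the real class in colossalai.inference.config may carry more fields):

from dataclasses import dataclass


@dataclass
class ModelShardInferenceConfig:  # stand-in; the real class lives in colossalai.inference.config
    use_cuda_kernel: bool = False   # dispatch to the compiled CUDA inference ops
    use_spec_dec: bool = False      # speculative decoding enabled for this shard
    use_flash_attn: bool = False    # FlashAttention 2 usable for this dtype/device

# The selector now receives this one object instead of three scalars:
#     backend = get_pre_attention_backend(model_shard_infer_config)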
@@ -1,18 +1,9 @@
 from abc import ABC, abstractmethod
 
 import torch
 
-from colossalai.inference.utils import can_use_flash_attn2
-from colossalai.kernel.kernel_loader import InferenceOpsLoader
+from colossalai.inference.config import ModelShardInferenceConfig
 from colossalai.inference.modeling.backends.attention_backend import AttentionMetaData
-from colossalai.logging import get_dist_logger
-from colossalai.kernel.triton import (
-    copy_k_to_blocked_cache,
-    decoding_fused_rotary_embedding,
-    rotary_embedding,
-)
-
-logger = get_dist_logger(__name__)
-
-inference_ops = InferenceOpsLoader().load()
+from colossalai.kernel.kernel_loader import InferenceOpsLoader
+from colossalai.kernel.triton import copy_k_to_blocked_cache, decoding_fused_rotary_embedding, rotary_embedding
 
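The import hunk also drops the module-level singletons (the logger and inference_ops = InferenceOpsLoader().load()); each backend now loads its own ops handle at construction, as the hunks below show. A self-contained sketch of that pattern, with a dummy loader standing in for InferenceOpsLoader:

class DummyOpsLoader:  # stands in for colossalai.kernel.kernel_loader.InferenceOpsLoader
    def load(self):
        return object()  # the real loader returns a compiled inference-ops module


class BackendSketch:  # illustrative; mirrors the __init__ added below
    def __init__(self):
        self.inference_ops = DummyOpsLoader().load()


a, b = BackendSketch(), BackendSketch()
assert a.inference_ops is not b.inference_ops  # each instance owns its handle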
@@ -25,17 +16,25 @@ class PreAttentionBackend(ABC):
         raise NotImplementedError
 
 
-class CudaPreAttentionBackend(PreAttentionBackend):
+class FlashPreAttentionBackend(PreAttentionBackend):
+    """
+    FlashPreAttentionBackend handles KV cache initialization and positional encoding for FlashAttentionBackend.
+    """
+
+    def __init__(self):
+        super().__init__()
+        self.inference_ops = InferenceOpsLoader().load()
+
     def prefill(self, attn_metadata: AttentionMetaData, **kwargs):
         if not attn_metadata.use_alibi_attn:
-            inference_ops.rotary_embedding(
+            self.inference_ops.rotary_embedding(
                 attn_metadata.query_states,
                 attn_metadata.key_states,
                 kwargs.get("cos", None),
                 kwargs.get("sin", None),
                 kwargs.get("high_precision", False),
             )
-        inference_ops.context_kv_cache_memcpy(
+        self.inference_ops.context_kv_cache_memcpy(
             attn_metadata.key_states,
             attn_metadata.value_states,
             attn_metadata.k_cache,
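For clarity, a behavioral mock of the prefill path this hunk introduces: RoPE is applied only for non-ALiBi models, while the context KV-cache copy always runs. SimpleNamespace stands in for the real AttentionMetaData and ops objects, and argument lists are elided:

from types import SimpleNamespace

def prefill_sketch(ops, meta):  # behavioral mirror of FlashPreAttentionBackend.prefill
    if not meta.use_alibi_attn:
        ops.rotary_embedding()       # cos/sin/high_precision arguments elided
    ops.context_kv_cache_memcpy()    # runs for RoPE and ALiBi models alike

calls = []
ops = SimpleNamespace(
    rotary_embedding=lambda: calls.append("rope"),
    context_kv_cache_memcpy=lambda: calls.append("kv_copy"),
)
prefill_sketch(ops, SimpleNamespace(use_alibi_attn=False))
assert calls == ["rope", "kv_copy"]

calls.clear()
prefill_sketch(ops, SimpleNamespace(use_alibi_attn=True))
assert calls == ["kv_copy"]  # ALiBi adds its bias inside attention, so RoPE is skipped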
@@ -48,7 +47,7 @@ class CudaPreAttentionBackend(PreAttentionBackend):
 
     def decode(self, attn_metadata: AttentionMetaData, **kwargs):
         if not attn_metadata.use_alibi_attn:
-            inference_ops.rotary_embedding_and_cache_copy(
+            self.inference_ops.rotary_embedding_and_cache_copy(
                 attn_metadata.query_states,
                 attn_metadata.key_states,
                 attn_metadata.value_states,
@@ -61,7 +60,50 @@ class CudaPreAttentionBackend(PreAttentionBackend):
                 kwargs.get("high_precision", None),
             )
         else:
-            inference_ops.decode_kv_cache_memcpy(
+            self.inference_ops.decode_kv_cache_memcpy(
+                attn_metadata.key_states,
+                attn_metadata.value_states,
+                attn_metadata.k_cache,
+                attn_metadata.v_cache,
+                attn_metadata.sequence_lengths,
+                attn_metadata.block_tables,
+            )
+
+
+class CudaPreAttentionBackend(PreAttentionBackend):
+    """
+    CudaPreAttentionBackend handles KV cache initialization and positional encoding for CudaAttentionBackend.
+    """
+
+    def __init__(self):
+        super().__init__()
+        self.inference_ops = InferenceOpsLoader().load()
+
+    def prefill(self, attn_metadata: AttentionMetaData, **kwargs):
+        if not attn_metadata.use_alibi_attn:
+            rotary_embedding(
+                attn_metadata.query_states,
+                attn_metadata.key_states,
+                kwargs.get("cos", None),
+                kwargs.get("sin", None),
+            )
+
+    def decode(self, attn_metadata: AttentionMetaData, **kwargs):
+        if not attn_metadata.use_alibi_attn:
+            self.inference_ops.rotary_embedding_and_cache_copy(
+                attn_metadata.query_states,
+                attn_metadata.key_states,
+                attn_metadata.value_states,
+                kwargs.get("cos", None),
+                kwargs.get("sin", None),
+                attn_metadata.k_cache,
+                attn_metadata.v_cache,
+                attn_metadata.sequence_lengths,
+                attn_metadata.block_tables,
+                kwargs.get("high_precision", None),
+            )
+        else:
+            self.inference_ops.decode_kv_cache_memcpy(
                 attn_metadata.key_states,
                 attn_metadata.value_states,
                 attn_metadata.k_cache,
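The decode side of both CUDA-kernel backends follows the same dispatch: non-ALiBi models take the fused rotary-plus-cache-copy kernel, while ALiBi models skip rotary math and only append K/V to the paged cache. A behavioral mock of that branch (stand-in objects, arguments elided):

from types import SimpleNamespace

def decode_sketch(ops, meta):  # mirrors the decode() bodies above
    if not meta.use_alibi_attn:
        ops.rotary_embedding_and_cache_copy()  # fused RoPE + KV-cache append
    else:
        ops.decode_kv_cache_memcpy()           # KV-cache append only

seen = []
ops = SimpleNamespace(
    rotary_embedding_and_cache_copy=lambda: seen.append("fused"),
    decode_kv_cache_memcpy=lambda: seen.append("copy_only"),
)
decode_sketch(ops, SimpleNamespace(use_alibi_attn=False))
decode_sketch(ops, SimpleNamespace(use_alibi_attn=True))
assert seen == ["fused", "copy_only"]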
@@ -72,6 +114,10 @@ class CudaPreAttentionBackend(PreAttentionBackend):
 
 
 class TritonPreAttentionBackend(PreAttentionBackend):
+    """
+    TritonPreAttentionBackend handles KV cache initialization and positional encoding for TritonAttentionBackend.
+    """
+
     def prefill(self, attn_metadata: AttentionMetaData, **kwargs):
         if not attn_metadata.use_alibi_attn:
             rotary_embedding(
@@ -94,7 +140,7 @@ class TritonPreAttentionBackend(PreAttentionBackend):
                 attn_metadata.block_tables,
                 attn_metadata.sequence_lengths,
             )
-        else:  # else if using speculative decoding
+        else:  # else if using speculative decoding
             if not attn_metadata.use_alibi_attn:
                 rotary_embedding(
                     attn_metadata.query_states,
@@ -119,13 +165,18 @@ class TritonPreAttentionBackend(PreAttentionBackend):
 
 
 def get_pre_attention_backend(
-    use_spec_dec: bool, use_cuda_kernel: bool, dtype: torch.dtype
+    model_shard_infer_config: ModelShardInferenceConfig,
 ) -> PreAttentionBackend:
     """
-    Get the backend for pre-attention computations, including positional encoding like RoPE and KV cache initialization.
+    Get the backend for pre-attention computations, including positional encoding like
+    RoPE and KV cache initialization. It adopts the same selection logic as attention_backend/get_attention_backend.
     """
-    use_flash_attn = can_use_flash_attn2(dtype)
-    if use_cuda_kernel and use_flash_attn and not use_spec_dec:
-        return CudaPreAttentionBackend()
-    else:
+    if model_shard_infer_config.use_spec_dec:
         return TritonPreAttentionBackend()
+
+    if model_shard_infer_config.use_cuda_kernel:
+        if model_shard_infer_config.use_flash_attn:
+            return FlashPreAttentionBackend()
+        return CudaPreAttentionBackend()
+
+    return TritonPreAttentionBackend()
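Putting the new signature together, a self-contained mirror of the selection order the final hunk introduces: speculative decoding forces the Triton backend, CUDA kernels prefer the Flash backend when FlashAttention is usable, and everything else falls back to Triton. Stub classes and a dataclass (redefined here so the sketch runs on its own) stand in for the real types:

from dataclasses import dataclass


@dataclass
class ModelShardInferenceConfig:  # stand-in for colossalai.inference.config's class
    use_cuda_kernel: bool = False
    use_spec_dec: bool = False
    use_flash_attn: bool = False


class TritonPreAttentionBackend: pass
class CudaPreAttentionBackend: pass
class FlashPreAttentionBackend: pass


def get_pre_attention_backend(cfg: ModelShardInferenceConfig):
    # same branch order as the hunk above
    if cfg.use_spec_dec:
        return TritonPreAttentionBackend()
    if cfg.use_cuda_kernel:
        if cfg.use_flash_attn:
            return FlashPreAttentionBackend()
        return CudaPreAttentionBackend()
    return TritonPreAttentionBackend()


assert isinstance(get_pre_attention_backend(ModelShardInferenceConfig(use_spec_dec=True)), TritonPreAttentionBackend)
assert isinstance(get_pre_attention_backend(ModelShardInferenceConfig(use_cuda_kernel=True, use_flash_attn=True)), FlashPreAttentionBackend)
assert isinstance(get_pre_attention_backend(ModelShardInferenceConfig(use_cuda_kernel=True)), CudaPreAttentionBackend)
assert isinstance(get_pre_attention_backend(ModelShardInferenceConfig()), TritonPreAttentionBackend)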