Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-10 13:30:19 +00:00
Remove flash attention backend
Signed-off-by: char-1ee <xingjianli59@gmail.com>
@@ -16,71 +16,37 @@ class PreAttentionBackend(ABC):
         raise NotImplementedError
 
 
-class FlashPreAttentionBackend(PreAttentionBackend):
-    """
-    FlashPreAttentionBackend handles KV cache initialization and positional encoding for FlashAttentionBackend.
-    """
-
-    def __init__(self):
-        super().__init__()
-        self.inference_ops = InferenceOpsLoader().load()
-
-    def prefill(self, attn_metadata: AttentionMetaData, **kwargs):
-        if not attn_metadata.use_alibi_attn:
-            self.inference_ops.rotary_embedding(
-                attn_metadata.query_states,
-                attn_metadata.key_states,
-                kwargs.get("cos", None),
-                kwargs.get("sin", None),
-                kwargs.get("high_precision", False),
-            )
-        self.inference_ops.context_kv_cache_memcpy(
-            attn_metadata.key_states,
-            attn_metadata.value_states,
-            attn_metadata.k_cache,
-            attn_metadata.v_cache,
-            attn_metadata.sequence_lengths,
-            attn_metadata.cu_seqlens,
-            attn_metadata.block_tables,
-            attn_metadata.kv_seq_len,
-        )
-
-    def decode(self, attn_metadata: AttentionMetaData, **kwargs):
-        if not attn_metadata.use_alibi_attn:
-            self.inference_ops.rotary_embedding_and_cache_copy(
-                attn_metadata.query_states,
-                attn_metadata.key_states,
-                attn_metadata.value_states,
-                kwargs.get("cos", None),
-                kwargs.get("sin", None),
-                attn_metadata.k_cache,
-                attn_metadata.v_cache,
-                attn_metadata.sequence_lengths,
-                attn_metadata.block_tables,
-                kwargs.get("high_precision", None),
-            )
-        else:
-            self.inference_ops.decode_kv_cache_memcpy(
-                attn_metadata.key_states,
-                attn_metadata.value_states,
-                attn_metadata.k_cache,
-                attn_metadata.v_cache,
-                attn_metadata.sequence_lengths,
-                attn_metadata.block_tables,
-            )
-
-
 class CudaPreAttentionBackend(PreAttentionBackend):
     """
     CudaPreAttentionBackend handles KV cache initialization and positional encoding for CudaAttentionBackend.
     """
 
-    def __init__(self):
+    def __init__(self, use_flash_attn: bool):
         super().__init__()
         self.inference_ops = InferenceOpsLoader().load()
+        self.use_flash_attn = use_flash_attn
 
     def prefill(self, attn_metadata: AttentionMetaData, **kwargs):
-        if not attn_metadata.use_alibi_attn:
+        if self.use_flash_attn:
+            if not attn_metadata.use_alibi_attn:
+                self.inference_ops.rotary_embedding(
+                    attn_metadata.query_states,
+                    attn_metadata.key_states,
+                    kwargs.get("cos", None),
+                    kwargs.get("sin", None),
+                    kwargs.get("high_precision", False),
+                )
+            self.inference_ops.context_kv_cache_memcpy(
+                attn_metadata.key_states,
+                attn_metadata.value_states,
+                attn_metadata.k_cache,
+                attn_metadata.v_cache,
+                attn_metadata.sequence_lengths,
+                attn_metadata.cu_seqlens,
+                attn_metadata.block_tables,
+                attn_metadata.kv_seq_len,
+            )
+        elif not attn_metadata.use_alibi_attn:
             rotary_embedding(
                 attn_metadata.query_states,
                 attn_metadata.key_states,
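Note (illustration, not part of the diff): with FlashPreAttentionBackend removed, CudaPreAttentionBackend covers both paths and branches on the new use_flash_attn flag at prefill time; the second hunk below updates the factory accordingly. A minimal sketch of that control flow follows, using hypothetical stand-in names (FakeAttentionMetaData, prefill_path), since the real kernel calls require ColossalAI's compiled CUDA extension ops.

# Illustrative sketch only: mirrors the branch structure of the new
# CudaPreAttentionBackend.prefill. The metadata field is the one referenced in
# the hunk above; the kernel calls are reduced to labels because the real
# inference_ops need ColossalAI's compiled CUDA extensions.
from dataclasses import dataclass


@dataclass
class FakeAttentionMetaData:  # hypothetical stand-in for AttentionMetaData
    use_alibi_attn: bool = False


def prefill_path(use_flash_attn: bool, meta: FakeAttentionMetaData) -> str:
    if use_flash_attn:
        if not meta.use_alibi_attn:
            # flash path: fused rotary embedding, then contextual KV cache copy
            return "inference_ops.rotary_embedding + inference_ops.context_kv_cache_memcpy"
        # ALiBi models skip rotary embedding but still copy the KV cache
        return "inference_ops.context_kv_cache_memcpy"
    if not meta.use_alibi_attn:
        # non-flash CUDA path falls back to the generic rotary_embedding op
        return "rotary_embedding (remainder of this branch lies outside the hunk)"
    return "nothing visible in this hunk"


print(prefill_path(True, FakeAttentionMetaData()))   # flash-attention prefill
print(prefill_path(False, FakeAttentionMetaData()))  # plain CUDA prefill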
@@ -175,8 +141,6 @@ def get_pre_attention_backend(
         return TritonPreAttentionBackend()
 
     if model_shard_infer_config.use_cuda_kernel:
-        if model_shard_infer_config.use_flash_attn:
-            return FlashPreAttentionBackend()
-        return CudaPreAttentionBackend()
+        return CudaPreAttentionBackend(model_shard_infer_config.use_flash_attn)
 
     return TritonPreAttentionBackend()
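The second hunk reduces backend selection to a single constructor call. A small sketch of the resulting dispatch, assuming a hypothetical FakeShardInferConfig in place of the real model_shard_infer_config (which carries more fields than shown here):

# Illustrative sketch of the selection logic after this commit; backend names are
# returned as strings because constructing the real classes needs the CUDA /
# Triton kernels to be available.
from dataclasses import dataclass


@dataclass
class FakeShardInferConfig:  # hypothetical stand-in for model_shard_infer_config
    use_cuda_kernel: bool
    use_flash_attn: bool


def pick_pre_attention_backend(cfg: FakeShardInferConfig) -> str:
    # FlashPreAttentionBackend no longer exists: the flash-attention choice is now
    # just a constructor argument of CudaPreAttentionBackend.
    if cfg.use_cuda_kernel:
        return "CudaPreAttentionBackend(use_flash_attn={})".format(cfg.use_flash_attn)
    return "TritonPreAttentionBackend()"


print(pick_pre_attention_backend(FakeShardInferConfig(use_cuda_kernel=True, use_flash_attn=True)))
print(pick_pre_attention_backend(FakeShardInferConfig(use_cuda_kernel=False, use_flash_attn=False)))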