@@ -19,7 +19,7 @@ from transformers.models.llama.modeling_llama import (
 from colossalai.inference.config import InputMetaData
 from colossalai.inference.flash_decoding_utils import FDIntermTensors
 from colossalai.inference.modeling.backends.attention_backend import get_attention_backend, AttentionMetaData
-from colossalai.inference.modeling.backends.attention_context import get_attention_context
+from colossalai.inference.modeling.backends.pre_attention_backend import get_pre_attention_backend
 from colossalai.inference.utils import can_use_flash_attn2
 from colossalai.kernel.kernel_loader import InferenceOpsLoader
 from colossalai.kernel.triton import (
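
The only functional change in this import block is the swap from the old attention_context module to the new pre_attention_backend module; it is mirrored by the call-site renames in the hunks below. For orientation, AttentionMetaData (imported unchanged above) is the container those backends consume. A minimal sketch, restricted to the keyword arguments that are visible later in this diff; the real dataclass in attention_backend.py carries more (query/key/value tensors, cache handles, and so on), so treat the field list here as illustrative rather than complete:

from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class AttentionMetaDataSketch:
    """Illustrative stand-in for colossalai's AttentionMetaData.

    Only the fields that appear at the call site in this diff are listed;
    everything else the real class holds is omitted here.
    """

    output_tensor: Optional[torch.Tensor] = None  # pre-allocated buffer for the attention output
    use_spec_dec: bool = False                    # True on the verifier pass of speculative decoding
    use_alibi_attn: bool = False                  # ALiBi bias; stays False for Llama
    use_cuda_kernel: bool = False                 # keyword passed at the call site in the hunk below
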
@@ -121,7 +121,7 @@ def llama_model_forward(
         cos_sin = (self._cos_cached[rotary_indexes], self._sin_cached[rotary_indexes])

     elif use_cuda_kernel:
-        if inputmetadata.dtype != torch.float32 and can_use_flash_attn2():
+        if can_use_flash_attn2(inputmetadata.dtype):
             cu_seqlens = F.pad(torch.cumsum(sequence_lengths, dim=0, dtype=torch.torch.int32), (1, 0))

         hidden_dim = self._cos_cached.size(-1)
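
The call-site change above moves the dtype gate into the helper: instead of checking inputmetadata.dtype != torch.float32 and then calling a zero-argument can_use_flash_attn2(), the caller now hands the dtype to the utility and lets it decide. A minimal sketch of that shape, assuming (this body is not taken from colossalai.inference.utils) that the check amounts to "half-precision only, and flash-attn must be installed":

import importlib.util

import torch


def can_use_flash_attn2(dtype: torch.dtype) -> bool:
    """Hypothetical single-argument helper mirroring the new call site.

    Assumption: flash-attn 2 kernels only accept FP16/BF16 inputs, so the
    dtype gate and the availability probe live together in one place.
    """
    if dtype not in (torch.float16, torch.bfloat16):
        return False
    # Probe for the flash_attn package without importing it eagerly.
    return importlib.util.find_spec("flash_attn") is not None

With the check centralized like this, the elif use_cuda_kernel: branch only has to ask one question, which is exactly what the new if can_use_flash_attn2(inputmetadata.dtype): line does.
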
@@ -544,13 +544,14 @@ class NopadLlamaAttention(LlamaAttention, ParallelModule):
             output_tensor=output_tensor,
             use_spec_dec=is_verifier,
             use_alibi_attn=False,
+            use_cuda_kernel=use_cuda_kernel,
         )

         attention_backend = get_attention_backend(use_spec_dec=is_verifier, use_cuda_kernel=use_cuda_kernel, dtype=query_states.dtype)
-        attention_context = get_attention_context(use_spec_dec=is_verifier, use_cuda_kernel=use_cuda_kernel, dtype=query_states.dtype)
+        pre_attention_backend = get_pre_attention_backend(use_spec_dec=is_verifier, use_cuda_kernel=use_cuda_kernel, dtype=query_states.dtype)

         if is_prompts:  # prefilling stage
-            attention_context.prefill(
+            pre_attention_backend.prefill(
                 attn_metadata,
                 cos=cos_sin[0],
                 sin=cos_sin[1],
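
The rename from attention_context to pre_attention_backend makes the division of labor explicit: the pre-attention backend prepares the inputs (rotary embedding, KV-cache writes) before the attention backend runs the attention kernel itself, and both objects are chosen from the same three switches. A rough sketch of how such a selector can look; only the two getter names and their parameters come from the diff, the classes and branch conditions below are assumptions:

import torch


class CudaPreAttentionBackend:
    """Placeholder: would call fused CUDA kernels for rotary embedding + KV-cache copy."""

    def prefill(self, attn_metadata, **kwargs):
        raise NotImplementedError

    def decode(self, attn_metadata, **kwargs):
        raise NotImplementedError


class TritonPreAttentionBackend:
    """Placeholder: would call the Triton rotary/copy kernels instead."""

    def prefill(self, attn_metadata, **kwargs):
        raise NotImplementedError

    def decode(self, attn_metadata, **kwargs):
        raise NotImplementedError


def get_pre_attention_backend(use_spec_dec: bool, use_cuda_kernel: bool, dtype: torch.dtype):
    """Hypothetical selector with the same keyword switches as the calls above."""
    # Assumed policy: the CUDA fast path needs the fused kernels and half precision;
    # speculative decoding and everything else fall back to the Triton path.
    if not use_spec_dec and use_cuda_kernel and dtype in (torch.float16, torch.bfloat16):
        return CudaPreAttentionBackend()
    return TritonPreAttentionBackend()

get_attention_backend presumably follows the same selection pattern for the attention kernels themselves.
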
@@ -563,7 +564,7 @@ class NopadLlamaAttention(LlamaAttention, ParallelModule):
         else:  # decoding stage
             q_len = tokens_to_verify + 1 if is_verifier else 1

-            attention_context.decode(
+            pre_attention_backend.decode(
                 attn_metadata,
                 cos=cos_sin[0],
                 sin=cos_sin[1],
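
Putting the last two hunks together, the per-layer flow is: build the metadata, pick the two backends, then call prefill or decode on each depending on the stage. The schematic below only restates what the hunks show; the q_len keyword and the attention_backend call signature are not visible in this diff and are assumptions:

# Schematic only: argument lists beyond what the hunks show are guesses.
from colossalai.inference.modeling.backends.attention_backend import get_attention_backend
from colossalai.inference.modeling.backends.pre_attention_backend import get_pre_attention_backend


def attention_forward_sketch(attn_metadata, cos_sin, is_prompts, is_verifier,
                             use_cuda_kernel, tokens_to_verify, dtype):
    attention_backend = get_attention_backend(
        use_spec_dec=is_verifier, use_cuda_kernel=use_cuda_kernel, dtype=dtype
    )
    pre_attention_backend = get_pre_attention_backend(
        use_spec_dec=is_verifier, use_cuda_kernel=use_cuda_kernel, dtype=dtype
    )

    if is_prompts:  # prefilling stage: prepare Q/K/V, then run prompt attention
        pre_attention_backend.prefill(attn_metadata, cos=cos_sin[0], sin=cos_sin[1])
        attn_output = attention_backend.prefill(attn_metadata)  # signature assumed
    else:  # decoding stage: one (or tokens_to_verify + 1) new token(s) per sequence
        q_len = tokens_to_verify + 1 if is_verifier else 1
        pre_attention_backend.decode(attn_metadata, cos=cos_sin[0], sin=cos_sin[1], q_len=q_len)  # q_len assumed
        attn_output = attention_backend.decode(attn_metadata, q_len=q_len)  # signature assumed
    return attn_output
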