mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-07 03:52:01 +00:00
[Inference] Fix flash-attn import and add model test (#5794)
* Fix torch int32 dtype Signed-off-by: char-1ee <xingjianli59@gmail.com> * Fix flash-attn import Signed-off-by: char-1ee <xingjianli59@gmail.com> * Add generalized model test Signed-off-by: char-1ee <xingjianli59@gmail.com> * Remove exposed path to model Signed-off-by: char-1ee <xingjianli59@gmail.com> * Add default value for use_flash_attn Signed-off-by: char-1ee <xingjianli59@gmail.com> * Rename model test Signed-off-by: char-1ee <xingjianli59@gmail.com> --------- Signed-off-by: char-1ee <xingjianli59@gmail.com>
This commit is contained in:
@@ -2,7 +2,6 @@ from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
from flash_attn import flash_attn_varlen_func
|
||||
|
||||
from colossalai.inference.config import ModelShardInferenceConfig
|
||||
from colossalai.kernel.kernel_loader import InferenceOpsLoader
|
||||
@@ -44,7 +43,7 @@ class CudaAttentionBackend(AttentionBackend):
|
||||
it uses Triton op `context_attention_unpadded` for prefilling and our cuda op `flash_decoding_attention` for decoding.
|
||||
"""
|
||||
|
||||
def __init__(self, use_flash_attn: bool):
|
||||
def __init__(self, use_flash_attn: bool = False):
|
||||
super().__init__()
|
||||
self.inference_ops = InferenceOpsLoader().load()
|
||||
self.use_flash_attn = use_flash_attn
|
||||
@@ -52,6 +51,9 @@ class CudaAttentionBackend(AttentionBackend):
|
||||
def prefill(self, attn_metadata: AttentionMetaData, **kwargs):
|
||||
if self.use_flash_attn:
|
||||
token_nums = kwargs.get("token_nums", -1)
|
||||
|
||||
from flash_attn import flash_attn_varlen_func
|
||||
|
||||
attn_output = flash_attn_varlen_func(
|
||||
attn_metadata.query_states,
|
||||
attn_metadata.key_states,
|
||||
|
@@ -200,8 +200,6 @@ class NopadBaichuanAttention(ParallelModule):
|
||||
|
||||
self.pre_attention_backend.decode(
|
||||
attn_metadata,
|
||||
cos=cos_sin[0],
|
||||
sin=cos_sin[1],
|
||||
q_len=q_len,
|
||||
)
|
||||
attn_output = self.attention_backend.decode(
|
||||
|
@@ -114,7 +114,7 @@ def llama_model_forward(
|
||||
|
||||
elif use_cuda_kernel:
|
||||
if can_use_flash_attn2(inputmetadata.dtype):
|
||||
cu_seqlens = F.pad(torch.cumsum(sequence_lengths, dim=0, dtype=torch.torch.int32), (1, 0))
|
||||
cu_seqlens = F.pad(torch.cumsum(sequence_lengths, dim=0, dtype=torch.int32), (1, 0))
|
||||
|
||||
hidden_dim = self._cos_cached.size(-1)
|
||||
total_length = hidden_states.size(0)
|
||||
@@ -265,7 +265,7 @@ class NopadLlamaMLP(LlamaMLP, ParallelModule):
|
||||
mlp_dproj: ParallelModule = None,
|
||||
process_group: ProcessGroup = None,
|
||||
):
|
||||
"""A Unified Layer for
|
||||
"""Replacement of LlamaMLP layer.
|
||||
|
||||
Args:
|
||||
config (LlamaConfig): Holding the Llama model config.
|
||||
|
Reference in New Issue
Block a user