[Inference] Fix flash-attn import and add model test (#5794)

* Fix torch int32 dtype

Signed-off-by: char-1ee <xingjianli59@gmail.com>

* Fix flash-attn import

Signed-off-by: char-1ee <xingjianli59@gmail.com>

* Add generalized model test

Signed-off-by: char-1ee <xingjianli59@gmail.com>

* Remove exposed path to model

Signed-off-by: char-1ee <xingjianli59@gmail.com>

* Add default value for use_flash_attn

Signed-off-by: char-1ee <xingjianli59@gmail.com>

* Rename model test

Signed-off-by: char-1ee <xingjianli59@gmail.com>

---------

Signed-off-by: char-1ee <xingjianli59@gmail.com>
This commit is contained in:
Li Xingjian
2024-06-12 14:13:50 +08:00
committed by GitHub
parent aac941ef78
commit 8554585a5f
7 changed files with 171 additions and 8 deletions

View File

@@ -2,7 +2,6 @@ from abc import ABC, abstractmethod
from dataclasses import dataclass
import torch
from flash_attn import flash_attn_varlen_func
from colossalai.inference.config import ModelShardInferenceConfig
from colossalai.kernel.kernel_loader import InferenceOpsLoader
@@ -44,7 +43,7 @@ class CudaAttentionBackend(AttentionBackend):
it uses Triton op `context_attention_unpadded` for prefilling and our cuda op `flash_decoding_attention` for decoding.
"""
def __init__(self, use_flash_attn: bool):
def __init__(self, use_flash_attn: bool = False):
super().__init__()
self.inference_ops = InferenceOpsLoader().load()
self.use_flash_attn = use_flash_attn
@@ -52,6 +51,9 @@ class CudaAttentionBackend(AttentionBackend):
def prefill(self, attn_metadata: AttentionMetaData, **kwargs):
if self.use_flash_attn:
token_nums = kwargs.get("token_nums", -1)
from flash_attn import flash_attn_varlen_func
attn_output = flash_attn_varlen_func(
attn_metadata.query_states,
attn_metadata.key_states,