Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-06 11:32:10 +00:00
[Inference] Fix flash-attn import and add model test (#5794)
* Fix torch int32 dtype
* Fix flash-attn import
* Add generalized model test
* Remove exposed path to model
* Add default value for use_flash_attn
* Rename model test

Signed-off-by: char-1ee <xingjianli59@gmail.com>
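The main fix defers the `flash_attn` import from module scope into the code path that actually uses it, so the backend module still imports cleanly on machines without flash-attn installed. Below is a minimal, self-contained sketch of that pattern; the class name, method signature, and fallback path are illustrative assumptions, not the ColossalAI implementation itself.

```python
import torch


class LazyFlashAttnBackend:
    """Illustrative backend: flash-attn is imported only when its path is taken."""

    def __init__(self, use_flash_attn: bool = False):
        # Defaulting to False (as this commit does for the real backend) keeps
        # construction working even when flash-attn is not installed.
        self.use_flash_attn = use_flash_attn

    def prefill(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
        if self.use_flash_attn:
            # Import inside the branch: an ImportError can only occur if the
            # flash-attn path was explicitly requested.
            from flash_attn import flash_attn_func

            # Note: flash-attn expects (batch, seqlen, heads, head_dim) layout,
            # while SDPA below expects (batch, heads, seqlen, head_dim);
            # layout handling is omitted in this sketch.
            return flash_attn_func(q, k, v)
        # Fallback path: plain PyTorch scaled dot-product attention.
        return torch.nn.functional.scaled_dot_product_attention(q, k, v)
```

Constructed with the new default, `LazyFlashAttnBackend()` never touches flash-attn at import or init time.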
@@ -2,7 +2,6 @@ from abc import ABC, abstractmethod
 from dataclasses import dataclass

 import torch
-from flash_attn import flash_attn_varlen_func

 from colossalai.inference.config import ModelShardInferenceConfig
 from colossalai.kernel.kernel_loader import InferenceOpsLoader
@@ -44,7 +43,7 @@ class CudaAttentionBackend(AttentionBackend):
     it uses Triton op `context_attention_unpadded` for prefilling and our cuda op `flash_decoding_attention` for decoding.
     """

-    def __init__(self, use_flash_attn: bool):
+    def __init__(self, use_flash_attn: bool = False):
         super().__init__()
         self.inference_ops = InferenceOpsLoader().load()
         self.use_flash_attn = use_flash_attn
@@ -52,6 +51,9 @@ class CudaAttentionBackend(AttentionBackend):
     def prefill(self, attn_metadata: AttentionMetaData, **kwargs):
         if self.use_flash_attn:
             token_nums = kwargs.get("token_nums", -1)
+
+            from flash_attn import flash_attn_varlen_func
+
             attn_output = flash_attn_varlen_func(
                 attn_metadata.query_states,
                 attn_metadata.key_states,
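Regarding the "Fix torch int32 dtype" item: flash-attn's varlen entry point `flash_attn_varlen_func` expects the cumulative sequence lengths (`cu_seqlens_q` / `cu_seqlens_k`) to be `torch.int32` tensors. The snippet below is a standalone sketch of building that metadata with the required dtype; the sequence lengths are made-up example values.

```python
import torch
import torch.nn.functional as F

# Per-sequence lengths of a packed (varlen) batch; values here are illustrative.
seq_lens = torch.tensor([5, 3, 7])

# Cumulative sequence lengths, prefixed with 0 and accumulated in int32,
# the dtype flash_attn_varlen_func expects for its cu_seqlens arguments.
cu_seqlens = F.pad(torch.cumsum(seq_lens, dim=0, dtype=torch.int32), (1, 0))
print(cu_seqlens)                 # tensor([ 0,  5,  8, 15], dtype=torch.int32)
max_seqlen = int(seq_lens.max())  # 7, passed as max_seqlen_q / max_seqlen_k
```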