From b303976a27f0058d3db9195b3c40e7f79514240f Mon Sep 17 00:00:00 2001
From: char-1ee
Date: Mon, 10 Jun 2024 02:03:30 +0000
Subject: [PATCH] Fix test import

Signed-off-by: char-1ee
---
 .../test_kernels/cuda/test_flash_decoding_attention.py          | 2 +-
 tests/test_infer/test_kernels/triton/test_context_attn_unpad.py | 2 +-
 tests/test_infer/test_kernels/triton/test_decoding_attn.py      | 2 +-
 3 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/tests/test_infer/test_kernels/cuda/test_flash_decoding_attention.py b/tests/test_infer/test_kernels/cuda/test_flash_decoding_attention.py
index 0bd398e2e..e9bf24d53 100644
--- a/tests/test_infer/test_kernels/cuda/test_flash_decoding_attention.py
+++ b/tests/test_infer/test_kernels/cuda/test_flash_decoding_attention.py
@@ -4,7 +4,7 @@ import numpy as np
 import pytest
 import torch
 
-from colossalai.inference.modeling.models.nopadding_baichuan import get_alibi_slopes
+from colossalai.inference.utils import get_alibi_slopes
 from colossalai.kernel.kernel_loader import InferenceOpsLoader
 from colossalai.utils import get_current_device
 from tests.test_infer.test_kernels.triton.test_context_attn_unpad import generate_alibi_mask
diff --git a/tests/test_infer/test_kernels/triton/test_context_attn_unpad.py b/tests/test_infer/test_kernels/triton/test_context_attn_unpad.py
index 9d76858ed..92173ac13 100644
--- a/tests/test_infer/test_kernels/triton/test_context_attn_unpad.py
+++ b/tests/test_infer/test_kernels/triton/test_context_attn_unpad.py
@@ -2,7 +2,7 @@ import pytest
 import torch
 from packaging import version
 
-from colossalai.inference.modeling.models.nopadding_baichuan import get_alibi_slopes
+from colossalai.inference.utils import get_alibi_slopes
 from colossalai.kernel.triton import context_attention_unpadded
 from colossalai.utils import get_current_device
 from tests.test_infer.test_kernels.triton.kernel_utils import (
diff --git a/tests/test_infer/test_kernels/triton/test_decoding_attn.py b/tests/test_infer/test_kernels/triton/test_decoding_attn.py
index 40a6eae58..aa2a7e2b4 100644
--- a/tests/test_infer/test_kernels/triton/test_decoding_attn.py
+++ b/tests/test_infer/test_kernels/triton/test_decoding_attn.py
@@ -3,7 +3,7 @@ import pytest
 import torch
 from packaging import version
 
-from colossalai.inference.modeling.models.nopadding_baichuan import get_alibi_slopes
+from colossalai.inference.utils import get_alibi_slopes
 from colossalai.kernel.triton import flash_decoding_attention
 from colossalai.utils import get_current_device
 from tests.test_infer.test_kernels.triton.kernel_utils import (