Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-04 18:40:28 +00:00)

Commit: [Fix] Fix & Update Inference Tests (compatibility w/ main)

Summary: the inference tests are updated for compatibility with main — the obsolete `config={}` argument is dropped from `colossalai.launch(...)` calls, imports are moved from the renamed `tests/test_infer/test_ops/` package to `tests/test_infer/test_kernels/`, and the parameterized helper `test_tp_engine` is renamed to `check_tp_engine` so that only `test_inference_engine` is collected by pytest.
@@ -80,7 +80,7 @@ def check_config_and_inference():


 def run_dist(rank, world_size, port):
-    colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host="localhost")
+    colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost")
     check_config_and_inference()

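For context, these `run_dist` helpers are driven by ColossalAI's spawn-based test harness, which supplies the rank, world size, and a free port to each process. A minimal sketch of the updated pattern, assuming the usual `spawn` entry point (the check body here is a stand-in, not part of this diff):

import colossalai
from colossalai.testing import rerun_if_address_is_in_use, spawn


def check_config_and_inference():
    pass  # stand-in for the real check defined earlier in the test file


def run_dist(rank, world_size, port):
    # main no longer accepts a config dict; launch takes the coordinates directly
    colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost")
    check_config_and_inference()


@rerun_if_address_is_in_use()
def test_config_and_inference():
    spawn(run_dist, 1)  # spawn passes rank, world_size, and port to run_dist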
@@ -80,7 +80,7 @@ def check_output_consistency(batch_size):


 def run_dist(rank, world_size, port):
-    colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host="localhost")
+    colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost")
     check_output_consistency(32)
     check_output_consistency(64)
     check_output_consistency(128)
@@ -157,7 +157,7 @@ def check_spec_dec(num_layers, max_length):


 def run_dist(rank, world_size, port, func_to_run, ret=None, **kwargs):
-    colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host="localhost")
+    colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost")

     if ret:
         ret[rank] = func_to_run(**kwargs)
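The `ret` mapping in this variant lets each rank hand a result back to the parent process. A sketch of the calling side, assuming a multiprocessing manager dict (`collect_results` is a hypothetical name, not from this diff):

import torch.multiprocessing as mp

from colossalai.testing import spawn


def collect_results(world_size, func_to_run, **kwargs):
    # a managed dict is shared between the spawned ranks and the parent,
    # so run_dist's `ret[rank] = func_to_run(**kwargs)` is visible here
    ret = mp.Manager().dict()
    spawn(run_dist, world_size, func_to_run=func_to_run, ret=ret, **kwargs)
    return [ret[rank] for rank in range(world_size)]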
@@ -7,11 +7,11 @@ import torch
 from colossalai.inference.modeling.models.nopadding_baichuan import get_alibi_slopes
 from colossalai.kernel.kernel_loader import InferenceOpsLoader
 from colossalai.utils import get_current_device
-from tests.test_infer.test_ops.triton.test_context_attn_unpad import generate_alibi_mask
+from tests.test_infer.test_kernels.triton.test_context_attn_unpad import generate_alibi_mask

 inference_ops = InferenceOpsLoader().load()

-from tests.test_infer.test_ops.triton.kernel_utils import (
+from tests.test_infer.test_kernels.triton.kernel_utils import (
     convert_kv_unpad_to_padded,
     create_attention_mask,
     generate_caches_and_block_tables_v3,
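This hunk and the similar ones below are one mechanical change: the kernel tests moved from `tests/test_infer/test_ops/` to `tests/test_infer/test_kernels/`, so every stale import path is rewritten. A hypothetical helper for catching stragglers after such a rename (the paths and the in-place rewrite are assumptions, not part of the commit):

import pathlib

# rewrite any import still pointing at the old test_ops package
for path in pathlib.Path("tests/test_infer").rglob("*.py"):
    text = path.read_text()
    if "tests.test_infer.test_ops." in text:
        path.write_text(text.replace("tests.test_infer.test_ops.", "tests.test_infer.test_kernels."))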
@@ -3,7 +3,7 @@ import pytest
 import torch

 from colossalai.kernel.kernel_loader import InferenceOpsLoader
-from tests.test_infer.test_ops.triton.test_xine_copy import get_cos_sin
+from tests.test_infer.test_kernels.triton.test_xine_copy import get_cos_sin

 inference_ops = InferenceOpsLoader().load()

@@ -4,7 +4,10 @@ import torch.nn.functional as F

 from colossalai.kernel.kernel_loader import InferenceOpsLoader
 from colossalai.utils import get_current_device
-from tests.test_infer.test_ops.triton.kernel_utils import generate_caches_and_block_tables_v3, mock_alloc_single_token
+from tests.test_infer.test_kernels.triton.kernel_utils import (
+    generate_caches_and_block_tables_v3,
+    mock_alloc_single_token,
+)

 inference_ops = InferenceOpsLoader().load()

@@ -7,8 +7,8 @@ from colossalai.kernel.kernel_loader import InferenceOpsLoader

 inference_ops = InferenceOpsLoader().load()

-from tests.test_infer.test_ops.triton.kernel_utils import mock_alloc_block_table_and_kvcache_v3
-from tests.test_infer.test_ops.triton.test_rotary_embdding_unpad import torch_rotary_emb
+from tests.test_infer.test_kernels.triton.kernel_utils import mock_alloc_block_table_and_kvcache_v3
+from tests.test_infer.test_kernels.triton.test_rotary_embdding_unpad import torch_rotary_emb


 def numpy_allclose(x, y, rtol, atol):
@@ -5,7 +5,7 @@ from packaging import version
 from colossalai.inference.modeling.models.nopadding_baichuan import get_alibi_slopes
 from colossalai.kernel.triton import context_attention_unpadded
 from colossalai.utils import get_current_device
-from tests.test_infer.test_ops.triton.kernel_utils import (
+from tests.test_infer.test_kernels.triton.kernel_utils import (
     generate_caches_and_block_tables_v2,
     generate_caches_and_block_tables_v3,
     torch_attn_ref,
@@ -6,14 +6,14 @@ from packaging import version
 from colossalai.inference.modeling.models.nopadding_baichuan import get_alibi_slopes
 from colossalai.kernel.triton import flash_decoding_attention
 from colossalai.utils import get_current_device
-from tests.test_infer.test_ops.triton.kernel_utils import (
+from tests.test_infer.test_kernels.triton.kernel_utils import (
     convert_kv_unpad_to_padded,
     create_attention_mask,
     generate_caches_and_block_tables_v2,
     generate_caches_and_block_tables_v3,
     torch_attn_ref,
 )
-from tests.test_infer.test_ops.triton.test_context_attn_unpad import generate_alibi_mask
+from tests.test_infer.test_kernels.triton.test_context_attn_unpad import generate_alibi_mask

 try:
     import triton  # noqa
@@ -4,7 +4,7 @@ from packaging import version

 from colossalai.kernel.triton import copy_k_to_blocked_cache, copy_kv_to_blocked_cache
 from colossalai.utils import get_current_device
-from tests.test_infer.test_ops.triton.kernel_utils import (
+from tests.test_infer.test_kernels.triton.kernel_utils import (
     generate_caches_and_block_tables_v2,
     generate_caches_and_block_tables_v3,
     mock_alloc_single_token,
@@ -4,7 +4,7 @@ from packaging import version
 from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding, apply_rotary_pos_emb

 from colossalai.kernel.triton import decoding_fused_rotary_embedding
-from tests.test_infer.test_ops.triton.kernel_utils import (
+from tests.test_infer.test_kernels.triton.kernel_utils import (
     mock_alloc_block_table_and_kvcache_v2,
     mock_alloc_block_table_and_kvcache_v3,
 )
@@ -164,7 +164,7 @@ def check_cache_manager(test_config):


 def run_dist(rank, world_size, port):
-    colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host="localhost")
+    colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost")
     check_cache_manager()

@@ -14,7 +14,6 @@ from colossalai.inference.core.engine import InferenceEngine
 from colossalai.inference.modeling.policy import NoPaddingBaichuanModelInferPolicy
 from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn

-# BAICHUAN_MODEL_NAME_OR_PATH = "baichuan-inc/Baichuan2-7B-Base"
 BAICHUAN_MODEL_NAME_OR_PATH = "baichuan-inc/Baichuan2-13B-Base"

@@ -87,7 +86,7 @@ def run_engine(world_size, **kwargs):


 def run_dist(rank, world_size, port, func_to_run, ret=None, **kwargs):
-    colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host="localhost")
+    colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost")

     if ret:
         ret[rank] = func_to_run(**kwargs)
@@ -99,7 +98,7 @@ def run_dist(rank, world_size, port, func_to_run, ret=None, **kwargs):
 @parameterize("prompt_template", [None, "baichuan"])
 @parameterize("do_sample", [False])
 @parameterize("use_cuda_kernel", [True])
-def test_tp_engine(prompt_template, do_sample, use_cuda_kernel):
+def check_tp_engine(prompt_template, do_sample, use_cuda_kernel):
     kwargs1 = {
         "use_engine": True,
         "prompt_template": prompt_template,
@@ -132,7 +131,7 @@ def test_tp_engine(prompt_template, do_sample, use_cuda_kernel):
 @pytest.mark.dist
 @rerun_if_address_is_in_use()
 def test_inference_engine():
-    test_tp_engine()
+    check_tp_engine()


 if __name__ == "__main__":
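The `test_tp_engine` to `check_tp_engine` rename is about pytest collection: pytest collects any module-level function whose name starts with `test_`, so the parameterized helper would otherwise also run directly, outside the spawned distributed context. A minimal sketch of the resulting shape (the assertion body is a stand-in):

import pytest

from colossalai.testing import parameterize, rerun_if_address_is_in_use


@parameterize("prompt_template", [None, "baichuan"])
def check_tp_engine(prompt_template):
    # not collected by pytest: the name does not start with test_
    assert prompt_template in (None, "baichuan")  # stand-in body


@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_inference_engine():
    check_tp_engine()  # @parameterize runs every parameter combination when called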
@@ -90,7 +90,7 @@ def check_request_handler():


 def run_dist(rank, world_size, port):
-    colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host="localhost")
+    colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost")
     check_running_list()
     check_request_handler()