From 719fb816677ffaeae2be7be10823d0e48b2b3ae9 Mon Sep 17 00:00:00 2001
From: CjhHa1
Date: Mon, 15 Apr 2024 15:43:14 +0800
Subject: [PATCH] fix test bugs

---
 colossalai/inference/core/engine.py           | 10 +++++-----
 .../test_async_engine/test_request_tracker.py |  4 ++--
 tests/test_infer/test_continuous_batching.py  | 18 ++++++++++++++++--
 tests/test_infer/test_inference_engine.py     | 17 ++++++++++++++---
 4 files changed, 37 insertions(+), 12 deletions(-)

diff --git a/colossalai/inference/core/engine.py b/colossalai/inference/core/engine.py
index 9e5228019..dcd79301d 100644
--- a/colossalai/inference/core/engine.py
+++ b/colossalai/inference/core/engine.py
@@ -450,12 +450,10 @@ class InferenceEngine:
             List[str]: Inference result returned by one generation.
         """
         with torch.inference_mode():
-            if generation_config is not None:
-                self.generation_config = generation_config
-
             if prompts is not None or prompts_token_ids is not None:
-                if isinstance(prompts, str) and isinstance(request_ids, int):
+                if isinstance(prompts, str):
                     prompts = [prompts]
+                if isinstance(request_ids, int):
                     request_ids = [request_ids]
 
                 self.add_request(request_ids=request_ids, prompts=prompts, prompts_token_ids=prompts_token_ids)
@@ -463,6 +461,8 @@ class InferenceEngine:
             total_tokens_list = []
 
             # intuition: If user provide a generation config, we should replace the existing one.
+            if generation_config is not None:
+                self.generation_config = generation_config
 
             if self.use_spec_dec:
                 assert self.drafter is not None, "Drafter Model is not initialized."
@@ -538,7 +538,7 @@ class InferenceEngine:
 
         if isinstance(prompts_token_ids, list):
             if isinstance(prompts_token_ids[0], torch.Tensor):
-                prompts_token_ids = [prompt_token_ids.tolist() for prompt_token_ids in prompts_token_ids]
+                prompts_token_ids = [prompt_token_id.tolist() for prompt_token_id in prompts_token_ids]
         elif isinstance(prompts_token_ids, torch.Tensor) or isinstance(prompts_token_ids, np.ndarray):
             prompts_token_ids = prompts_token_ids.tolist()
         else:
diff --git a/tests/test_infer/test_async_engine/test_request_tracker.py b/tests/test_infer/test_async_engine/test_request_tracker.py
index 9a797a862..4b15d46c1 100644
--- a/tests/test_infer/test_async_engine/test_request_tracker.py
+++ b/tests/test_infer/test_async_engine/test_request_tracker.py
@@ -1,6 +1,6 @@
 import pytest
 
-from colossalai.inference.core.async_engine import RequestTracker
+from colossalai.inference.core.async_engine import Tracer
 from colossalai.inference.struct import Sequence
 
 
@@ -16,7 +16,7 @@ class SampleEvent:
 
 
 def test_request_tracker():
-    tracker = RequestTracker()
+    tracker = Tracer()
     tracker.new_requests_event = SampleEvent()
     stream_1 = tracker.add_request(1)
     assert tracker.new_requests_event.flag
diff --git a/tests/test_infer/test_continuous_batching.py b/tests/test_infer/test_continuous_batching.py
index 0b0d92c7c..350ed473e 100644
--- a/tests/test_infer/test_continuous_batching.py
+++ b/tests/test_infer/test_continuous_batching.py
@@ -29,10 +29,24 @@ def generate_inputs(num_sequences, min_length, max_length):
 
 
 @parameterize(
-    "max_batch_size", 8, "max_output_len", 512, "max_input_len", 64, "do_sample", True, "top_p", 0.5, "top_k", 50
+    "test_config",
+    [
+        {
+            "max_batch_size": 8,
+            "max_output_len": 512,
+            "max_input_len": 64,
+            "do_sample": False,
+        }
+    ],
 )
-def check_inference_engine(use_engine=False, prompt_template=None):
+def check_inference_engine(test_config, use_engine=False, prompt_template=None):
     setup_seed(20)
+    max_batch_size = test_config["max_batch_size"]
+    max_input_len = test_config["max_input_len"]
+    max_output_len = test_config["max_output_len"]
+    do_sample = test_config["do_sample"]
+    top_p = 0.5
+    top_k = 50
     tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
     model = LlamaForCausalLM.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0").cuda().half()
     model = model.eval()
diff --git a/tests/test_infer/test_inference_engine.py b/tests/test_infer/test_inference_engine.py
index 088b1f5aa..6c7604629 100644
--- a/tests/test_infer/test_inference_engine.py
+++ b/tests/test_infer/test_inference_engine.py
@@ -36,17 +36,27 @@ def check_inference_engine(use_engine=False, prompt_template=None):
     ]
 
     output_len = 38
-    do_sample = True
+    do_sample = False
     top_p = 0.5
     top_k = 50
 
     if use_engine:
-        inference_config = InferenceConfig(max_output_len=output_len, prompt_template=prompt_template, dtype="fp32")
+        inference_config = InferenceConfig(
+            max_output_len=output_len,
+            prompt_template=prompt_template,
+            dtype="fp32",
+            do_sample=do_sample,
+            top_p=top_p,
+            top_k=top_k,
+            use_cuda_kernel=True,
+        )
         inference_engine = InferenceEngine(model, tokenizer, inference_config, verbose=True)
         assert inference_engine.generation_config.max_new_tokens == output_len
         inference_engine.add_request(prompts=inputs)
         assert inference_engine.request_handler._has_waiting()
-        generation_config = GenerationConfig(do_sample=do_sample, top_p=top_p, top_k=top_k)
+        generation_config = GenerationConfig(
+            max_new_tokens=output_len, do_sample=do_sample, dtype="fp32", top_p=top_p, top_k=top_k
+        )
         outputs = inference_engine.generate(generation_config=generation_config)
     else:
         if prompt_template:
@@ -58,6 +68,7 @@ def check_inference_engine(use_engine=False, prompt_template=None):
         inputs = inputs.cuda()
         generation_config = GenerationConfig(
             do_sample=do_sample,
+            dtype="fp32",
             top_p=top_p,
             top_k=top_k,
             pad_token_id=tokenizer.pad_token_id,
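
A minimal sketch (illustration only, not part of the patch) of the dict-style parameterization the reworked tests rely on. It assumes `parameterize` from `colossalai.testing` re-invokes the wrapped function once per entry in the supplied list, passing each dict as the named keyword argument; the file and function names below are hypothetical.

    # sketch_test_config.py (hypothetical, illustration only)
    from colossalai.testing import parameterize


    @parameterize(
        "test_config",
        [
            {
                "max_batch_size": 8,
                "max_output_len": 512,
                "max_input_len": 64,
                "do_sample": False,  # the patch switches sampling off in both tests
            }
        ],
    )
    def check_config_plumbing(test_config):
        # Unpack the config the same way the patched check_inference_engine does.
        max_batch_size = test_config["max_batch_size"]
        max_input_len = test_config["max_input_len"]
        max_output_len = test_config["max_output_len"]
        do_sample = test_config["do_sample"]
        assert (max_batch_size, max_input_len, max_output_len, do_sample) == (8, 64, 512, False)


    if __name__ == "__main__":
        # parameterize injects each config dict, so the call site passes no arguments.
        check_config_plumbing()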