fix test bugs

CjhHa1 2024-04-15 15:43:14 +08:00
parent dec8fdb229
commit 719fb81667
4 changed files with 37 additions and 12 deletions

View File

@@ -450,12 +450,10 @@ class InferenceEngine:
             List[str]: Inference result returned by one generation.
         """
         with torch.inference_mode():
-            if generation_config is not None:
-                self.generation_config = generation_config
             if prompts is not None or prompts_token_ids is not None:
-                if isinstance(prompts, str) and isinstance(request_ids, int):
+                if isinstance(prompts, str):
                     prompts = [prompts]
+                if isinstance(request_ids, int):
                     request_ids = [request_ids]
                 self.add_request(request_ids=request_ids, prompts=prompts, prompts_token_ids=prompts_token_ids)
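
The split isinstance checks matter when only one of the two arguments arrives as a scalar, for example a single prompt string paired with a list of request IDs, which the old combined check silently skipped. A self-contained sketch of the same normalization, using a hypothetical helper name:

# Hypothetical helper mirroring the normalization above; not part of InferenceEngine.
from typing import List, Union

def normalize_inputs(
    prompts: Union[List[str], str, None],
    request_ids: Union[List[int], int, None],
):
    # Wrap each scalar independently, so a lone str prompt no longer requires
    # request_ids to also be a bare int (the old combined check).
    if isinstance(prompts, str):
        prompts = [prompts]
    if isinstance(request_ids, int):
        request_ids = [request_ids]
    return prompts, request_ids

assert normalize_inputs("hello", [0]) == (["hello"], [0])  # skipped by the old combined check
assert normalize_inputs("hello", 0) == (["hello"], [0])
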
@@ -463,6 +461,8 @@ class InferenceEngine:
             total_tokens_list = []
             # intuition: If user provide a generation config, we should replace the existing one.
+            if generation_config is not None:
+                self.generation_config = generation_config
             if self.use_spec_dec:
                 assert self.drafter is not None, "Drafter Model is not initialized."
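
Per the intuition comment, a config passed to generate() now replaces the engine's stored generation_config at this later point in the call. A hedged usage sketch, assuming `engine` is an already-initialized InferenceEngine and GenerationConfig is the Hugging Face class used in the tests below:

# Illustrative only: `engine` is assumed to be a ready InferenceEngine instance.
from transformers import GenerationConfig

def run_with_override(engine, prompts, output_len=38):
    # Queue the prompts first (as the test file below does), then let the
    # per-call config replace engine.generation_config inside generate().
    engine.add_request(prompts=prompts)
    override = GenerationConfig(max_new_tokens=output_len, do_sample=False)
    return engine.generate(generation_config=override)
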
@@ -538,7 +538,7 @@ class InferenceEngine:
         if isinstance(prompts_token_ids, list):
             if isinstance(prompts_token_ids[0], torch.Tensor):
-                prompts_token_ids = [prompt_token_ids.tolist() for prompt_token_ids in prompts_token_ids]
+                prompts_token_ids = [prompt_token_id.tolist() for prompt_token_id in prompts_token_ids]
         elif isinstance(prompts_token_ids, torch.Tensor) or isinstance(prompts_token_ids, np.ndarray):
             prompts_token_ids = prompts_token_ids.tolist()
         else:
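
The comprehension-variable rename is purely cosmetic; the branch converts a list of per-prompt tensors (or a single tensor/ndarray) into plain token-id lists. A standalone sketch of that conversion follows; the hunk's else-branch body is not shown, so the TypeError here is an assumption:

# Sketch of the token-id normalization above (hypothetical function name).
from typing import List, Union

import numpy as np
import torch

def to_token_id_lists(prompts_token_ids: Union[List, torch.Tensor, np.ndarray]) -> List[List[int]]:
    if isinstance(prompts_token_ids, list):
        if isinstance(prompts_token_ids[0], torch.Tensor):
            # Convert each per-prompt tensor into a plain Python list.
            prompts_token_ids = [prompt_token_id.tolist() for prompt_token_id in prompts_token_ids]
    elif isinstance(prompts_token_ids, (torch.Tensor, np.ndarray)):
        prompts_token_ids = prompts_token_ids.tolist()
    else:
        # Assumed behavior; the original else-branch is outside the hunk.
        raise TypeError(f"Unsupported prompts_token_ids type: {type(prompts_token_ids)}")
    return prompts_token_ids

assert to_token_id_lists(torch.tensor([[1, 2, 3]])) == [[1, 2, 3]]
assert to_token_id_lists([torch.tensor([1, 2, 3])]) == [[1, 2, 3]]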

View File

@@ -1,6 +1,6 @@
 import pytest
-from colossalai.inference.core.async_engine import RequestTracker
+from colossalai.inference.core.async_engine import Tracer
 from colossalai.inference.struct import Sequence
@@ -16,7 +16,7 @@ class SampleEvent:
 def test_request_tracker():
-    tracker = RequestTracker()
+    tracker = Tracer()
     tracker.new_requests_event = SampleEvent()
     stream_1 = tracker.add_request(1)
     assert tracker.new_requests_event.flag
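
The hunk context refers to a SampleEvent defined earlier in this test file; it stands in for the event the Tracer sets when a new request arrives, so the test can assert on a plain flag. A plausible minimal stub, assuming it only needs set/clear and a readable flag (illustrative, not the file's actual definition):

# Assumed shape of the SampleEvent stub referenced above.
class SampleEvent:
    def __init__(self):
        self.flag = False

    def set(self):
        self.flag = True

    def clear(self):
        self.flag = False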

View File

@@ -29,10 +29,24 @@ def generate_inputs(num_sequences, min_length, max_length):
 @parameterize(
-    "max_batch_size", 8, "max_output_len", 512, "max_input_len", 64, "do_sample", True, "top_p", 0.5, "top_k", 50
+    "test_config",
+    [
+        {
+            "max_batch_size": 8,
+            "max_output_len": 512,
+            "max_input_len": 64,
+            "do_sample": False,
+        }
+    ],
 )
-def check_inference_engine(use_engine=False, prompt_template=None):
+def check_inference_engine(test_config, use_engine=False, prompt_template=None):
     setup_seed(20)
+    max_batch_size = test_config["max_batch_size"]
+    max_input_len = test_config["max_input_len"]
+    max_output_len = test_config["max_output_len"]
+    do_sample = test_config["do_sample"]
+    top_p = 0.5
+    top_k = 50
     tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
     model = LlamaForCausalLM.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0").cuda().half()
     model = model.eval()
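
As used here, @parameterize from colossalai.testing appears to invoke the decorated check once per entry in the supplied list, passing each dict as the named keyword argument; that is why check_inference_engine now takes test_config. A rough, illustrative reimplementation of that calling convention (not the library's actual code):

# Sketch of a parameterize-style decorator; hypothetical name, illustration only.
import functools

def parameterize_sketch(arg_name, values):
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            for value in values:
                func(*args, **{**kwargs, arg_name: value})
        return wrapper
    return decorator

@parameterize_sketch("test_config", [{"max_batch_size": 8, "do_sample": False}])
def check(test_config):
    assert test_config["do_sample"] is False

check()  # the wrapper runs check once per config dict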

View File

@@ -36,17 +36,27 @@ def check_inference_engine(use_engine=False, prompt_template=None):
     ]
     output_len = 38
-    do_sample = True
+    do_sample = False
     top_p = 0.5
     top_k = 50
     if use_engine:
-        inference_config = InferenceConfig(max_output_len=output_len, prompt_template=prompt_template, dtype="fp32")
+        inference_config = InferenceConfig(
+            max_output_len=output_len,
+            prompt_template=prompt_template,
+            dtype="fp32",
+            do_sample=do_sample,
+            top_p=top_p,
+            top_k=top_k,
+            use_cuda_kernel=True,
+        )
         inference_engine = InferenceEngine(model, tokenizer, inference_config, verbose=True)
         assert inference_engine.generation_config.max_new_tokens == output_len
         inference_engine.add_request(prompts=inputs)
         assert inference_engine.request_handler._has_waiting()
-        generation_config = GenerationConfig(do_sample=do_sample, top_p=top_p, top_k=top_k)
+        generation_config = GenerationConfig(
+            max_new_tokens=output_len, do_sample=do_sample, dtype="fp32", top_p=top_p, top_k=top_k
+        )
         outputs = inference_engine.generate(generation_config=generation_config)
     else:
         if prompt_template:
@@ -58,6 +68,7 @@ def check_inference_engine(use_engine=False, prompt_template=None):
         inputs = inputs.cuda()
         generation_config = GenerationConfig(
             do_sample=do_sample,
+            dtype="fp32",
             top_p=top_p,
             top_k=top_k,
             pad_token_id=tokenizer.pad_token_id,
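
Flipping do_sample to False switches both code paths to greedy decoding, so the engine's output strings can be compared deterministically against a plain transformers run. A hedged sketch of such a reference pass, reusing the model, tokenizer, and output length from the diff (the helper itself is illustrative):

# Illustrative reference path; model and settings taken from the test above.
from transformers import AutoTokenizer, GenerationConfig, LlamaForCausalLM

def greedy_reference(prompts, output_len=38):
    tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
    model = LlamaForCausalLM.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0").cuda().half().eval()
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token  # needed for batched padding
    inputs = tokenizer(prompts, return_tensors="pt", padding=True).input_ids.cuda()
    generation_config = GenerationConfig(
        do_sample=False,  # greedy decoding: identical inputs yield identical tokens
        max_new_tokens=output_len,
        pad_token_id=tokenizer.pad_token_id,
    )
    outputs = model.generate(inputs, generation_config=generation_config)
    return tokenizer.batch_decode(outputs, skip_special_tokens=True)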