fix test bugs

CjhHa1 2024-04-15 15:43:14 +08:00
parent dec8fdb229
commit 719fb81667
4 changed files with 37 additions and 12 deletions

File 1 of 4

@@ -450,12 +450,10 @@ class InferenceEngine:
             List[str]: Inference result returned by one generation.
         """
         with torch.inference_mode():
-            if generation_config is not None:
-                self.generation_config = generation_config
-
             if prompts is not None or prompts_token_ids is not None:
-                if isinstance(prompts, str) and isinstance(request_ids, int):
+                if isinstance(prompts, str):
                     prompts = [prompts]
+                if isinstance(request_ids, int):
                     request_ids = [request_ids]
                 self.add_request(request_ids=request_ids, prompts=prompts, prompts_token_ids=prompts_token_ids)
@@ -463,6 +461,8 @@ class InferenceEngine:
             total_tokens_list = []
             # intuition: If user provide a generation config, we should replace the existing one.
+            if generation_config is not None:
+                self.generation_config = generation_config
             if self.use_spec_dec:
                 assert self.drafter is not None, "Drafter Model is not initialized."
@@ -538,7 +538,7 @@ class InferenceEngine:
         if isinstance(prompts_token_ids, list):
             if isinstance(prompts_token_ids[0], torch.Tensor):
-                prompts_token_ids = [prompt_token_ids.tolist() for prompt_token_ids in prompts_token_ids]
+                prompts_token_ids = [prompt_token_id.tolist() for prompt_token_id in prompts_token_ids]
         elif isinstance(prompts_token_ids, torch.Tensor) or isinstance(prompts_token_ids, np.ndarray):
            prompts_token_ids = prompts_token_ids.tolist()
         else:
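Taken together, the engine hunks split the combined isinstance check so that a bare prompt string and a bare integer request id are each wrapped into a list on their own, and they move the generation_config override out of the input-normalization branch. Below is a minimal standalone sketch of that normalization behavior; the helper name is hypothetical and not part of the engine API.

from typing import List, Optional, Union

def normalize_generate_inputs(
    prompts: Optional[Union[str, List[str]]] = None,
    request_ids: Optional[Union[int, List[int]]] = None,
):
    # Wrap scalars into one-element lists independently, mirroring the fixed
    # logic in InferenceEngine.generate: previously both checks were tied
    # together, so a list of prompts paired with a single int request id was
    # never normalized.
    if isinstance(prompts, str):
        prompts = [prompts]
    if isinstance(request_ids, int):
        request_ids = [request_ids]
    return prompts, request_ids

# A bare string with a bare int, or a list with a bare int, both normalize:
assert normalize_generate_inputs("hello", 1) == (["hello"], [1])
assert normalize_generate_inputs(["hello"], 1) == (["hello"], [1])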

File 2 of 4

@@ -1,6 +1,6 @@
 import pytest
-from colossalai.inference.core.async_engine import RequestTracker
+from colossalai.inference.core.async_engine import Tracer
 from colossalai.inference.struct import Sequence
@@ -16,7 +16,7 @@ class SampleEvent:
 def test_request_tracker():
-    tracker = RequestTracker()
+    tracker = Tracer()
     tracker.new_requests_event = SampleEvent()
     stream_1 = tracker.add_request(1)
     assert tracker.new_requests_event.flag

File 3 of 4

@@ -29,10 +29,24 @@ def generate_inputs(num_sequences, min_length, max_length):
 @parameterize(
-    "max_batch_size", 8, "max_output_len", 512, "max_input_len", 64, "do_sample", True, "top_p", 0.5, "top_k", 50
+    "test_config",
+    [
+        {
+            "max_batch_size": 8,
+            "max_output_len": 512,
+            "max_input_len": 64,
+            "do_sample": False,
+        }
+    ],
 )
-def check_inference_engine(use_engine=False, prompt_template=None):
+def check_inference_engine(test_config, use_engine=False, prompt_template=None):
     setup_seed(20)
+    max_batch_size = test_config["max_batch_size"]
+    max_input_len = test_config["max_input_len"]
+    max_output_len = test_config["max_output_len"]
+    do_sample = test_config["do_sample"]
+    top_p = 0.5
+    top_k = 50
     tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
     model = LlamaForCausalLM.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0").cuda().half()
     model = model.eval()
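The decorator now feeds the test a single test_config dict rather than flat key/value pairs. The stand-in below is a rough approximation of that call pattern, assuming the usual behavior of colossalai.testing.parameterize (run the wrapped test once per value, bound to the named keyword argument); it is a simplified sketch, not the library's implementation.

import functools

def parameterize(argument, values):
    # Simplified stand-in: invoke the wrapped test once per value, binding the
    # value to the named keyword argument.
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            for value in values:
                func(*args, **{argument: value}, **kwargs)
        return wrapper
    return decorator

@parameterize(
    "test_config",
    [{"max_batch_size": 8, "max_output_len": 512, "max_input_len": 64, "do_sample": False}],
)
def check_inference_engine(test_config, use_engine=False, prompt_template=None):
    # The test body unpacks its knobs from the dict, exactly as in the hunk above.
    print(test_config["max_batch_size"], test_config["do_sample"])

check_inference_engine()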

File 4 of 4

@@ -36,17 +36,27 @@ def check_inference_engine(use_engine=False, prompt_template=None):
     ]
     output_len = 38
-    do_sample = True
+    do_sample = False
     top_p = 0.5
     top_k = 50
     if use_engine:
-        inference_config = InferenceConfig(max_output_len=output_len, prompt_template=prompt_template, dtype="fp32")
+        inference_config = InferenceConfig(
+            max_output_len=output_len,
+            prompt_template=prompt_template,
+            dtype="fp32",
+            do_sample=do_sample,
+            top_p=top_p,
+            top_k=top_k,
+            use_cuda_kernel=True,
+        )
         inference_engine = InferenceEngine(model, tokenizer, inference_config, verbose=True)
         assert inference_engine.generation_config.max_new_tokens == output_len
         inference_engine.add_request(prompts=inputs)
         assert inference_engine.request_handler._has_waiting()
-        generation_config = GenerationConfig(do_sample=do_sample, top_p=top_p, top_k=top_k)
+        generation_config = GenerationConfig(
+            max_new_tokens=output_len, do_sample=do_sample, dtype="fp32", top_p=top_p, top_k=top_k
+        )
         outputs = inference_engine.generate(generation_config=generation_config)
     else:
         if prompt_template:
@@ -58,6 +68,7 @@ def check_inference_engine(use_engine=False, prompt_template=None):
         inputs = inputs.cuda()
         generation_config = GenerationConfig(
             do_sample=do_sample,
+            dtype="fp32",
             top_p=top_p,
             top_k=top_k,
             pad_token_id=tokenizer.pad_token_id,
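The remaining test changes flip do_sample to False and pass the same sampling settings to both the engine path and the transformers baseline, so the two outputs can be compared deterministically. A minimal sketch of the baseline config, assuming transformers is installed; values are copied from the diff, and top_p/top_k only take effect when sampling is enabled.

from transformers import GenerationConfig

# Greedy decoding keeps the comparison between the inference engine and the
# HuggingFace baseline deterministic; top_p/top_k are kept for parity but are
# ignored when do_sample=False.
generation_config = GenerationConfig(
    max_new_tokens=38,
    do_sample=False,
    top_p=0.5,
    top_k=50,
)
print(generation_config.do_sample, generation_config.max_new_tokens)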