Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-07-07 20:39:48 +00:00)
fix test bugs

commit 719fb81667 (parent dec8fdb229)
@@ -450,12 +450,10 @@ class InferenceEngine:
             List[str]: Inference result returned by one generation.
         """
         with torch.inference_mode():
-            if generation_config is not None:
-                self.generation_config = generation_config
-
             if prompts is not None or prompts_token_ids is not None:
-                if isinstance(prompts, str) and isinstance(request_ids, int):
+                if isinstance(prompts, str):
                     prompts = [prompts]
+                if isinstance(request_ids, int):
                     request_ids = [request_ids]
                 self.add_request(request_ids=request_ids, prompts=prompts, prompts_token_ids=prompts_token_ids)
 
@@ -463,6 +461,8 @@ class InferenceEngine:
             total_tokens_list = []
 
+            # intuition: If user provide a generation config, we should replace the existing one.
             if generation_config is not None:
                 self.generation_config = generation_config
+
             if self.use_spec_dec:
                 assert self.drafter is not None, "Drafter Model is not initialized."
@@ -538,7 +538,7 @@ class InferenceEngine:
 
         if isinstance(prompts_token_ids, list):
             if isinstance(prompts_token_ids[0], torch.Tensor):
-                prompts_token_ids = [prompt_token_ids.tolist() for prompt_token_ids in prompts_token_ids]
+                prompts_token_ids = [prompt_token_id.tolist() for prompt_token_id in prompts_token_ids]
         elif isinstance(prompts_token_ids, torch.Tensor) or isinstance(prompts_token_ids, np.ndarray):
             prompts_token_ids = prompts_token_ids.tolist()
         else:
@@ -1,6 +1,6 @@
 import pytest
 
-from colossalai.inference.core.async_engine import RequestTracker
+from colossalai.inference.core.async_engine import Tracer
 from colossalai.inference.struct import Sequence
 
 
@@ -16,7 +16,7 @@ class SampleEvent:
 
 
 def test_request_tracker():
-    tracker = RequestTracker()
+    tracker = Tracer()
     tracker.new_requests_event = SampleEvent()
     stream_1 = tracker.add_request(1)
     assert tracker.new_requests_event.flag
@@ -29,10 +29,24 @@ def generate_inputs(num_sequences, min_length, max_length):
 
 
 @parameterize(
-    "max_batch_size", 8, "max_output_len", 512, "max_input_len", 64, "do_sample", True, "top_p", 0.5, "top_k", 50
+    "test_config",
+    [
+        {
+            "max_batch_size": 8,
+            "max_output_len": 512,
+            "max_input_len": 64,
+            "do_sample": False,
+        }
+    ],
 )
-def check_inference_engine(use_engine=False, prompt_template=None):
+def check_inference_engine(test_config, use_engine=False, prompt_template=None):
     setup_seed(20)
+    max_batch_size = test_config["max_batch_size"]
+    max_input_len = test_config["max_input_len"]
+    max_output_len = test_config["max_output_len"]
+    do_sample = test_config["do_sample"]
+    top_p = 0.5
+    top_k = 50
     tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
     model = LlamaForCausalLM.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0").cuda().half()
     model = model.eval()
@@ -36,17 +36,27 @@ def check_inference_engine(use_engine=False, prompt_template=None):
     ]
 
     output_len = 38
-    do_sample = True
+    do_sample = False
     top_p = 0.5
     top_k = 50
 
     if use_engine:
-        inference_config = InferenceConfig(max_output_len=output_len, prompt_template=prompt_template, dtype="fp32")
+        inference_config = InferenceConfig(
+            max_output_len=output_len,
+            prompt_template=prompt_template,
+            dtype="fp32",
+            do_sample=do_sample,
+            top_p=top_p,
+            top_k=top_k,
+            use_cuda_kernel=True,
+        )
         inference_engine = InferenceEngine(model, tokenizer, inference_config, verbose=True)
         assert inference_engine.generation_config.max_new_tokens == output_len
         inference_engine.add_request(prompts=inputs)
         assert inference_engine.request_handler._has_waiting()
-        generation_config = GenerationConfig(do_sample=do_sample, top_p=top_p, top_k=top_k)
+        generation_config = GenerationConfig(
+            max_new_tokens=output_len, do_sample=do_sample, dtype="fp32", top_p=top_p, top_k=top_k
+        )
         outputs = inference_engine.generate(generation_config=generation_config)
     else:
         if prompt_template:
@@ -58,6 +68,7 @@ def check_inference_engine(use_engine=False, prompt_template=None):
         inputs = inputs.cuda()
         generation_config = GenerationConfig(
             do_sample=do_sample,
+            dtype="fp32",
             top_p=top_p,
             top_k=top_k,
             pad_token_id=tokenizer.pad_token_id,
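
For orientation, below is a minimal sketch of how the reworked test path exercises the engine after this commit. It is not the repository's test verbatim: the import paths, the prompt string, and the helper name run_engine are assumptions based only on the classes and arguments visible in the diff above (InferenceConfig, InferenceEngine, GenerationConfig, and the parameterize-driven test_config dict).

# Sketch only: import locations are assumed from where these classes appear in the diff.
from transformers import AutoTokenizer, GenerationConfig, LlamaForCausalLM

from colossalai.inference.config import InferenceConfig
from colossalai.inference.core.engine import InferenceEngine
from colossalai.testing import parameterize


@parameterize("test_config", [{"max_batch_size": 8, "max_output_len": 512, "max_input_len": 64, "do_sample": False}])
def run_engine(test_config, prompt_template=None):
    # The test body unpacks the single test_config dict instead of taking one
    # keyword argument per setting, mirroring the new @parameterize usage above.
    output_len = test_config["max_output_len"]
    do_sample = test_config["do_sample"]

    tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
    model = LlamaForCausalLM.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0").cuda().eval()

    inference_config = InferenceConfig(
        max_output_len=output_len,
        prompt_template=prompt_template,
        dtype="fp32",
        do_sample=do_sample,
        use_cuda_kernel=True,
    )
    engine = InferenceEngine(model, tokenizer, inference_config, verbose=True)

    # A GenerationConfig passed to generate() replaces the engine's stored one,
    # per the "intuition" comment added in engine.py above.
    generation_config = GenerationConfig(max_new_tokens=output_len, do_sample=do_sample)
    return engine.generate(prompts=["hello world"], generation_config=generation_config)  # hypothetical prompt


if __name__ == "__main__":
    run_engine()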