commit 719fb81667
parent dec8fdb229

    fix test bugs
@@ -450,12 +450,10 @@ class InferenceEngine:
             List[str]: Inference result returned by one generation.
         """
         with torch.inference_mode():
-            if generation_config is not None:
-                self.generation_config = generation_config
-
             if prompts is not None or prompts_token_ids is not None:
-                if isinstance(prompts, str) and isinstance(request_ids, int):
+                if isinstance(prompts, str):
                     prompts = [prompts]
+                if isinstance(request_ids, int):
                     request_ids = [request_ids]
                 self.add_request(request_ids=request_ids, prompts=prompts, prompts_token_ids=prompts_token_ids)
 
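For context on the isinstance change in the hunk above: the old combined check only wrapped prompts and request_ids when both were scalars, so a single string prompt paired with a list of request ids was never wrapped. Below is a minimal standalone sketch of the corrected normalization; _normalize is a hypothetical helper written for illustration, not part of ColossalAI's API.

# Minimal standalone sketch of the argument normalization introduced above.
# _normalize is a hypothetical helper, not a ColossalAI function.
from typing import List, Union


def _normalize(prompts: Union[str, List[str], None], request_ids: Union[int, List[int], None]):
    # The old combined check (`isinstance(prompts, str) and isinstance(request_ids, int)`)
    # wrapped the arguments only when BOTH were scalars, so a string prompt paired with
    # a list of request ids was silently left unwrapped. Checking each argument
    # independently fixes that.
    if isinstance(prompts, str):
        prompts = [prompts]
    if isinstance(request_ids, int):
        request_ids = [request_ids]
    return prompts, request_ids


print(_normalize("hello", [0]))  # (['hello'], [0]) -- the old check would have left "hello" unwrapped
print(_normalize("hello", 0))    # (['hello'], [0])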
@@ -463,6 +461,8 @@ class InferenceEngine:
         total_tokens_list = []
 
         # intuition: If user provide a generation config, we should replace the existing one.
+        if generation_config is not None:
+            self.generation_config = generation_config
 
         if self.use_spec_dec:
             assert self.drafter is not None, "Drafter Model is not initialized."
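The two added lines implement the override described by the comment above: a caller-supplied generation config replaces the engine's stored one. A minimal sketch of that pattern, using hypothetical stand-in classes rather than the real InferenceEngine or GenerationConfig:

# Sketch of the "caller-supplied config wins" pattern added above.
# Engine and GenerationSettings are hypothetical stand-ins, not ColossalAI classes.
from dataclasses import dataclass
from typing import Optional


@dataclass
class GenerationSettings:
    max_new_tokens: int = 64
    do_sample: bool = False


class Engine:
    def __init__(self) -> None:
        self.generation_config = GenerationSettings()

    def generate(self, generation_config: Optional[GenerationSettings] = None) -> GenerationSettings:
        # If the caller provides a config, it replaces the engine-level default,
        # mirroring the two added lines in the hunk above.
        if generation_config is not None:
            self.generation_config = generation_config
        return self.generation_config


engine = Engine()
print(engine.generate())                                      # engine default
print(engine.generate(GenerationSettings(max_new_tokens=8)))  # caller override wins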
@@ -538,7 +538,7 @@ class InferenceEngine:
 
         if isinstance(prompts_token_ids, list):
             if isinstance(prompts_token_ids[0], torch.Tensor):
-                prompts_token_ids = [prompt_token_ids.tolist() for prompt_token_ids in prompts_token_ids]
+                prompts_token_ids = [prompt_token_id.tolist() for prompt_token_id in prompts_token_ids]
         elif isinstance(prompts_token_ids, torch.Tensor) or isinstance(prompts_token_ids, np.ndarray):
             prompts_token_ids = prompts_token_ids.tolist()
         else:
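The change above only renames the comprehension variable; the surrounding logic converts every accepted prompts_token_ids form into plain Python lists. A standalone sketch of that conversion follows (torch and numpy are the only dependencies; _to_list_of_lists and the TypeError in the else branch, whose body is not shown in the hunk, are assumptions):

# Standalone sketch of the token-id normalization shown above.
import numpy as np
import torch


def _to_list_of_lists(prompts_token_ids):
    if isinstance(prompts_token_ids, list):
        if isinstance(prompts_token_ids[0], torch.Tensor):
            # Same behaviour as before the rename: each per-prompt tensor becomes a plain list.
            prompts_token_ids = [prompt_token_id.tolist() for prompt_token_id in prompts_token_ids]
    elif isinstance(prompts_token_ids, (torch.Tensor, np.ndarray)):
        prompts_token_ids = prompts_token_ids.tolist()
    else:
        # Assumed error handling; the else body is outside the hunk above.
        raise TypeError(f"Unsupported prompts_token_ids type: {type(prompts_token_ids)}")
    return prompts_token_ids


print(_to_list_of_lists([torch.tensor([1, 2]), torch.tensor([3])]))  # [[1, 2], [3]]
print(_to_list_of_lists(np.array([[1, 2], [3, 4]])))                 # [[1, 2], [3, 4]]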
@@ -1,6 +1,6 @@
 import pytest
 
-from colossalai.inference.core.async_engine import RequestTracker
+from colossalai.inference.core.async_engine import Tracer
 from colossalai.inference.struct import Sequence
 
 
@@ -16,7 +16,7 @@ class SampleEvent:
 
 
 def test_request_tracker():
-    tracker = RequestTracker()
+    tracker = Tracer()
     tracker.new_requests_event = SampleEvent()
     stream_1 = tracker.add_request(1)
     assert tracker.new_requests_event.flag
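The test keeps its structure and only follows the RequestTracker -> Tracer rename. The body of SampleEvent lies outside these hunks, so the following is a guessed reconstruction of the stub-event pattern the test relies on, not the repository's actual classes:

# Hypothetical reconstruction of the event-stub pattern used by the test above;
# SampleEvent and MiniTracer are illustrative guesses, not the real classes.
class SampleEvent:
    def __init__(self) -> None:
        self.flag = False

    def set(self) -> None:
        self.flag = True

    def clear(self) -> None:
        self.flag = False


class MiniTracer:
    """Records request ids and signals new_requests_event when one arrives."""

    def __init__(self) -> None:
        self.new_requests_event = SampleEvent()
        self._streams = {}

    def add_request(self, request_id: int):
        self._streams[request_id] = []  # stand-in for a per-request response stream
        self.new_requests_event.set()   # what the assertion in the test checks
        return self._streams[request_id]


tracer = MiniTracer()
tracer.new_requests_event = SampleEvent()  # inject the stub, as the test does
stream_1 = tracer.add_request(1)
assert tracer.new_requests_event.flag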
@@ -29,10 +29,24 @@ def generate_inputs(num_sequences, min_length, max_length):
 
 
 @parameterize(
-    "max_batch_size", 8, "max_output_len", 512, "max_input_len", 64, "do_sample", True, "top_p", 0.5, "top_k", 50
+    "test_config",
+    [
+        {
+            "max_batch_size": 8,
+            "max_output_len": 512,
+            "max_input_len": 64,
+            "do_sample": False,
+        }
+    ],
 )
-def check_inference_engine(use_engine=False, prompt_template=None):
+def check_inference_engine(test_config, use_engine=False, prompt_template=None):
     setup_seed(20)
+    max_batch_size = test_config["max_batch_size"]
+    max_input_len = test_config["max_input_len"]
+    max_output_len = test_config["max_output_len"]
+    do_sample = test_config["do_sample"]
+    top_p = 0.5
+    top_k = 50
     tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
     model = LlamaForCausalLM.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0").cuda().half()
     model = model.eval()
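The decorator now passes a single dict-valued test_config instead of a flat argument list. Below is a simplified stand-in for colossalai.testing.parameterize, written only to show the calling convention this change assumes; it is not the library's implementation:

# Sketch of how a dict-valued test_config reaches the test body.
import functools


def parameterize(arg_name, values):
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            for value in values:  # run the test once per config dict
                func(*args, **{arg_name: value}, **kwargs)
        return wrapper
    return decorator


@parameterize("test_config", [{"max_batch_size": 8, "max_output_len": 512, "max_input_len": 64, "do_sample": False}])
def check_config(test_config):
    print(test_config["max_batch_size"], test_config["do_sample"])


check_config()  # prints: 8 False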
@@ -36,17 +36,27 @@ def check_inference_engine(use_engine=False, prompt_template=None):
     ]
 
     output_len = 38
-    do_sample = True
+    do_sample = False
     top_p = 0.5
     top_k = 50
 
     if use_engine:
-        inference_config = InferenceConfig(max_output_len=output_len, prompt_template=prompt_template, dtype="fp32")
+        inference_config = InferenceConfig(
+            max_output_len=output_len,
+            prompt_template=prompt_template,
+            dtype="fp32",
+            do_sample=do_sample,
+            top_p=top_p,
+            top_k=top_k,
+            use_cuda_kernel=True,
+        )
         inference_engine = InferenceEngine(model, tokenizer, inference_config, verbose=True)
         assert inference_engine.generation_config.max_new_tokens == output_len
         inference_engine.add_request(prompts=inputs)
         assert inference_engine.request_handler._has_waiting()
-        generation_config = GenerationConfig(do_sample=do_sample, top_p=top_p, top_k=top_k)
+        generation_config = GenerationConfig(
+            max_new_tokens=output_len, do_sample=do_sample, dtype="fp32", top_p=top_p, top_k=top_k
+        )
         outputs = inference_engine.generate(generation_config=generation_config)
     else:
         if prompt_template:
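Switching to do_sample=False makes decoding greedy, so top_p and top_k no longer influence the chosen tokens and the engine's output can be compared deterministically against the transformers baseline; that reading of the test's intent is an inference from the diff, not stated in the commit. A quick check of the decoding settings, using only transformers.GenerationConfig:

# Inspect the decoding settings the updated test passes to transformers.
from transformers import GenerationConfig

generation_config = GenerationConfig(
    max_new_tokens=38,
    do_sample=False,
    top_p=0.5,
    top_k=50,
)
print(generation_config.do_sample, generation_config.max_new_tokens)  # False 38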
@@ -58,6 +68,7 @@ def check_inference_engine(use_engine=False, prompt_template=None):
         inputs = inputs.cuda()
         generation_config = GenerationConfig(
             do_sample=do_sample,
+            dtype="fp32",
             top_p=top_p,
             top_k=top_k,
             pad_token_id=tokenizer.pad_token_id,