Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-08-07 11:03:58 +00:00)
[Inference] Optimize generation process of inference engine (#5356)
* opt inference engine
* fix run_benchmark.sh
* fix generate in engine.py
* rollback test_inference_engine.py
This commit is contained in:
parent 21ad4a27f9
commit 631862f339
@@ -134,12 +134,16 @@ class InferenceEngine:
 
     def generate(
         self,
+        prompts: List[str] = None,
+        prompts_token_ids: Union[List[int], torch.Tensor, np.ndarray] = None,
         generation_config: GenerationConfig = None,
     ) -> List[str]:
         """
         Executing the inference step.
 
         Args:
+            prompts (List[str], optional): Input prompts. Defaults to None.
+            prompts_token_ids (List[List[int]], optional): token ids of input prompts. Defaults to None.
             generation_config (GenerationConfig, optional): Huggingface GenerationConfig used for inference. Defaults to None.
 
         Returns:
@@ -147,13 +151,23 @@ class InferenceEngine:
         """
         self.generation_config = generation_config
+        if prompts is not None or prompts_token_ids is not None:
+            self.add_request(prompts=prompts, prompts_token_ids=prompts_token_ids)
 
-        output_list = []
+        output_seqs_list = []
+        output_tokens_list = []
 
         while self.request_handler.check_unfinished_seqs():
-            output_list += self.step()
+            output_seqs_list += self.step()
 
-        return output_list
+        output_seqs_list = sorted(output_seqs_list, key=lambda x: int(x.request_id))
+
+        for seq in output_seqs_list:
+            output_tokens_list.append(seq.input_token_id + seq.output_token_id)
+
+        output_str = self.tokenizer.batch_decode(output_tokens_list, skip_special_tokens=True)
+
+        return output_str
 
     def add_request(
         self,
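After this hunk, a single generate() call covers enqueueing, stepping, and batched decoding, with results ordered by request id. A minimal usage sketch, assuming an InferenceEngine instance named engine that was constructed elsewhere (engine construction is outside this diff):

    from transformers import GenerationConfig

    # Hypothetical caller; only the generate() signature comes from this commit.
    generation_config = GenerationConfig(max_new_tokens=64, do_sample=False)
    outputs = engine.generate(
        prompts=["Hello, my name is", "The capital of France is"],
        generation_config=generation_config,
    )
    for text in outputs:  # List[str], sorted by request_id
        print(text)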
@@ -235,7 +249,6 @@ class InferenceEngine:
             List[str]: Decoded finished sequences generated by one step.
         """
 
-        output_list = []
         batch = self.request_handler.schedule()
 
         # TODO: padding_id is used for generating attn_mask and will be removed if nopad version is supported.
@@ -251,10 +264,4 @@ class InferenceEngine:
         self.request_handler.search_tokens(self.generation_config, logits)
         finished_sequences = self.request_handler.update()
 
-        # Decode completed sentences.
-        # TODO : update decoding step
-        for seq in finished_sequences:
-            output_str = self.tokenizer.decode(seq.input_token_id + seq.output_token_id, skip_special_tokens=True)
-            output_list.append(output_str)
-
-        return output_list
+        return finished_sequences
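With decoding moved into generate(), step() now hands back finished Sequence objects instead of strings. A sketch of the equivalent manual loop, using only the fields this diff shows on a sequence (request_id, input_token_id, output_token_id); driving the engine by hand like this is an assumption, not a documented workflow:

    # Enqueue, step until done, then decode once in a batch.
    engine.add_request(prompts=["An example prompt"])
    finished = []
    while engine.request_handler.check_unfinished_seqs():
        finished += engine.step()  # finished Sequence objects, no longer strings

    finished.sort(key=lambda seq: int(seq.request_id))
    token_lists = [seq.input_token_id + seq.output_token_id for seq in finished]
    texts = engine.tokenizer.batch_decode(token_lists, skip_special_tokens=True)

Batching the decode once at the end avoids a per-sequence tokenizer.decode() call in every step, which appears to be part of the optimization the commit title refers to.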
@@ -141,8 +141,7 @@ def benchmark_inference(args):
     with ctx:
         for _ in range(N_WARMUP_STEPS):
             if args.mode == "caiinference":
-                engine.add_request(prompts_token_ids=data)
-                engine.generate(generation_config)
+                engine.generate(prompts_token_ids=data, generation_config=generation_config)
             else:
                 engine.generate(data, generation_config=generation_config)
     if args.profile:
@@ -156,8 +155,7 @@ def benchmark_inference(args):
     whole_end2end = time.perf_counter()
     if args.mode == "caiinference":
         for _ in range(args.batch_size // mbsz):
-            engine.add_request(prompts_token_ids=data)
-            engine.generate(generation_config)
+            engine.generate(prompts_token_ids=data, generation_config=generation_config)
     else:
         for _ in range(args.batch_size // mbsz):
            engine.generate(data, generation_config=generation_config)
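In both benchmark branches, the old add_request(...) + generate(generation_config) pair collapses into one generate(prompts_token_ids=..., generation_config=...) call. A hedged sketch of that call with stand-in data; mbsz appears in this diff, while the vocabulary size, sequence length, and output length below are made-up placeholders:

    import torch
    from transformers import GenerationConfig

    mbsz, seq_len, vocab_size = 8, 512, 32000  # placeholders, not from the diff
    data = torch.randint(low=0, high=vocab_size, size=(mbsz, seq_len))
    generation_config = GenerationConfig(max_new_tokens=128)

    # One call now enqueues the batch and runs it to completion.
    engine.generate(prompts_token_ids=data, generation_config=generation_config)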
@@ -49,7 +49,7 @@ def check_inference_engine(test_cai=False):
         inference_engine.add_request(prompts=inputs)
         assert inference_engine.request_handler._has_waiting()
         generation_config = GenerationConfig(do_sample=do_sample, top_p=top_p, top_k=top_k)
-        outputs = inference_engine.generate(generation_config)
+        outputs = inference_engine.generate(generation_config=generation_config)
     else:
         tokenizer.pad_token = tokenizer.eos_token
         tokenizer.pad_token_id = tokenizer.eos_token_id