[inference] added inference template (#5375)

This commit is contained in:
Frank Lee
2024-02-07 17:11:43 +08:00
committed by GitHub
parent 8106ede07f
commit 58740b5f68
3 changed files with 65 additions and 9 deletions

@@ -170,6 +170,26 @@ class InferenceEngine:
        return output_str

    @property
    def has_prompt_template(self) -> bool:
        """
        Whether a prompt template is set in the inference config.
        """
        return self.inference_config.prompt_template is not None
    def format_prompt(self, prompts: Union[List[str], str]) -> Union[List[str], str]:
        """
        Format the input prompt(s) according to the prompt template given in the InferenceConfig.
        """
        assert (
            self.has_prompt_template
        ), "The prompt_template is None. Please provide a valid prompt_template in InferenceConfig."

        if isinstance(prompts, (list, tuple)):
            return [self.inference_config.prompt_template.format(input_text=prompt) for prompt in prompts]
        elif isinstance(prompts, str):
            return self.inference_config.prompt_template.format(input_text=prompts)
        else:
            raise TypeError(f"Expected the input prompt to be one of list, tuple, or str, but got {type(prompts)}.")
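
For reference, a minimal sketch of what this formatting amounts to, assuming a chat-style template string with an {input_text} placeholder (the template value here is illustrative, not one shipped by this commit):

    # Any template string with an {input_text} placeholder works.
    template = "[INST] {input_text} [/INST]"

    # A single string is formatted directly...
    single = template.format(input_text="What is paged attention?")
    # -> "[INST] What is paged attention? [/INST]"

    # ...while a list or tuple is formatted element-wise.
    batch = [template.format(input_text=p) for p in ["hi", "bye"]]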
    def add_request(
        self,
        requests_id: List[int] = None,

@@ -185,6 +205,10 @@ class InferenceEngine:
            prompts_token_ids (List[List[int]], optional): token ids of input prompts. Defaults to None.
        """
        # apply the prompt template to the input prompts
        if self.has_prompt_template and prompts is not None:
            prompts = self.format_prompt(prompts)

        block_size = self.inference_config.block_size

        if prompts_token_ids is None:
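
Taken together, the two hunks mean callers no longer need to pre-format prompts themselves. A hedged end-to-end sketch, assuming InferenceConfig accepts prompt_template as a keyword; the model and tokenizer arguments are placeholders, and the constructor signature is not verified against this diff:

    # Hypothetical usage based only on this diff.
    config = InferenceConfig(prompt_template="[INST] {input_text} [/INST]")
    engine = InferenceEngine(model, tokenizer, config)  # construction details not shown here

    # add_request wraps each raw prompt with the template before tokenization,
    # since has_prompt_template is True and prompts is not None.
    engine.add_request(prompts=["Explain continuous batching."])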