mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-09 04:50:17 +00:00
[inference] added inference template (#5375)
This commit is contained in:
@@ -170,6 +170,26 @@ class InferenceEngine:
|
||||
|
||||
return output_str
|
||||
|
||||
@property
def has_prompt_template(self) -> bool:
    """Whether a prompt template was supplied in the InferenceConfig.

    Returns:
        bool: True if ``self.inference_config.prompt_template`` is set (not None).
    """
    return self.inference_config.prompt_template is not None
|
||||
|
||||
def format_prompt(self, prompts: Union[List[str], str]) -> Union[List[str], str]:
    """
    Format the input prompt(s) according to the prompt template given in the InferenceConfig.

    Args:
        prompts (Union[List[str], str]): a single prompt string, or a list/tuple of prompt strings.

    Returns:
        Union[List[str], str]: the formatted prompt(s); a list is returned for list/tuple input,
            a single str for str input.

    Raises:
        TypeError: if ``prompts`` is not a list, tuple, or str.
    """
    assert (
        self.has_prompt_template
    ), "Found the prompt_template is None. Please provide a valid prompt_template in InferenceConfig."

    if isinstance(prompts, (list, tuple)):
        return [self.inference_config.prompt_template.format(input_text=prompt) for prompt in prompts]
    elif isinstance(prompts, str):
        # BUG FIX: the original read `self.inference_config.rompt_template` (missing 'p'),
        # which raised AttributeError for every single-string prompt.
        return self.inference_config.prompt_template.format(input_text=prompts)
    else:
        raise TypeError(f"Expected the input prompt to be one of list, tuple, or str, but got {type(prompts)}.")
|
||||
|
||||
def add_request(
|
||||
self,
|
||||
requests_id: List[int] = None,
|
||||
@@ -185,6 +205,10 @@ class InferenceEngine:
|
||||
prompts_token_ids (List[List[int]], optional): token ids of input prompts. Defaults to None.
|
||||
"""
|
||||
|
||||
# apply the prompt template to the input prompts
|
||||
if self.has_prompt_template and prompts is not None:
|
||||
prompts = self.format_prompt(prompts)
|
||||
|
||||
block_size = self.inference_config.block_size
|
||||
|
||||
if prompts_token_ids is None:
|
||||
|
Reference in New Issue
Block a user