fix(scene): fix chat scene config custom prompt (#2725)

This commit is contained in:
alanchen 2025-05-23 17:05:56 +08:00 committed by GitHub
parent 22daf08ac5
commit fe98239d3d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 38 additions and 15 deletions

View File

@ -172,7 +172,18 @@ class BaseChat(ABC):
prompt_template = self._prompt_service.get_template(self.prompt_code)
chat_prompt_template = ChatPromptTemplate(
messages=[
SystemPromptTemplate.from_template(prompt_template.template),
SystemPromptTemplate.from_template(
template=prompt_template.template,
template_format=prompt_template.template_format,
response_format=(
prompt_template.response_format
if prompt_template.response_format
and prompt_template.response_format != "{}"
else None
),
response_key=prompt_template.response_key,
template_is_strict=prompt_template.template_is_strict,
),
MessagesPlaceholder(variable_name="chat_history"),
HumanPromptTemplate.from_template("{question}"),
]
@ -212,6 +223,30 @@ class BaseChat(ABC):
"""
return self.parse_user_input()
async def prepare_input_values(self) -> Dict:
"""Generate input value compatible with custom prompt to LLM
Please note that you must not perform any blocking operations in this function
Returns:
a dictionary to be formatted by prompt template
"""
input_values = await self.generate_input_values()
# Mapping variable names: compatible with custom prompt template variable names
# Get the input_variables of the current prompt
input_variables = []
if hasattr(self.prompt_template, "prompt") and hasattr(
self.prompt_template.prompt, "input_variables"
):
input_variables = self.prompt_template.prompt.input_variables
# Compatible with question and user_input
if "question" in input_variables and "question" not in input_values:
input_values["question"] = self.current_user_input
if "user_input" in input_variables and "user_input" not in input_values:
input_values["user_input"] = self.current_user_input
return input_values
@property
def llm_client(self) -> LLMClient:
"""Return the LLM client."""
@ -313,7 +348,7 @@ class BaseChat(ABC):
return user_params
async def _build_model_request(self) -> ModelRequest:
input_values = await self.generate_input_values()
input_values = await self.prepare_input_values()
# Load history
self.history_messages = self.current_message.get_history_message()
self.current_message.start_new_round()

View File

@ -95,19 +95,6 @@ class ChatDashboard(BaseChat):
"supported_chat_type": self.dashboard_template["supported_chart_type"],
}
# Mapping variable names: compatible with custom prompt template variable names
# Get the input_variables of the current prompt
input_variables = []
if hasattr(self.prompt_template, "prompt") and hasattr(
self.prompt_template.prompt, "input_variables"
):
input_variables = self.prompt_template.prompt.input_variables
# Compatible with question and user_input
if "question" in input_variables:
input_values["question"] = self.current_user_input
if "user_input" in input_variables:
input_values["user_input"] = self.current_user_input
return input_values
def do_action(self, prompt_response):

View File

@ -139,6 +139,7 @@ class ServeDao(BaseDao[ServeEntity, ServeRequest, ServerResponse]):
user_code=entity.user_code,
model=entity.model,
input_variables=entity.input_variables,
response_schema=entity.response_schema,
prompt_language=entity.prompt_language,
sys_code=entity.sys_code,
gmt_created=gmt_created_str,