diff --git a/packages/dbgpt-app/src/dbgpt_app/scene/base_chat.py b/packages/dbgpt-app/src/dbgpt_app/scene/base_chat.py
index 1f80a8121..236f68160 100644
--- a/packages/dbgpt-app/src/dbgpt_app/scene/base_chat.py
+++ b/packages/dbgpt-app/src/dbgpt_app/scene/base_chat.py
@@ -172,7 +172,18 @@ class BaseChat(ABC):
             prompt_template = self._prompt_service.get_template(self.prompt_code)
             chat_prompt_template = ChatPromptTemplate(
                 messages=[
-                    SystemPromptTemplate.from_template(prompt_template.template),
+                    SystemPromptTemplate.from_template(
+                        template=prompt_template.template,
+                        template_format=prompt_template.template_format,
+                        response_format=(
+                            prompt_template.response_format
+                            if prompt_template.response_format
+                            and prompt_template.response_format != "{}"
+                            else None
+                        ),
+                        response_key=prompt_template.response_key,
+                        template_is_strict=prompt_template.template_is_strict,
+                    ),
                     MessagesPlaceholder(variable_name="chat_history"),
                     HumanPromptTemplate.from_template("{question}"),
                 ]
@@ -212,6 +223,30 @@ class BaseChat(ABC):
         """
         return self.parse_user_input()

+    async def prepare_input_values(self) -> Dict:
+        """Generate input values compatible with a custom prompt for the LLM.
+
+        Please note that you must not perform any blocking operations in this function.
+
+        Returns:
+            A dictionary to be formatted by the prompt template.
+        """
+        input_values = await self.generate_input_values()
+
+        # Map variable names for compatibility with custom prompt template variables:
+        # get the input_variables of the current prompt
+        input_variables = []
+        if hasattr(self.prompt_template, "prompt") and hasattr(
+            self.prompt_template.prompt, "input_variables"
+        ):
+            input_variables = self.prompt_template.prompt.input_variables
+        # Compatible with both "question" and "user_input"
+        if "question" in input_variables and "question" not in input_values:
+            input_values["question"] = self.current_user_input
+        if "user_input" in input_variables and "user_input" not in input_values:
+            input_values["user_input"] = self.current_user_input
+        return input_values
+
     @property
     def llm_client(self) -> LLMClient:
         """Return the LLM client."""
@@ -313,7 +348,7 @@ class BaseChat(ABC):
         return user_params

     async def _build_model_request(self) -> ModelRequest:
-        input_values = await self.generate_input_values()
+        input_values = await self.prepare_input_values()
         # Load history
         self.history_messages = self.current_message.get_history_message()
         self.current_message.start_new_round()
diff --git a/packages/dbgpt-app/src/dbgpt_app/scene/chat_dashboard/chat.py b/packages/dbgpt-app/src/dbgpt_app/scene/chat_dashboard/chat.py
index ed3b11383..23e6d8039 100644
--- a/packages/dbgpt-app/src/dbgpt_app/scene/chat_dashboard/chat.py
+++ b/packages/dbgpt-app/src/dbgpt_app/scene/chat_dashboard/chat.py
@@ -95,19 +95,6 @@ class ChatDashboard(BaseChat):
             "supported_chat_type": self.dashboard_template["supported_chart_type"],
         }

-        # Mapping variable names: compatible with custom prompt template variable names
-        # Get the input_variables of the current prompt
-        input_variables = []
-        if hasattr(self.prompt_template, "prompt") and hasattr(
-            self.prompt_template.prompt, "input_variables"
-        ):
-            input_variables = self.prompt_template.prompt.input_variables
-        # Compatible with question and user_input
-        if "question" in input_variables:
-            input_values["question"] = self.current_user_input
-        if "user_input" in input_variables:
-            input_values["user_input"] = self.current_user_input
-
         return input_values

     def do_action(self, prompt_response):
diff --git a/packages/dbgpt-serve/src/dbgpt_serve/prompt/models/models.py b/packages/dbgpt-serve/src/dbgpt_serve/prompt/models/models.py
index 9f6122846..2efed1c55 100644
--- a/packages/dbgpt-serve/src/dbgpt_serve/prompt/models/models.py
+++ b/packages/dbgpt-serve/src/dbgpt_serve/prompt/models/models.py
@@ -139,6 +139,7 @@ class ServeDao(BaseDao[ServeEntity, ServeRequest, ServerResponse]):
             user_code=entity.user_code,
             model=entity.model,
             input_variables=entity.input_variables,
+            response_schema=entity.response_schema,
             prompt_language=entity.prompt_language,
             sys_code=entity.sys_code,
             gmt_created=gmt_created_str,
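Note (not part of the patch): a minimal sketch of how the extended SystemPromptTemplate.from_template call above is expected to be driven for a custom prompt that carries a response schema. The import path, the literal template text, the example schema JSON, the "f-string" format, and the "response" key are illustrative assumptions for review purposes, not values introduced by this change.

    # Illustrative only; mirrors the wiring added in base_chat.py above.
    # Assumes these classes are importable from dbgpt.core as in base_chat.py.
    from dbgpt.core import (
        ChatPromptTemplate,
        HumanPromptTemplate,
        MessagesPlaceholder,
        SystemPromptTemplate,
    )

    # Hypothetical schema stored with the prompt; the patch forwards it only
    # when it is non-empty and not the placeholder "{}".
    response_schema = '{"thoughts": "", "sql": ""}'

    chat_prompt_template = ChatPromptTemplate(
        messages=[
            SystemPromptTemplate.from_template(
                # Assumed template text; "{response}" is the slot that the
                # response_key below is expected to fill.
                template="You are a helpful data analyst.\n{response}",
                template_format="f-string",
                response_format=response_schema,
                response_key="response",
                template_is_strict=False,
            ),
            MessagesPlaceholder(variable_name="chat_history"),
            HumanPromptTemplate.from_template("{question}"),
        ]
    )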