mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-07-23 12:21:08 +00:00
fix(scene): fix chat scene config custom prompt (#2725)
This commit is contained in:
parent
22daf08ac5
commit
fe98239d3d
@ -172,7 +172,18 @@ class BaseChat(ABC):
|
|||||||
prompt_template = self._prompt_service.get_template(self.prompt_code)
|
prompt_template = self._prompt_service.get_template(self.prompt_code)
|
||||||
chat_prompt_template = ChatPromptTemplate(
|
chat_prompt_template = ChatPromptTemplate(
|
||||||
messages=[
|
messages=[
|
||||||
SystemPromptTemplate.from_template(prompt_template.template),
|
SystemPromptTemplate.from_template(
|
||||||
|
template=prompt_template.template,
|
||||||
|
template_format=prompt_template.template_format,
|
||||||
|
response_format=(
|
||||||
|
prompt_template.response_format
|
||||||
|
if prompt_template.response_format
|
||||||
|
and prompt_template.response_format != "{}"
|
||||||
|
else None
|
||||||
|
),
|
||||||
|
response_key=prompt_template.response_key,
|
||||||
|
template_is_strict=prompt_template.template_is_strict,
|
||||||
|
),
|
||||||
MessagesPlaceholder(variable_name="chat_history"),
|
MessagesPlaceholder(variable_name="chat_history"),
|
||||||
HumanPromptTemplate.from_template("{question}"),
|
HumanPromptTemplate.from_template("{question}"),
|
||||||
]
|
]
|
||||||
@ -212,6 +223,30 @@ class BaseChat(ABC):
|
|||||||
"""
|
"""
|
||||||
return self.parse_user_input()
|
return self.parse_user_input()
|
||||||
|
|
||||||
|
async def prepare_input_values(self) -> Dict:
|
||||||
|
"""Generate input value compatible with custom prompt to LLM
|
||||||
|
|
||||||
|
Please note that you must not perform any blocking operations in this function
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
a dictionary to be formatted by prompt template
|
||||||
|
"""
|
||||||
|
input_values = await self.generate_input_values()
|
||||||
|
|
||||||
|
# Mapping variable names: compatible with custom prompt template variable names
|
||||||
|
# Get the input_variables of the current prompt
|
||||||
|
input_variables = []
|
||||||
|
if hasattr(self.prompt_template, "prompt") and hasattr(
|
||||||
|
self.prompt_template.prompt, "input_variables"
|
||||||
|
):
|
||||||
|
input_variables = self.prompt_template.prompt.input_variables
|
||||||
|
# Compatible with question and user_input
|
||||||
|
if "question" in input_variables and "question" not in input_values:
|
||||||
|
input_values["question"] = self.current_user_input
|
||||||
|
if "user_input" in input_variables and "user_input" not in input_values:
|
||||||
|
input_values["user_input"] = self.current_user_input
|
||||||
|
return input_values
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def llm_client(self) -> LLMClient:
|
def llm_client(self) -> LLMClient:
|
||||||
"""Return the LLM client."""
|
"""Return the LLM client."""
|
||||||
@ -313,7 +348,7 @@ class BaseChat(ABC):
|
|||||||
return user_params
|
return user_params
|
||||||
|
|
||||||
async def _build_model_request(self) -> ModelRequest:
|
async def _build_model_request(self) -> ModelRequest:
|
||||||
input_values = await self.generate_input_values()
|
input_values = await self.prepare_input_values()
|
||||||
# Load history
|
# Load history
|
||||||
self.history_messages = self.current_message.get_history_message()
|
self.history_messages = self.current_message.get_history_message()
|
||||||
self.current_message.start_new_round()
|
self.current_message.start_new_round()
|
||||||
|
@ -95,19 +95,6 @@ class ChatDashboard(BaseChat):
|
|||||||
"supported_chat_type": self.dashboard_template["supported_chart_type"],
|
"supported_chat_type": self.dashboard_template["supported_chart_type"],
|
||||||
}
|
}
|
||||||
|
|
||||||
# Mapping variable names: compatible with custom prompt template variable names
|
|
||||||
# Get the input_variables of the current prompt
|
|
||||||
input_variables = []
|
|
||||||
if hasattr(self.prompt_template, "prompt") and hasattr(
|
|
||||||
self.prompt_template.prompt, "input_variables"
|
|
||||||
):
|
|
||||||
input_variables = self.prompt_template.prompt.input_variables
|
|
||||||
# Compatible with question and user_input
|
|
||||||
if "question" in input_variables:
|
|
||||||
input_values["question"] = self.current_user_input
|
|
||||||
if "user_input" in input_variables:
|
|
||||||
input_values["user_input"] = self.current_user_input
|
|
||||||
|
|
||||||
return input_values
|
return input_values
|
||||||
|
|
||||||
def do_action(self, prompt_response):
|
def do_action(self, prompt_response):
|
||||||
|
@ -139,6 +139,7 @@ class ServeDao(BaseDao[ServeEntity, ServeRequest, ServerResponse]):
|
|||||||
user_code=entity.user_code,
|
user_code=entity.user_code,
|
||||||
model=entity.model,
|
model=entity.model,
|
||||||
input_variables=entity.input_variables,
|
input_variables=entity.input_variables,
|
||||||
|
response_schema=entity.response_schema,
|
||||||
prompt_language=entity.prompt_language,
|
prompt_language=entity.prompt_language,
|
||||||
sys_code=entity.sys_code,
|
sys_code=entity.sys_code,
|
||||||
gmt_created=gmt_created_str,
|
gmt_created=gmt_created_str,
|
||||||
|
Loading…
Reference in New Issue
Block a user