Mirror of https://github.com/csunny/DB-GPT.git (synced 2025-07-22 20:01:46 +00:00)
fix(scene): fix chat scene config custom prompt (#2725)
This commit is contained in:
  parent: 22daf08ac5
  commit: fe98239d3d
@@ -172,7 +172,18 @@ class BaseChat(ABC):
             prompt_template = self._prompt_service.get_template(self.prompt_code)
             chat_prompt_template = ChatPromptTemplate(
                 messages=[
-                    SystemPromptTemplate.from_template(prompt_template.template),
+                    SystemPromptTemplate.from_template(
+                        template=prompt_template.template,
+                        template_format=prompt_template.template_format,
+                        response_format=(
+                            prompt_template.response_format
+                            if prompt_template.response_format
+                            and prompt_template.response_format != "{}"
+                            else None
+                        ),
+                        response_key=prompt_template.response_key,
+                        template_is_strict=prompt_template.template_is_strict,
+                    ),
                     MessagesPlaceholder(variable_name="chat_history"),
                     HumanPromptTemplate.from_template("{question}"),
                 ]
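Note: the behavioral change in this hunk is that an unset or empty-object response format no longer reaches the system prompt template. A minimal standalone sketch of that normalization rule (the helper name and sample values are illustrative, not part of the commit):

    from typing import Optional

    def normalize_response_format(response_format: Optional[str]) -> Optional[str]:
        # Treat a missing or empty-object ("{}") response format as no format at all.
        if response_format and response_format != "{}":
            return response_format
        return None

    assert normalize_response_format(None) is None
    assert normalize_response_format("{}") is None
    assert normalize_response_format('{"sql": "..."}') == '{"sql": "..."}'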
@@ -212,6 +223,30 @@ class BaseChat(ABC):
         """
         return self.parse_user_input()
 
+    async def prepare_input_values(self) -> Dict:
+        """Generate input value compatible with custom prompt to LLM
+
+        Please note that you must not perform any blocking operations in this function
+
+        Returns:
+            a dictionary to be formatted by prompt template
+        """
+        input_values = await self.generate_input_values()
+
+        # Mapping variable names: compatible with custom prompt template variable names
+        # Get the input_variables of the current prompt
+        input_variables = []
+        if hasattr(self.prompt_template, "prompt") and hasattr(
+            self.prompt_template.prompt, "input_variables"
+        ):
+            input_variables = self.prompt_template.prompt.input_variables
+        # Compatible with question and user_input
+        if "question" in input_variables and "question" not in input_values:
+            input_values["question"] = self.current_user_input
+        if "user_input" in input_variables and "user_input" not in input_values:
+            input_values["user_input"] = self.current_user_input
+        return input_values
+
     @property
     def llm_client(self) -> LLMClient:
         """Return the LLM client."""
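The new prepare_input_values() wraps the scene-specific generate_input_values() and back-fills the two names a custom prompt may use for the user's input, without overwriting values the scene already supplied. A standalone sketch of just that mapping rule, with illustrative names (only the question/user_input behavior comes from the diff):

    from typing import Dict, List

    def map_compat_variables(
        input_values: Dict, input_variables: List[str], user_input: str
    ) -> Dict:
        # Fill "question"/"user_input" only when the prompt declares the variable
        # and the scene has not already supplied a value for it.
        for key in ("question", "user_input"):
            if key in input_variables and key not in input_values:
                input_values[key] = user_input
        return input_values

    values = map_compat_variables({"dialect": "mysql"}, ["dialect", "user_input"], "show tables")
    assert values == {"dialect": "mysql", "user_input": "show tables"}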
@@ -313,7 +348,7 @@ class BaseChat(ABC):
         return user_params
 
     async def _build_model_request(self) -> ModelRequest:
-        input_values = await self.generate_input_values()
+        input_values = await self.prepare_input_values()
         # Load history
         self.history_messages = self.current_message.get_history_message()
         self.current_message.start_new_round()
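With this one-line call-site change, subclasses keep overriding generate_input_values() exactly as before; the compatibility mapping now runs once in the base class for every model request.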
@@ -95,19 +95,6 @@ class ChatDashboard(BaseChat):
             "supported_chat_type": self.dashboard_template["supported_chart_type"],
         }
 
-        # Mapping variable names: compatible with custom prompt template variable names
-        # Get the input_variables of the current prompt
-        input_variables = []
-        if hasattr(self.prompt_template, "prompt") and hasattr(
-            self.prompt_template.prompt, "input_variables"
-        ):
-            input_variables = self.prompt_template.prompt.input_variables
-        # Compatible with question and user_input
-        if "question" in input_variables:
-            input_values["question"] = self.current_user_input
-        if "user_input" in input_variables:
-            input_values["user_input"] = self.current_user_input
-
         return input_values
 
     def do_action(self, prompt_response):
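This removal is the other half of the refactor: ChatDashboard's hand-rolled copy of the mapping moves into the shared BaseChat.prepare_input_values(), which also adds the "not in input_values" guard that this local copy lacked.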
@@ -139,6 +139,7 @@ class ServeDao(BaseDao[ServeEntity, ServeRequest, ServerResponse]):
             user_code=entity.user_code,
             model=entity.model,
+            input_variables=entity.input_variables,
             response_schema=entity.response_schema,
             prompt_language=entity.prompt_language,
             sys_code=entity.sys_code,
             gmt_created=gmt_created_str,
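The DAO hunk appears to carry the new input_variables column through when a stored prompt entity is converted, so custom prompts retain the variables they declare (the enclosing conversion method is not shown in this excerpt).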