feat(flow): Add Chat Data template (#2716)

This commit is contained in:
Fangyin Cheng 2025-05-22 15:52:24 +08:00 committed by GitHub
parent 7e7581e891
commit f79f81ccc3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 3353 additions and 34 deletions

View File

@ -137,18 +137,21 @@ def _initialize_openapi(system_app: SystemApp):
def _initialize_operators():
from dbgpt_app.operators.code import CodeMapOperator # noqa: F401
from dbgpt_app.operators.converter import StringToInteger # noqa: F401
from dbgpt_app.operators.datasource import ( # noqa: F401
HODatasourceExecutorOperator,
HODatasourceRetrieverOperator,
)
from dbgpt_app.operators.llm import ( # noqa: F401
HOLLMOperator,
HOStreamingLLMOperator,
)
from dbgpt_app.operators.rag import HOKnowledgeOperator # noqa: F401
from dbgpt_serve.agent.resource.datasource import DatasourceResource # noqa: F401
from dbgpt.core.awel import BaseOperator
from dbgpt.util.module_utils import ModelScanner, ScannerConfig
modules = ["dbgpt_app.operators", "dbgpt_serve.agent.resource"]
scanner = ModelScanner[BaseOperator]()
registered_items = {}
for module in modules:
config = ScannerConfig(
module_path=module,
base_class=BaseOperator,
)
items = scanner.scan_and_register(config)
registered_items[module] = items
return scanner.get_registered_items()
def _initialize_code_server(system_app: SystemApp):

View File

@ -105,6 +105,7 @@ _PARAMETER_DATASOURCE = Parameter.build_from(
type=DBResource,
description=_("The datasource to retrieve the context"),
)
_PARAMETER_PROMPT_TEMPLATE = Parameter.build_from(
_("Prompt Template"),
"prompt_template",
@ -172,7 +173,7 @@ _OUTPUTS_SQL_RESULT = IOField.build_from(
_("SQL result"),
"sql_result",
str,
description=_("The result of the SQL execution"),
description=_("The result of the SQL execution(GPT-Vis format)"),
)
_INPUTS_SQL_DICT_LIST = IOField.build_from(
@ -189,7 +190,9 @@ _INPUTS_SQL_DICT_LIST = IOField.build_from(
class GPTVisMixin:
async def save_view_message(self, dag_ctx: DAGContext, view: str):
"""Save the view message."""
await dag_ctx.save_to_share_data(BaseLLM.SHARE_DATA_KEY_MODEL_OUTPUT_VIEW, view)
await dag_ctx.save_to_share_data(
BaseLLM.SHARE_DATA_KEY_MODEL_OUTPUT_VIEW, view, overwrite=True
)
class HODatasourceRetrieverOperator(MapOperator[str, HOContextBody]):
@ -286,6 +289,19 @@ class HODatasourceRetrieverOperator(MapOperator[str, HOContextBody]):
class HODatasourceExecutorOperator(GPTVisMixin, MapOperator[dict, str]):
"""Execute the context from the datasource."""
_share_data_key = "__datasource_executor_result__"
class MarkdownMapper(MapOperator[str, str]):
async def map(self, context: str) -> str:
"""Convert the result to markdown."""
from dbgpt.util.pd_utils import df_to_markdown
df = await self.current_dag_context.get_from_share_data(
HODatasourceExecutorOperator._share_data_key
)
return df_to_markdown(df)
metadata = ViewMetadata(
label=_("Datasource Executor Operator"),
name="higher_order_datasource_executor_operator",
@ -293,7 +309,16 @@ class HODatasourceExecutorOperator(GPTVisMixin, MapOperator[dict, str]):
category=OperatorCategory.DATABASE,
parameters=[_PARAMETER_DATASOURCE.new()],
inputs=[_INPUTS_SQL_DICT.new()],
outputs=[_OUTPUTS_SQL_RESULT.new()],
outputs=[
_OUTPUTS_SQL_RESULT.new(),
IOField.build_from(
_("Markdown result"),
"markdown_result",
str,
description=_("The markdown result of the SQL execution"),
mappers=[MarkdownMapper],
),
],
tags={"order": TAGS_ORDER_HIGH},
)
@ -314,8 +339,16 @@ class HODatasourceExecutorOperator(GPTVisMixin, MapOperator[dict, str]):
sql = sql_dict.get("sql")
if not sql:
return sql_dict.get("thoughts", "No SQL found in the input dictionary.")
thoughts = sql_dict.get("thoughts", "")
data_df = await self._datasource.query_to_df(sql)
# Save the result to share data, for markdown mapper
await self.current_dag_context.save_to_share_data(
HODatasourceExecutorOperator._share_data_key, data_df
)
view = await vis.display(chart=sql_dict, data_df=data_df)
view = thoughts + "\n\n" + view
await self.save_view_message(self.current_dag_context, view)
return view

View File

@ -0,0 +1,233 @@
from functools import cache
from typing import Optional
from dbgpt.core import (
ChatPromptTemplate,
HumanPromptTemplate,
LLMClient,
ModelMessage,
SystemPromptTemplate,
)
from dbgpt.core.awel import JoinOperator
from dbgpt.core.awel.flow.base import (
TAGS_ORDER_HIGH,
IOField,
OperatorCategory,
Parameter,
ViewMetadata,
)
from dbgpt.core.interface.llm import ModelRequest
from dbgpt.model.operators import MixinLLMOperator
from dbgpt.util.i18n_utils import _
from dbgpt_app.operators.datasource import GPTVisMixin
_DEFAULT_PROMPT_EN = """You are a helpful AI assistant.
Please carefully read the data in the Markdown table format below, the data is a
database query result based on the user question. Please analyze and summarize the
data carefully, and provide a summary report in markdown format.
<data-report>
{data_report}
</data-report>
user question:
{user_input}
Please answer in the same language as the user's question.
"""
_DEFAULT_PROMPT_ZH = """你是一个有用的AI助手。
请你仔细阅读下面的 Markdown 表格格式的数据这是一份根据用户问题查询到的数据库的数据\
你需要根据数据仔细分析和总结给出一份总结报告使用 markdown 格式输出
<data-report>
{data_report}
</data-report>
用户的问题:
{user_input}
请用用户提问的语言回答
"""
_DEFAULT_USER_PROMPT = """\
{user_input}
"""
@cache
def _get_default_prompt(language: str) -> ChatPromptTemplate:
if language == "zh":
sys_prompt = _DEFAULT_PROMPT_ZH
user_prompt = _DEFAULT_USER_PROMPT
else:
sys_prompt = _DEFAULT_PROMPT_EN
user_prompt = _DEFAULT_USER_PROMPT
return ChatPromptTemplate(
messages=[
SystemPromptTemplate.from_template(sys_prompt),
HumanPromptTemplate.from_template(user_prompt),
]
)
class ReportAnalystOperator(MixinLLMOperator, JoinOperator[str]):
metadata = ViewMetadata(
label=_("Report Analyst"),
name="report_analyst",
description=_("Report Analyst"),
category=OperatorCategory.DATABASE,
tags={"order": TAGS_ORDER_HIGH},
parameters=[
Parameter.build_from(
_("Prompt Template"),
"prompt_template",
ChatPromptTemplate,
description=_("The prompt template for the conversation."),
optional=True,
default=None,
),
Parameter.build_from(
_("Model Name"),
"model",
str,
optional=True,
default=None,
description=_("The model name."),
),
Parameter.build_from(
_("LLM Client"),
"llm_client",
LLMClient,
optional=True,
default=None,
description=_(
"The LLM Client, how to connect to the LLM model, if not provided,"
" it will use the default client deployed by DB-GPT."
),
),
],
inputs=[
IOField.build_from(
_("User question"),
"question",
str,
description=_("The question of user"),
),
IOField.build_from(
_("The data report"),
"data_report",
str,
_("The data report in markdown format."),
dynamic=True,
),
],
outputs=[
IOField.build_from(
_("Report Analyst Result"),
"report_analyst_result",
str,
description=_("The report analyst result."),
)
],
)
def __init__(
self,
prompt_template: Optional[ChatPromptTemplate] = None,
model: Optional[str] = None,
llm_client: Optional[LLMClient] = None,
**kwargs,
):
JoinOperator.__init__(self, combine_function=self._join_func, **kwargs)
MixinLLMOperator.__init__(self, llm_client=llm_client, **kwargs)
# User must select a history merge mode
self._prompt_template = prompt_template
self._model = model
@property
def prompt_template(self) -> ChatPromptTemplate:
"""Get the prompt template."""
language = "en"
if self.system_app:
language = self.system_app.config.get_current_lang()
if self._prompt_template is None:
return _get_default_prompt(language)
return self._prompt_template
async def _join_func(self, question: str, data_report: str, *args):
dynamic_inputs = [data_report]
for arg in args:
if isinstance(arg, str):
dynamic_inputs.append(arg)
data_report = "\n".join(dynamic_inputs)
messages = self.prompt_template.format_messages(
user_input=question,
data_report=data_report,
)
model_messages = ModelMessage.from_base_messages(messages)
models = await self.llm_client.models()
if not models:
raise Exception("No models available.")
model = self._model or models[0].model
model_request = ModelRequest.build_request(model, messages=model_messages)
model_output = await self.llm_client.generate(model_request)
text = model_output.gen_text_with_thinking()
return text
class StringJoinOperator(GPTVisMixin, JoinOperator[str]):
"""Join operator for strings.
This operator joins the input strings with a specified separator.
"""
metadata = ViewMetadata(
label=_("String Join Operator"),
name="string_join_operator",
description=_("Merge multiple inputs into a single string."),
category=OperatorCategory.COMMON,
parameters=[
Parameter.build_from(
_("Separator"),
"separator",
str,
optional=True,
default="\n\n",
description=_("The separator to join the strings."),
),
],
inputs=[
IOField.build_from(
_("Input Strings"),
"input_strings",
str,
description=_("The input strings to join."),
dynamic=True,
),
],
outputs=[
IOField.build_from(
_("Joined String"),
"joined_string",
str,
description=_("The joined string."),
)
],
tags={"order": TAGS_ORDER_HIGH},
)
def __init__(self, separator: str = "\n\n", **kwargs):
super().__init__(combine_function=self._join_func, **kwargs)
self.separator = separator
async def _join_func(self, *args) -> str:
"""Join the strings with the separator."""
view = self.separator.join(args)
await self.save_view_message(self.current_dag_context, view)
return view

View File

@ -20,3 +20,25 @@ def csv_colunm_foramt(val):
return val
except ValueError:
return val
def df_to_markdown(df: pd.DataFrame, index=False) -> str:
"""Convert a pandas DataFrame to a Markdown table."""
columns = df.columns
header = "| " + " | ".join(columns) + " |"
separator = "| " + " | ".join(["---"] * len(columns)) + " |"
rows = []
for _, row in df.iterrows():
row_str = "| " + " | ".join(map(str, row.values)) + " |"
rows.append(row_str)
if index:
header = "| index | " + " | ".join(columns) + " |"
separator = "| --- | " + " | ".join(["---"] * len(columns)) + " |"
rows = []
for idx, row in df.iterrows():
row_str = f"| {idx} | " + " | ".join(map(str, row.values)) + " |"
rows.append(row_str)
return "\n".join([header, separator] + rows)

View File

@ -1,22 +1,23 @@
"""
Run unit test with command: pytest dbgpt/datasource/rdbms/tests/test_conn_mysql.py
docker run -itd --name mysql-test -p 3307:3306 -e MYSQL_ROOT_PASSWORD=12345678 mysql:5.7
mysql -h 127.0.0.1 -uroot -p -P3307
Enter password:
Welcome to the MySQL monitor. Commands end with ; or \g.
Your MySQL connection id is 2
Server version: 5.7.41 MySQL Community Server (GPL)
Run unit test with command: pytest dbgpt/datasource/rdbms/tests/test_conn_mysql.py
docker run -itd --name mysql-test -p 3307:3306 -e MYSQL_ROOT_PASSWORD=12345678 mysql:5.7
mysql -h 127.0.0.1 -uroot -p -P3307
Enter password:
Welcome to the MySQL monitor. Commands end with ; or \g.
Your MySQL connection id is 2
Server version: 5.7.41 MySQL Community Server (GPL)
Copyright (c) 2000, 2023, Oracle and/or its affiliates.
Copyright (c) 2000, 2023, Oracle and/or its affiliates.
Oracle is a registered trademark of Oracle Corporation and/or its
affiliates. Other names may be trademarks of their respective
owners.
Oracle is a registered trademark of Oracle Corporation and/or its
affiliates. Other names may be trademarks of their respective
owners.
Type 'help;' or '\h' for help. Type '\c' to clear the current input statement.
> create database test;
Type 'help;' or '\h' for help. Type '\c' to clear the current input statement.
> create database test;
"""
import pytest
from dbgpt_ext.datasource.rdbms.conn_oracle import OracleConnector
@ -27,6 +28,7 @@ CREATE TABLE test (
)
"""
@pytest.fixture
def db():
# 注意Oracle 默认端口是 1521连接方式建议用 service_name
@ -45,32 +47,37 @@ def db():
except Exception:
pass # 如果表不存在也忽略错误
def test_get_usable_table_names(db):
db.run(_create_table_sql)
db.run("COMMIT")
table_names = db.get_usable_table_names()
assert "TEST" in map(str.upper, table_names)
def test_get_table_info(db):
db.run(_create_table_sql)
db.run("COMMIT")
table_info = db.get_table_info()
assert "CREATE TABLE TEST" in table_info.upper()
def test_run_no_throw(db):
result = db.run_no_throw("this is a error sql")
# run_no_throw 返回的是 list错误时为空
assert result == [] or isinstance(result, list)
def test_get_index_empty(db):
db.run(_create_table_sql)
db.run("COMMIT")
indexes = db.get_indexes("TEST")
assert indexes == []
def test_get_fields(db):
#db.run(_create_table_sql)
#db.run("COMMIT")
# db.run(_create_table_sql)
# db.run("COMMIT")
print("进入方法...")
fields = db.get_fields("PY_TEST")
print("正在打印字段信息...")
@ -81,16 +88,21 @@ def test_get_fields(db):
print(f"Is Nullable: {field[3]}")
print(f"Column Comment: {field[4]}")
print("-" * 30) # 可选的分隔符
#assert fields[0][0].upper() == "ID"
# assert fields[0][0].upper() == "ID"
def test_get_charset(db):
result = db.run("SELECT VALUE FROM NLS_DATABASE_PARAMETERS WHERE PARAMETER = 'NLS_CHARACTERSET'")
result = db.run(
"SELECT VALUE FROM NLS_DATABASE_PARAMETERS WHERE PARAMETER = 'NLS_CHARACTERSET'"
)
assert result[1][0] in ("AL32UTF8", "UTF8") # result[0] 是字段名元组
def test_get_users(db):
users = db.get_users()
assert any(user[0].upper() in ("SYS", "SYSTEM") for user in users)
def test_get_database_lists(db):
cdb_result = db.run("SELECT CDB FROM V$DATABASE")
if cdb_result[1][0] == "YES":
@ -98,4 +110,4 @@ def test_get_database_lists(db):
pdb_names = [name[0] for name in databases[1:]]
else:
pdb_names = ["ORCL"]
assert any(name in ("ORCLPDB1", "ORCL") for name in pdb_names)
assert any(name in ("ORCLPDB1", "ORCL") for name in pdb_names)