mirror of https://github.com/csunny/DB-GPT.git (synced 2025-07-23 12:21:08 +00:00)

feat(flow): Add Chat Data template (#2716)

This commit is contained in:
parent 7e7581e891
commit f79f81ccc3
@@ -137,18 +137,21 @@ def _initialize_openapi(system_app: SystemApp):

def _initialize_operators():
-    from dbgpt_app.operators.code import CodeMapOperator  # noqa: F401
-    from dbgpt_app.operators.converter import StringToInteger  # noqa: F401
-    from dbgpt_app.operators.datasource import (  # noqa: F401
-        HODatasourceExecutorOperator,
-        HODatasourceRetrieverOperator,
-    )
-    from dbgpt_app.operators.llm import (  # noqa: F401
-        HOLLMOperator,
-        HOStreamingLLMOperator,
-    )
-    from dbgpt_app.operators.rag import HOKnowledgeOperator  # noqa: F401
-    from dbgpt_serve.agent.resource.datasource import DatasourceResource  # noqa: F401
+    from dbgpt.core.awel import BaseOperator
+    from dbgpt.util.module_utils import ModelScanner, ScannerConfig
+
+    modules = ["dbgpt_app.operators", "dbgpt_serve.agent.resource"]
+
+    scanner = ModelScanner[BaseOperator]()
+    registered_items = {}
+    for module in modules:
+        config = ScannerConfig(
+            module_path=module,
+            base_class=BaseOperator,
+        )
+        items = scanner.scan_and_register(config)
+        registered_items[module] = items
+    return scanner.get_registered_items()


def _initialize_code_server(system_app: SystemApp):
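The hunk above swaps the manual `# noqa: F401` import-for-side-effect pattern for a `ModelScanner` that walks whole packages for `BaseOperator` subclasses. A minimal sketch of applying the same pattern to an additional package, using only the calls visible in the diff (the package name below is a hypothetical placeholder):

# Sketch only: assumes ModelScanner / ScannerConfig behave as used in the diff above.
from dbgpt.core.awel import BaseOperator
from dbgpt.util.module_utils import ModelScanner, ScannerConfig


def register_extra_operators(module_path: str) -> dict:
    # Scan one package for BaseOperator subclasses and register them,
    # mirroring _initialize_operators() above.
    scanner = ModelScanner[BaseOperator]()
    config = ScannerConfig(module_path=module_path, base_class=BaseOperator)
    scanner.scan_and_register(config)
    return scanner.get_registered_items()


# Hypothetical package name, for illustration only:
# register_extra_operators("my_company.custom_operators")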
@@ -105,6 +105,7 @@ _PARAMETER_DATASOURCE = Parameter.build_from(
    type=DBResource,
    description=_("The datasource to retrieve the context"),
)

_PARAMETER_PROMPT_TEMPLATE = Parameter.build_from(
    _("Prompt Template"),
    "prompt_template",
@@ -172,7 +173,7 @@ _OUTPUTS_SQL_RESULT = IOField.build_from(
    _("SQL result"),
    "sql_result",
    str,
-    description=_("The result of the SQL execution"),
+    description=_("The result of the SQL execution(GPT-Vis format)"),
)

_INPUTS_SQL_DICT_LIST = IOField.build_from(
@@ -189,7 +190,9 @@ _INPUTS_SQL_DICT_LIST = IOField.build_from(
class GPTVisMixin:
    async def save_view_message(self, dag_ctx: DAGContext, view: str):
        """Save the view message."""
-        await dag_ctx.save_to_share_data(BaseLLM.SHARE_DATA_KEY_MODEL_OUTPUT_VIEW, view)
+        await dag_ctx.save_to_share_data(
+            BaseLLM.SHARE_DATA_KEY_MODEL_OUTPUT_VIEW, view, overwrite=True
+        )


class HODatasourceRetrieverOperator(MapOperator[str, HOContextBody]):
@@ -286,6 +289,19 @@ class HODatasourceRetrieverOperator(MapOperator[str, HOContextBody]):
class HODatasourceExecutorOperator(GPTVisMixin, MapOperator[dict, str]):
    """Execute the context from the datasource."""

    _share_data_key = "__datasource_executor_result__"

    class MarkdownMapper(MapOperator[str, str]):
        async def map(self, context: str) -> str:
            """Convert the result to markdown."""

            from dbgpt.util.pd_utils import df_to_markdown

            df = await self.current_dag_context.get_from_share_data(
                HODatasourceExecutorOperator._share_data_key
            )
            return df_to_markdown(df)

    metadata = ViewMetadata(
        label=_("Datasource Executor Operator"),
        name="higher_order_datasource_executor_operator",
@@ -293,7 +309,16 @@ class HODatasourceExecutorOperator(GPTVisMixin, MapOperator[dict, str]):
        category=OperatorCategory.DATABASE,
        parameters=[_PARAMETER_DATASOURCE.new()],
        inputs=[_INPUTS_SQL_DICT.new()],
-        outputs=[_OUTPUTS_SQL_RESULT.new()],
+        outputs=[
+            _OUTPUTS_SQL_RESULT.new(),
+            IOField.build_from(
+                _("Markdown result"),
+                "markdown_result",
+                str,
+                description=_("The markdown result of the SQL execution"),
+                mappers=[MarkdownMapper],
+            ),
+        ],
        tags={"order": TAGS_ORDER_HIGH},
    )
@@ -314,8 +339,16 @@ class HODatasourceExecutorOperator(GPTVisMixin, MapOperator[dict, str]):
        sql = sql_dict.get("sql")
        if not sql:
            return sql_dict.get("thoughts", "No SQL found in the input dictionary.")

        thoughts = sql_dict.get("thoughts", "")

        data_df = await self._datasource.query_to_df(sql)
        # Save the result to share data, for markdown mapper
        await self.current_dag_context.save_to_share_data(
            HODatasourceExecutorOperator._share_data_key, data_df
        )
        view = await vis.display(chart=sql_dict, data_df=data_df)
        view = thoughts + "\n\n" + view
        await self.save_view_message(self.current_dag_context, view)
        return view
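The executor hunk above only relies on two keys of the incoming dictionary, `sql` and `thoughts`. A minimal sketch of that input shape (field names taken from the `get` calls above; any chart-specific keys consumed by `vis.display` are assumptions):

# Shape inferred from sql_dict.get("sql") / sql_dict.get("thoughts") above.
example_sql_dict = {
    "thoughts": "Aggregate orders per month to answer the user's question.",
    "sql": "SELECT month, COUNT(*) AS order_count FROM orders GROUP BY month",
    # Additional chart hints may also be present for GPT-Vis rendering (assumption).
}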
packages/dbgpt-app/src/dbgpt_app/operators/report.py (new file, 233 lines)
@@ -0,0 +1,233 @@
from functools import cache
from typing import Optional

from dbgpt.core import (
    ChatPromptTemplate,
    HumanPromptTemplate,
    LLMClient,
    ModelMessage,
    SystemPromptTemplate,
)
from dbgpt.core.awel import JoinOperator
from dbgpt.core.awel.flow.base import (
    TAGS_ORDER_HIGH,
    IOField,
    OperatorCategory,
    Parameter,
    ViewMetadata,
)
from dbgpt.core.interface.llm import ModelRequest
from dbgpt.model.operators import MixinLLMOperator
from dbgpt.util.i18n_utils import _
from dbgpt_app.operators.datasource import GPTVisMixin

_DEFAULT_PROMPT_EN = """You are a helpful AI assistant.

Please carefully read the data in the Markdown table format below, the data is a
database query result based on the user question. Please analyze and summarize the
data carefully, and provide a summary report in markdown format.

<data-report>
{data_report}
</data-report>

user question:
{user_input}

Please answer in the same language as the user's question.
"""

_DEFAULT_PROMPT_ZH = """你是一个有用的AI助手。

请你仔细阅读下面的 Markdown 表格格式的数据,这是一份根据用户问题查询到的数据库的数据,\
你需要根据数据仔细分析和总结,给出一份总结报告,使用 markdown 格式输出。

<data-report>
{data_report}
</data-report>

用户的问题:
{user_input}

请用用户提问的语言回答。
"""

_DEFAULT_USER_PROMPT = """\
{user_input}
"""


@cache
def _get_default_prompt(language: str) -> ChatPromptTemplate:
    if language == "zh":
        sys_prompt = _DEFAULT_PROMPT_ZH
        user_prompt = _DEFAULT_USER_PROMPT
    else:
        sys_prompt = _DEFAULT_PROMPT_EN
        user_prompt = _DEFAULT_USER_PROMPT

    return ChatPromptTemplate(
        messages=[
            SystemPromptTemplate.from_template(sys_prompt),
            HumanPromptTemplate.from_template(user_prompt),
        ]
    )


class ReportAnalystOperator(MixinLLMOperator, JoinOperator[str]):
    metadata = ViewMetadata(
        label=_("Report Analyst"),
        name="report_analyst",
        description=_("Report Analyst"),
        category=OperatorCategory.DATABASE,
        tags={"order": TAGS_ORDER_HIGH},
        parameters=[
            Parameter.build_from(
                _("Prompt Template"),
                "prompt_template",
                ChatPromptTemplate,
                description=_("The prompt template for the conversation."),
                optional=True,
                default=None,
            ),
            Parameter.build_from(
                _("Model Name"),
                "model",
                str,
                optional=True,
                default=None,
                description=_("The model name."),
            ),
            Parameter.build_from(
                _("LLM Client"),
                "llm_client",
                LLMClient,
                optional=True,
                default=None,
                description=_(
                    "The LLM Client, how to connect to the LLM model, if not provided,"
                    " it will use the default client deployed by DB-GPT."
                ),
            ),
        ],
        inputs=[
            IOField.build_from(
                _("User question"),
                "question",
                str,
                description=_("The question of user"),
            ),
            IOField.build_from(
                _("The data report"),
                "data_report",
                str,
                _("The data report in markdown format."),
                dynamic=True,
            ),
        ],
        outputs=[
            IOField.build_from(
                _("Report Analyst Result"),
                "report_analyst_result",
                str,
                description=_("The report analyst result."),
            )
        ],
    )

    def __init__(
        self,
        prompt_template: Optional[ChatPromptTemplate] = None,
        model: Optional[str] = None,
        llm_client: Optional[LLMClient] = None,
        **kwargs,
    ):
        JoinOperator.__init__(self, combine_function=self._join_func, **kwargs)
        MixinLLMOperator.__init__(self, llm_client=llm_client, **kwargs)

        # User must select a history merge mode
        self._prompt_template = prompt_template
        self._model = model

    @property
    def prompt_template(self) -> ChatPromptTemplate:
        """Get the prompt template."""
        language = "en"
        if self.system_app:
            language = self.system_app.config.get_current_lang()
        if self._prompt_template is None:
            return _get_default_prompt(language)
        return self._prompt_template

    async def _join_func(self, question: str, data_report: str, *args):
        dynamic_inputs = [data_report]
        for arg in args:
            if isinstance(arg, str):
                dynamic_inputs.append(arg)
        data_report = "\n".join(dynamic_inputs)
        messages = self.prompt_template.format_messages(
            user_input=question,
            data_report=data_report,
        )
        model_messages = ModelMessage.from_base_messages(messages)
        models = await self.llm_client.models()
        if not models:
            raise Exception("No models available.")
        model = self._model or models[0].model

        model_request = ModelRequest.build_request(model, messages=model_messages)
        model_output = await self.llm_client.generate(model_request)
        text = model_output.gen_text_with_thinking()

        return text


class StringJoinOperator(GPTVisMixin, JoinOperator[str]):
    """Join operator for strings.

    This operator joins the input strings with a specified separator.
    """

    metadata = ViewMetadata(
        label=_("String Join Operator"),
        name="string_join_operator",
        description=_("Merge multiple inputs into a single string."),
        category=OperatorCategory.COMMON,
        parameters=[
            Parameter.build_from(
                _("Separator"),
                "separator",
                str,
                optional=True,
                default="\n\n",
                description=_("The separator to join the strings."),
            ),
        ],
        inputs=[
            IOField.build_from(
                _("Input Strings"),
                "input_strings",
                str,
                description=_("The input strings to join."),
                dynamic=True,
            ),
        ],
        outputs=[
            IOField.build_from(
                _("Joined String"),
                "joined_string",
                str,
                description=_("The joined string."),
            )
        ],
        tags={"order": TAGS_ORDER_HIGH},
    )

    def __init__(self, separator: str = "\n\n", **kwargs):
        super().__init__(combine_function=self._join_func, **kwargs)
        self.separator = separator

    async def _join_func(self, *args) -> str:
        """Join the strings with the separator."""
        view = self.separator.join(args)
        await self.save_view_message(self.current_dag_context, view)
        return view
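For reference, a small sketch of how the default prompt defined above turns a user question plus a markdown data report into chat messages. It only uses `_get_default_prompt` and the `format_messages(user_input=..., data_report=...)` call already present in `_join_func`; the sample values are illustrative.

# Illustrative values only; the template and call pattern come from report.py above.
from dbgpt_app.operators.report import _get_default_prompt

prompt = _get_default_prompt("en")
messages = prompt.format_messages(
    user_input="Which month had the most orders?",
    data_report="| month | order_count |\n| --- | --- |\n| 2024-01 | 42 |",
)
# `messages` is then converted with ModelMessage.from_base_messages(...) and sent
# to the LLM via ModelRequest.build_request(...), as in ReportAnalystOperator._join_func.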
@@ -20,3 +20,25 @@ def csv_colunm_foramt(val):
            return val
    except ValueError:
        return val


def df_to_markdown(df: pd.DataFrame, index=False) -> str:
    """Convert a pandas DataFrame to a Markdown table."""
    columns = df.columns
    header = "| " + " | ".join(columns) + " |"
    separator = "| " + " | ".join(["---"] * len(columns)) + " |"

    rows = []
    for _, row in df.iterrows():
        row_str = "| " + " | ".join(map(str, row.values)) + " |"
        rows.append(row_str)

    if index:
        header = "| index | " + " | ".join(columns) + " |"
        separator = "| --- | " + " | ".join(["---"] * len(columns)) + " |"
        rows = []
        for idx, row in df.iterrows():
            row_str = f"| {idx} | " + " | ".join(map(str, row.values)) + " |"
            rows.append(row_str)

    return "\n".join([header, separator] + rows)
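A quick usage sketch of the `df_to_markdown` helper added above (the sample DataFrame is illustrative):

import pandas as pd

from dbgpt.util.pd_utils import df_to_markdown

# Illustrative data only.
df = pd.DataFrame({"name": ["a", "b"], "value": [1, 2]})
print(df_to_markdown(df))
# | name | value |
# | --- | --- |
# | a | 1 |
# | b | 2 |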
File diff suppressed because it is too large
@@ -1,22 +1,23 @@
"""
Run unit test with command: pytest dbgpt/datasource/rdbms/tests/test_conn_mysql.py
docker run -itd --name mysql-test -p 3307:3306 -e MYSQL_ROOT_PASSWORD=12345678 mysql:5.7
mysql -h 127.0.0.1 -uroot -p -P3307
Enter password:
Welcome to the MySQL monitor. Commands end with ; or \g.
Your MySQL connection id is 2
Server version: 5.7.41 MySQL Community Server (GPL)

Copyright (c) 2000, 2023, Oracle and/or its affiliates.

Oracle is a registered trademark of Oracle Corporation and/or its
affiliates. Other names may be trademarks of their respective
owners.

Type 'help;' or '\h' for help. Type '\c' to clear the current input statement.

> create database test;
"""

import pytest
from dbgpt_ext.datasource.rdbms.conn_oracle import OracleConnector
@@ -27,6 +28,7 @@ CREATE TABLE test (
)
"""


@pytest.fixture
def db():
    # Note: Oracle's default port is 1521; connecting via service_name is recommended
@@ -45,32 +47,37 @@ def db():
    except Exception:
        pass  # Ignore the error if the table does not exist


def test_get_usable_table_names(db):
    db.run(_create_table_sql)
    db.run("COMMIT")
    table_names = db.get_usable_table_names()
    assert "TEST" in map(str.upper, table_names)


def test_get_table_info(db):
    db.run(_create_table_sql)
    db.run("COMMIT")
    table_info = db.get_table_info()
    assert "CREATE TABLE TEST" in table_info.upper()


def test_run_no_throw(db):
    result = db.run_no_throw("this is a error sql")
    # run_no_throw returns a list; it is empty on error
    assert result == [] or isinstance(result, list)


def test_get_index_empty(db):
    db.run(_create_table_sql)
    db.run("COMMIT")
    indexes = db.get_indexes("TEST")
    assert indexes == []


def test_get_fields(db):
-    #db.run(_create_table_sql)
-    #db.run("COMMIT")
+    # db.run(_create_table_sql)
+    # db.run("COMMIT")
    print("Entering the method...")
    fields = db.get_fields("PY_TEST")
    print("Printing field information...")
@@ -81,16 +88,21 @@ def test_get_fields(db):
        print(f"Is Nullable: {field[3]}")
        print(f"Column Comment: {field[4]}")
        print("-" * 30)  # Optional separator
-    #assert fields[0][0].upper() == "ID"
+    # assert fields[0][0].upper() == "ID"


def test_get_charset(db):
-    result = db.run("SELECT VALUE FROM NLS_DATABASE_PARAMETERS WHERE PARAMETER = 'NLS_CHARACTERSET'")
+    result = db.run(
+        "SELECT VALUE FROM NLS_DATABASE_PARAMETERS WHERE PARAMETER = 'NLS_CHARACTERSET'"
+    )
    assert result[1][0] in ("AL32UTF8", "UTF8")  # result[0] is the tuple of column names


def test_get_users(db):
    users = db.get_users()
    assert any(user[0].upper() in ("SYS", "SYSTEM") for user in users)


def test_get_database_lists(db):
    cdb_result = db.run("SELECT CDB FROM V$DATABASE")
    if cdb_result[1][0] == "YES":