mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-07-23 12:21:08 +00:00
feat(flow): Add Chat Data template (#2716)
This commit is contained in:
parent
7e7581e891
commit
f79f81ccc3
@ -137,18 +137,21 @@ def _initialize_openapi(system_app: SystemApp):
|
|||||||
|
|
||||||
|
|
||||||
def _initialize_operators():
|
def _initialize_operators():
|
||||||
from dbgpt_app.operators.code import CodeMapOperator # noqa: F401
|
from dbgpt.core.awel import BaseOperator
|
||||||
from dbgpt_app.operators.converter import StringToInteger # noqa: F401
|
from dbgpt.util.module_utils import ModelScanner, ScannerConfig
|
||||||
from dbgpt_app.operators.datasource import ( # noqa: F401
|
|
||||||
HODatasourceExecutorOperator,
|
modules = ["dbgpt_app.operators", "dbgpt_serve.agent.resource"]
|
||||||
HODatasourceRetrieverOperator,
|
|
||||||
)
|
scanner = ModelScanner[BaseOperator]()
|
||||||
from dbgpt_app.operators.llm import ( # noqa: F401
|
registered_items = {}
|
||||||
HOLLMOperator,
|
for module in modules:
|
||||||
HOStreamingLLMOperator,
|
config = ScannerConfig(
|
||||||
)
|
module_path=module,
|
||||||
from dbgpt_app.operators.rag import HOKnowledgeOperator # noqa: F401
|
base_class=BaseOperator,
|
||||||
from dbgpt_serve.agent.resource.datasource import DatasourceResource # noqa: F401
|
)
|
||||||
|
items = scanner.scan_and_register(config)
|
||||||
|
registered_items[module] = items
|
||||||
|
return scanner.get_registered_items()
|
||||||
|
|
||||||
|
|
||||||
def _initialize_code_server(system_app: SystemApp):
|
def _initialize_code_server(system_app: SystemApp):
|
||||||
|
@ -105,6 +105,7 @@ _PARAMETER_DATASOURCE = Parameter.build_from(
|
|||||||
type=DBResource,
|
type=DBResource,
|
||||||
description=_("The datasource to retrieve the context"),
|
description=_("The datasource to retrieve the context"),
|
||||||
)
|
)
|
||||||
|
|
||||||
_PARAMETER_PROMPT_TEMPLATE = Parameter.build_from(
|
_PARAMETER_PROMPT_TEMPLATE = Parameter.build_from(
|
||||||
_("Prompt Template"),
|
_("Prompt Template"),
|
||||||
"prompt_template",
|
"prompt_template",
|
||||||
@ -172,7 +173,7 @@ _OUTPUTS_SQL_RESULT = IOField.build_from(
|
|||||||
_("SQL result"),
|
_("SQL result"),
|
||||||
"sql_result",
|
"sql_result",
|
||||||
str,
|
str,
|
||||||
description=_("The result of the SQL execution"),
|
description=_("The result of the SQL execution(GPT-Vis format)"),
|
||||||
)
|
)
|
||||||
|
|
||||||
_INPUTS_SQL_DICT_LIST = IOField.build_from(
|
_INPUTS_SQL_DICT_LIST = IOField.build_from(
|
||||||
@ -189,7 +190,9 @@ _INPUTS_SQL_DICT_LIST = IOField.build_from(
|
|||||||
class GPTVisMixin:
|
class GPTVisMixin:
|
||||||
async def save_view_message(self, dag_ctx: DAGContext, view: str):
|
async def save_view_message(self, dag_ctx: DAGContext, view: str):
|
||||||
"""Save the view message."""
|
"""Save the view message."""
|
||||||
await dag_ctx.save_to_share_data(BaseLLM.SHARE_DATA_KEY_MODEL_OUTPUT_VIEW, view)
|
await dag_ctx.save_to_share_data(
|
||||||
|
BaseLLM.SHARE_DATA_KEY_MODEL_OUTPUT_VIEW, view, overwrite=True
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class HODatasourceRetrieverOperator(MapOperator[str, HOContextBody]):
|
class HODatasourceRetrieverOperator(MapOperator[str, HOContextBody]):
|
||||||
@ -286,6 +289,19 @@ class HODatasourceRetrieverOperator(MapOperator[str, HOContextBody]):
|
|||||||
class HODatasourceExecutorOperator(GPTVisMixin, MapOperator[dict, str]):
|
class HODatasourceExecutorOperator(GPTVisMixin, MapOperator[dict, str]):
|
||||||
"""Execute the context from the datasource."""
|
"""Execute the context from the datasource."""
|
||||||
|
|
||||||
|
_share_data_key = "__datasource_executor_result__"
|
||||||
|
|
||||||
|
class MarkdownMapper(MapOperator[str, str]):
|
||||||
|
async def map(self, context: str) -> str:
|
||||||
|
"""Convert the result to markdown."""
|
||||||
|
|
||||||
|
from dbgpt.util.pd_utils import df_to_markdown
|
||||||
|
|
||||||
|
df = await self.current_dag_context.get_from_share_data(
|
||||||
|
HODatasourceExecutorOperator._share_data_key
|
||||||
|
)
|
||||||
|
return df_to_markdown(df)
|
||||||
|
|
||||||
metadata = ViewMetadata(
|
metadata = ViewMetadata(
|
||||||
label=_("Datasource Executor Operator"),
|
label=_("Datasource Executor Operator"),
|
||||||
name="higher_order_datasource_executor_operator",
|
name="higher_order_datasource_executor_operator",
|
||||||
@ -293,7 +309,16 @@ class HODatasourceExecutorOperator(GPTVisMixin, MapOperator[dict, str]):
|
|||||||
category=OperatorCategory.DATABASE,
|
category=OperatorCategory.DATABASE,
|
||||||
parameters=[_PARAMETER_DATASOURCE.new()],
|
parameters=[_PARAMETER_DATASOURCE.new()],
|
||||||
inputs=[_INPUTS_SQL_DICT.new()],
|
inputs=[_INPUTS_SQL_DICT.new()],
|
||||||
outputs=[_OUTPUTS_SQL_RESULT.new()],
|
outputs=[
|
||||||
|
_OUTPUTS_SQL_RESULT.new(),
|
||||||
|
IOField.build_from(
|
||||||
|
_("Markdown result"),
|
||||||
|
"markdown_result",
|
||||||
|
str,
|
||||||
|
description=_("The markdown result of the SQL execution"),
|
||||||
|
mappers=[MarkdownMapper],
|
||||||
|
),
|
||||||
|
],
|
||||||
tags={"order": TAGS_ORDER_HIGH},
|
tags={"order": TAGS_ORDER_HIGH},
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -314,8 +339,16 @@ class HODatasourceExecutorOperator(GPTVisMixin, MapOperator[dict, str]):
|
|||||||
sql = sql_dict.get("sql")
|
sql = sql_dict.get("sql")
|
||||||
if not sql:
|
if not sql:
|
||||||
return sql_dict.get("thoughts", "No SQL found in the input dictionary.")
|
return sql_dict.get("thoughts", "No SQL found in the input dictionary.")
|
||||||
|
|
||||||
|
thoughts = sql_dict.get("thoughts", "")
|
||||||
|
|
||||||
data_df = await self._datasource.query_to_df(sql)
|
data_df = await self._datasource.query_to_df(sql)
|
||||||
|
# Save the result to share data, for markdown mapper
|
||||||
|
await self.current_dag_context.save_to_share_data(
|
||||||
|
HODatasourceExecutorOperator._share_data_key, data_df
|
||||||
|
)
|
||||||
view = await vis.display(chart=sql_dict, data_df=data_df)
|
view = await vis.display(chart=sql_dict, data_df=data_df)
|
||||||
|
view = thoughts + "\n\n" + view
|
||||||
await self.save_view_message(self.current_dag_context, view)
|
await self.save_view_message(self.current_dag_context, view)
|
||||||
return view
|
return view
|
||||||
|
|
||||||
|
233
packages/dbgpt-app/src/dbgpt_app/operators/report.py
Normal file
233
packages/dbgpt-app/src/dbgpt_app/operators/report.py
Normal file
@ -0,0 +1,233 @@
|
|||||||
|
from functools import cache
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from dbgpt.core import (
|
||||||
|
ChatPromptTemplate,
|
||||||
|
HumanPromptTemplate,
|
||||||
|
LLMClient,
|
||||||
|
ModelMessage,
|
||||||
|
SystemPromptTemplate,
|
||||||
|
)
|
||||||
|
from dbgpt.core.awel import JoinOperator
|
||||||
|
from dbgpt.core.awel.flow.base import (
|
||||||
|
TAGS_ORDER_HIGH,
|
||||||
|
IOField,
|
||||||
|
OperatorCategory,
|
||||||
|
Parameter,
|
||||||
|
ViewMetadata,
|
||||||
|
)
|
||||||
|
from dbgpt.core.interface.llm import ModelRequest
|
||||||
|
from dbgpt.model.operators import MixinLLMOperator
|
||||||
|
from dbgpt.util.i18n_utils import _
|
||||||
|
from dbgpt_app.operators.datasource import GPTVisMixin
|
||||||
|
|
||||||
|
_DEFAULT_PROMPT_EN = """You are a helpful AI assistant.
|
||||||
|
|
||||||
|
Please carefully read the data in the Markdown table format below, the data is a
|
||||||
|
database query result based on the user question. Please analyze and summarize the
|
||||||
|
data carefully, and provide a summary report in markdown format.
|
||||||
|
|
||||||
|
<data-report>
|
||||||
|
{data_report}
|
||||||
|
</data-report>
|
||||||
|
|
||||||
|
user question:
|
||||||
|
{user_input}
|
||||||
|
|
||||||
|
Please answer in the same language as the user's question.
|
||||||
|
"""
|
||||||
|
|
||||||
|
_DEFAULT_PROMPT_ZH = """你是一个有用的AI助手。
|
||||||
|
|
||||||
|
请你仔细阅读下面的 Markdown 表格格式的数据,这是一份根据用户问题查询到的数据库的数据,\
|
||||||
|
你需要根据数据仔细分析和总结,给出一份总结报告,使用 markdown 格式输出。
|
||||||
|
|
||||||
|
<data-report>
|
||||||
|
{data_report}
|
||||||
|
</data-report>
|
||||||
|
|
||||||
|
用户的问题:
|
||||||
|
{user_input}
|
||||||
|
|
||||||
|
请用用户提问的语言回答。
|
||||||
|
"""
|
||||||
|
|
||||||
|
_DEFAULT_USER_PROMPT = """\
|
||||||
|
{user_input}
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
@cache
|
||||||
|
def _get_default_prompt(language: str) -> ChatPromptTemplate:
|
||||||
|
if language == "zh":
|
||||||
|
sys_prompt = _DEFAULT_PROMPT_ZH
|
||||||
|
user_prompt = _DEFAULT_USER_PROMPT
|
||||||
|
else:
|
||||||
|
sys_prompt = _DEFAULT_PROMPT_EN
|
||||||
|
user_prompt = _DEFAULT_USER_PROMPT
|
||||||
|
|
||||||
|
return ChatPromptTemplate(
|
||||||
|
messages=[
|
||||||
|
SystemPromptTemplate.from_template(sys_prompt),
|
||||||
|
HumanPromptTemplate.from_template(user_prompt),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ReportAnalystOperator(MixinLLMOperator, JoinOperator[str]):
|
||||||
|
metadata = ViewMetadata(
|
||||||
|
label=_("Report Analyst"),
|
||||||
|
name="report_analyst",
|
||||||
|
description=_("Report Analyst"),
|
||||||
|
category=OperatorCategory.DATABASE,
|
||||||
|
tags={"order": TAGS_ORDER_HIGH},
|
||||||
|
parameters=[
|
||||||
|
Parameter.build_from(
|
||||||
|
_("Prompt Template"),
|
||||||
|
"prompt_template",
|
||||||
|
ChatPromptTemplate,
|
||||||
|
description=_("The prompt template for the conversation."),
|
||||||
|
optional=True,
|
||||||
|
default=None,
|
||||||
|
),
|
||||||
|
Parameter.build_from(
|
||||||
|
_("Model Name"),
|
||||||
|
"model",
|
||||||
|
str,
|
||||||
|
optional=True,
|
||||||
|
default=None,
|
||||||
|
description=_("The model name."),
|
||||||
|
),
|
||||||
|
Parameter.build_from(
|
||||||
|
_("LLM Client"),
|
||||||
|
"llm_client",
|
||||||
|
LLMClient,
|
||||||
|
optional=True,
|
||||||
|
default=None,
|
||||||
|
description=_(
|
||||||
|
"The LLM Client, how to connect to the LLM model, if not provided,"
|
||||||
|
" it will use the default client deployed by DB-GPT."
|
||||||
|
),
|
||||||
|
),
|
||||||
|
],
|
||||||
|
inputs=[
|
||||||
|
IOField.build_from(
|
||||||
|
_("User question"),
|
||||||
|
"question",
|
||||||
|
str,
|
||||||
|
description=_("The question of user"),
|
||||||
|
),
|
||||||
|
IOField.build_from(
|
||||||
|
_("The data report"),
|
||||||
|
"data_report",
|
||||||
|
str,
|
||||||
|
_("The data report in markdown format."),
|
||||||
|
dynamic=True,
|
||||||
|
),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
IOField.build_from(
|
||||||
|
_("Report Analyst Result"),
|
||||||
|
"report_analyst_result",
|
||||||
|
str,
|
||||||
|
description=_("The report analyst result."),
|
||||||
|
)
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
prompt_template: Optional[ChatPromptTemplate] = None,
|
||||||
|
model: Optional[str] = None,
|
||||||
|
llm_client: Optional[LLMClient] = None,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
JoinOperator.__init__(self, combine_function=self._join_func, **kwargs)
|
||||||
|
MixinLLMOperator.__init__(self, llm_client=llm_client, **kwargs)
|
||||||
|
|
||||||
|
# User must select a history merge mode
|
||||||
|
self._prompt_template = prompt_template
|
||||||
|
self._model = model
|
||||||
|
|
||||||
|
@property
|
||||||
|
def prompt_template(self) -> ChatPromptTemplate:
|
||||||
|
"""Get the prompt template."""
|
||||||
|
language = "en"
|
||||||
|
if self.system_app:
|
||||||
|
language = self.system_app.config.get_current_lang()
|
||||||
|
if self._prompt_template is None:
|
||||||
|
return _get_default_prompt(language)
|
||||||
|
return self._prompt_template
|
||||||
|
|
||||||
|
async def _join_func(self, question: str, data_report: str, *args):
|
||||||
|
dynamic_inputs = [data_report]
|
||||||
|
for arg in args:
|
||||||
|
if isinstance(arg, str):
|
||||||
|
dynamic_inputs.append(arg)
|
||||||
|
data_report = "\n".join(dynamic_inputs)
|
||||||
|
messages = self.prompt_template.format_messages(
|
||||||
|
user_input=question,
|
||||||
|
data_report=data_report,
|
||||||
|
)
|
||||||
|
model_messages = ModelMessage.from_base_messages(messages)
|
||||||
|
models = await self.llm_client.models()
|
||||||
|
if not models:
|
||||||
|
raise Exception("No models available.")
|
||||||
|
model = self._model or models[0].model
|
||||||
|
|
||||||
|
model_request = ModelRequest.build_request(model, messages=model_messages)
|
||||||
|
model_output = await self.llm_client.generate(model_request)
|
||||||
|
text = model_output.gen_text_with_thinking()
|
||||||
|
|
||||||
|
return text
|
||||||
|
|
||||||
|
|
||||||
|
class StringJoinOperator(GPTVisMixin, JoinOperator[str]):
|
||||||
|
"""Join operator for strings.
|
||||||
|
This operator joins the input strings with a specified separator.
|
||||||
|
"""
|
||||||
|
|
||||||
|
metadata = ViewMetadata(
|
||||||
|
label=_("String Join Operator"),
|
||||||
|
name="string_join_operator",
|
||||||
|
description=_("Merge multiple inputs into a single string."),
|
||||||
|
category=OperatorCategory.COMMON,
|
||||||
|
parameters=[
|
||||||
|
Parameter.build_from(
|
||||||
|
_("Separator"),
|
||||||
|
"separator",
|
||||||
|
str,
|
||||||
|
optional=True,
|
||||||
|
default="\n\n",
|
||||||
|
description=_("The separator to join the strings."),
|
||||||
|
),
|
||||||
|
],
|
||||||
|
inputs=[
|
||||||
|
IOField.build_from(
|
||||||
|
_("Input Strings"),
|
||||||
|
"input_strings",
|
||||||
|
str,
|
||||||
|
description=_("The input strings to join."),
|
||||||
|
dynamic=True,
|
||||||
|
),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
IOField.build_from(
|
||||||
|
_("Joined String"),
|
||||||
|
"joined_string",
|
||||||
|
str,
|
||||||
|
description=_("The joined string."),
|
||||||
|
)
|
||||||
|
],
|
||||||
|
tags={"order": TAGS_ORDER_HIGH},
|
||||||
|
)
|
||||||
|
|
||||||
|
def __init__(self, separator: str = "\n\n", **kwargs):
|
||||||
|
super().__init__(combine_function=self._join_func, **kwargs)
|
||||||
|
self.separator = separator
|
||||||
|
|
||||||
|
async def _join_func(self, *args) -> str:
|
||||||
|
"""Join the strings with the separator."""
|
||||||
|
view = self.separator.join(args)
|
||||||
|
await self.save_view_message(self.current_dag_context, view)
|
||||||
|
return view
|
@ -20,3 +20,25 @@ def csv_colunm_foramt(val):
|
|||||||
return val
|
return val
|
||||||
except ValueError:
|
except ValueError:
|
||||||
return val
|
return val
|
||||||
|
|
||||||
|
|
||||||
|
def df_to_markdown(df: pd.DataFrame, index=False) -> str:
|
||||||
|
"""Convert a pandas DataFrame to a Markdown table."""
|
||||||
|
columns = df.columns
|
||||||
|
header = "| " + " | ".join(columns) + " |"
|
||||||
|
separator = "| " + " | ".join(["---"] * len(columns)) + " |"
|
||||||
|
|
||||||
|
rows = []
|
||||||
|
for _, row in df.iterrows():
|
||||||
|
row_str = "| " + " | ".join(map(str, row.values)) + " |"
|
||||||
|
rows.append(row_str)
|
||||||
|
|
||||||
|
if index:
|
||||||
|
header = "| index | " + " | ".join(columns) + " |"
|
||||||
|
separator = "| --- | " + " | ".join(["---"] * len(columns)) + " |"
|
||||||
|
rows = []
|
||||||
|
for idx, row in df.iterrows():
|
||||||
|
row_str = f"| {idx} | " + " | ".join(map(str, row.values)) + " |"
|
||||||
|
rows.append(row_str)
|
||||||
|
|
||||||
|
return "\n".join([header, separator] + rows)
|
||||||
|
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@ -1,22 +1,23 @@
|
|||||||
"""
|
"""
|
||||||
Run unit test with command: pytest dbgpt/datasource/rdbms/tests/test_conn_mysql.py
|
Run unit test with command: pytest dbgpt/datasource/rdbms/tests/test_conn_mysql.py
|
||||||
docker run -itd --name mysql-test -p 3307:3306 -e MYSQL_ROOT_PASSWORD=12345678 mysql:5.7
|
docker run -itd --name mysql-test -p 3307:3306 -e MYSQL_ROOT_PASSWORD=12345678 mysql:5.7
|
||||||
mysql -h 127.0.0.1 -uroot -p -P3307
|
mysql -h 127.0.0.1 -uroot -p -P3307
|
||||||
Enter password:
|
Enter password:
|
||||||
Welcome to the MySQL monitor. Commands end with ; or \g.
|
Welcome to the MySQL monitor. Commands end with ; or \g.
|
||||||
Your MySQL connection id is 2
|
Your MySQL connection id is 2
|
||||||
Server version: 5.7.41 MySQL Community Server (GPL)
|
Server version: 5.7.41 MySQL Community Server (GPL)
|
||||||
|
|
||||||
Copyright (c) 2000, 2023, Oracle and/or its affiliates.
|
Copyright (c) 2000, 2023, Oracle and/or its affiliates.
|
||||||
|
|
||||||
Oracle is a registered trademark of Oracle Corporation and/or its
|
Oracle is a registered trademark of Oracle Corporation and/or its
|
||||||
affiliates. Other names may be trademarks of their respective
|
affiliates. Other names may be trademarks of their respective
|
||||||
owners.
|
owners.
|
||||||
|
|
||||||
Type 'help;' or '\h' for help. Type '\c' to clear the current input statement.
|
Type 'help;' or '\h' for help. Type '\c' to clear the current input statement.
|
||||||
|
|
||||||
> create database test;
|
> create database test;
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from dbgpt_ext.datasource.rdbms.conn_oracle import OracleConnector
|
from dbgpt_ext.datasource.rdbms.conn_oracle import OracleConnector
|
||||||
|
|
||||||
@ -27,6 +28,7 @@ CREATE TABLE test (
|
|||||||
)
|
)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def db():
|
def db():
|
||||||
# 注意:Oracle 默认端口是 1521,连接方式建议用 service_name
|
# 注意:Oracle 默认端口是 1521,连接方式建议用 service_name
|
||||||
@ -45,32 +47,37 @@ def db():
|
|||||||
except Exception:
|
except Exception:
|
||||||
pass # 如果表不存在也忽略错误
|
pass # 如果表不存在也忽略错误
|
||||||
|
|
||||||
|
|
||||||
def test_get_usable_table_names(db):
|
def test_get_usable_table_names(db):
|
||||||
db.run(_create_table_sql)
|
db.run(_create_table_sql)
|
||||||
db.run("COMMIT")
|
db.run("COMMIT")
|
||||||
table_names = db.get_usable_table_names()
|
table_names = db.get_usable_table_names()
|
||||||
assert "TEST" in map(str.upper, table_names)
|
assert "TEST" in map(str.upper, table_names)
|
||||||
|
|
||||||
|
|
||||||
def test_get_table_info(db):
|
def test_get_table_info(db):
|
||||||
db.run(_create_table_sql)
|
db.run(_create_table_sql)
|
||||||
db.run("COMMIT")
|
db.run("COMMIT")
|
||||||
table_info = db.get_table_info()
|
table_info = db.get_table_info()
|
||||||
assert "CREATE TABLE TEST" in table_info.upper()
|
assert "CREATE TABLE TEST" in table_info.upper()
|
||||||
|
|
||||||
|
|
||||||
def test_run_no_throw(db):
|
def test_run_no_throw(db):
|
||||||
result = db.run_no_throw("this is a error sql")
|
result = db.run_no_throw("this is a error sql")
|
||||||
# run_no_throw 返回的是 list,错误时为空
|
# run_no_throw 返回的是 list,错误时为空
|
||||||
assert result == [] or isinstance(result, list)
|
assert result == [] or isinstance(result, list)
|
||||||
|
|
||||||
|
|
||||||
def test_get_index_empty(db):
|
def test_get_index_empty(db):
|
||||||
db.run(_create_table_sql)
|
db.run(_create_table_sql)
|
||||||
db.run("COMMIT")
|
db.run("COMMIT")
|
||||||
indexes = db.get_indexes("TEST")
|
indexes = db.get_indexes("TEST")
|
||||||
assert indexes == []
|
assert indexes == []
|
||||||
|
|
||||||
|
|
||||||
def test_get_fields(db):
|
def test_get_fields(db):
|
||||||
#db.run(_create_table_sql)
|
# db.run(_create_table_sql)
|
||||||
#db.run("COMMIT")
|
# db.run("COMMIT")
|
||||||
print("进入方法...")
|
print("进入方法...")
|
||||||
fields = db.get_fields("PY_TEST")
|
fields = db.get_fields("PY_TEST")
|
||||||
print("正在打印字段信息...")
|
print("正在打印字段信息...")
|
||||||
@ -81,16 +88,21 @@ def test_get_fields(db):
|
|||||||
print(f"Is Nullable: {field[3]}")
|
print(f"Is Nullable: {field[3]}")
|
||||||
print(f"Column Comment: {field[4]}")
|
print(f"Column Comment: {field[4]}")
|
||||||
print("-" * 30) # 可选的分隔符
|
print("-" * 30) # 可选的分隔符
|
||||||
#assert fields[0][0].upper() == "ID"
|
# assert fields[0][0].upper() == "ID"
|
||||||
|
|
||||||
|
|
||||||
def test_get_charset(db):
|
def test_get_charset(db):
|
||||||
result = db.run("SELECT VALUE FROM NLS_DATABASE_PARAMETERS WHERE PARAMETER = 'NLS_CHARACTERSET'")
|
result = db.run(
|
||||||
|
"SELECT VALUE FROM NLS_DATABASE_PARAMETERS WHERE PARAMETER = 'NLS_CHARACTERSET'"
|
||||||
|
)
|
||||||
assert result[1][0] in ("AL32UTF8", "UTF8") # result[0] 是字段名元组
|
assert result[1][0] in ("AL32UTF8", "UTF8") # result[0] 是字段名元组
|
||||||
|
|
||||||
|
|
||||||
def test_get_users(db):
|
def test_get_users(db):
|
||||||
users = db.get_users()
|
users = db.get_users()
|
||||||
assert any(user[0].upper() in ("SYS", "SYSTEM") for user in users)
|
assert any(user[0].upper() in ("SYS", "SYSTEM") for user in users)
|
||||||
|
|
||||||
|
|
||||||
def test_get_database_lists(db):
|
def test_get_database_lists(db):
|
||||||
cdb_result = db.run("SELECT CDB FROM V$DATABASE")
|
cdb_result = db.run("SELECT CDB FROM V$DATABASE")
|
||||||
if cdb_result[1][0] == "YES":
|
if cdb_result[1][0] == "YES":
|
||||||
@ -98,4 +110,4 @@ def test_get_database_lists(db):
|
|||||||
pdb_names = [name[0] for name in databases[1:]]
|
pdb_names = [name[0] for name in databases[1:]]
|
||||||
else:
|
else:
|
||||||
pdb_names = ["ORCL"]
|
pdb_names = ["ORCL"]
|
||||||
assert any(name in ("ORCLPDB1", "ORCL") for name in pdb_names)
|
assert any(name in ("ORCLPDB1", "ORCL") for name in pdb_names)
|
||||||
|
Loading…
Reference in New Issue
Block a user