mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-08-01 16:18:27 +00:00
Merge remote-tracking branch 'origin/test' into test
This commit is contained in:
commit
657484608e
@ -29,7 +29,7 @@ class DBType(Enum):
|
||||
Postgresql = DbInfo("postgresql")
|
||||
|
||||
def value(self):
|
||||
return self._value_.name;
|
||||
return self._value_.name
|
||||
|
||||
def is_file_db(self):
|
||||
return self._value_.is_file_db
|
||||
|
@ -293,7 +293,6 @@ class Database:
|
||||
result = list(result)
|
||||
return field_names, result
|
||||
|
||||
|
||||
def run(self, session, command: str, fetch: str = "all") -> List:
|
||||
"""Execute a SQL command and return a string representing the results."""
|
||||
print("SQL:" + command)
|
||||
@ -346,7 +345,14 @@ class Database:
|
||||
return [
|
||||
d[0]
|
||||
for d in results
|
||||
if d[0] not in ["information_schema", "performance_schema", "sys", "mysql"]
|
||||
if d[0]
|
||||
not in [
|
||||
"information_schema",
|
||||
"performance_schema",
|
||||
"sys",
|
||||
"mysql",
|
||||
"knowledge_management",
|
||||
]
|
||||
]
|
||||
|
||||
def convert_sql_write_to_select(self, write_sql):
|
||||
|
@ -7,7 +7,7 @@ from pydantic import BaseModel, Extra, Field, root_validator
|
||||
from typing import Any, Iterable, List, Optional
|
||||
|
||||
|
||||
class BaseConnect( ABC):
|
||||
class BaseConnect(ABC):
|
||||
def get_connect(self, db_name: str):
|
||||
pass
|
||||
|
||||
|
@ -1,5 +1,6 @@
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class DBConfig(BaseModel):
|
||||
db_type: str
|
||||
db_name: str
|
||||
@ -8,4 +9,4 @@ class DBConfig(BaseModel):
|
||||
db_port: int = 0
|
||||
db_user: str = ""
|
||||
db_pwd: str = ""
|
||||
comment: str = ""
|
||||
comment: str = ""
|
||||
|
@ -26,7 +26,16 @@ class DuckdbConnectConfig:
|
||||
)
|
||||
self.connect.execute("CREATE SEQUENCE seq_id START 1;")
|
||||
|
||||
def add_url_db(self, db_name, db_type, db_host: str, db_port: int, db_user: str, db_pwd: str, comment: str = ""):
|
||||
def add_url_db(
|
||||
self,
|
||||
db_name,
|
||||
db_type,
|
||||
db_host: str,
|
||||
db_port: int,
|
||||
db_user: str,
|
||||
db_pwd: str,
|
||||
comment: str = "",
|
||||
):
|
||||
try:
|
||||
cursor = self.connect.cursor()
|
||||
cursor.execute(
|
||||
@ -41,7 +50,7 @@ class DuckdbConnectConfig:
|
||||
def get_file_db_name(self, path):
|
||||
try:
|
||||
conn = duckdb.connect(path)
|
||||
result = conn.execute('SELECT current_database()').fetchone()[0]
|
||||
result = conn.execute("SELECT current_database()").fetchone()[0]
|
||||
return result
|
||||
except Exception as e:
|
||||
raise "Unusable duckdb database path:" + path
|
||||
@ -60,9 +69,7 @@ class DuckdbConnectConfig:
|
||||
|
||||
def delete_db(self, db_name):
|
||||
cursor = self.connect.cursor()
|
||||
cursor.execute(
|
||||
"DELETE FROM connect_config where db_name=?", [db_name]
|
||||
)
|
||||
cursor.execute("DELETE FROM connect_config where db_name=?", [db_name])
|
||||
cursor.commit()
|
||||
return True
|
||||
|
||||
@ -87,9 +94,7 @@ class DuckdbConnectConfig:
|
||||
def get_db_list(self):
|
||||
if os.path.isfile(duckdb_path):
|
||||
cursor = duckdb.connect(duckdb_path).cursor()
|
||||
cursor.execute(
|
||||
"SELECT db_name, db_type, comment FROM connect_config "
|
||||
)
|
||||
cursor.execute("SELECT db_name, db_type, comment FROM connect_config ")
|
||||
|
||||
fields = [field[0] for field in cursor.description]
|
||||
data = []
|
||||
@ -102,15 +107,12 @@ class DuckdbConnectConfig:
|
||||
|
||||
return []
|
||||
|
||||
|
||||
def get_db_names(self):
|
||||
if os.path.isfile(duckdb_path):
|
||||
cursor = duckdb.connect(duckdb_path).cursor()
|
||||
cursor.execute(
|
||||
"SELECT db_name FROM connect_config "
|
||||
)
|
||||
cursor.execute("SELECT db_name FROM connect_config ")
|
||||
data = []
|
||||
for row in cursor.fetchall():
|
||||
data.append(row[0])
|
||||
return data
|
||||
return []
|
||||
return []
|
||||
|
@ -16,7 +16,6 @@ CFG = Config()
|
||||
|
||||
|
||||
class ConnectManager:
|
||||
|
||||
def get_all_subclasses(self, cls):
|
||||
subclasses = cls.__subclasses__()
|
||||
for subclass in subclasses:
|
||||
@ -41,8 +40,15 @@ class ConnectManager:
|
||||
if CFG.LOCAL_DB_HOST:
|
||||
# default mysql
|
||||
if CFG.LOCAL_DB_NAME:
|
||||
self.storage.add_url_db(CFG.LOCAL_DB_NAME, DBType.Mysql.value(), CFG.LOCAL_DB_HOST, CFG.LOCAL_DB_PORT,
|
||||
CFG.LOCAL_DB_USER, CFG.LOCAL_DB_PASSWORD, "")
|
||||
self.storage.add_url_db(
|
||||
CFG.LOCAL_DB_NAME,
|
||||
DBType.Mysql.value(),
|
||||
CFG.LOCAL_DB_HOST,
|
||||
CFG.LOCAL_DB_PORT,
|
||||
CFG.LOCAL_DB_USER,
|
||||
CFG.LOCAL_DB_PASSWORD,
|
||||
"",
|
||||
)
|
||||
else:
|
||||
# get all default mysql database
|
||||
default_mysql = Database.from_uri(
|
||||
@ -69,32 +75,42 @@ class ConnectManager:
|
||||
# )
|
||||
dbs = default_mysql.get_database_list()
|
||||
for name in dbs:
|
||||
self.storage.add_url_db(name, DBType.Mysql.value(), CFG.LOCAL_DB_HOST, CFG.LOCAL_DB_PORT,
|
||||
CFG.LOCAL_DB_USER, CFG.LOCAL_DB_PASSWORD, "")
|
||||
self.storage.add_url_db(
|
||||
name,
|
||||
DBType.Mysql.value(),
|
||||
CFG.LOCAL_DB_HOST,
|
||||
CFG.LOCAL_DB_PORT,
|
||||
CFG.LOCAL_DB_USER,
|
||||
CFG.LOCAL_DB_PASSWORD,
|
||||
"",
|
||||
)
|
||||
if CFG.LOCAL_DB_PATH:
|
||||
# default file db is duckdb
|
||||
db_name = self.storage.get_file_db_name(CFG.LOCAL_DB_PATH)
|
||||
if db_name:
|
||||
self.storage.add_file_db(db_name, DBType.DuckDb.value(), CFG.LOCAL_DB_PATH)
|
||||
self.storage.add_file_db(
|
||||
db_name, DBType.DuckDb.value(), CFG.LOCAL_DB_PATH
|
||||
)
|
||||
|
||||
def get_connect(self, db_name):
|
||||
db_config = self.storage.get_db_config(db_name)
|
||||
db_type = DBType.of_db_type(db_config.get('db_type'))
|
||||
db_type = DBType.of_db_type(db_config.get("db_type"))
|
||||
connect_instance = self.get_cls_by_dbtype(db_type.value())
|
||||
if db_type.is_file_db():
|
||||
db_path = db_config.get('db_path')
|
||||
db_path = db_config.get("db_path")
|
||||
return connect_instance.from_file_path(db_path)
|
||||
else:
|
||||
db_host = db_config.get('db_host')
|
||||
db_port = db_config.get('db_port')
|
||||
db_user = db_config.get('db_user')
|
||||
db_pwd = db_config.get('db_pwd')
|
||||
return connect_instance.from_uri_db(db_host, db_port, db_user, db_pwd, db_name)
|
||||
db_host = db_config.get("db_host")
|
||||
db_port = db_config.get("db_port")
|
||||
db_user = db_config.get("db_user")
|
||||
db_pwd = db_config.get("db_pwd")
|
||||
return connect_instance.from_uri_db(
|
||||
db_host, db_port, db_user, db_pwd, db_name
|
||||
)
|
||||
|
||||
def get_db_list(self):
|
||||
return self.storage.get_db_list()
|
||||
|
||||
|
||||
def get_db_names(self):
|
||||
return self.storage.get_db_names()
|
||||
|
||||
@ -104,7 +120,17 @@ class ConnectManager:
|
||||
def add_db(self, db_info: DBConfig):
|
||||
db_type = DBType.of_db_type(db_info.db_type)
|
||||
if db_type.is_file_db():
|
||||
self.storage.add_file_db(db_info.db_name, db_info.db_type, db_info.file_path)
|
||||
self.storage.add_file_db(
|
||||
db_info.db_name, db_info.db_type, db_info.file_path
|
||||
)
|
||||
else:
|
||||
self.storage.add_url_db(db_info.db_name, db_info.db_type, db_info.db_host, db_info.db_port, db_info.db_user, db_info.db_pwd, db_info.comment)
|
||||
return True
|
||||
self.storage.add_url_db(
|
||||
db_info.db_name,
|
||||
db_info.db_type,
|
||||
db_info.db_host,
|
||||
db_info.db_port,
|
||||
db_info.db_user,
|
||||
db_info.db_pwd,
|
||||
db_info.comment,
|
||||
)
|
||||
return True
|
||||
|
@ -11,11 +11,13 @@ from sqlalchemy.ext.declarative import declarative_base
|
||||
|
||||
from pilot.connections.rdbms.rdbms_connect import RDBMSDatabase
|
||||
|
||||
|
||||
class DuckDbConnect(RDBMSDatabase):
|
||||
"""Connect Duckdb Database fetch MetaData
|
||||
Args:
|
||||
Usage:
|
||||
"""
|
||||
|
||||
db_type: str = "duckdb"
|
||||
db_dialect: str = "duckdb"
|
||||
|
||||
@ -33,7 +35,7 @@ class DuckDbConnect(RDBMSDatabase):
|
||||
"""
|
||||
cursor = self.session.execute(text(_tables_sql))
|
||||
tables_results = cursor.fetchall()
|
||||
results =[]
|
||||
results = []
|
||||
for row in tables_results:
|
||||
table_name = row[0]
|
||||
_sql = f"""
|
||||
@ -49,11 +51,18 @@ class DuckDbConnect(RDBMSDatabase):
|
||||
results.append(f"{table_name}({','.join(table_colums)});")
|
||||
return results
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
engine = create_engine('duckdb:////Users/tuyang.yhj/Code/PycharmProjects/DB-GPT/pilot/mock_datas/db-gpt-test.db')
|
||||
engine = create_engine(
|
||||
"duckdb:////Users/tuyang.yhj/Code/PycharmProjects/DB-GPT/pilot/mock_datas/db-gpt-test.db"
|
||||
)
|
||||
metadata = MetaData(engine)
|
||||
|
||||
results = engine.connect().execute("SELECT name FROM sqlite_master WHERE type='table'").fetchall()
|
||||
results = (
|
||||
engine.connect()
|
||||
.execute("SELECT name FROM sqlite_master WHERE type='table'")
|
||||
.fetchall()
|
||||
)
|
||||
|
||||
print(str(results))
|
||||
|
||||
|
@ -25,14 +25,13 @@ class MSSQLConnect(RDBMSDatabase):
|
||||
|
||||
default_db = ["master", "model", "msdb", "tempdb", "modeldb", "resource", "sys"]
|
||||
|
||||
|
||||
def table_simple_info(self) -> Iterable[str]:
|
||||
_tables_sql = f"""
|
||||
SELECT TABLE_NAME FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_TYPE='BASE TABLE'
|
||||
"""
|
||||
cursor = self.session.execute(text(_tables_sql))
|
||||
tables_results = cursor.fetchall()
|
||||
results =[]
|
||||
results = []
|
||||
for row in tables_results:
|
||||
table_name = row[0]
|
||||
_sql = f"""
|
||||
|
@ -36,19 +36,20 @@ def _format_index(index: sqlalchemy.engine.interfaces.ReflectedIndex) -> str:
|
||||
|
||||
class RDBMSDatabase(BaseConnect):
|
||||
"""SQLAlchemy wrapper around a database."""
|
||||
|
||||
db_type: str = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
engine,
|
||||
schema: Optional[str] = None,
|
||||
metadata: Optional[MetaData] = None,
|
||||
ignore_tables: Optional[List[str]] = None,
|
||||
include_tables: Optional[List[str]] = None,
|
||||
sample_rows_in_table_info: int = 3,
|
||||
indexes_in_table_info: bool = False,
|
||||
custom_table_info: Optional[dict] = None,
|
||||
view_support: bool = False,
|
||||
self,
|
||||
engine,
|
||||
schema: Optional[str] = None,
|
||||
metadata: Optional[MetaData] = None,
|
||||
ignore_tables: Optional[List[str]] = None,
|
||||
include_tables: Optional[List[str]] = None,
|
||||
sample_rows_in_table_info: int = 3,
|
||||
indexes_in_table_info: bool = False,
|
||||
custom_table_info: Optional[dict] = None,
|
||||
view_support: bool = False,
|
||||
):
|
||||
"""Create engine from database URI."""
|
||||
self._engine = engine
|
||||
@ -88,36 +89,35 @@ class RDBMSDatabase(BaseConnect):
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@classmethod
|
||||
def from_uri_db(
|
||||
cls,
|
||||
host: str,
|
||||
port: int,
|
||||
user: str,
|
||||
pwd: str,
|
||||
db_name: str,
|
||||
engine_args: Optional[dict] = None,
|
||||
**kwargs: Any,
|
||||
cls,
|
||||
host: str,
|
||||
port: int,
|
||||
user: str,
|
||||
pwd: str,
|
||||
db_name: str,
|
||||
engine_args: Optional[dict] = None,
|
||||
**kwargs: Any,
|
||||
) -> RDBMSDatabase:
|
||||
db_url: str = (
|
||||
cls.driver
|
||||
+ "://"
|
||||
+ CFG.LOCAL_DB_USER
|
||||
+ ":"
|
||||
+ CFG.LOCAL_DB_PASSWORD
|
||||
+ "@"
|
||||
+ CFG.LOCAL_DB_HOST
|
||||
+ ":"
|
||||
+ str(CFG.LOCAL_DB_PORT)
|
||||
+ "/"
|
||||
+ db_name
|
||||
cls.driver
|
||||
+ "://"
|
||||
+ CFG.LOCAL_DB_USER
|
||||
+ ":"
|
||||
+ CFG.LOCAL_DB_PASSWORD
|
||||
+ "@"
|
||||
+ CFG.LOCAL_DB_HOST
|
||||
+ ":"
|
||||
+ str(CFG.LOCAL_DB_PORT)
|
||||
+ "/"
|
||||
+ db_name
|
||||
)
|
||||
return cls.from_uri(db_url, engine_args, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def from_uri(
|
||||
cls, database_uri: str, engine_args: Optional[dict] = None, **kwargs: Any
|
||||
cls, database_uri: str, engine_args: Optional[dict] = None, **kwargs: Any
|
||||
) -> RDBMSDatabase:
|
||||
"""Construct a SQLAlchemy engine from URI."""
|
||||
_engine_args = engine_args or {}
|
||||
@ -141,7 +141,6 @@ class RDBMSDatabase(BaseConnect):
|
||||
def get_session(self):
|
||||
session = self._db_sessions()
|
||||
|
||||
|
||||
return session
|
||||
|
||||
def get_current_db_name(self) -> str:
|
||||
@ -181,7 +180,7 @@ class RDBMSDatabase(BaseConnect):
|
||||
tbl
|
||||
for tbl in self._metadata.sorted_tables
|
||||
if tbl.name in set(all_table_names)
|
||||
and not (self.dialect == "sqlite" and tbl.name.startswith("sqlite_"))
|
||||
and not (self.dialect == "sqlite" and tbl.name.startswith("sqlite_"))
|
||||
]
|
||||
|
||||
tables = []
|
||||
@ -194,7 +193,7 @@ class RDBMSDatabase(BaseConnect):
|
||||
create_table = str(CreateTable(table).compile(self._engine))
|
||||
table_info = f"{create_table.rstrip()}"
|
||||
has_extra_info = (
|
||||
self._indexes_in_table_info or self._sample_rows_in_table_info
|
||||
self._indexes_in_table_info or self._sample_rows_in_table_info
|
||||
)
|
||||
if has_extra_info:
|
||||
table_info += "\n\n/*"
|
||||
@ -413,9 +412,9 @@ class RDBMSDatabase(BaseConnect):
|
||||
set_idx = parts.index("set")
|
||||
where_idx = parts.index("where")
|
||||
# 截取 `set` 子句中的字段名
|
||||
set_clause = parts[set_idx + 1: where_idx][0].split("=")[0].strip()
|
||||
set_clause = parts[set_idx + 1 : where_idx][0].split("=")[0].strip()
|
||||
# 截取 `where` 之后的条件语句
|
||||
where_clause = " ".join(parts[where_idx + 1:])
|
||||
where_clause = " ".join(parts[where_idx + 1 :])
|
||||
# 返回一个select语句,它选择更新的数据
|
||||
return f"SELECT {set_clause} FROM {table_name} WHERE {where_clause}"
|
||||
else:
|
||||
|
@ -23,7 +23,7 @@ from pilot.openapi.api_v1.api_view_model import (
|
||||
Result,
|
||||
ConversationVo,
|
||||
MessageVo,
|
||||
ChatSceneVo
|
||||
ChatSceneVo,
|
||||
)
|
||||
from pilot.connections.db_conn_info import DBConfig
|
||||
from pilot.configs.config import Config
|
||||
@ -96,16 +96,16 @@ def knowledge_list():
|
||||
return params
|
||||
|
||||
|
||||
|
||||
|
||||
@router.get("/v1/chat/db/list", response_model=Result[DBConfig])
|
||||
async def dialogue_list():
|
||||
return Result.succ(CFG.LOCAL_DB_MANAGE.get_db_list())
|
||||
|
||||
|
||||
@router.post("/v1/chat/db/add", response_model=Result[bool])
|
||||
async def dialogue_list(db_config: DBConfig = Body() ):
|
||||
async def dialogue_list(db_config: DBConfig = Body()):
|
||||
return Result.succ(CFG.LOCAL_DB_MANAGE.add_db(db_config))
|
||||
|
||||
|
||||
@router.post("/v1/chat/db/delete", response_model=Result[bool])
|
||||
async def dialogue_list(db_name: str = None):
|
||||
return Result.succ(CFG.LOCAL_DB_MANAGE.delete_db(db_name))
|
||||
@ -115,6 +115,7 @@ async def dialogue_list(db_name: str = None):
|
||||
async def db_support_types():
|
||||
return Result[str].succ(["mysql", "mssql", "duckdb"])
|
||||
|
||||
|
||||
@router.get("/v1/chat/dialogue/list", response_model=Result[ConversationVo])
|
||||
async def dialogue_list(user_id: str = None):
|
||||
dialogues: List = []
|
||||
|
@ -3,7 +3,15 @@ from typing import List
|
||||
|
||||
|
||||
class Scene:
|
||||
def __init__(self, code, name, describe, param_types: List = [], is_inner: bool = False, show_disable=False):
|
||||
def __init__(
|
||||
self,
|
||||
code,
|
||||
name,
|
||||
describe,
|
||||
param_types: List = [],
|
||||
is_inner: bool = False,
|
||||
show_disable=False,
|
||||
):
|
||||
self.code = code
|
||||
self.name = name
|
||||
self.describe = describe
|
||||
@ -31,7 +39,7 @@ class ChatScene(Enum):
|
||||
"Use tools through dialogue to accomplish your goals.",
|
||||
["Plugin Select"],
|
||||
False,
|
||||
True
|
||||
True,
|
||||
)
|
||||
ChatDefaultKnowledge = Scene(
|
||||
"chat_default_knowledge",
|
||||
@ -87,4 +95,4 @@ class ChatScene(Enum):
|
||||
return self._value_.param_types
|
||||
|
||||
def show_disable(self):
|
||||
return self._value_.show_disable
|
||||
return self._value_.show_disable
|
||||
|
@ -19,7 +19,7 @@ from pilot.scene.chat_dashboard.prompt import prompt
|
||||
from pilot.scene.chat_dashboard.data_preparation.report_schma import (
|
||||
ChartData,
|
||||
ReportData,
|
||||
ValueItem
|
||||
ValueItem,
|
||||
)
|
||||
|
||||
CFG = Config()
|
||||
@ -65,7 +65,9 @@ class ChatDashboard(BaseChat):
|
||||
|
||||
client = DBSummaryClient()
|
||||
try:
|
||||
table_infos = client.get_similar_tables(dbname=self.db_name, query=self.current_user_input, topk=self.top_k)
|
||||
table_infos = client.get_similar_tables(
|
||||
dbname=self.db_name, query=self.current_user_input, topk=self.top_k
|
||||
)
|
||||
print("dashboard vector find tables:{}", table_infos)
|
||||
except Exception as e:
|
||||
print("db summary find error!" + str(e))
|
||||
@ -74,7 +76,7 @@ class ChatDashboard(BaseChat):
|
||||
"input": self.current_user_input,
|
||||
"dialect": self.database.dialect,
|
||||
"table_info": self.database.table_simple_info(),
|
||||
"supported_chat_type": self.dashboard_template['supported_chart_type']
|
||||
"supported_chat_type": self.dashboard_template["supported_chart_type"]
|
||||
# "table_info": client.get_similar_tables(dbname=self.db_name, query=self.current_user_input, topk=self.top_k)
|
||||
}
|
||||
|
||||
@ -85,7 +87,9 @@ class ChatDashboard(BaseChat):
|
||||
chart_datas: List[ChartData] = []
|
||||
for chart_item in prompt_response:
|
||||
try:
|
||||
field_names, datas = self.database.query_ex(self.db_connect, chart_item.sql)
|
||||
field_names, datas = self.database.query_ex(
|
||||
self.db_connect, chart_item.sql
|
||||
)
|
||||
values: List[ValueItem] = []
|
||||
data_map = {}
|
||||
field_map = {}
|
||||
@ -96,28 +100,45 @@ class ChatDashboard(BaseChat):
|
||||
if not data_map[field_name]:
|
||||
field_map.update({f"{field_name}": False})
|
||||
else:
|
||||
field_map.update({f"{field_name}": all(
|
||||
isinstance(item, (int, float, Decimal)) for item in data_map[field_name])})
|
||||
field_map.update(
|
||||
{
|
||||
f"{field_name}": all(
|
||||
isinstance(item, (int, float, Decimal))
|
||||
for item in data_map[field_name]
|
||||
)
|
||||
}
|
||||
)
|
||||
|
||||
for field_name in field_names[1:]:
|
||||
if not field_map[field_name]:
|
||||
print("more than 2 non-numeric column")
|
||||
else:
|
||||
for data in datas:
|
||||
value_item = ValueItem(name=data[0], type=field_name,
|
||||
value=data[field_names.index(field_name)])
|
||||
value_item = ValueItem(
|
||||
name=data[0],
|
||||
type=field_name,
|
||||
value=data[field_names.index(field_name)],
|
||||
)
|
||||
values.append(value_item)
|
||||
|
||||
chart_datas.append(ChartData(chart_uid=str(uuid.uuid1()),
|
||||
chart_name=chart_item.title,
|
||||
chart_type=chart_item.showcase,
|
||||
chart_desc=chart_item.thoughts,
|
||||
chart_sql=chart_item.sql,
|
||||
column_name=field_names,
|
||||
values=values))
|
||||
chart_datas.append(
|
||||
ChartData(
|
||||
chart_uid=str(uuid.uuid1()),
|
||||
chart_name=chart_item.title,
|
||||
chart_type=chart_item.showcase,
|
||||
chart_desc=chart_item.thoughts,
|
||||
chart_sql=chart_item.sql,
|
||||
column_name=field_names,
|
||||
values=values,
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
# TODO 修复流程
|
||||
print(str(e))
|
||||
|
||||
return ReportData(conv_uid=self.chat_session_id, template_name=self.report_name, template_introduce=None,
|
||||
charts=chart_datas)
|
||||
return ReportData(
|
||||
conv_uid=self.chat_session_id,
|
||||
template_name=self.report_name,
|
||||
template_introduce=None,
|
||||
charts=chart_datas,
|
||||
)
|
||||
|
@ -3,16 +3,15 @@ from pydantic import BaseModel, Field
|
||||
from typing import TypeVar, Union, List, Generic, Any
|
||||
from dataclasses import dataclass, asdict
|
||||
|
||||
|
||||
class ValueItem(BaseModel):
|
||||
name: str
|
||||
type: str = None
|
||||
value: float
|
||||
|
||||
def dict(self, *args, **kwargs):
|
||||
return {
|
||||
"name": self.name,
|
||||
"type": self.type,
|
||||
"value": self.value
|
||||
}
|
||||
return {"name": self.name, "type": self.type, "value": self.value}
|
||||
|
||||
|
||||
class ChartData(BaseModel):
|
||||
chart_uid: str
|
||||
@ -32,10 +31,11 @@ class ChartData(BaseModel):
|
||||
"chart_desc": self.chart_desc,
|
||||
"chart_sql": self.chart_sql,
|
||||
"column_name": [str(item) for item in self.column_name],
|
||||
"values": [value.dict() for value in self.values],
|
||||
"style": self.style
|
||||
"values": [value.dict() for value in self.values],
|
||||
"style": self.style,
|
||||
}
|
||||
|
||||
|
||||
class ReportData(BaseModel):
|
||||
conv_uid: str
|
||||
template_name: str
|
||||
@ -47,5 +47,5 @@ class ReportData(BaseModel):
|
||||
"conv_uid": self.conv_uid,
|
||||
"template_name": self.template_name,
|
||||
"template_introduce": self.template_introduce,
|
||||
"charts": [chart.dict() for chart in self.charts]
|
||||
}
|
||||
"charts": [chart.dict() for chart in self.charts],
|
||||
}
|
||||
|
@ -52,5 +52,3 @@ prompt = PromptTemplate(
|
||||
),
|
||||
)
|
||||
CFG.prompt_template_registry.register(prompt, is_default=True)
|
||||
|
||||
|
||||
|
@ -44,16 +44,18 @@ class ChatWithDbAutoExecute(BaseChat):
|
||||
raise ValueError("Could not import DBSummaryClient. ")
|
||||
client = DBSummaryClient()
|
||||
try:
|
||||
table_infos = client.get_db_summary(dbname=self.db_name, query=self.current_user_input, topk=self.top_k)
|
||||
table_infos = client.get_db_summary(
|
||||
dbname=self.db_name, query=self.current_user_input, topk=self.top_k
|
||||
)
|
||||
except Exception as e:
|
||||
print("db summary find error!" + str(e))
|
||||
table_infos = self.database.table_simple_info()
|
||||
table_infos = self.database.table_simple_info()
|
||||
|
||||
input_values = {
|
||||
"input": self.current_user_input,
|
||||
"top_k": str(self.top_k),
|
||||
"dialect": self.database.dialect,
|
||||
"table_info": table_infos
|
||||
"table_info": table_infos,
|
||||
}
|
||||
return input_values
|
||||
|
||||
|
@ -48,8 +48,9 @@ class ChatWithDbQA(BaseChat):
|
||||
if self.db_name:
|
||||
client = DBSummaryClient()
|
||||
try:
|
||||
table_infos = client.get_db_summary(dbname=self.db_name, query=self.current_user_input,
|
||||
topk=self.top_k)
|
||||
table_infos = client.get_db_summary(
|
||||
dbname=self.db_name, query=self.current_user_input, topk=self.top_k
|
||||
)
|
||||
except Exception as e:
|
||||
print("db summary find error!" + str(e))
|
||||
table_infos = self.database.table_simple_info()
|
||||
|
@ -138,7 +138,6 @@ def get_simlar(q):
|
||||
return "\n".join(contents)
|
||||
|
||||
|
||||
|
||||
def plugins_select_info():
|
||||
plugins_infos: dict = {}
|
||||
for plugin in CFG.plugins:
|
||||
|
@ -58,7 +58,6 @@ class MysqlSummary(DBSummary):
|
||||
|
||||
self.db = CFG.LOCAL_DB_MANAGE.get_connect(name)
|
||||
|
||||
|
||||
self.metadata = """user info :{users}, grant info:{grant}, charset:{charset}, collation:{collation}""".format(
|
||||
users=self.db.get_users(),
|
||||
grant=self.db.get_grants(),
|
||||
|
Loading…
Reference in New Issue
Block a user