lint: fix pylint

This commit is contained in:
csunny 2023-07-24 18:55:57 +08:00
parent ea97bfccc7
commit 632aee3149
18 changed files with 181 additions and 117 deletions

View File

@ -29,7 +29,7 @@ class DBType(Enum):
Postgresql = DbInfo("postgresql") Postgresql = DbInfo("postgresql")
def value(self): def value(self):
return self._value_.name; return self._value_.name
def is_file_db(self): def is_file_db(self):
return self._value_.is_file_db return self._value_.is_file_db

View File

@ -293,7 +293,6 @@ class Database:
result = list(result) result = list(result)
return field_names, result return field_names, result
def run(self, session, command: str, fetch: str = "all") -> List: def run(self, session, command: str, fetch: str = "all") -> List:
"""Execute a SQL command and return a string representing the results.""" """Execute a SQL command and return a string representing the results."""
print("SQL:" + command) print("SQL:" + command)

View File

@ -7,7 +7,7 @@ from pydantic import BaseModel, Extra, Field, root_validator
from typing import Any, Iterable, List, Optional from typing import Any, Iterable, List, Optional
class BaseConnect( ABC): class BaseConnect(ABC):
def get_connect(self, db_name: str): def get_connect(self, db_name: str):
pass pass

View File

@ -1,5 +1,6 @@
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
class DBConfig(BaseModel): class DBConfig(BaseModel):
db_type: str db_type: str
db_name: str db_name: str
@ -8,4 +9,4 @@ class DBConfig(BaseModel):
db_port: int = 0 db_port: int = 0
db_user: str = "" db_user: str = ""
db_pwd: str = "" db_pwd: str = ""
comment: str = "" comment: str = ""

View File

@ -26,7 +26,16 @@ class DuckdbConnectConfig:
) )
self.connect.execute("CREATE SEQUENCE seq_id START 1;") self.connect.execute("CREATE SEQUENCE seq_id START 1;")
def add_url_db(self, db_name, db_type, db_host: str, db_port: int, db_user: str, db_pwd: str, comment: str = ""): def add_url_db(
self,
db_name,
db_type,
db_host: str,
db_port: int,
db_user: str,
db_pwd: str,
comment: str = "",
):
try: try:
cursor = self.connect.cursor() cursor = self.connect.cursor()
cursor.execute( cursor.execute(
@ -41,7 +50,7 @@ class DuckdbConnectConfig:
def get_file_db_name(self, path): def get_file_db_name(self, path):
try: try:
conn = duckdb.connect(path) conn = duckdb.connect(path)
result = conn.execute('SELECT current_database()').fetchone()[0] result = conn.execute("SELECT current_database()").fetchone()[0]
return result return result
except Exception as e: except Exception as e:
raise "Unusable duckdb database path:" + path raise "Unusable duckdb database path:" + path
@ -60,9 +69,7 @@ class DuckdbConnectConfig:
def delete_db(self, db_name): def delete_db(self, db_name):
cursor = self.connect.cursor() cursor = self.connect.cursor()
cursor.execute( cursor.execute("DELETE FROM connect_config where db_name=?", [db_name])
"DELETE FROM connect_config where db_name=?", [db_name]
)
cursor.commit() cursor.commit()
return True return True
@ -87,9 +94,7 @@ class DuckdbConnectConfig:
def get_db_list(self): def get_db_list(self):
if os.path.isfile(duckdb_path): if os.path.isfile(duckdb_path):
cursor = duckdb.connect(duckdb_path).cursor() cursor = duckdb.connect(duckdb_path).cursor()
cursor.execute( cursor.execute("SELECT db_name, db_type, comment FROM connect_config ")
"SELECT db_name, db_type, comment FROM connect_config "
)
fields = [field[0] for field in cursor.description] fields = [field[0] for field in cursor.description]
data = [] data = []
@ -102,15 +107,12 @@ class DuckdbConnectConfig:
return [] return []
def get_db_names(self): def get_db_names(self):
if os.path.isfile(duckdb_path): if os.path.isfile(duckdb_path):
cursor = duckdb.connect(duckdb_path).cursor() cursor = duckdb.connect(duckdb_path).cursor()
cursor.execute( cursor.execute("SELECT db_name FROM connect_config ")
"SELECT db_name FROM connect_config "
)
data = [] data = []
for row in cursor.fetchall(): for row in cursor.fetchall():
data.append(row[0]) data.append(row[0])
return data return data
return [] return []

View File

@ -16,7 +16,6 @@ CFG = Config()
class ConnectManager: class ConnectManager:
def get_all_subclasses(self, cls): def get_all_subclasses(self, cls):
subclasses = cls.__subclasses__() subclasses = cls.__subclasses__()
for subclass in subclasses: for subclass in subclasses:
@ -41,8 +40,15 @@ class ConnectManager:
if CFG.LOCAL_DB_HOST: if CFG.LOCAL_DB_HOST:
# default mysql # default mysql
if CFG.LOCAL_DB_NAME: if CFG.LOCAL_DB_NAME:
self.storage.add_url_db(CFG.LOCAL_DB_NAME, DBType.Mysql.value(), CFG.LOCAL_DB_HOST, CFG.LOCAL_DB_PORT, self.storage.add_url_db(
CFG.LOCAL_DB_USER, CFG.LOCAL_DB_PASSWORD, "") CFG.LOCAL_DB_NAME,
DBType.Mysql.value(),
CFG.LOCAL_DB_HOST,
CFG.LOCAL_DB_PORT,
CFG.LOCAL_DB_USER,
CFG.LOCAL_DB_PASSWORD,
"",
)
else: else:
# get all default mysql database # get all default mysql database
default_mysql = Database.from_uri( default_mysql = Database.from_uri(
@ -69,32 +75,42 @@ class ConnectManager:
# ) # )
dbs = default_mysql.get_database_list() dbs = default_mysql.get_database_list()
for name in dbs: for name in dbs:
self.storage.add_url_db(name, DBType.Mysql.value(), CFG.LOCAL_DB_HOST, CFG.LOCAL_DB_PORT, self.storage.add_url_db(
CFG.LOCAL_DB_USER, CFG.LOCAL_DB_PASSWORD, "") name,
DBType.Mysql.value(),
CFG.LOCAL_DB_HOST,
CFG.LOCAL_DB_PORT,
CFG.LOCAL_DB_USER,
CFG.LOCAL_DB_PASSWORD,
"",
)
if CFG.LOCAL_DB_PATH: if CFG.LOCAL_DB_PATH:
# default file db is duckdb # default file db is duckdb
db_name = self.storage.get_file_db_name(CFG.LOCAL_DB_PATH) db_name = self.storage.get_file_db_name(CFG.LOCAL_DB_PATH)
if db_name: if db_name:
self.storage.add_file_db(db_name, DBType.DuckDb.value(), CFG.LOCAL_DB_PATH) self.storage.add_file_db(
db_name, DBType.DuckDb.value(), CFG.LOCAL_DB_PATH
)
def get_connect(self, db_name): def get_connect(self, db_name):
db_config = self.storage.get_db_config(db_name) db_config = self.storage.get_db_config(db_name)
db_type = DBType.of_db_type(db_config.get('db_type')) db_type = DBType.of_db_type(db_config.get("db_type"))
connect_instance = self.get_cls_by_dbtype(db_type.value()) connect_instance = self.get_cls_by_dbtype(db_type.value())
if db_type.is_file_db(): if db_type.is_file_db():
db_path = db_config.get('db_path') db_path = db_config.get("db_path")
return connect_instance.from_file_path(db_path) return connect_instance.from_file_path(db_path)
else: else:
db_host = db_config.get('db_host') db_host = db_config.get("db_host")
db_port = db_config.get('db_port') db_port = db_config.get("db_port")
db_user = db_config.get('db_user') db_user = db_config.get("db_user")
db_pwd = db_config.get('db_pwd') db_pwd = db_config.get("db_pwd")
return connect_instance.from_uri_db(db_host, db_port, db_user, db_pwd, db_name) return connect_instance.from_uri_db(
db_host, db_port, db_user, db_pwd, db_name
)
def get_db_list(self): def get_db_list(self):
return self.storage.get_db_list() return self.storage.get_db_list()
def get_db_names(self): def get_db_names(self):
return self.storage.get_db_names() return self.storage.get_db_names()
@ -104,7 +120,17 @@ class ConnectManager:
def add_db(self, db_info: DBConfig): def add_db(self, db_info: DBConfig):
db_type = DBType.of_db_type(db_info.db_type) db_type = DBType.of_db_type(db_info.db_type)
if db_type.is_file_db(): if db_type.is_file_db():
self.storage.add_file_db(db_info.db_name, db_info.db_type, db_info.file_path) self.storage.add_file_db(
db_info.db_name, db_info.db_type, db_info.file_path
)
else: else:
self.storage.add_url_db(db_info.db_name, db_info.db_type, db_info.db_host, db_info.db_port, db_info.db_user, db_info.db_pwd, db_info.comment) self.storage.add_url_db(
return True db_info.db_name,
db_info.db_type,
db_info.db_host,
db_info.db_port,
db_info.db_user,
db_info.db_pwd,
db_info.comment,
)
return True

View File

@ -11,11 +11,13 @@ from sqlalchemy.ext.declarative import declarative_base
from pilot.connections.rdbms.rdbms_connect import RDBMSDatabase from pilot.connections.rdbms.rdbms_connect import RDBMSDatabase
class DuckDbConnect(RDBMSDatabase): class DuckDbConnect(RDBMSDatabase):
"""Connect Duckdb Database fetch MetaData """Connect Duckdb Database fetch MetaData
Args: Args:
Usage: Usage:
""" """
db_type: str = "duckdb" db_type: str = "duckdb"
db_dialect: str = "duckdb" db_dialect: str = "duckdb"
@ -33,7 +35,7 @@ class DuckDbConnect(RDBMSDatabase):
""" """
cursor = self.session.execute(text(_tables_sql)) cursor = self.session.execute(text(_tables_sql))
tables_results = cursor.fetchall() tables_results = cursor.fetchall()
results =[] results = []
for row in tables_results: for row in tables_results:
table_name = row[0] table_name = row[0]
_sql = f""" _sql = f"""
@ -49,11 +51,18 @@ class DuckDbConnect(RDBMSDatabase):
results.append(f"{table_name}({','.join(table_colums)});") results.append(f"{table_name}({','.join(table_colums)});")
return results return results
if __name__ == "__main__": if __name__ == "__main__":
engine = create_engine('duckdb:////Users/tuyang.yhj/Code/PycharmProjects/DB-GPT/pilot/mock_datas/db-gpt-test.db') engine = create_engine(
"duckdb:////Users/tuyang.yhj/Code/PycharmProjects/DB-GPT/pilot/mock_datas/db-gpt-test.db"
)
metadata = MetaData(engine) metadata = MetaData(engine)
results = engine.connect().execute("SELECT name FROM sqlite_master WHERE type='table'").fetchall() results = (
engine.connect()
.execute("SELECT name FROM sqlite_master WHERE type='table'")
.fetchall()
)
print(str(results)) print(str(results))

View File

@ -25,14 +25,13 @@ class MSSQLConnect(RDBMSDatabase):
default_db = ["master", "model", "msdb", "tempdb", "modeldb", "resource", "sys"] default_db = ["master", "model", "msdb", "tempdb", "modeldb", "resource", "sys"]
def table_simple_info(self) -> Iterable[str]: def table_simple_info(self) -> Iterable[str]:
_tables_sql = f""" _tables_sql = f"""
SELECT TABLE_NAME FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_TYPE='BASE TABLE' SELECT TABLE_NAME FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_TYPE='BASE TABLE'
""" """
cursor = self.session.execute(text(_tables_sql)) cursor = self.session.execute(text(_tables_sql))
tables_results = cursor.fetchall() tables_results = cursor.fetchall()
results =[] results = []
for row in tables_results: for row in tables_results:
table_name = row[0] table_name = row[0]
_sql = f""" _sql = f"""

View File

@ -36,19 +36,20 @@ def _format_index(index: sqlalchemy.engine.interfaces.ReflectedIndex) -> str:
class RDBMSDatabase(BaseConnect): class RDBMSDatabase(BaseConnect):
"""SQLAlchemy wrapper around a database.""" """SQLAlchemy wrapper around a database."""
db_type: str = None db_type: str = None
def __init__( def __init__(
self, self,
engine, engine,
schema: Optional[str] = None, schema: Optional[str] = None,
metadata: Optional[MetaData] = None, metadata: Optional[MetaData] = None,
ignore_tables: Optional[List[str]] = None, ignore_tables: Optional[List[str]] = None,
include_tables: Optional[List[str]] = None, include_tables: Optional[List[str]] = None,
sample_rows_in_table_info: int = 3, sample_rows_in_table_info: int = 3,
indexes_in_table_info: bool = False, indexes_in_table_info: bool = False,
custom_table_info: Optional[dict] = None, custom_table_info: Optional[dict] = None,
view_support: bool = False, view_support: bool = False,
): ):
"""Create engine from database URI.""" """Create engine from database URI."""
self._engine = engine self._engine = engine
@ -88,36 +89,35 @@ class RDBMSDatabase(BaseConnect):
) )
) )
@classmethod @classmethod
def from_uri_db( def from_uri_db(
cls, cls,
host: str, host: str,
port: int, port: int,
user: str, user: str,
pwd: str, pwd: str,
db_name: str, db_name: str,
engine_args: Optional[dict] = None, engine_args: Optional[dict] = None,
**kwargs: Any, **kwargs: Any,
) -> RDBMSDatabase: ) -> RDBMSDatabase:
db_url: str = ( db_url: str = (
cls.driver cls.driver
+ "://" + "://"
+ CFG.LOCAL_DB_USER + CFG.LOCAL_DB_USER
+ ":" + ":"
+ CFG.LOCAL_DB_PASSWORD + CFG.LOCAL_DB_PASSWORD
+ "@" + "@"
+ CFG.LOCAL_DB_HOST + CFG.LOCAL_DB_HOST
+ ":" + ":"
+ str(CFG.LOCAL_DB_PORT) + str(CFG.LOCAL_DB_PORT)
+ "/" + "/"
+ db_name + db_name
) )
return cls.from_uri(db_url, engine_args, **kwargs) return cls.from_uri(db_url, engine_args, **kwargs)
@classmethod @classmethod
def from_uri( def from_uri(
cls, database_uri: str, engine_args: Optional[dict] = None, **kwargs: Any cls, database_uri: str, engine_args: Optional[dict] = None, **kwargs: Any
) -> RDBMSDatabase: ) -> RDBMSDatabase:
"""Construct a SQLAlchemy engine from URI.""" """Construct a SQLAlchemy engine from URI."""
_engine_args = engine_args or {} _engine_args = engine_args or {}
@ -141,7 +141,6 @@ class RDBMSDatabase(BaseConnect):
def get_session(self): def get_session(self):
session = self._db_sessions() session = self._db_sessions()
return session return session
def get_current_db_name(self) -> str: def get_current_db_name(self) -> str:
@ -181,7 +180,7 @@ class RDBMSDatabase(BaseConnect):
tbl tbl
for tbl in self._metadata.sorted_tables for tbl in self._metadata.sorted_tables
if tbl.name in set(all_table_names) if tbl.name in set(all_table_names)
and not (self.dialect == "sqlite" and tbl.name.startswith("sqlite_")) and not (self.dialect == "sqlite" and tbl.name.startswith("sqlite_"))
] ]
tables = [] tables = []
@ -194,7 +193,7 @@ class RDBMSDatabase(BaseConnect):
create_table = str(CreateTable(table).compile(self._engine)) create_table = str(CreateTable(table).compile(self._engine))
table_info = f"{create_table.rstrip()}" table_info = f"{create_table.rstrip()}"
has_extra_info = ( has_extra_info = (
self._indexes_in_table_info or self._sample_rows_in_table_info self._indexes_in_table_info or self._sample_rows_in_table_info
) )
if has_extra_info: if has_extra_info:
table_info += "\n\n/*" table_info += "\n\n/*"
@ -413,9 +412,9 @@ class RDBMSDatabase(BaseConnect):
set_idx = parts.index("set") set_idx = parts.index("set")
where_idx = parts.index("where") where_idx = parts.index("where")
# 截取 `set` 子句中的字段名 # 截取 `set` 子句中的字段名
set_clause = parts[set_idx + 1: where_idx][0].split("=")[0].strip() set_clause = parts[set_idx + 1 : where_idx][0].split("=")[0].strip()
# 截取 `where` 之后的条件语句 # 截取 `where` 之后的条件语句
where_clause = " ".join(parts[where_idx + 1:]) where_clause = " ".join(parts[where_idx + 1 :])
# 返回一个select语句它选择更新的数据 # 返回一个select语句它选择更新的数据
return f"SELECT {set_clause} FROM {table_name} WHERE {where_clause}" return f"SELECT {set_clause} FROM {table_name} WHERE {where_clause}"
else: else:

View File

@ -23,7 +23,7 @@ from pilot.openapi.api_v1.api_view_model import (
Result, Result,
ConversationVo, ConversationVo,
MessageVo, MessageVo,
ChatSceneVo ChatSceneVo,
) )
from pilot.connections.db_conn_info import DBConfig from pilot.connections.db_conn_info import DBConfig
from pilot.configs.config import Config from pilot.configs.config import Config
@ -96,16 +96,16 @@ def knowledge_list():
return params return params
@router.get("/v1/chat/db/list", response_model=Result[DBConfig]) @router.get("/v1/chat/db/list", response_model=Result[DBConfig])
async def dialogue_list(): async def dialogue_list():
return Result.succ(CFG.LOCAL_DB_MANAGE.get_db_list()) return Result.succ(CFG.LOCAL_DB_MANAGE.get_db_list())
@router.post("/v1/chat/db/add", response_model=Result[bool]) @router.post("/v1/chat/db/add", response_model=Result[bool])
async def dialogue_list(db_config: DBConfig = Body() ): async def dialogue_list(db_config: DBConfig = Body()):
return Result.succ(CFG.LOCAL_DB_MANAGE.add_db(db_config)) return Result.succ(CFG.LOCAL_DB_MANAGE.add_db(db_config))
@router.post("/v1/chat/db/delete", response_model=Result[bool]) @router.post("/v1/chat/db/delete", response_model=Result[bool])
async def dialogue_list(db_name: str = None): async def dialogue_list(db_name: str = None):
return Result.succ(CFG.LOCAL_DB_MANAGE.delete_db(db_name)) return Result.succ(CFG.LOCAL_DB_MANAGE.delete_db(db_name))
@ -115,6 +115,7 @@ async def dialogue_list(db_name: str = None):
async def db_support_types(): async def db_support_types():
return Result[str].succ(["mysql", "mssql", "duckdb"]) return Result[str].succ(["mysql", "mssql", "duckdb"])
@router.get("/v1/chat/dialogue/list", response_model=Result[ConversationVo]) @router.get("/v1/chat/dialogue/list", response_model=Result[ConversationVo])
async def dialogue_list(user_id: str = None): async def dialogue_list(user_id: str = None):
dialogues: List = [] dialogues: List = []

View File

@ -3,7 +3,15 @@ from typing import List
class Scene: class Scene:
def __init__(self, code, name, describe, param_types: List = [], is_inner: bool = False, show_disable=False): def __init__(
self,
code,
name,
describe,
param_types: List = [],
is_inner: bool = False,
show_disable=False,
):
self.code = code self.code = code
self.name = name self.name = name
self.describe = describe self.describe = describe
@ -31,7 +39,7 @@ class ChatScene(Enum):
"Use tools through dialogue to accomplish your goals.", "Use tools through dialogue to accomplish your goals.",
["Plugin Select"], ["Plugin Select"],
False, False,
True True,
) )
ChatDefaultKnowledge = Scene( ChatDefaultKnowledge = Scene(
"chat_default_knowledge", "chat_default_knowledge",
@ -87,4 +95,4 @@ class ChatScene(Enum):
return self._value_.param_types return self._value_.param_types
def show_disable(self): def show_disable(self):
return self._value_.show_disable return self._value_.show_disable

View File

@ -19,7 +19,7 @@ from pilot.scene.chat_dashboard.prompt import prompt
from pilot.scene.chat_dashboard.data_preparation.report_schma import ( from pilot.scene.chat_dashboard.data_preparation.report_schma import (
ChartData, ChartData,
ReportData, ReportData,
ValueItem ValueItem,
) )
CFG = Config() CFG = Config()
@ -65,7 +65,9 @@ class ChatDashboard(BaseChat):
client = DBSummaryClient() client = DBSummaryClient()
try: try:
table_infos = client.get_similar_tables(dbname=self.db_name, query=self.current_user_input, topk=self.top_k) table_infos = client.get_similar_tables(
dbname=self.db_name, query=self.current_user_input, topk=self.top_k
)
print("dashboard vector find tables:{}", table_infos) print("dashboard vector find tables:{}", table_infos)
except Exception as e: except Exception as e:
print("db summary find error!" + str(e)) print("db summary find error!" + str(e))
@ -74,7 +76,7 @@ class ChatDashboard(BaseChat):
"input": self.current_user_input, "input": self.current_user_input,
"dialect": self.database.dialect, "dialect": self.database.dialect,
"table_info": self.database.table_simple_info(), "table_info": self.database.table_simple_info(),
"supported_chat_type": self.dashboard_template['supported_chart_type'] "supported_chat_type": self.dashboard_template["supported_chart_type"]
# "table_info": client.get_similar_tables(dbname=self.db_name, query=self.current_user_input, topk=self.top_k) # "table_info": client.get_similar_tables(dbname=self.db_name, query=self.current_user_input, topk=self.top_k)
} }
@ -85,7 +87,9 @@ class ChatDashboard(BaseChat):
chart_datas: List[ChartData] = [] chart_datas: List[ChartData] = []
for chart_item in prompt_response: for chart_item in prompt_response:
try: try:
field_names, datas = self.database.query_ex(self.db_connect, chart_item.sql) field_names, datas = self.database.query_ex(
self.db_connect, chart_item.sql
)
values: List[ValueItem] = [] values: List[ValueItem] = []
data_map = {} data_map = {}
field_map = {} field_map = {}
@ -96,28 +100,45 @@ class ChatDashboard(BaseChat):
if not data_map[field_name]: if not data_map[field_name]:
field_map.update({f"{field_name}": False}) field_map.update({f"{field_name}": False})
else: else:
field_map.update({f"{field_name}": all( field_map.update(
isinstance(item, (int, float, Decimal)) for item in data_map[field_name])}) {
f"{field_name}": all(
isinstance(item, (int, float, Decimal))
for item in data_map[field_name]
)
}
)
for field_name in field_names[1:]: for field_name in field_names[1:]:
if not field_map[field_name]: if not field_map[field_name]:
print("more than 2 non-numeric column") print("more than 2 non-numeric column")
else: else:
for data in datas: for data in datas:
value_item = ValueItem(name=data[0], type=field_name, value_item = ValueItem(
value=data[field_names.index(field_name)]) name=data[0],
type=field_name,
value=data[field_names.index(field_name)],
)
values.append(value_item) values.append(value_item)
chart_datas.append(ChartData(chart_uid=str(uuid.uuid1()), chart_datas.append(
chart_name=chart_item.title, ChartData(
chart_type=chart_item.showcase, chart_uid=str(uuid.uuid1()),
chart_desc=chart_item.thoughts, chart_name=chart_item.title,
chart_sql=chart_item.sql, chart_type=chart_item.showcase,
column_name=field_names, chart_desc=chart_item.thoughts,
values=values)) chart_sql=chart_item.sql,
column_name=field_names,
values=values,
)
)
except Exception as e: except Exception as e:
# TODO 修复流程 # TODO 修复流程
print(str(e)) print(str(e))
return ReportData(conv_uid=self.chat_session_id, template_name=self.report_name, template_introduce=None, return ReportData(
charts=chart_datas) conv_uid=self.chat_session_id,
template_name=self.report_name,
template_introduce=None,
charts=chart_datas,
)

View File

@ -3,16 +3,15 @@ from pydantic import BaseModel, Field
from typing import TypeVar, Union, List, Generic, Any from typing import TypeVar, Union, List, Generic, Any
from dataclasses import dataclass, asdict from dataclasses import dataclass, asdict
class ValueItem(BaseModel): class ValueItem(BaseModel):
name: str name: str
type: str = None type: str = None
value: float value: float
def dict(self, *args, **kwargs): def dict(self, *args, **kwargs):
return { return {"name": self.name, "type": self.type, "value": self.value}
"name": self.name,
"type": self.type,
"value": self.value
}
class ChartData(BaseModel): class ChartData(BaseModel):
chart_uid: str chart_uid: str
@ -32,10 +31,11 @@ class ChartData(BaseModel):
"chart_desc": self.chart_desc, "chart_desc": self.chart_desc,
"chart_sql": self.chart_sql, "chart_sql": self.chart_sql,
"column_name": [str(item) for item in self.column_name], "column_name": [str(item) for item in self.column_name],
"values": [value.dict() for value in self.values], "values": [value.dict() for value in self.values],
"style": self.style "style": self.style,
} }
class ReportData(BaseModel): class ReportData(BaseModel):
conv_uid: str conv_uid: str
template_name: str template_name: str
@ -47,5 +47,5 @@ class ReportData(BaseModel):
"conv_uid": self.conv_uid, "conv_uid": self.conv_uid,
"template_name": self.template_name, "template_name": self.template_name,
"template_introduce": self.template_introduce, "template_introduce": self.template_introduce,
"charts": [chart.dict() for chart in self.charts] "charts": [chart.dict() for chart in self.charts],
} }

View File

@ -52,5 +52,3 @@ prompt = PromptTemplate(
), ),
) )
CFG.prompt_template_registry.register(prompt, is_default=True) CFG.prompt_template_registry.register(prompt, is_default=True)

View File

@ -44,16 +44,18 @@ class ChatWithDbAutoExecute(BaseChat):
raise ValueError("Could not import DBSummaryClient. ") raise ValueError("Could not import DBSummaryClient. ")
client = DBSummaryClient() client = DBSummaryClient()
try: try:
table_infos = client.get_db_summary(dbname=self.db_name, query=self.current_user_input, topk=self.top_k) table_infos = client.get_db_summary(
dbname=self.db_name, query=self.current_user_input, topk=self.top_k
)
except Exception as e: except Exception as e:
print("db summary find error!" + str(e)) print("db summary find error!" + str(e))
table_infos = self.database.table_simple_info() table_infos = self.database.table_simple_info()
input_values = { input_values = {
"input": self.current_user_input, "input": self.current_user_input,
"top_k": str(self.top_k), "top_k": str(self.top_k),
"dialect": self.database.dialect, "dialect": self.database.dialect,
"table_info": table_infos "table_info": table_infos,
} }
return input_values return input_values

View File

@ -48,8 +48,9 @@ class ChatWithDbQA(BaseChat):
if self.db_name: if self.db_name:
client = DBSummaryClient() client = DBSummaryClient()
try: try:
table_infos = client.get_db_summary(dbname=self.db_name, query=self.current_user_input, table_infos = client.get_db_summary(
topk=self.top_k) dbname=self.db_name, query=self.current_user_input, topk=self.top_k
)
except Exception as e: except Exception as e:
print("db summary find error!" + str(e)) print("db summary find error!" + str(e))
table_infos = self.database.table_simple_info() table_infos = self.database.table_simple_info()

View File

@ -138,7 +138,6 @@ def get_simlar(q):
return "\n".join(contents) return "\n".join(contents)
def plugins_select_info(): def plugins_select_info():
plugins_infos: dict = {} plugins_infos: dict = {}
for plugin in CFG.plugins: for plugin in CFG.plugins:

View File

@ -58,7 +58,6 @@ class MysqlSummary(DBSummary):
self.db = CFG.LOCAL_DB_MANAGE.get_connect(name) self.db = CFG.LOCAL_DB_MANAGE.get_connect(name)
self.metadata = """user info :{users}, grant info:{grant}, charset:{charset}, collation:{collation}""".format( self.metadata = """user info :{users}, grant info:{grant}, charset:{charset}, collation:{collation}""".format(
users=self.db.get_users(), users=self.db.get_users(),
grant=self.db.get_grants(), grant=self.db.get_grants(),