lint: fix pylint

This commit is contained in:
csunny 2023-07-24 18:55:57 +08:00
parent ea97bfccc7
commit 632aee3149
18 changed files with 181 additions and 117 deletions

View File

@ -29,7 +29,7 @@ class DBType(Enum):
Postgresql = DbInfo("postgresql")
def value(self):
return self._value_.name;
return self._value_.name
def is_file_db(self):
return self._value_.is_file_db

View File

@ -293,7 +293,6 @@ class Database:
result = list(result)
return field_names, result
def run(self, session, command: str, fetch: str = "all") -> List:
"""Execute a SQL command and return a string representing the results."""
print("SQL:" + command)

View File

@ -7,7 +7,7 @@ from pydantic import BaseModel, Extra, Field, root_validator
from typing import Any, Iterable, List, Optional
class BaseConnect( ABC):
class BaseConnect(ABC):
def get_connect(self, db_name: str):
pass

View File

@ -1,5 +1,6 @@
from pydantic import BaseModel, Field
class DBConfig(BaseModel):
db_type: str
db_name: str
@ -8,4 +9,4 @@ class DBConfig(BaseModel):
db_port: int = 0
db_user: str = ""
db_pwd: str = ""
comment: str = ""
comment: str = ""

View File

@ -26,7 +26,16 @@ class DuckdbConnectConfig:
)
self.connect.execute("CREATE SEQUENCE seq_id START 1;")
def add_url_db(self, db_name, db_type, db_host: str, db_port: int, db_user: str, db_pwd: str, comment: str = ""):
def add_url_db(
self,
db_name,
db_type,
db_host: str,
db_port: int,
db_user: str,
db_pwd: str,
comment: str = "",
):
try:
cursor = self.connect.cursor()
cursor.execute(
@ -41,7 +50,7 @@ class DuckdbConnectConfig:
def get_file_db_name(self, path):
try:
conn = duckdb.connect(path)
result = conn.execute('SELECT current_database()').fetchone()[0]
result = conn.execute("SELECT current_database()").fetchone()[0]
return result
except Exception as e:
raise "Unusable duckdb database path:" + path
@ -60,9 +69,7 @@ class DuckdbConnectConfig:
def delete_db(self, db_name):
cursor = self.connect.cursor()
cursor.execute(
"DELETE FROM connect_config where db_name=?", [db_name]
)
cursor.execute("DELETE FROM connect_config where db_name=?", [db_name])
cursor.commit()
return True
@ -87,9 +94,7 @@ class DuckdbConnectConfig:
def get_db_list(self):
if os.path.isfile(duckdb_path):
cursor = duckdb.connect(duckdb_path).cursor()
cursor.execute(
"SELECT db_name, db_type, comment FROM connect_config "
)
cursor.execute("SELECT db_name, db_type, comment FROM connect_config ")
fields = [field[0] for field in cursor.description]
data = []
@ -102,15 +107,12 @@ class DuckdbConnectConfig:
return []
def get_db_names(self):
if os.path.isfile(duckdb_path):
cursor = duckdb.connect(duckdb_path).cursor()
cursor.execute(
"SELECT db_name FROM connect_config "
)
cursor.execute("SELECT db_name FROM connect_config ")
data = []
for row in cursor.fetchall():
data.append(row[0])
return data
return []
return []

View File

@ -16,7 +16,6 @@ CFG = Config()
class ConnectManager:
def get_all_subclasses(self, cls):
subclasses = cls.__subclasses__()
for subclass in subclasses:
@ -41,8 +40,15 @@ class ConnectManager:
if CFG.LOCAL_DB_HOST:
# default mysql
if CFG.LOCAL_DB_NAME:
self.storage.add_url_db(CFG.LOCAL_DB_NAME, DBType.Mysql.value(), CFG.LOCAL_DB_HOST, CFG.LOCAL_DB_PORT,
CFG.LOCAL_DB_USER, CFG.LOCAL_DB_PASSWORD, "")
self.storage.add_url_db(
CFG.LOCAL_DB_NAME,
DBType.Mysql.value(),
CFG.LOCAL_DB_HOST,
CFG.LOCAL_DB_PORT,
CFG.LOCAL_DB_USER,
CFG.LOCAL_DB_PASSWORD,
"",
)
else:
# get all default mysql database
default_mysql = Database.from_uri(
@ -69,32 +75,42 @@ class ConnectManager:
# )
dbs = default_mysql.get_database_list()
for name in dbs:
self.storage.add_url_db(name, DBType.Mysql.value(), CFG.LOCAL_DB_HOST, CFG.LOCAL_DB_PORT,
CFG.LOCAL_DB_USER, CFG.LOCAL_DB_PASSWORD, "")
self.storage.add_url_db(
name,
DBType.Mysql.value(),
CFG.LOCAL_DB_HOST,
CFG.LOCAL_DB_PORT,
CFG.LOCAL_DB_USER,
CFG.LOCAL_DB_PASSWORD,
"",
)
if CFG.LOCAL_DB_PATH:
# default file db is duckdb
db_name = self.storage.get_file_db_name(CFG.LOCAL_DB_PATH)
if db_name:
self.storage.add_file_db(db_name, DBType.DuckDb.value(), CFG.LOCAL_DB_PATH)
self.storage.add_file_db(
db_name, DBType.DuckDb.value(), CFG.LOCAL_DB_PATH
)
def get_connect(self, db_name):
db_config = self.storage.get_db_config(db_name)
db_type = DBType.of_db_type(db_config.get('db_type'))
db_type = DBType.of_db_type(db_config.get("db_type"))
connect_instance = self.get_cls_by_dbtype(db_type.value())
if db_type.is_file_db():
db_path = db_config.get('db_path')
db_path = db_config.get("db_path")
return connect_instance.from_file_path(db_path)
else:
db_host = db_config.get('db_host')
db_port = db_config.get('db_port')
db_user = db_config.get('db_user')
db_pwd = db_config.get('db_pwd')
return connect_instance.from_uri_db(db_host, db_port, db_user, db_pwd, db_name)
db_host = db_config.get("db_host")
db_port = db_config.get("db_port")
db_user = db_config.get("db_user")
db_pwd = db_config.get("db_pwd")
return connect_instance.from_uri_db(
db_host, db_port, db_user, db_pwd, db_name
)
def get_db_list(self):
return self.storage.get_db_list()
def get_db_names(self):
return self.storage.get_db_names()
@ -104,7 +120,17 @@ class ConnectManager:
def add_db(self, db_info: DBConfig):
db_type = DBType.of_db_type(db_info.db_type)
if db_type.is_file_db():
self.storage.add_file_db(db_info.db_name, db_info.db_type, db_info.file_path)
self.storage.add_file_db(
db_info.db_name, db_info.db_type, db_info.file_path
)
else:
self.storage.add_url_db(db_info.db_name, db_info.db_type, db_info.db_host, db_info.db_port, db_info.db_user, db_info.db_pwd, db_info.comment)
return True
self.storage.add_url_db(
db_info.db_name,
db_info.db_type,
db_info.db_host,
db_info.db_port,
db_info.db_user,
db_info.db_pwd,
db_info.comment,
)
return True

View File

@ -11,11 +11,13 @@ from sqlalchemy.ext.declarative import declarative_base
from pilot.connections.rdbms.rdbms_connect import RDBMSDatabase
class DuckDbConnect(RDBMSDatabase):
"""Connect Duckdb Database fetch MetaData
Args:
Usage:
"""
db_type: str = "duckdb"
db_dialect: str = "duckdb"
@ -33,7 +35,7 @@ class DuckDbConnect(RDBMSDatabase):
"""
cursor = self.session.execute(text(_tables_sql))
tables_results = cursor.fetchall()
results =[]
results = []
for row in tables_results:
table_name = row[0]
_sql = f"""
@ -49,11 +51,18 @@ class DuckDbConnect(RDBMSDatabase):
results.append(f"{table_name}({','.join(table_colums)});")
return results
if __name__ == "__main__":
engine = create_engine('duckdb:////Users/tuyang.yhj/Code/PycharmProjects/DB-GPT/pilot/mock_datas/db-gpt-test.db')
engine = create_engine(
"duckdb:////Users/tuyang.yhj/Code/PycharmProjects/DB-GPT/pilot/mock_datas/db-gpt-test.db"
)
metadata = MetaData(engine)
results = engine.connect().execute("SELECT name FROM sqlite_master WHERE type='table'").fetchall()
results = (
engine.connect()
.execute("SELECT name FROM sqlite_master WHERE type='table'")
.fetchall()
)
print(str(results))

View File

@ -25,14 +25,13 @@ class MSSQLConnect(RDBMSDatabase):
default_db = ["master", "model", "msdb", "tempdb", "modeldb", "resource", "sys"]
def table_simple_info(self) -> Iterable[str]:
_tables_sql = f"""
SELECT TABLE_NAME FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_TYPE='BASE TABLE'
"""
cursor = self.session.execute(text(_tables_sql))
tables_results = cursor.fetchall()
results =[]
results = []
for row in tables_results:
table_name = row[0]
_sql = f"""

View File

@ -36,19 +36,20 @@ def _format_index(index: sqlalchemy.engine.interfaces.ReflectedIndex) -> str:
class RDBMSDatabase(BaseConnect):
"""SQLAlchemy wrapper around a database."""
db_type: str = None
def __init__(
self,
engine,
schema: Optional[str] = None,
metadata: Optional[MetaData] = None,
ignore_tables: Optional[List[str]] = None,
include_tables: Optional[List[str]] = None,
sample_rows_in_table_info: int = 3,
indexes_in_table_info: bool = False,
custom_table_info: Optional[dict] = None,
view_support: bool = False,
self,
engine,
schema: Optional[str] = None,
metadata: Optional[MetaData] = None,
ignore_tables: Optional[List[str]] = None,
include_tables: Optional[List[str]] = None,
sample_rows_in_table_info: int = 3,
indexes_in_table_info: bool = False,
custom_table_info: Optional[dict] = None,
view_support: bool = False,
):
"""Create engine from database URI."""
self._engine = engine
@ -88,36 +89,35 @@ class RDBMSDatabase(BaseConnect):
)
)
@classmethod
def from_uri_db(
cls,
host: str,
port: int,
user: str,
pwd: str,
db_name: str,
engine_args: Optional[dict] = None,
**kwargs: Any,
cls,
host: str,
port: int,
user: str,
pwd: str,
db_name: str,
engine_args: Optional[dict] = None,
**kwargs: Any,
) -> RDBMSDatabase:
db_url: str = (
cls.driver
+ "://"
+ CFG.LOCAL_DB_USER
+ ":"
+ CFG.LOCAL_DB_PASSWORD
+ "@"
+ CFG.LOCAL_DB_HOST
+ ":"
+ str(CFG.LOCAL_DB_PORT)
+ "/"
+ db_name
cls.driver
+ "://"
+ CFG.LOCAL_DB_USER
+ ":"
+ CFG.LOCAL_DB_PASSWORD
+ "@"
+ CFG.LOCAL_DB_HOST
+ ":"
+ str(CFG.LOCAL_DB_PORT)
+ "/"
+ db_name
)
return cls.from_uri(db_url, engine_args, **kwargs)
@classmethod
def from_uri(
cls, database_uri: str, engine_args: Optional[dict] = None, **kwargs: Any
cls, database_uri: str, engine_args: Optional[dict] = None, **kwargs: Any
) -> RDBMSDatabase:
"""Construct a SQLAlchemy engine from URI."""
_engine_args = engine_args or {}
@ -141,7 +141,6 @@ class RDBMSDatabase(BaseConnect):
def get_session(self):
session = self._db_sessions()
return session
def get_current_db_name(self) -> str:
@ -181,7 +180,7 @@ class RDBMSDatabase(BaseConnect):
tbl
for tbl in self._metadata.sorted_tables
if tbl.name in set(all_table_names)
and not (self.dialect == "sqlite" and tbl.name.startswith("sqlite_"))
and not (self.dialect == "sqlite" and tbl.name.startswith("sqlite_"))
]
tables = []
@ -194,7 +193,7 @@ class RDBMSDatabase(BaseConnect):
create_table = str(CreateTable(table).compile(self._engine))
table_info = f"{create_table.rstrip()}"
has_extra_info = (
self._indexes_in_table_info or self._sample_rows_in_table_info
self._indexes_in_table_info or self._sample_rows_in_table_info
)
if has_extra_info:
table_info += "\n\n/*"
@ -413,9 +412,9 @@ class RDBMSDatabase(BaseConnect):
set_idx = parts.index("set")
where_idx = parts.index("where")
# 截取 `set` 子句中的字段名
set_clause = parts[set_idx + 1: where_idx][0].split("=")[0].strip()
set_clause = parts[set_idx + 1 : where_idx][0].split("=")[0].strip()
# 截取 `where` 之后的条件语句
where_clause = " ".join(parts[where_idx + 1:])
where_clause = " ".join(parts[where_idx + 1 :])
# 返回一个select语句它选择更新的数据
return f"SELECT {set_clause} FROM {table_name} WHERE {where_clause}"
else:

View File

@ -23,7 +23,7 @@ from pilot.openapi.api_v1.api_view_model import (
Result,
ConversationVo,
MessageVo,
ChatSceneVo
ChatSceneVo,
)
from pilot.connections.db_conn_info import DBConfig
from pilot.configs.config import Config
@ -96,16 +96,16 @@ def knowledge_list():
return params
@router.get("/v1/chat/db/list", response_model=Result[DBConfig])
async def dialogue_list():
return Result.succ(CFG.LOCAL_DB_MANAGE.get_db_list())
@router.post("/v1/chat/db/add", response_model=Result[bool])
async def dialogue_list(db_config: DBConfig = Body() ):
async def dialogue_list(db_config: DBConfig = Body()):
return Result.succ(CFG.LOCAL_DB_MANAGE.add_db(db_config))
@router.post("/v1/chat/db/delete", response_model=Result[bool])
async def dialogue_list(db_name: str = None):
return Result.succ(CFG.LOCAL_DB_MANAGE.delete_db(db_name))
@ -115,6 +115,7 @@ async def dialogue_list(db_name: str = None):
async def db_support_types():
return Result[str].succ(["mysql", "mssql", "duckdb"])
@router.get("/v1/chat/dialogue/list", response_model=Result[ConversationVo])
async def dialogue_list(user_id: str = None):
dialogues: List = []

View File

@ -3,7 +3,15 @@ from typing import List
class Scene:
def __init__(self, code, name, describe, param_types: List = [], is_inner: bool = False, show_disable=False):
def __init__(
self,
code,
name,
describe,
param_types: List = [],
is_inner: bool = False,
show_disable=False,
):
self.code = code
self.name = name
self.describe = describe
@ -31,7 +39,7 @@ class ChatScene(Enum):
"Use tools through dialogue to accomplish your goals.",
["Plugin Select"],
False,
True
True,
)
ChatDefaultKnowledge = Scene(
"chat_default_knowledge",
@ -87,4 +95,4 @@ class ChatScene(Enum):
return self._value_.param_types
def show_disable(self):
return self._value_.show_disable
return self._value_.show_disable

View File

@ -19,7 +19,7 @@ from pilot.scene.chat_dashboard.prompt import prompt
from pilot.scene.chat_dashboard.data_preparation.report_schma import (
ChartData,
ReportData,
ValueItem
ValueItem,
)
CFG = Config()
@ -65,7 +65,9 @@ class ChatDashboard(BaseChat):
client = DBSummaryClient()
try:
table_infos = client.get_similar_tables(dbname=self.db_name, query=self.current_user_input, topk=self.top_k)
table_infos = client.get_similar_tables(
dbname=self.db_name, query=self.current_user_input, topk=self.top_k
)
print("dashboard vector find tables:{}", table_infos)
except Exception as e:
print("db summary find error!" + str(e))
@ -74,7 +76,7 @@ class ChatDashboard(BaseChat):
"input": self.current_user_input,
"dialect": self.database.dialect,
"table_info": self.database.table_simple_info(),
"supported_chat_type": self.dashboard_template['supported_chart_type']
"supported_chat_type": self.dashboard_template["supported_chart_type"]
# "table_info": client.get_similar_tables(dbname=self.db_name, query=self.current_user_input, topk=self.top_k)
}
@ -85,7 +87,9 @@ class ChatDashboard(BaseChat):
chart_datas: List[ChartData] = []
for chart_item in prompt_response:
try:
field_names, datas = self.database.query_ex(self.db_connect, chart_item.sql)
field_names, datas = self.database.query_ex(
self.db_connect, chart_item.sql
)
values: List[ValueItem] = []
data_map = {}
field_map = {}
@ -96,28 +100,45 @@ class ChatDashboard(BaseChat):
if not data_map[field_name]:
field_map.update({f"{field_name}": False})
else:
field_map.update({f"{field_name}": all(
isinstance(item, (int, float, Decimal)) for item in data_map[field_name])})
field_map.update(
{
f"{field_name}": all(
isinstance(item, (int, float, Decimal))
for item in data_map[field_name]
)
}
)
for field_name in field_names[1:]:
if not field_map[field_name]:
print("more than 2 non-numeric column")
else:
for data in datas:
value_item = ValueItem(name=data[0], type=field_name,
value=data[field_names.index(field_name)])
value_item = ValueItem(
name=data[0],
type=field_name,
value=data[field_names.index(field_name)],
)
values.append(value_item)
chart_datas.append(ChartData(chart_uid=str(uuid.uuid1()),
chart_name=chart_item.title,
chart_type=chart_item.showcase,
chart_desc=chart_item.thoughts,
chart_sql=chart_item.sql,
column_name=field_names,
values=values))
chart_datas.append(
ChartData(
chart_uid=str(uuid.uuid1()),
chart_name=chart_item.title,
chart_type=chart_item.showcase,
chart_desc=chart_item.thoughts,
chart_sql=chart_item.sql,
column_name=field_names,
values=values,
)
)
except Exception as e:
# TODO 修复流程
print(str(e))
return ReportData(conv_uid=self.chat_session_id, template_name=self.report_name, template_introduce=None,
charts=chart_datas)
return ReportData(
conv_uid=self.chat_session_id,
template_name=self.report_name,
template_introduce=None,
charts=chart_datas,
)

View File

@ -3,16 +3,15 @@ from pydantic import BaseModel, Field
from typing import TypeVar, Union, List, Generic, Any
from dataclasses import dataclass, asdict
class ValueItem(BaseModel):
name: str
type: str = None
value: float
def dict(self, *args, **kwargs):
return {
"name": self.name,
"type": self.type,
"value": self.value
}
return {"name": self.name, "type": self.type, "value": self.value}
class ChartData(BaseModel):
chart_uid: str
@ -32,10 +31,11 @@ class ChartData(BaseModel):
"chart_desc": self.chart_desc,
"chart_sql": self.chart_sql,
"column_name": [str(item) for item in self.column_name],
"values": [value.dict() for value in self.values],
"style": self.style
"values": [value.dict() for value in self.values],
"style": self.style,
}
class ReportData(BaseModel):
conv_uid: str
template_name: str
@ -47,5 +47,5 @@ class ReportData(BaseModel):
"conv_uid": self.conv_uid,
"template_name": self.template_name,
"template_introduce": self.template_introduce,
"charts": [chart.dict() for chart in self.charts]
}
"charts": [chart.dict() for chart in self.charts],
}

View File

@ -52,5 +52,3 @@ prompt = PromptTemplate(
),
)
CFG.prompt_template_registry.register(prompt, is_default=True)

View File

@ -44,16 +44,18 @@ class ChatWithDbAutoExecute(BaseChat):
raise ValueError("Could not import DBSummaryClient. ")
client = DBSummaryClient()
try:
table_infos = client.get_db_summary(dbname=self.db_name, query=self.current_user_input, topk=self.top_k)
table_infos = client.get_db_summary(
dbname=self.db_name, query=self.current_user_input, topk=self.top_k
)
except Exception as e:
print("db summary find error!" + str(e))
table_infos = self.database.table_simple_info()
table_infos = self.database.table_simple_info()
input_values = {
"input": self.current_user_input,
"top_k": str(self.top_k),
"dialect": self.database.dialect,
"table_info": table_infos
"table_info": table_infos,
}
return input_values

View File

@ -48,8 +48,9 @@ class ChatWithDbQA(BaseChat):
if self.db_name:
client = DBSummaryClient()
try:
table_infos = client.get_db_summary(dbname=self.db_name, query=self.current_user_input,
topk=self.top_k)
table_infos = client.get_db_summary(
dbname=self.db_name, query=self.current_user_input, topk=self.top_k
)
except Exception as e:
print("db summary find error!" + str(e))
table_infos = self.database.table_simple_info()

View File

@ -138,7 +138,6 @@ def get_simlar(q):
return "\n".join(contents)
def plugins_select_info():
plugins_infos: dict = {}
for plugin in CFG.plugins:

View File

@ -58,7 +58,6 @@ class MysqlSummary(DBSummary):
self.db = CFG.LOCAL_DB_MANAGE.get_connect(name)
self.metadata = """user info :{users}, grant info:{grant}, charset:{charset}, collation:{collation}""".format(
users=self.db.get_users(),
grant=self.db.get_grants(),