feat(db-g[t): black param change

This commit is contained in:
yhjun1026 2023-08-01 18:11:59 +08:00
parent c6ee16ab82
commit c519203cd6
11 changed files with 46 additions and 27 deletions

View File

@ -29,4 +29,4 @@ jobs:
pip install -U black isort
- name: check the code lint
run: |
black .
black . --check

View File

@ -131,7 +131,6 @@ def scan_plugins(cfg: Config, debug: bool = False) -> List[AutoGPTPluginTemplate
# Generic plugins
plugins_path_path = Path(PLUGINS_DIR)
for plugin in plugins_path_path.glob("*.zip"):
if moduleList := inspect_zip_for_modules(str(plugin), debug):
for module in moduleList:

View File

@ -11,7 +11,7 @@ class DBConfig(BaseModel):
db_pwd: str = ""
comment: str = ""
class DbTypeInfo(BaseModel):
db_type:str
is_file_db: bool = False
class DbTypeInfo(BaseModel):
db_type: str
is_file_db: bool = False

View File

@ -47,7 +47,8 @@ class DuckdbConnectConfig:
except Exception as e:
print("add db connect info error1" + str(e))
def update_db_info(self,
def update_db_info(
self,
db_name,
db_type,
db_path: str = "",
@ -55,15 +56,20 @@ class DuckdbConnectConfig:
db_port: int = 0,
db_user: str = "",
db_pwd: str = "",
comment: str = "" ):
comment: str = "",
):
old_db_conf = self.get_db_config(db_name)
if old_db_conf:
try:
cursor = self.connect.cursor()
if not db_path:
cursor.execute(f"UPDATE connect_config set db_type='{db_type}', db_host='{db_host}', db_port={db_port}, db_user='{db_user}', db_pwd='{db_pwd}', comment='{comment}' where db_name='{db_name}'")
cursor.execute(
f"UPDATE connect_config set db_type='{db_type}', db_host='{db_host}', db_port={db_port}, db_user='{db_user}', db_pwd='{db_pwd}', comment='{comment}' where db_name='{db_name}'"
)
else:
cursor.execute(f"UPDATE connect_config set db_type='{db_type}', db_path='{db_path}', comment='{comment}' where db_name='{db_name}'")
cursor.execute(
f"UPDATE connect_config set db_type='{db_type}', db_path='{db_path}', comment='{comment}' where db_name='{db_name}'"
)
cursor.commit()
self.connect.commit()
except Exception as e:
@ -79,7 +85,6 @@ class DuckdbConnectConfig:
except Exception as e:
raise "Unusable duckdb database path:" + path
def add_file_db(self, db_name, db_type, db_path: str, comment: str = ""):
try:
cursor = self.connect.cursor()

View File

@ -122,14 +122,16 @@ class ConnectManager:
return self.storage.delete_db(db_name)
def edit_db(self, db_info: DBConfig):
return self.storage.update_db_info(db_info.db_name,
db_info.db_type,
db_info.file_path,
db_info.db_host,
db_info.db_port,
db_info.db_user,
db_info.db_pwd,
db_info.comment)
return self.storage.update_db_info(
db_info.db_name,
db_info.db_type,
db_info.file_path,
db_info.db_host,
db_info.db_port,
db_info.db_user,
db_info.db_pwd,
db_info.comment,
)
def add_db(self, db_info: DBConfig):
print(f"add_db:{db_info.__dict__}")
@ -140,7 +142,6 @@ class ConnectManager:
db_info.db_name, db_info.db_type, db_info.file_path
)
else:
self.storage.add_url_db(
db_info.db_name,
db_info.db_type,
@ -151,7 +152,11 @@ class ConnectManager:
db_info.comment,
)
# async embedding
thread = threading.Thread(target=self.db_summary_client.db_summary_embedding(db_info.db_name, db_info.db_type))
thread = threading.Thread(
target=self.db_summary_client.db_summary_embedding(
db_info.db_name, db_info.db_type
)
)
thread.start()
except Exception as e:
raise ValueError("Add db connect info error" + str(e))

View File

@ -30,7 +30,11 @@ class DuckDbConnect(RDBMSDatabase):
return cls(create_engine("duckdb:///" + file_path, **_engine_args), **kwargs)
def get_users(self):
cursor = self.session.execute(text(f"SELECT * FROM sqlite_master WHERE type = 'table' AND name = 'duckdb_sys_users';"))
cursor = self.session.execute(
text(
f"SELECT * FROM sqlite_master WHERE type = 'table' AND name = 'duckdb_sys_users';"
)
)
users = cursor.fetchall()
return [(user[0], user[1]) for user in users]
@ -40,6 +44,7 @@ class DuckDbConnect(RDBMSDatabase):
def get_collation(self):
"""Get collation."""
return "UTF-8"
def get_charset(self):
return "UTF-8"

View File

@ -121,7 +121,9 @@ async def db_support_types():
support_types = [DBType.Mysql, DBType.MSSQL, DBType.DuckDb]
db_type_infos = []
for type in support_types:
db_type_infos.append(DbTypeInfo(db_type=type.value(), is_file_db=type.is_file_db()))
db_type_infos.append(
DbTypeInfo(db_type=type.value(), is_file_db=type.is_file_db())
)
return Result[DbTypeInfo].succ(db_type_infos)
@ -169,7 +171,7 @@ async def dialogue_scenes():
@router.post("/v1/chat/dialogue/new", response_model=Result[ConversationVo])
async def dialogue_new(
chat_mode: str = ChatScene.ChatNormal.value(), user_id: str = None
chat_mode: str = ChatScene.ChatNormal.value(), user_id: str = None
):
conv_vo = __new_conversation(chat_mode, user_id)
return Result.succ(conv_vo)

View File

@ -5,10 +5,12 @@ import shutil
import argparse
import sys
import logging
ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(ROOT_PATH)
import signal
from pilot.configs.config import Config
# from pilot.configs.model_config import (
# DATASETS_DIR,
# KNOWLEDGE_UPLOAD_ROOT_PATH,

View File

@ -7,6 +7,7 @@ import sys
from pilot.summary.db_summary_client import DBSummaryClient
from pilot.commands.command_mange import CommandRegistry
from pilot.configs.config import Config
# from pilot.configs.model_config import (
# DATASETS_DIR,
# KNOWLEDGE_UPLOAD_ROOT_PATH,

View File

@ -58,8 +58,8 @@ class DBSummaryClient:
)
embedding.source_embedding()
for (
table_name,
table_summary,
table_name,
table_summary,
) in db_summary_client.get_table_summary().items():
table_vector_store_config = {
"vector_store_name": dbname + "_" + table_name + "_ts",

View File

@ -5,12 +5,13 @@ from pilot.summary.db_summary import DBSummary, TableSummary, FieldSummary, Inde
CFG = Config()
class RdbmsSummary(DBSummary):
"""Get mysql summary template."""
def __init__(self, name, type):
self.name = name
self.type = type
self.type = type
self.summery = """{{"database_name": "{name}", "type": "{type}", "tables": "{tables}", "qps": "{qps}", "tps": {tps}}}"""
self.tables = {}
self.tables_info = []
@ -177,4 +178,3 @@ class RdbmsIndexSummary(IndexSummary):
return self.summery_template.format(
name=self.name, bind_fields=self.bind_fields
)