feat(db-g[t): black param change

This commit is contained in:
yhjun1026 2023-08-01 18:11:59 +08:00
parent c6ee16ab82
commit c519203cd6
11 changed files with 46 additions and 27 deletions

View File

@ -29,4 +29,4 @@ jobs:
pip install -U black isort pip install -U black isort
- name: check the code lint - name: check the code lint
run: | run: |
black . black . --check

View File

@ -131,7 +131,6 @@ def scan_plugins(cfg: Config, debug: bool = False) -> List[AutoGPTPluginTemplate
# Generic plugins # Generic plugins
plugins_path_path = Path(PLUGINS_DIR) plugins_path_path = Path(PLUGINS_DIR)
for plugin in plugins_path_path.glob("*.zip"): for plugin in plugins_path_path.glob("*.zip"):
if moduleList := inspect_zip_for_modules(str(plugin), debug): if moduleList := inspect_zip_for_modules(str(plugin), debug):
for module in moduleList: for module in moduleList:

View File

@ -11,7 +11,7 @@ class DBConfig(BaseModel):
db_pwd: str = "" db_pwd: str = ""
comment: str = "" comment: str = ""
class DbTypeInfo(BaseModel):
db_type:str
is_file_db: bool = False
class DbTypeInfo(BaseModel):
db_type: str
is_file_db: bool = False

View File

@ -47,7 +47,8 @@ class DuckdbConnectConfig:
except Exception as e: except Exception as e:
print("add db connect info error1" + str(e)) print("add db connect info error1" + str(e))
def update_db_info(self, def update_db_info(
self,
db_name, db_name,
db_type, db_type,
db_path: str = "", db_path: str = "",
@ -55,15 +56,20 @@ class DuckdbConnectConfig:
db_port: int = 0, db_port: int = 0,
db_user: str = "", db_user: str = "",
db_pwd: str = "", db_pwd: str = "",
comment: str = "" ): comment: str = "",
):
old_db_conf = self.get_db_config(db_name) old_db_conf = self.get_db_config(db_name)
if old_db_conf: if old_db_conf:
try: try:
cursor = self.connect.cursor() cursor = self.connect.cursor()
if not db_path: if not db_path:
cursor.execute(f"UPDATE connect_config set db_type='{db_type}', db_host='{db_host}', db_port={db_port}, db_user='{db_user}', db_pwd='{db_pwd}', comment='{comment}' where db_name='{db_name}'") cursor.execute(
f"UPDATE connect_config set db_type='{db_type}', db_host='{db_host}', db_port={db_port}, db_user='{db_user}', db_pwd='{db_pwd}', comment='{comment}' where db_name='{db_name}'"
)
else: else:
cursor.execute(f"UPDATE connect_config set db_type='{db_type}', db_path='{db_path}', comment='{comment}' where db_name='{db_name}'") cursor.execute(
f"UPDATE connect_config set db_type='{db_type}', db_path='{db_path}', comment='{comment}' where db_name='{db_name}'"
)
cursor.commit() cursor.commit()
self.connect.commit() self.connect.commit()
except Exception as e: except Exception as e:
@ -79,7 +85,6 @@ class DuckdbConnectConfig:
except Exception as e: except Exception as e:
raise "Unusable duckdb database path:" + path raise "Unusable duckdb database path:" + path
def add_file_db(self, db_name, db_type, db_path: str, comment: str = ""): def add_file_db(self, db_name, db_type, db_path: str, comment: str = ""):
try: try:
cursor = self.connect.cursor() cursor = self.connect.cursor()

View File

@ -122,14 +122,16 @@ class ConnectManager:
return self.storage.delete_db(db_name) return self.storage.delete_db(db_name)
def edit_db(self, db_info: DBConfig): def edit_db(self, db_info: DBConfig):
return self.storage.update_db_info(db_info.db_name, return self.storage.update_db_info(
db_info.db_type, db_info.db_name,
db_info.file_path, db_info.db_type,
db_info.db_host, db_info.file_path,
db_info.db_port, db_info.db_host,
db_info.db_user, db_info.db_port,
db_info.db_pwd, db_info.db_user,
db_info.comment) db_info.db_pwd,
db_info.comment,
)
def add_db(self, db_info: DBConfig): def add_db(self, db_info: DBConfig):
print(f"add_db:{db_info.__dict__}") print(f"add_db:{db_info.__dict__}")
@ -140,7 +142,6 @@ class ConnectManager:
db_info.db_name, db_info.db_type, db_info.file_path db_info.db_name, db_info.db_type, db_info.file_path
) )
else: else:
self.storage.add_url_db( self.storage.add_url_db(
db_info.db_name, db_info.db_name,
db_info.db_type, db_info.db_type,
@ -151,7 +152,11 @@ class ConnectManager:
db_info.comment, db_info.comment,
) )
# async embedding # async embedding
thread = threading.Thread(target=self.db_summary_client.db_summary_embedding(db_info.db_name, db_info.db_type)) thread = threading.Thread(
target=self.db_summary_client.db_summary_embedding(
db_info.db_name, db_info.db_type
)
)
thread.start() thread.start()
except Exception as e: except Exception as e:
raise ValueError("Add db connect info error" + str(e)) raise ValueError("Add db connect info error" + str(e))

View File

@ -30,7 +30,11 @@ class DuckDbConnect(RDBMSDatabase):
return cls(create_engine("duckdb:///" + file_path, **_engine_args), **kwargs) return cls(create_engine("duckdb:///" + file_path, **_engine_args), **kwargs)
def get_users(self): def get_users(self):
cursor = self.session.execute(text(f"SELECT * FROM sqlite_master WHERE type = 'table' AND name = 'duckdb_sys_users';")) cursor = self.session.execute(
text(
f"SELECT * FROM sqlite_master WHERE type = 'table' AND name = 'duckdb_sys_users';"
)
)
users = cursor.fetchall() users = cursor.fetchall()
return [(user[0], user[1]) for user in users] return [(user[0], user[1]) for user in users]
@ -40,6 +44,7 @@ class DuckDbConnect(RDBMSDatabase):
def get_collation(self): def get_collation(self):
"""Get collation.""" """Get collation."""
return "UTF-8" return "UTF-8"
def get_charset(self): def get_charset(self):
return "UTF-8" return "UTF-8"

View File

@ -121,7 +121,9 @@ async def db_support_types():
support_types = [DBType.Mysql, DBType.MSSQL, DBType.DuckDb] support_types = [DBType.Mysql, DBType.MSSQL, DBType.DuckDb]
db_type_infos = [] db_type_infos = []
for type in support_types: for type in support_types:
db_type_infos.append(DbTypeInfo(db_type=type.value(), is_file_db=type.is_file_db())) db_type_infos.append(
DbTypeInfo(db_type=type.value(), is_file_db=type.is_file_db())
)
return Result[DbTypeInfo].succ(db_type_infos) return Result[DbTypeInfo].succ(db_type_infos)
@ -169,7 +171,7 @@ async def dialogue_scenes():
@router.post("/v1/chat/dialogue/new", response_model=Result[ConversationVo]) @router.post("/v1/chat/dialogue/new", response_model=Result[ConversationVo])
async def dialogue_new( async def dialogue_new(
chat_mode: str = ChatScene.ChatNormal.value(), user_id: str = None chat_mode: str = ChatScene.ChatNormal.value(), user_id: str = None
): ):
conv_vo = __new_conversation(chat_mode, user_id) conv_vo = __new_conversation(chat_mode, user_id)
return Result.succ(conv_vo) return Result.succ(conv_vo)

View File

@ -5,10 +5,12 @@ import shutil
import argparse import argparse
import sys import sys
import logging import logging
ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(ROOT_PATH) sys.path.append(ROOT_PATH)
import signal import signal
from pilot.configs.config import Config from pilot.configs.config import Config
# from pilot.configs.model_config import ( # from pilot.configs.model_config import (
# DATASETS_DIR, # DATASETS_DIR,
# KNOWLEDGE_UPLOAD_ROOT_PATH, # KNOWLEDGE_UPLOAD_ROOT_PATH,

View File

@ -7,6 +7,7 @@ import sys
from pilot.summary.db_summary_client import DBSummaryClient from pilot.summary.db_summary_client import DBSummaryClient
from pilot.commands.command_mange import CommandRegistry from pilot.commands.command_mange import CommandRegistry
from pilot.configs.config import Config from pilot.configs.config import Config
# from pilot.configs.model_config import ( # from pilot.configs.model_config import (
# DATASETS_DIR, # DATASETS_DIR,
# KNOWLEDGE_UPLOAD_ROOT_PATH, # KNOWLEDGE_UPLOAD_ROOT_PATH,

View File

@ -58,8 +58,8 @@ class DBSummaryClient:
) )
embedding.source_embedding() embedding.source_embedding()
for ( for (
table_name, table_name,
table_summary, table_summary,
) in db_summary_client.get_table_summary().items(): ) in db_summary_client.get_table_summary().items():
table_vector_store_config = { table_vector_store_config = {
"vector_store_name": dbname + "_" + table_name + "_ts", "vector_store_name": dbname + "_" + table_name + "_ts",

View File

@ -5,12 +5,13 @@ from pilot.summary.db_summary import DBSummary, TableSummary, FieldSummary, Inde
CFG = Config() CFG = Config()
class RdbmsSummary(DBSummary): class RdbmsSummary(DBSummary):
"""Get mysql summary template.""" """Get mysql summary template."""
def __init__(self, name, type): def __init__(self, name, type):
self.name = name self.name = name
self.type = type self.type = type
self.summery = """{{"database_name": "{name}", "type": "{type}", "tables": "{tables}", "qps": "{qps}", "tps": {tps}}}""" self.summery = """{{"database_name": "{name}", "type": "{type}", "tables": "{tables}", "qps": "{qps}", "tps": {tps}}}"""
self.tables = {} self.tables = {}
self.tables_info = [] self.tables_info = []
@ -177,4 +178,3 @@ class RdbmsIndexSummary(IndexSummary):
return self.summery_template.format( return self.summery_template.format(
name=self.name, bind_fields=self.bind_fields name=self.name, bind_fields=self.bind_fields
) )