diff --git a/pilot/connections/db_conn_info.py b/pilot/connections/db_conn_info.py index 767aac881..e1c633979 100644 --- a/pilot/connections/db_conn_info.py +++ b/pilot/connections/db_conn_info.py @@ -10,3 +10,8 @@ class DBConfig(BaseModel): db_user: str = "" db_pwd: str = "" comment: str = "" + +class DbTypeInfo(BaseModel): + db_type:str + is_file_db: bool = False + diff --git a/pilot/connections/manages/connect_storage_duckdb.py b/pilot/connections/manages/connect_storage_duckdb.py index 9ff291ad2..29542f6d5 100644 --- a/pilot/connections/manages/connect_storage_duckdb.py +++ b/pilot/connections/manages/connect_storage_duckdb.py @@ -47,6 +47,30 @@ class DuckdbConnectConfig: except Exception as e: print("add db connect info error1!" + str(e)) + def update_db_info(self, + db_name, + db_type, + db_path: str = "", + db_host: str = "", + db_port: int = 0, + db_user: str = "", + db_pwd: str = "", + comment: str = "" ): + old_db_conf = self.get_db_config(db_name) + if old_db_conf: + try: + cursor = self.connect.cursor() + if not db_path: + cursor.execute(f"UPDATE connect_config set db_type='{db_type}', db_host='{db_host}', db_port={db_port}, db_user='{db_user}', db_pwd='{db_pwd}', comment='{comment}' where db_name='{db_name}'") + else: + cursor.execute(f"UPDATE connect_config set db_type='{db_type}', db_path='{db_path}', comment='{comment}' where db_name='{db_name}'") + cursor.commit() + self.connect.commit() + except Exception as e: + print("edit db connect info error2!" + str(e)) + return True + raise ValueError(f"{db_name} not have config info!") + def get_file_db_name(self, path): try: conn = duckdb.connect(path) @@ -55,12 +79,13 @@ class DuckdbConnectConfig: except Exception as e: raise "Unusable duckdb database path:" + path + def add_file_db(self, db_name, db_type, db_path: str, comment: str = ""): try: cursor = self.connect.cursor() cursor.execute( "INSERT INTO connect_config(id, db_name, db_type, db_path, db_host, db_port, db_user, db_pwd, comment)VALUES(nextval('seq_id'),?,?,?,?,?,?,?,?)", - [db_name, db_type, db_path, "", "", "", "", comment], + [db_name, db_type, db_path, "", 0, "", "", comment], ) cursor.commit() self.connect.commit() @@ -89,12 +114,12 @@ class DuckdbConnectConfig: for i, field in enumerate(fields): row_dict[field] = row_1[i] return row_dict - return {} + return None def get_db_list(self): if os.path.isfile(duckdb_path): cursor = duckdb.connect(duckdb_path).cursor() - cursor.execute("SELECT db_name, db_type, comment FROM connect_config ") + cursor.execute("SELECT * FROM connect_config ") fields = [field[0] for field in cursor.description] data = [] diff --git a/pilot/connections/manages/connection_manager.py b/pilot/connections/manages/connection_manager.py index 43a4dce2c..2c94082e7 100644 --- a/pilot/connections/manages/connection_manager.py +++ b/pilot/connections/manages/connection_manager.py @@ -1,3 +1,5 @@ +import threading + from pilot.configs.config import Config from pilot.connections.manages.connect_storage_duckdb import DuckdbConnectConfig from pilot.common.schema import DBType @@ -11,6 +13,7 @@ from pilot.connections.rdbms.base import RDBMSDatabase from pilot.singleton import Singleton from pilot.common.sql_database import Database from pilot.connections.db_conn_info import DBConfig +from pilot.summary.db_summary_client import DBSummaryClient CFG = Config() @@ -34,6 +37,7 @@ class ConnectManager: def __init__(self): self.storage = DuckdbConnectConfig() + self.db_summary_client = DBSummaryClient() self.__load_config_db() def __load_config_db(self): @@ -117,20 +121,39 @@ class ConnectManager: def delete_db(self, db_name: str): return self.storage.delete_db(db_name) + def edit_db(self, db_info: DBConfig): + return self.storage.update_db_info(db_info.db_name, + db_info.db_type, + db_info.file_path, + db_info.db_host, + db_info.db_port, + db_info.db_user, + db_info.db_pwd, + db_info.comment) + def add_db(self, db_info: DBConfig): - db_type = DBType.of_db_type(db_info.db_type) - if db_type.is_file_db(): - self.storage.add_file_db( - db_info.db_name, db_info.db_type, db_info.file_path - ) - else: - self.storage.add_url_db( - db_info.db_name, - db_info.db_type, - db_info.db_host, - db_info.db_port, - db_info.db_user, - db_info.db_pwd, - db_info.comment, - ) + print(f"add_db:{db_info.__dict__}") + try: + db_type = DBType.of_db_type(db_info.db_type) + if db_type.is_file_db(): + self.storage.add_file_db( + db_info.db_name, db_info.db_type, db_info.file_path + ) + else: + + self.storage.add_url_db( + db_info.db_name, + db_info.db_type, + db_info.db_host, + db_info.db_port, + db_info.db_user, + db_info.db_pwd, + db_info.comment, + ) + # async embedding + thread = threading.Thread(target=self.db_summary_client.db_summary_embedding(db_info.db_name, db_info.db_type)) + thread.start() + except Exception as e: + raise ValueError("Add db connect info error!" + str(e)) + return True diff --git a/pilot/openapi/api_v1/api_v1.py b/pilot/openapi/api_v1/api_v1.py index 7d3482dfc..7c483e2e9 100644 --- a/pilot/openapi/api_v1/api_v1.py +++ b/pilot/openapi/api_v1/api_v1.py @@ -25,7 +25,7 @@ from pilot.openapi.api_v1.api_view_model import ( MessageVo, ChatSceneVo, ) -from pilot.connections.db_conn_info import DBConfig +from pilot.connections.db_conn_info import DBConfig, DbTypeInfo from pilot.configs.config import Config from pilot.server.knowledge.service import KnowledgeService from pilot.server.knowledge.request.request import KnowledgeSpaceRequest @@ -35,7 +35,7 @@ from pilot.scene.base import ChatScene from pilot.scene.chat_factory import ChatFactory from pilot.configs.model_config import LOGDIR from pilot.utils import build_logger -from pilot.scene.base_message import BaseMessage +from pilot.common.schema import DBType from pilot.memory.chat_history.duckdb_history import DuckdbHistoryMemory from pilot.scene.message import OnceConversation @@ -97,23 +97,32 @@ def knowledge_list(): @router.get("/v1/chat/db/list", response_model=Result[DBConfig]) -async def dialogue_list(): +async def db_connect_list(): return Result.succ(CFG.LOCAL_DB_MANAGE.get_db_list()) @router.post("/v1/chat/db/add", response_model=Result[bool]) -async def dialogue_list(db_config: DBConfig = Body()): +async def db_connect_add(db_config: DBConfig = Body()): return Result.succ(CFG.LOCAL_DB_MANAGE.add_db(db_config)) +@router.post("/v1/chat/db/edit", response_model=Result[bool]) +async def db_connect_edit(db_config: DBConfig = Body()): + return Result.succ(CFG.LOCAL_DB_MANAGE.edit_db(db_config)) + + @router.post("/v1/chat/db/delete", response_model=Result[bool]) -async def dialogue_list(db_name: str = None): +async def db_connect_delete(db_name: str = None): return Result.succ(CFG.LOCAL_DB_MANAGE.delete_db(db_name)) -@router.get("/v1/chat/db/support/type", response_model=Result[str]) +@router.get("/v1/chat/db/support/type", response_model=Result[DbTypeInfo]) async def db_support_types(): - return Result[str].succ(["mysql", "mssql", "duckdb"]) + support_types = [DBType.Mysql, DBType.MSSQL, DBType.DuckDb] + db_type_infos = [] + for type in support_types: + db_type_infos.append(DbTypeInfo(db_type=type.value(), is_file_db=type.is_file_db())) + return Result[DbTypeInfo].succ(db_type_infos) @router.get("/v1/chat/dialogue/list", response_model=Result[ConversationVo]) @@ -160,7 +169,7 @@ async def dialogue_scenes(): @router.post("/v1/chat/dialogue/new", response_model=Result[ConversationVo]) async def dialogue_new( - chat_mode: str = ChatScene.ChatNormal.value(), user_id: str = None + chat_mode: str = ChatScene.ChatNormal.value(), user_id: str = None ): conv_vo = __new_conversation(chat_mode, user_id) return Result.succ(conv_vo) diff --git a/pilot/scene/chat_data/__init__.py b/pilot/scene/chat_data/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/pilot/scene/chat_data/chat_excel/__init__.py b/pilot/scene/chat_data/chat_excel/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/pilot/scene/chat_data/chat_excel/excel_analyze/__init__.py b/pilot/scene/chat_data/chat_excel/excel_analyze/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/pilot/scene/chat_data/chat_excel/excel_learning/__init__.py b/pilot/scene/chat_data/chat_excel/excel_learning/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/pilot/scene/chat_db/auto_execute/prompt.py b/pilot/scene/chat_db/auto_execute/prompt.py index 6ee691a88..e284ce1a9 100644 --- a/pilot/scene/chat_db/auto_execute/prompt.py +++ b/pilot/scene/chat_db/auto_execute/prompt.py @@ -8,10 +8,10 @@ from pilot.scene.chat_db.auto_execute.example import sql_data_example CFG = Config() -PROMPT_SCENE_DEFINE = None +PROMPT_SCENE_DEFINE = "You are a SQL expert. " _DEFAULT_TEMPLATE = """ -You are a SQL expert. Given an input question, create a syntactically correct {dialect} sql. +Given an input question, create a syntactically correct {dialect} sql. Unless the user specifies in his question a specific number of examples he wishes to obtain, always limit your query to at most {top_k} results. Use as few tables as possible when querying. diff --git a/pilot/server/static/404.html b/pilot/server/static/404.html index 4da41deb4..2dda42bba 100644 --- a/pilot/server/static/404.html +++ b/pilot/server/static/404.html @@ -1 +1 @@ -