diff --git a/pilot/common/plugins.py b/pilot/common/plugins.py
index e22224399..3ee5f4ac2 100644
--- a/pilot/common/plugins.py
+++ b/pilot/common/plugins.py
@@ -131,9 +131,6 @@ def scan_plugins(cfg: Config, debug: bool = False) -> List[AutoGPTPluginTemplate
# Generic plugins
plugins_path_path = Path(PLUGINS_DIR)
- logger.debug(f"Allowlisted Plugins: {cfg.plugins_allowlist}")
- logger.debug(f"Denylisted Plugins: {cfg.plugins_denylist}")
-
for plugin in plugins_path_path.glob("*.zip"):
if moduleList := inspect_zip_for_modules(str(plugin), debug):
for module in moduleList:
diff --git a/pilot/connections/db_conn_info.py b/pilot/connections/db_conn_info.py
index 767aac881..6430d0552 100644
--- a/pilot/connections/db_conn_info.py
+++ b/pilot/connections/db_conn_info.py
@@ -10,3 +10,8 @@ class DBConfig(BaseModel):
db_user: str = ""
db_pwd: str = ""
comment: str = ""
+
+
+class DbTypeInfo(BaseModel):
+ db_type: str
+ is_file_db: bool = False
diff --git a/pilot/connections/manages/connect_storage_duckdb.py b/pilot/connections/manages/connect_storage_duckdb.py
index 9ff291ad2..acacaffd7 100644
--- a/pilot/connections/manages/connect_storage_duckdb.py
+++ b/pilot/connections/manages/connect_storage_duckdb.py
@@ -47,6 +47,36 @@ class DuckdbConnectConfig:
except Exception as e:
print("add db connect info error1!" + str(e))
+ def update_db_info(
+ self,
+ db_name,
+ db_type,
+ db_path: str = "",
+ db_host: str = "",
+ db_port: int = 0,
+ db_user: str = "",
+ db_pwd: str = "",
+ comment: str = "",
+ ):
+ old_db_conf = self.get_db_config(db_name)
+ if old_db_conf:
+ try:
+ cursor = self.connect.cursor()
+ if not db_path:
+ cursor.execute(
+ f"UPDATE connect_config set db_type='{db_type}', db_host='{db_host}', db_port={db_port}, db_user='{db_user}', db_pwd='{db_pwd}', comment='{comment}' where db_name='{db_name}'"
+ )
+ else:
+ cursor.execute(
+ f"UPDATE connect_config set db_type='{db_type}', db_path='{db_path}', comment='{comment}' where db_name='{db_name}'"
+ )
+ cursor.commit()
+ self.connect.commit()
+ except Exception as e:
+ print("edit db connect info error2!" + str(e))
+ return True
+ raise ValueError(f"{db_name} not have config info!")
+
def get_file_db_name(self, path):
try:
conn = duckdb.connect(path)
@@ -60,7 +90,7 @@ class DuckdbConnectConfig:
cursor = self.connect.cursor()
cursor.execute(
"INSERT INTO connect_config(id, db_name, db_type, db_path, db_host, db_port, db_user, db_pwd, comment)VALUES(nextval('seq_id'),?,?,?,?,?,?,?,?)",
- [db_name, db_type, db_path, "", "", "", "", comment],
+ [db_name, db_type, db_path, "", 0, "", "", comment],
)
cursor.commit()
self.connect.commit()
@@ -89,12 +119,12 @@ class DuckdbConnectConfig:
for i, field in enumerate(fields):
row_dict[field] = row_1[i]
return row_dict
- return {}
+ return None
def get_db_list(self):
if os.path.isfile(duckdb_path):
cursor = duckdb.connect(duckdb_path).cursor()
- cursor.execute("SELECT db_name, db_type, comment FROM connect_config ")
+ cursor.execute("SELECT * FROM connect_config ")
fields = [field[0] for field in cursor.description]
data = []
diff --git a/pilot/connections/manages/connection_manager.py b/pilot/connections/manages/connection_manager.py
index 291127127..38d7dd9f0 100644
--- a/pilot/connections/manages/connection_manager.py
+++ b/pilot/connections/manages/connection_manager.py
@@ -1,3 +1,5 @@
+import threading
+
from pilot.configs.config import Config
from pilot.connections.manages.connect_storage_duckdb import DuckdbConnectConfig
from pilot.common.schema import DBType
@@ -7,10 +9,11 @@ from pilot.connections.base import BaseConnect
from pilot.connections.rdbms.conn_mysql import MySQLConnect
from pilot.connections.rdbms.conn_duckdb import DuckDbConnect
from pilot.connections.rdbms.conn_mssql import MSSQLConnect
-from pilot.connections.rdbms.rdbms_connect import RDBMSDatabase
+from pilot.connections.rdbms.base import RDBMSDatabase
from pilot.singleton import Singleton
from pilot.common.sql_database import Database
from pilot.connections.db_conn_info import DBConfig
+from pilot.summary.db_summary_client import DBSummaryClient
CFG = Config()
@@ -34,6 +37,7 @@ class ConnectManager:
def __init__(self):
self.storage = DuckdbConnectConfig()
+ self.db_summary_client = DBSummaryClient()
self.__load_config_db()
def __load_config_db(self):
@@ -117,20 +121,44 @@ class ConnectManager:
def delete_db(self, db_name: str):
return self.storage.delete_db(db_name)
+ def edit_db(self, db_info: DBConfig):
+ return self.storage.update_db_info(
+ db_info.db_name,
+ db_info.db_type,
+ db_info.file_path,
+ db_info.db_host,
+ db_info.db_port,
+ db_info.db_user,
+ db_info.db_pwd,
+ db_info.comment,
+ )
+
def add_db(self, db_info: DBConfig):
- db_type = DBType.of_db_type(db_info.db_type)
- if db_type.is_file_db():
- self.storage.add_file_db(
- db_info.db_name, db_info.db_type, db_info.file_path
- )
- else:
- self.storage.add_url_db(
- db_info.db_name,
- db_info.db_type,
- db_info.db_host,
- db_info.db_port,
- db_info.db_user,
- db_info.db_pwd,
- db_info.comment,
+ print(f"add_db:{db_info.__dict__}")
+ try:
+ db_type = DBType.of_db_type(db_info.db_type)
+ if db_type.is_file_db():
+ self.storage.add_file_db(
+ db_info.db_name, db_info.db_type, db_info.file_path
+ )
+ else:
+ self.storage.add_url_db(
+ db_info.db_name,
+ db_info.db_type,
+ db_info.db_host,
+ db_info.db_port,
+ db_info.db_user,
+ db_info.db_pwd,
+ db_info.comment,
+ )
+ # async embedding
+ thread = threading.Thread(
+ target=self.db_summary_client.db_summary_embedding(
+ db_info.db_name, db_info.db_type
+ )
)
+ thread.start()
+ except Exception as e:
+ raise ValueError("Add db connect info error!" + str(e))
+
return True
diff --git a/pilot/connections/rdbms/rdbms_connect.py b/pilot/connections/rdbms/base.py
similarity index 100%
rename from pilot/connections/rdbms/rdbms_connect.py
rename to pilot/connections/rdbms/base.py
diff --git a/pilot/connections/rdbms/clickhouse.py b/pilot/connections/rdbms/clickhouse.py
deleted file mode 100644
index fdf8e803b..000000000
--- a/pilot/connections/rdbms/clickhouse.py
+++ /dev/null
@@ -1,38 +0,0 @@
-#!/usr/bin/env python3
-# -*- coding: utf-8 -*-
-from pilot.connections.rdbms.rdbms_connect import RDBMSDatabase
-
-from pilot.configs.config import Config
-
-CFG = Config()
-
-
-class ClickHouseConnector(RDBMSDatabase):
- """ClickHouseConnector"""
-
- db_type: str = "duckdb"
-
- driver: str = "duckdb"
-
- file_path: str
-
- default_db = ["information_schema", "performance_schema", "sys", "mysql"]
-
- @classmethod
- def from_config(cls) -> RDBMSDatabase:
- """
- Todo password encryption
- Returns:
- """
- return cls.from_uri_db(
- cls,
- CFG.LOCAL_DB_PATH,
- engine_args={"pool_size": 10, "pool_recycle": 3600, "echo": True},
- )
-
- @classmethod
- def from_uri_db(
- cls, db_path: str, engine_args: Optional[dict] = None, **kwargs: Any
- ) -> RDBMSDatabase:
- db_url: str = cls.connect_driver + "://" + db_path
- return cls.from_uri(db_url, engine_args, **kwargs)
diff --git a/pilot/connections/rdbms/conn_duckdb.py b/pilot/connections/rdbms/conn_duckdb.py
index 90d2fc6e9..668ee9cf6 100644
--- a/pilot/connections/rdbms/conn_duckdb.py
+++ b/pilot/connections/rdbms/conn_duckdb.py
@@ -9,7 +9,7 @@ from sqlalchemy import (
)
from sqlalchemy.ext.declarative import declarative_base
-from pilot.connections.rdbms.rdbms_connect import RDBMSDatabase
+from pilot.connections.rdbms.base import RDBMSDatabase
class DuckDbConnect(RDBMSDatabase):
@@ -29,6 +29,38 @@ class DuckDbConnect(RDBMSDatabase):
_engine_args = engine_args or {}
return cls(create_engine("duckdb:///" + file_path, **_engine_args), **kwargs)
+ def get_users(self):
+ cursor = self.session.execute(
+ text(
+ f"SELECT * FROM sqlite_master WHERE type = 'table' AND name = 'duckdb_sys_users';"
+ )
+ )
+ users = cursor.fetchall()
+ return [(user[0], user[1]) for user in users]
+
+ def get_grants(self):
+ return []
+
+ def get_collation(self):
+ """Get collation."""
+ return "UTF-8"
+
+ def get_charset(self):
+ return "UTF-8"
+
+ def get_table_comments(self, db_name):
+ cursor = self.session.execute(
+ text(
+ f"""
+ SELECT name, sql FROM sqlite_master WHERE type='table'
+ """
+ )
+ )
+ table_comments = cursor.fetchall()
+ return [
+ (table_comment[0], table_comment[1]) for table_comment in table_comments
+ ]
+
def table_simple_info(self) -> Iterable[str]:
_tables_sql = f"""
SELECT name FROM sqlite_master WHERE type='table'
diff --git a/pilot/connections/rdbms/conn_mssql.py b/pilot/connections/rdbms/conn_mssql.py
index a248f8f66..c2eb947aa 100644
--- a/pilot/connections/rdbms/conn_mssql.py
+++ b/pilot/connections/rdbms/conn_mssql.py
@@ -10,7 +10,7 @@ from sqlalchemy import (
select,
text,
)
-from pilot.connections.rdbms.rdbms_connect import RDBMSDatabase
+from pilot.connections.rdbms.base import RDBMSDatabase
class MSSQLConnect(RDBMSDatabase):
diff --git a/pilot/connections/rdbms/conn_mysql.py b/pilot/connections/rdbms/conn_mysql.py
index 196919dbe..3bdd1ee44 100644
--- a/pilot/connections/rdbms/conn_mysql.py
+++ b/pilot/connections/rdbms/conn_mysql.py
@@ -2,7 +2,7 @@
# -*- coding: utf-8 -*-
from typing import Optional, Any
-from pilot.connections.rdbms.rdbms_connect import RDBMSDatabase
+from pilot.connections.rdbms.base import RDBMSDatabase
class MySQLConnect(RDBMSDatabase):
diff --git a/pilot/connections/rdbms/oracle.py b/pilot/connections/rdbms/oracle.py
deleted file mode 100644
index e28c17fe5..000000000
--- a/pilot/connections/rdbms/oracle.py
+++ /dev/null
@@ -1,13 +0,0 @@
-#!/usr/bin/env python3
-# -*- coding:utf-8 -*-
-from pilot.connections.rdbms.rdbms_connect import RDBMSDatabase
-
-
-class OracleConnector(RDBMSDatabase):
- """OracleConnector"""
-
- db_type: str = "oracle"
-
- driver: str = "oracle"
-
- default_db = ["SYS", "SYSTEM", "OUTLN", "ORDDATA", "XDB"]
diff --git a/pilot/connections/rdbms/postgres.py b/pilot/connections/rdbms/postgres.py
deleted file mode 100644
index 7873eaaf4..000000000
--- a/pilot/connections/rdbms/postgres.py
+++ /dev/null
@@ -1,12 +0,0 @@
-#!/usr/bin/env python3
-# -*- coding: utf-8 -*-
-from pilot.connections.rdbms.rdbms_connect import RDBMSDatabase
-
-
-class PostgresConnector(RDBMSDatabase):
- """PostgresConnector is a class which Connector"""
-
- db_type: str = "postgresql"
- driver: str = "postgresql"
-
- default_db = ["information_schema", "performance_schema", "sys", "mysql"]
diff --git a/pilot/openapi/api_v1/api_v1.py b/pilot/openapi/api_v1/api_v1.py
index 7d3482dfc..aa63df27f 100644
--- a/pilot/openapi/api_v1/api_v1.py
+++ b/pilot/openapi/api_v1/api_v1.py
@@ -25,7 +25,7 @@ from pilot.openapi.api_v1.api_view_model import (
MessageVo,
ChatSceneVo,
)
-from pilot.connections.db_conn_info import DBConfig
+from pilot.connections.db_conn_info import DBConfig, DbTypeInfo
from pilot.configs.config import Config
from pilot.server.knowledge.service import KnowledgeService
from pilot.server.knowledge.request.request import KnowledgeSpaceRequest
@@ -35,7 +35,7 @@ from pilot.scene.base import ChatScene
from pilot.scene.chat_factory import ChatFactory
from pilot.configs.model_config import LOGDIR
from pilot.utils import build_logger
-from pilot.scene.base_message import BaseMessage
+from pilot.common.schema import DBType
from pilot.memory.chat_history.duckdb_history import DuckdbHistoryMemory
from pilot.scene.message import OnceConversation
@@ -97,23 +97,34 @@ def knowledge_list():
@router.get("/v1/chat/db/list", response_model=Result[DBConfig])
-async def dialogue_list():
+async def db_connect_list():
return Result.succ(CFG.LOCAL_DB_MANAGE.get_db_list())
@router.post("/v1/chat/db/add", response_model=Result[bool])
-async def dialogue_list(db_config: DBConfig = Body()):
+async def db_connect_add(db_config: DBConfig = Body()):
return Result.succ(CFG.LOCAL_DB_MANAGE.add_db(db_config))
+@router.post("/v1/chat/db/edit", response_model=Result[bool])
+async def db_connect_edit(db_config: DBConfig = Body()):
+ return Result.succ(CFG.LOCAL_DB_MANAGE.edit_db(db_config))
+
+
@router.post("/v1/chat/db/delete", response_model=Result[bool])
-async def dialogue_list(db_name: str = None):
+async def db_connect_delete(db_name: str = None):
return Result.succ(CFG.LOCAL_DB_MANAGE.delete_db(db_name))
-@router.get("/v1/chat/db/support/type", response_model=Result[str])
+@router.get("/v1/chat/db/support/type", response_model=Result[DbTypeInfo])
async def db_support_types():
- return Result[str].succ(["mysql", "mssql", "duckdb"])
+ support_types = [DBType.Mysql, DBType.MSSQL, DBType.DuckDb]
+ db_type_infos = []
+ for type in support_types:
+ db_type_infos.append(
+ DbTypeInfo(db_type=type.value(), is_file_db=type.is_file_db())
+ )
+ return Result[DbTypeInfo].succ(db_type_infos)
@router.get("/v1/chat/dialogue/list", response_model=Result[ConversationVo])
diff --git a/pilot/scene/chat_data/__init__.py b/pilot/scene/chat_data/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/pilot/scene/chat_data/chat_excel/__init__.py b/pilot/scene/chat_data/chat_excel/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/pilot/scene/chat_data/chat_excel/excel_analyze/__init__.py b/pilot/scene/chat_data/chat_excel/excel_analyze/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/pilot/scene/chat_data/chat_excel/excel_learning/__init__.py b/pilot/scene/chat_data/chat_excel/excel_learning/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/pilot/scene/chat_db/auto_execute/prompt.py b/pilot/scene/chat_db/auto_execute/prompt.py
index 6ee691a88..e284ce1a9 100644
--- a/pilot/scene/chat_db/auto_execute/prompt.py
+++ b/pilot/scene/chat_db/auto_execute/prompt.py
@@ -8,10 +8,10 @@ from pilot.scene.chat_db.auto_execute.example import sql_data_example
CFG = Config()
-PROMPT_SCENE_DEFINE = None
+PROMPT_SCENE_DEFINE = "You are a SQL expert. "
_DEFAULT_TEMPLATE = """
-You are a SQL expert. Given an input question, create a syntactically correct {dialect} sql.
+Given an input question, create a syntactically correct {dialect} sql.
Unless the user specifies in his question a specific number of examples he wishes to obtain, always limit your query to at most {top_k} results.
Use as few tables as possible when querying.
diff --git a/pilot/server/dbgpt_server.py b/pilot/server/dbgpt_server.py
index def585069..f061635a3 100644
--- a/pilot/server/dbgpt_server.py
+++ b/pilot/server/dbgpt_server.py
@@ -4,17 +4,19 @@ import os
import shutil
import argparse
import sys
+import logging
ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(ROOT_PATH)
import signal
from pilot.configs.config import Config
-from pilot.configs.model_config import (
- DATASETS_DIR,
- KNOWLEDGE_UPLOAD_ROOT_PATH,
- LLM_MODEL_CONFIG,
- LOGDIR,
-)
+
+# from pilot.configs.model_config import (
+# DATASETS_DIR,
+# KNOWLEDGE_UPLOAD_ROOT_PATH,
+# LLM_MODEL_CONFIG,
+# LOGDIR,
+# )
from pilot.utils import build_logger
from pilot.server.webserver_base import server_init
@@ -30,11 +32,13 @@ from pilot.server.knowledge.api import router as knowledge_router
from pilot.openapi.api_v1.api_v1 import router as api_v1, validation_exception_handler
+logging.basicConfig(level=logging.INFO)
+
static_file_path = os.path.join(os.getcwd(), "server/static")
CFG = Config()
-logger = build_logger("webserver", LOGDIR + "webserver.log")
+# logger = build_logger("webserver", LOGDIR + "webserver.log")
def signal_handler():
@@ -113,5 +117,6 @@ if __name__ == "__main__":
import uvicorn
- uvicorn.run(app, host="0.0.0.0", port=args.port)
+ logging.basicConfig(level=logging.INFO)
+ uvicorn.run(app, host="0.0.0.0", port=args.port, log_level=0)
signal.signal(signal.SIGINT, signal_handler())
diff --git a/pilot/server/static/404.html b/pilot/server/static/404.html
index 4da41deb4..2dda42bba 100644
--- a/pilot/server/static/404.html
+++ b/pilot/server/static/404.html
@@ -1 +1 @@
-