mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-08-09 04:08:10 +00:00
Multi DB support
This commit is contained in:
parent
73b3b7069f
commit
288e55a97f
1751
logsDbChatOutputParser.log.2023-07-13
Normal file
1751
logsDbChatOutputParser.log.2023-07-13
Normal file
File diff suppressed because it is too large
Load Diff
@ -12,3 +12,31 @@ class SeparatorStyle(Enum):
|
|||||||
class ExampleType(Enum):
|
class ExampleType(Enum):
|
||||||
ONE_SHOT = "one_shot"
|
ONE_SHOT = "one_shot"
|
||||||
FEW_SHOT = "few_shot"
|
FEW_SHOT = "few_shot"
|
||||||
|
|
||||||
|
|
||||||
|
class DbInfo:
    """Static description of one supported database engine.

    Attributes:
        name: Engine identifier (for example ``"mysql"``); this is the string
            that ``DBType.value()`` returns and ``DBType.of_db_type`` matches.
        is_file_db: True when the engine is flagged as file based.
    """

    def __init__(self, name, is_file_db: bool = False):
        self.name = name
        self.is_file_db = is_file_db

    def __repr__(self):
        # Makes enum members readable in logs: <DBType.Mysql: DbInfo(...)>.
        return f"DbInfo(name={self.name!r}, is_file_db={self.is_file_db!r})"


class DBType(Enum):
    """Database engines known to the connection layer.

    Every member wraps a ``DbInfo``. Note that ``value`` and ``is_file_db``
    are methods here (``DBType.Mysql.value()``), shadowing the usual
    ``Enum.value`` attribute.
    """

    Mysql = DbInfo("mysql")
    OCeanBase = DbInfo("oceanbase")
    DuckDb = DbInfo("duckdb", True)
    Oracle = DbInfo("oracle")
    MSSQL = DbInfo("mssql")
    Postgresql = DbInfo("postgresql")

    def value(self):
        """Return the engine identifier string, e.g. ``"mysql"``."""
        return self._value_.name

    def is_file_db(self):
        """Return True when this engine is flagged as file based."""
        return self._value_.is_file_db

    @staticmethod
    def of_db_type(db_type: str):
        """Look up the member whose identifier equals ``db_type``.

        Args:
            db_type: Engine identifier such as ``"mysql"``.

        Returns:
            The matching ``DBType`` member, or ``None`` when nothing matches.
        """
        # Iterate the members themselves; ``__members__`` yields name strings,
        # and ``value()`` already returns the identifier string.
        for item in DBType:
            if item.value() == db_type:
                return item
        return None
|
||||||
|
@ -28,6 +28,8 @@ class Config(metaclass=Singleton):
|
|||||||
self.skip_reprompt = False
|
self.skip_reprompt = False
|
||||||
self.temperature = float(os.getenv("TEMPERATURE", 0.7))
|
self.temperature = float(os.getenv("TEMPERATURE", 0.7))
|
||||||
|
|
||||||
|
self.NUM_GPUS = int(os.getenv("NUM_GPUS",1))
|
||||||
|
|
||||||
self.execute_local_commands = (
|
self.execute_local_commands = (
|
||||||
os.getenv("EXECUTE_LOCAL_COMMANDS", "False") == "True"
|
os.getenv("EXECUTE_LOCAL_COMMANDS", "False") == "True"
|
||||||
)
|
)
|
||||||
@ -116,25 +118,15 @@ class Config(metaclass=Singleton):
|
|||||||
os.getenv("NATIVE_SQL_CAN_RUN_WRITE", "True") == "True"
|
os.getenv("NATIVE_SQL_CAN_RUN_WRITE", "True") == "True"
|
||||||
)
|
)
|
||||||
|
|
||||||
### Local database connection configuration
|
### default Local database connection configuration
|
||||||
self.LOCAL_DB_HOST = os.getenv("LOCAL_DB_HOST", "127.0.0.1")
|
self.LOCAL_DB_HOST = os.getenv("LOCAL_DB_HOST", "127.0.0.1")
|
||||||
self.LOCAL_DB_PATH = os.getenv("LOCAL_DB_PATH", "xx.db")
|
self.LOCAL_DB_PATH = os.getenv("LOCAL_DB_PATH", "xx.db")
|
||||||
|
self.LOCAL_DB_NAME = os.getenv("LOCAL_DB_NAME")
|
||||||
self.LOCAL_DB_PORT = int(os.getenv("LOCAL_DB_PORT", 3306))
|
self.LOCAL_DB_PORT = int(os.getenv("LOCAL_DB_PORT", 3306))
|
||||||
self.LOCAL_DB_USER = os.getenv("LOCAL_DB_USER", "root")
|
self.LOCAL_DB_USER = os.getenv("LOCAL_DB_USER", "root")
|
||||||
self.LOCAL_DB_PASSWORD = os.getenv("LOCAL_DB_PASSWORD", "aa123456")
|
self.LOCAL_DB_PASSWORD = os.getenv("LOCAL_DB_PASSWORD", "aa123456")
|
||||||
|
|
||||||
### TODO Adapt to multiple types of libraries
|
self.LOCAL_DB_MANAGE = None
|
||||||
self.local_db = Database.from_uri(
|
|
||||||
"mysql+pymysql://"
|
|
||||||
+ self.LOCAL_DB_USER
|
|
||||||
+ ":"
|
|
||||||
+ self.LOCAL_DB_PASSWORD
|
|
||||||
+ "@"
|
|
||||||
+ self.LOCAL_DB_HOST
|
|
||||||
+ ":"
|
|
||||||
+ str(self.LOCAL_DB_PORT),
|
|
||||||
engine_args={"pool_size": 10, "pool_recycle": 3600, "echo": True},
|
|
||||||
)
|
|
||||||
|
|
||||||
### LLM Model Service Configuration
|
### LLM Model Service Configuration
|
||||||
self.LLM_MODEL = os.getenv("LLM_MODEL", "vicuna-13b")
|
self.LLM_MODEL = os.getenv("LLM_MODEL", "vicuna-13b")
|
||||||
|
@ -7,11 +7,8 @@ from pydantic import BaseModel, Extra, Field, root_validator
|
|||||||
from typing import Any, Iterable, List, Optional
|
from typing import Any, Iterable, List, Optional
|
||||||
|
|
||||||
|
|
||||||
class BaseConnect(BaseModel, ABC):
|
class BaseConnect( ABC):
|
||||||
type
|
def get_connect(self, db_name: str):
|
||||||
driver: str
|
|
||||||
|
|
||||||
def get_session(self, db_name: str):
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def get_table_names(self) -> Iterable[str]:
|
def get_table_names(self) -> Iterable[str]:
|
||||||
@ -20,14 +17,41 @@ class BaseConnect(BaseModel, ABC):
|
|||||||
def get_table_info(self, table_names: Optional[List[str]] = None) -> str:
|
def get_table_info(self, table_names: Optional[List[str]] = None) -> str:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def get_table_info(self, table_names: Optional[List[str]] = None) -> str:
|
def get_index_info(self, table_names: Optional[List[str]] = None) -> str:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def get_index_info(self, table_names: Optional[List[str]] = None) -> str:
|
def get_example_data(self, table: str, count: int = 3):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def get_database_list(self):
|
def get_database_list(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def get_database_names(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def get_table_comments(self, db_name):
|
||||||
|
pass
|
||||||
|
|
||||||
def run(self, session, command: str, fetch: str = "all") -> List:
|
def run(self, session, command: str, fetch: str = "all") -> List:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def get_users(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def get_grants(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def get_collation(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def get_charset(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def get_fields(self, table_name):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def get_show_create_table(self, table_name):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def get_indexes(self, table_name):
|
||||||
|
pass
|
||||||
|
0
pilot/connections/manages/__init__.py
Normal file
0
pilot/connections/manages/__init__.py
Normal file
116
pilot/connections/manages/connect_storage_duckdb.py
Normal file
116
pilot/connections/manages/connect_storage_duckdb.py
Normal file
@ -0,0 +1,116 @@
|
|||||||
|
import os
|
||||||
|
import duckdb
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
default_db_path = os.path.join(os.getcwd(), "message")
|
||||||
|
duckdb_path = os.getenv("DB_DUCKDB_PATH", default_db_path + "/connect_config.db")
|
||||||
|
table_name = "connect_config"
|
||||||
|
|
||||||
|
|
||||||
|
class DuckdbConnectConfig:
    """Store database connection settings in a local DuckDB file.

    One row of the ``connect_config`` table describes one database: either a
    host/port/user/password endpoint or a file path. The DuckDB file lives at
    the module-level ``duckdb_path``.
    """

    _INSERT_SQL = (
        "INSERT INTO connect_config(id, db_name, db_type, db_path, db_host, db_port, db_user, db_pwd, comment)"
        "VALUES(nextval('seq_id'),?,?,?,?,?,?,?,?)"
    )

    def __init__(self):
        os.makedirs(default_db_path, exist_ok=True)
        self.connect = duckdb.connect(duckdb_path)
        self.__init_config_tables()

    def __init_config_tables(self):
        """Create the config table and its id sequence on first use."""
        result = self.connect.execute(
            "SELECT name FROM sqlite_master WHERE type='table' AND name=?", [table_name]
        ).fetchall()

        if not result:
            self.connect.execute(
                "CREATE TABLE connect_config (id integer primary key, db_name VARCHAR(100) UNIQUE, db_type VARCHAR(50), db_path VARCHAR(255) NULL, db_host VARCHAR(255) NULL, db_port INTEGER NULL, db_user VARCHAR(255) NULL, db_pwd VARCHAR(255) NULL, comment TEXT NULL)"
            )
            self.connect.execute("CREATE SEQUENCE seq_id START 1;")

    def _insert(self, params):
        """Insert one config row; ``params`` follows the column order of ``_INSERT_SQL``."""
        cursor = self.connect.cursor()
        cursor.execute(self._INSERT_SQL, params)
        cursor.commit()
        self.connect.commit()

    @staticmethod
    def _rows_as_dicts(cursor):
        """Turn every row of an executed cursor into a ``{column: value}`` dict."""
        fields = [field[0] for field in cursor.description]
        return [dict(zip(fields, row)) for row in cursor.fetchall()]

    def add_url_db(self, db_name, db_type, db_host: str, db_port: int, db_user: str, db_pwd: str, comment: str = ""):
        """Register a database reached through host/port credentials.

        Best effort: a failure (for example a duplicate ``db_name``, which the
        UNIQUE constraint rejects) is printed and not raised.
        """
        try:
            self._insert([db_name, db_type, "", db_host, db_port, db_user, db_pwd, comment])
        except Exception as e:
            print("add db connect info error1!" + str(e))

    def get_file_db_name(self, path):
        """Open the DuckDB file at ``path`` and return its current database name.

        Raises:
            ValueError: if the file cannot be opened or queried.
        """
        try:
            conn = duckdb.connect(path)
            try:
                return conn.execute('SELECT current_database()').fetchone()[0]
            finally:
                # Only a probe: do not keep the file open.
                conn.close()
        except Exception as e:
            # Raising a bare string is itself a TypeError in Python 3.
            raise ValueError(f"Unusable duckdb database path:{path}") from e

    def add_file_db(self, db_name, db_type, db_path: str):
        """Register a file-based database.

        Best effort: failures are printed and not raised.
        """
        try:
            # db_port is an INTEGER column: bind NULL, because an empty string
            # cannot be cast to INTEGER and would make every insert fail.
            self._insert([db_name, db_type, db_path, "", None, "", "", ""])
        except Exception as e:
            print("add db connect info error2!" + str(e))

    def delete_db(self, db_name):
        """Delete the config row named ``db_name``; always returns True."""
        cursor = self.connect.cursor()
        cursor.execute(
            "DELETE FROM connect_config where db_name=?", [db_name]
        )
        cursor.commit()
        return True

    def get_db_config(self, db_name):
        """Return the config row for ``db_name`` as a dict.

        Returns:
            ``{column: value}`` for the matching row, or ``{}`` when the
            config file is missing or no row matches.

        Raises:
            ValueError: if ``db_name`` is empty.
        """
        if not os.path.isfile(duckdb_path):
            return {}
        if not db_name:
            raise ValueError(f"Cannot get database by name {db_name!r}")

        # Reuse the open connection instead of opening (and leaking) a new one.
        cursor = self.connect.cursor()
        cursor.execute(
            "SELECT * FROM connect_config where db_name=? ", [db_name]
        )
        rows = self._rows_as_dicts(cursor)
        # db_name is UNIQUE, so there is at most one row.
        return rows[0] if rows else {}

    def get_db_list(self):
        """Return ``db_name``, ``db_type`` and ``comment`` of every row as dicts."""
        if not os.path.isfile(duckdb_path):
            return []
        cursor = self.connect.cursor()
        cursor.execute(
            "SELECT db_name, db_type, comment FROM connect_config "
        )
        return self._rows_as_dicts(cursor)

    def get_db_names(self):
        """Return the ``db_name`` of every configured database."""
        if not os.path.isfile(duckdb_path):
            return []
        cursor = self.connect.cursor()
        cursor.execute(
            "SELECT db_name FROM connect_config "
        )
        return [row[0] for row in cursor.fetchall()]
|
87
pilot/connections/manages/connection_manager.py
Normal file
87
pilot/connections/manages/connection_manager.py
Normal file
@ -0,0 +1,87 @@
|
|||||||
|
from pilot.configs.config import Config
|
||||||
|
from pilot.connections.manages.connect_storage_duckdb import DuckdbConnectConfig
|
||||||
|
from pilot.common.schema import DBType
|
||||||
|
from pilot.connections.rdbms.mysql import MySQLConnect
|
||||||
|
from pilot.connections.base import BaseConnect
|
||||||
|
from pilot.connections.rdbms.rdbms_connect import RDBMSDatabase
|
||||||
|
from pilot.singleton import Singleton
|
||||||
|
from pilot.common.sql_database import Database
|
||||||
|
|
||||||
|
CFG = Config()
|
||||||
|
|
||||||
|
|
||||||
|
class ConnectManager:
    """Resolve a configured database name to a connector object.

    Connection settings are kept in a ``DuckdbConnectConfig`` store, which is
    seeded from the ``LOCAL_DB_*`` settings of ``CFG`` on construction.
    """

    @staticmethod
    def _find_connect_class(db_type):
        """Return the ``BaseConnect`` subclass whose ``db_type`` matches.

        Args:
            db_type: A ``DBType`` member or the plain identifier string.

        Raises:
            Exception: if no loaded subclass declares that ``db_type``.
        """
        # Connector classes declare ``db_type`` as a string ("mysql", ...),
        # so compare identifiers, never an enum member against a string.
        type_name = db_type.value() if isinstance(db_type, DBType) else db_type
        # Concrete connectors derive from intermediate bases, so walk the
        # whole subclass tree rather than only the direct children.
        # NOTE(review): only classes whose modules have been imported are
        # visible here, and two connectors sharing one ``db_type`` string
        # would be ambiguous -- confirm each connector is imported and unique.
        pending = list(BaseConnect.__subclasses__())
        while pending:
            candidate = pending.pop(0)
            if getattr(candidate, "db_type", None) == type_name:
                return candidate
            pending.extend(candidate.__subclasses__())
        raise Exception(f"Invalid db connect implementation!DbType:{db_type}")

    @staticmethod
    def get_instance_by_dbtype(db_type, **kwargs):
        """Instantiate the connector registered for ``db_type``.

        ``kwargs`` are forwarded to the connector constructor.

        Raises:
            Exception: if no connector matches ``db_type``.
        """
        return ConnectManager._find_connect_class(db_type)(**kwargs)

    def __init__(self):
        self.storage = DuckdbConnectConfig()
        self.__load_config_db()

    @staticmethod
    def _default_mysql_uri():
        """Build the SQLAlchemy URI of the default MySQL server from ``CFG``."""
        return (
            "mysql+pymysql://"
            + CFG.LOCAL_DB_USER
            + ":"
            + CFG.LOCAL_DB_PASSWORD
            + "@"
            + CFG.LOCAL_DB_HOST
            + ":"
            + str(CFG.LOCAL_DB_PORT)
        )

    def __load_config_db(self):
        """Seed the config store from the ``LOCAL_DB_*`` settings."""
        if CFG.LOCAL_DB_HOST:
            # The default host is treated as a MySQL server.
            if CFG.LOCAL_DB_NAME:
                names = [CFG.LOCAL_DB_NAME]
            else:
                # No database pinned: register every database on the server.
                default_mysql = Database.from_uri(
                    self._default_mysql_uri(),
                    engine_args={"pool_size": 10, "pool_recycle": 3600, "echo": True},
                )
                names = default_mysql.get_database_list()
            for name in names:
                self.storage.add_url_db(
                    name,
                    DBType.Mysql.value(),
                    CFG.LOCAL_DB_HOST,
                    CFG.LOCAL_DB_PORT,
                    CFG.LOCAL_DB_USER,
                    CFG.LOCAL_DB_PASSWORD,
                    "",
                )
        if CFG.LOCAL_DB_PATH:
            # The default file database is DuckDB.
            db_name = self.storage.get_file_db_name(CFG.LOCAL_DB_PATH)
            if db_name:
                self.storage.add_file_db(db_name, DBType.DuckDb.value(), CFG.LOCAL_DB_PATH)

    def get_connect(self, db_name):
        """Build a connector for the configured database ``db_name``.

        Raises:
            ValueError: if the stored ``db_type`` is missing or unknown.
            Exception: if no connector class matches the stored type.
        """
        db_config = self.storage.get_db_config(db_name)
        db_type = DBType.of_db_type(db_config.get('db_type'))
        if db_type is None:
            raise ValueError(
                f"Unknown db type {db_config.get('db_type')!r} for database {db_name!r}"
            )
        # The factory methods are classmethods, so no throw-away instance
        # (which would need an engine) has to be constructed first.
        connect_cls = self._find_connect_class(db_type)
        if db_type.is_file_db():
            return connect_cls.from_file_path(db_config.get('db_path'))
        return connect_cls.from_uri_db(
            db_config.get('db_host'),
            db_config.get('db_port'),
            db_config.get('db_user'),
            db_config.get('db_pwd'),
            db_name,
        )

    def get_db_list(self):
        """Return the stored summary of every configured database."""
        return self.storage.get_db_list()

    def get_db_names(self):
        """Return the names of every configured database."""
        return self.storage.get_db_names()
|
@ -10,7 +10,7 @@ CFG = Config()
|
|||||||
class ClickHouseConnector(RDBMSDatabase):
|
class ClickHouseConnector(RDBMSDatabase):
|
||||||
"""ClickHouseConnector"""
|
"""ClickHouseConnector"""
|
||||||
|
|
||||||
type: str = "DUCKDB"
|
db_type: str = "duckdb"
|
||||||
|
|
||||||
driver: str = "duckdb"
|
driver: str = "duckdb"
|
||||||
|
|
||||||
|
@ -1,7 +1,14 @@
|
|||||||
from typing import Optional, Any
|
from typing import Optional, Any, Iterable
|
||||||
|
from sqlalchemy import (
|
||||||
|
MetaData,
|
||||||
|
Table,
|
||||||
|
create_engine,
|
||||||
|
inspect,
|
||||||
|
select,
|
||||||
|
text,
|
||||||
|
)
|
||||||
|
|
||||||
from pilot.connections.rdbms.rdbms_connect import RDBMSDatabase
|
from pilot.connections.rdbms.rdbms_connect import RDBMSDatabase
|
||||||
|
|
||||||
from pilot.configs.config import Config
|
from pilot.configs.config import Config
|
||||||
|
|
||||||
CFG = Config()
|
CFG = Config()
|
||||||
@ -13,29 +20,15 @@ class DuckDbConnect(RDBMSDatabase):
|
|||||||
Usage:
|
Usage:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
type: str = "DUCKDB"
|
def table_simple_info(self) -> Iterable[str]:
|
||||||
|
return super().get_table_names()
|
||||||
|
|
||||||
driver: str = "duckdb"
|
db_type: str = "duckdb"
|
||||||
|
|
||||||
file_path: str
|
|
||||||
|
|
||||||
default_db = ["information_schema", "performance_schema", "sys", "mysql"]
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_config(cls) -> RDBMSDatabase:
|
def from_file_path(
|
||||||
"""
|
cls, file_path: str, engine_args: Optional[dict] = None, **kwargs: Any
|
||||||
Todo password encryption
|
|
||||||
Returns:
|
|
||||||
"""
|
|
||||||
return cls.from_uri_db(
|
|
||||||
cls,
|
|
||||||
CFG.LOCAL_DB_PATH,
|
|
||||||
engine_args={"pool_size": 10, "pool_recycle": 3600, "echo": True},
|
|
||||||
)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_uri_db(
|
|
||||||
cls, db_path: str, engine_args: Optional[dict] = None, **kwargs: Any
|
|
||||||
) -> RDBMSDatabase:
|
) -> RDBMSDatabase:
|
||||||
db_url: str = cls.connect_driver + "://" + db_path
|
"""Construct a SQLAlchemy engine from URI."""
|
||||||
return cls.from_uri(db_url, engine_args, **kwargs)
|
_engine_args = engine_args or {}
|
||||||
|
return cls(create_engine("duckdb://" + file_path, **_engine_args), **kwargs)
|
||||||
|
@ -11,8 +11,8 @@ class MSSQLConnect(RDBMSDatabase):
|
|||||||
Usage:
|
Usage:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
type: str = "MSSQL"
|
db_type: str = "mssql"
|
||||||
dialect: str = "mssql"
|
db_dialect: str = "mssql"
|
||||||
driver: str = "pyodbc"
|
driver: str = "pyodbc"
|
||||||
|
|
||||||
default_db = ["master", "model", "msdb", "tempdb", "modeldb", "resource"]
|
default_db = ["master", "model", "msdb", "tempdb", "modeldb", "resource"]
|
||||||
|
@ -11,8 +11,8 @@ class MySQLConnect(RDBMSDatabase):
|
|||||||
Usage:
|
Usage:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
type: str = "MySQL"
|
db_type: str = "mysql"
|
||||||
dialect: str = "mysql"
|
db_dialect: str = "mysql"
|
||||||
driver: str = "pymysql"
|
driver: str = "pymysql"
|
||||||
|
|
||||||
default_db = ["information_schema", "performance_schema", "sys", "mysql"]
|
default_db = ["information_schema", "performance_schema", "sys", "mysql"]
|
||||||
|
@ -6,7 +6,7 @@ from pilot.connections.rdbms.rdbms_connect import RDBMSDatabase
|
|||||||
class OracleConnector(RDBMSDatabase):
|
class OracleConnector(RDBMSDatabase):
|
||||||
"""OracleConnector"""
|
"""OracleConnector"""
|
||||||
|
|
||||||
type: str = "ORACLE"
|
db_type: str = "oracle"
|
||||||
|
|
||||||
driver: str = "oracle"
|
driver: str = "oracle"
|
||||||
|
|
||||||
|
@ -6,7 +6,7 @@ from pilot.connections.rdbms.rdbms_connect import RDBMSDatabase
|
|||||||
class PostgresConnector(RDBMSDatabase):
|
class PostgresConnector(RDBMSDatabase):
|
||||||
"""PostgresConnector is a class which Connector"""
|
"""PostgresConnector is a class which Connector"""
|
||||||
|
|
||||||
type: str = "POSTGRESQL"
|
db_type: str = "postgresql"
|
||||||
driver: str = "postgresql"
|
driver: str = "postgresql"
|
||||||
|
|
||||||
default_db = ["information_schema", "performance_schema", "sys", "mysql"]
|
default_db = ["information_schema", "performance_schema", "sys", "mysql"]
|
||||||
|
@ -1,6 +1,9 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import warnings
|
import warnings
|
||||||
|
import sqlparse
|
||||||
|
import regex as re
|
||||||
|
|
||||||
from typing import Any, Iterable, List, Optional
|
from typing import Any, Iterable, List, Optional
|
||||||
from pydantic import BaseModel, Field, root_validator, validator, Extra
|
from pydantic import BaseModel, Field, root_validator, validator, Extra
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
@ -48,26 +51,11 @@ class RDBMSDatabase(BaseConnect):
|
|||||||
if include_tables and ignore_tables:
|
if include_tables and ignore_tables:
|
||||||
raise ValueError("Cannot specify both include_tables and ignore_tables")
|
raise ValueError("Cannot specify both include_tables and ignore_tables")
|
||||||
|
|
||||||
self._inspector = inspect(self._engine)
|
self._inspector = inspect(engine)
|
||||||
session_factory = sessionmaker(bind=engine)
|
session_factory = sessionmaker(bind=engine)
|
||||||
Session = scoped_session(session_factory)
|
Session_Manages = scoped_session(session_factory)
|
||||||
|
self._db_sessions = Session_Manages
|
||||||
self._db_sessions = Session
|
self.session = self.get_session()
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_config(cls) -> RDBMSDatabase:
|
|
||||||
"""
|
|
||||||
Todo password encryption
|
|
||||||
Returns:
|
|
||||||
"""
|
|
||||||
return cls.from_uri_db(
|
|
||||||
cls,
|
|
||||||
CFG.LOCAL_DB_HOST,
|
|
||||||
CFG.LOCAL_DB_PORT,
|
|
||||||
CFG.LOCAL_DB_USER,
|
|
||||||
CFG.LOCAL_DB_PASSWORD,
|
|
||||||
engine_args={"pool_size": 10, "pool_recycle": 3600, "echo": True},
|
|
||||||
)
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_uri_db(
|
def from_uri_db(
|
||||||
@ -76,7 +64,7 @@ class RDBMSDatabase(BaseConnect):
|
|||||||
port: int,
|
port: int,
|
||||||
user: str,
|
user: str,
|
||||||
pwd: str,
|
pwd: str,
|
||||||
db_name: str = None,
|
db_name: str,
|
||||||
engine_args: Optional[dict] = None,
|
engine_args: Optional[dict] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> RDBMSDatabase:
|
) -> RDBMSDatabase:
|
||||||
@ -90,11 +78,9 @@ class RDBMSDatabase(BaseConnect):
|
|||||||
+ CFG.LOCAL_DB_HOST
|
+ CFG.LOCAL_DB_HOST
|
||||||
+ ":"
|
+ ":"
|
||||||
+ str(CFG.LOCAL_DB_PORT)
|
+ str(CFG.LOCAL_DB_PORT)
|
||||||
|
+ "/"
|
||||||
|
+ db_name
|
||||||
)
|
)
|
||||||
if cls.dialect:
|
|
||||||
db_url = cls.dialect + "+" + db_url
|
|
||||||
if db_name:
|
|
||||||
db_url = db_url + "/" + db_name
|
|
||||||
return cls.from_uri(db_url, engine_args, **kwargs)
|
return cls.from_uri(db_url, engine_args, **kwargs)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -123,24 +109,17 @@ class RDBMSDatabase(BaseConnect):
|
|||||||
)
|
)
|
||||||
return self.get_usable_table_names()
|
return self.get_usable_table_names()
|
||||||
|
|
||||||
def get_session(self, db_name: str):
|
def get_session(self):
|
||||||
session = self._db_sessions()
|
session = self._db_sessions()
|
||||||
|
|
||||||
self._metadata = MetaData()
|
self._metadata = MetaData()
|
||||||
# sql = f"use {db_name}"
|
self._metadata.reflect(bind=self._engine)
|
||||||
sql = text(f"use `{db_name}`")
|
|
||||||
session.execute(sql)
|
|
||||||
|
|
||||||
# 处理表信息数据
|
|
||||||
|
|
||||||
self._metadata.reflect(bind=self._engine, schema=db_name)
|
|
||||||
|
|
||||||
# including view support by adding the views as well as tables to the all
|
# including view support by adding the views as well as tables to the all
|
||||||
# tables list if view_support is True
|
# tables list if view_support is True
|
||||||
self._all_tables = set(
|
self._all_tables = set(
|
||||||
self._inspector.get_table_names(schema=db_name)
|
self._inspector.get_table_names()
|
||||||
+ (
|
+ (
|
||||||
self._inspector.get_view_names(schema=db_name)
|
self._inspector.get_view_names()
|
||||||
if self.view_support
|
if self.view_support
|
||||||
else []
|
else []
|
||||||
)
|
)
|
||||||
@ -148,14 +127,14 @@ class RDBMSDatabase(BaseConnect):
|
|||||||
|
|
||||||
return session
|
return session
|
||||||
|
|
||||||
def get_current_db_name(self, session) -> str:
|
def get_current_db_name(self) -> str:
|
||||||
return session.execute(text("SELECT DATABASE()")).scalar()
|
return self.session.execute(text("SELECT DATABASE()")).scalar()
|
||||||
|
|
||||||
def table_simple_info(self, session):
|
def table_simple_info(self):
|
||||||
_sql = f"""
|
_sql = f"""
|
||||||
select concat(table_name, "(" , group_concat(column_name), ")") as schema_info from information_schema.COLUMNS where table_schema="{self.get_current_db_name(session)}" group by TABLE_NAME;
|
select concat(table_name, "(" , group_concat(column_name), ")") as schema_info from information_schema.COLUMNS where table_schema="{self.get_current_db_name()}" group by TABLE_NAME;
|
||||||
"""
|
"""
|
||||||
cursor = session.execute(text(_sql))
|
cursor = self.session.execute(text(_sql))
|
||||||
results = cursor.fetchall()
|
results = cursor.fetchall()
|
||||||
return results
|
return results
|
||||||
|
|
||||||
@ -255,9 +234,31 @@ class RDBMSDatabase(BaseConnect):
|
|||||||
"""Format the error message"""
|
"""Format the error message"""
|
||||||
return f"Error: {e}"
|
return f"Error: {e}"
|
||||||
|
|
||||||
def run(self, session, command: str, fetch: str = "all") -> List:
|
def __write(self, session, write_sql):
|
||||||
"""Execute a SQL command and return a string representing the results."""
|
print(f"Write[{write_sql}]")
|
||||||
cursor = session.execute(text(command))
|
db_cache = self.get_session_db(session)
|
||||||
|
result = session.execute(text(write_sql))
|
||||||
|
session.commit()
|
||||||
|
# TODO Subsequent optimization of dynamically specified database submission loss target problem
|
||||||
|
session.execute(text(f"use `{db_cache}`"))
|
||||||
|
print(f"SQL[{write_sql}], result:{result.rowcount}")
|
||||||
|
return result.rowcount
|
||||||
|
|
||||||
|
def __query(self, session, query, fetch: str = "all"):
|
||||||
|
"""
|
||||||
|
only for query
|
||||||
|
Args:
|
||||||
|
session:
|
||||||
|
query:
|
||||||
|
fetch:
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
|
||||||
|
"""
|
||||||
|
print(f"Query[{query}]")
|
||||||
|
if not query:
|
||||||
|
return []
|
||||||
|
cursor = session.execute(text(query))
|
||||||
if cursor.returns_rows:
|
if cursor.returns_rows:
|
||||||
if fetch == "all":
|
if fetch == "all":
|
||||||
result = cursor.fetchall()
|
result = cursor.fetchall()
|
||||||
@ -271,6 +272,62 @@ class RDBMSDatabase(BaseConnect):
|
|||||||
result.insert(0, field_names)
|
result.insert(0, field_names)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
def query_ex(self, session, query, fetch: str = "all"):
    """Run a read-only query and return its column names and rows.

    Args:
        session: Session used to execute the statement.
        query: SQL text; an empty value short-circuits.
        fetch: ``"all"`` for every row or ``"one"`` for only the first row.

    Returns:
        ``[]`` when ``query`` is empty, ``None`` when the statement returns
        no rows object, otherwise a ``(field_names, rows)`` tuple where
        ``rows`` is a list (holding at most one row for ``fetch="one"``).

    Raises:
        ValueError: if ``fetch`` is neither ``"one"`` nor ``"all"``.
    """
    print(f"Query[{query}]")
    if not query:
        return []
    cursor = session.execute(text(query))
    if not cursor.returns_rows:
        return None

    if fetch == "all":
        result = list(cursor.fetchall())
    elif fetch == "one":
        # Keep the row intact: indexing [0] and then calling list() crashed
        # on an empty result and split string cells into characters.
        first_row = cursor.fetchone()
        result = [] if first_row is None else [first_row]
    else:
        raise ValueError("Fetch parameter must be either 'one' or 'all'")
    field_names = list(cursor.keys())
    return field_names, result
|
||||||
|
|
||||||
|
def run(self, session, command: str, fetch: str = "all") -> List:
    """Execute a SQL command and return its result rows.

    SELECT statements are returned directly. Other DML is executed,
    committed, and followed by a SELECT derived from the write so the caller
    sees the affected data. Anything else is executed as-is and committed.

    Args:
        session: Session used to execute the statement.
        command: SQL text; an empty value returns ``[]``.
        fetch: Passed to the query helper for SELECT statements.

    Returns:
        A list of rows, preceded by a header row for row-returning
        non-DML statements.
    """
    # Format instead of concatenating so a None command reaches the guard.
    print(f"SQL:{command}")
    if not command:
        return []
    _, ttype, sql_type = self.__sql_parse(command)
    if ttype == sqlparse.tokens.DML:
        if sql_type == "SELECT":
            return self.__query(session, command, fetch)
        self.__write(session, command)
        select_sql = self.convert_sql_write_to_select(command)
        print(f"write result query:{select_sql}")
        return self.__query(session, select_sql)

    print(f"DDL execution determines whether to enable through configuration ")
    cursor = session.execute(text(command))
    session.commit()
    if cursor.returns_rows:
        result = list(cursor.fetchall())
        # The header row goes first, so the list is never empty from here on.
        result.insert(0, tuple(cursor.keys()))
        print("DDL Result:" + str(result))
        return result
    # NOTE(review): the table name "test" is hard-coded here; confirm whether
    # the table targeted by the DDL statement was intended instead.
    return self.__query(session, "SHOW COLUMNS FROM test")
|
||||||
|
|
||||||
def run_no_throw(self, session, command: str, fetch: str = "all") -> List:
|
def run_no_throw(self, session, command: str, fetch: str = "all") -> List:
|
||||||
"""Execute a SQL command and return a string representing the results.
|
"""Execute a SQL command and return a string representing the results.
|
||||||
|
|
||||||
@ -294,3 +351,152 @@ class RDBMSDatabase(BaseConnect):
|
|||||||
for d in results
|
for d in results
|
||||||
if d[0] not in ["information_schema", "performance_schema", "sys", "mysql"]
|
if d[0] not in ["information_schema", "performance_schema", "sys", "mysql"]
|
||||||
]
|
]
|
||||||
|
|
||||||
|
def convert_sql_write_to_select(self, write_sql):
    """Derive a SELECT that shows the data touched by a write statement.

    The statement is analysed with simple token splitting, not a real SQL
    parser, so only plain single-table statements are handled. Identifier and
    literal case is preserved; only keywords are matched case-insensitively.

    Args:
        write_sql: An INSERT, DELETE or UPDATE statement.

    Returns:
        * INSERT ``(cols) VALUES (vals)``: a SELECT filtered on each
          ``col=val`` pair, or ``None`` if the statement has another shape.
        * DELETE: a SELECT of the whole table.
        * UPDATE: a SELECT of the first assigned column, filtered by the
          original WHERE clause when there is one.

    Raises:
        ValueError: for an empty, malformed or unsupported statement.
    """
    sql = write_sql.strip()
    parts = sql.split()
    if not parts:
        raise ValueError("Unsupported SQL command type: ")
    cmd_type = parts[0].lower()

    if cmd_type == "insert":
        match = re.match(
            r"insert\s+into\s+(\w+)\s*\((.*?)\)\s*values\s*\((.*?)\)",
            sql,
            re.IGNORECASE | re.DOTALL,
        )
        if not match:
            # e.g. INSERT ... SELECT or an INSERT without a column list.
            return None
        table_name, columns, values = match.groups()
        where_clause = " AND ".join(
            f"{col.strip()}={val.strip()}"
            for col, val in zip(columns.split(","), values.split(","))
        )
        return f"SELECT * FROM {table_name} WHERE {where_clause}"

    if cmd_type == "delete":
        if len(parts) < 3:
            raise ValueError(f"Malformed DELETE statement: {write_sql}")
        table_name = parts[2]  # delete from <table_name> ...
        return f"SELECT * FROM {table_name} "

    if cmd_type == "update":
        keywords = [part.lower() for part in parts]
        if len(parts) < 2 or "set" not in keywords:
            raise ValueError(f"Malformed UPDATE statement: {write_sql}")
        table_name = parts[1]
        set_idx = keywords.index("set")
        # An UPDATE without WHERE touches every row, so select without a filter.
        where_idx = keywords.index("where") if "where" in keywords else len(parts)
        assignments = parts[set_idx + 1: where_idx]
        if not assignments:
            raise ValueError(f"Malformed UPDATE statement: {write_sql}")
        # Only the first assigned column is selected.
        set_clause = assignments[0].split("=")[0].strip()
        select_sql = f"SELECT {set_clause} FROM {table_name}"
        if where_idx == len(parts):
            return select_sql
        where_clause = " ".join(parts[where_idx + 1:])
        return f"{select_sql} WHERE {where_clause}"

    raise ValueError(f"Unsupported SQL command type: {cmd_type}")
|
||||||
|
|
||||||
|
def __sql_parse(self, sql):
    """Parse the first statement of ``sql`` with sqlparse.

    Returns:
        ``(statement, ttype, sql_type)``: the parsed statement, the token
        type of its first non-whitespace token, and the statement type
        reported by ``get_type()``.
    """
    sql = sql.strip()
    statement = sqlparse.parse(sql)[0]
    sql_type = statement.get_type()

    # Comments are not skipped, so a leading comment is the first token.
    ttype = statement.token_first(skip_ws=True, skip_cm=False).ttype
    print(f"SQL:{sql}, ttype:{ttype}, sql_type:{sql_type}")
    return statement, ttype, sql_type
|
||||||
|
|
||||||
|
def get_indexes(self, table_name):
    """Return columns 2 and 4 of each ``SHOW INDEXES`` row for ``table_name``.

    Presumably these are the key name and column name in MySQL's output --
    confirm against the target server.
    """
    rows = self._db_sessions().execute(
        text(f"SHOW INDEXES FROM {table_name}")
    ).fetchall()
    pairs = []
    for row in rows:
        pairs.append((row[2], row[4]))
    return pairs
|
||||||
|
|
||||||
|
def get_show_create_table(self, table_name):
    """Return the second column of the first ``SHOW CREATE TABLE`` row."""
    statement = text(f"SHOW CREATE TABLE {table_name}")
    rows = self._db_sessions().execute(statement).fetchall()
    first_row = rows[0]
    return first_row[1]
|
||||||
|
|
||||||
|
def get_fields(self, table_name):
    """Get column fields about specified table.

    Args:
        table_name: Table whose columns are listed.

    Returns:
        A list of ``(name, type, default, is_nullable, comment)`` tuples
        read from ``information_schema.COLUMNS``.
    """
    session = self._db_sessions()
    # Bind the table name instead of interpolating it: this removes the SQL
    # injection risk and the stray ``.format`` call that raised on braces.
    cursor = session.execute(
        text(
            "SELECT COLUMN_NAME, COLUMN_TYPE, COLUMN_DEFAULT, IS_NULLABLE, COLUMN_COMMENT "
            "from information_schema.COLUMNS where table_name=:table_name"
        ),
        {"table_name": table_name},
    )
    fields = cursor.fetchall()
    return [(field[0], field[1], field[2], field[3], field[4]) for field in fields]
|
||||||
|
|
||||||
|
def get_charset(self):
    """Return the value of ``@@character_set_database``."""
    statement = text("SELECT @@character_set_database")
    row = self._db_sessions().execute(statement).fetchone()
    return row[0]
|
||||||
|
|
||||||
|
def get_collation(self):
    """Return the value of ``@@collation_database``."""
    statement = text("SELECT @@collation_database")
    row = self._db_sessions().execute(statement).fetchone()
    return row[0]
|
||||||
|
|
||||||
|
def get_grants(self):
    """Return the rows produced by ``SHOW GRANTS``.

    Returns:
        All result rows, exactly as returned by ``fetchall()``.
    """
    db_session = self._db_sessions()
    return db_session.execute(text("SHOW GRANTS")).fetchall()
|
||||||
|
|
||||||
|
def get_users(self):
    """Return ``(user, host)`` pairs read from ``mysql.user``.

    Returns:
        A list of 2-tuples built from the two selected columns of every
        result row.
    """
    # NOTE(review): this reads self.session directly, whereas most sibling
    # getters call self._db_sessions() -- confirm that is intended.
    result = self.session.execute(text("SELECT user, host FROM mysql.user"))
    pairs = []
    for row in result.fetchall():
        pairs.append((row[0], row[1]))
    return pairs
|
||||||
|
|
||||||
|
def get_table_comments(self, db_name):
    """Return ``(table_name, table_comment)`` pairs for schema ``db_name``.

    Queries ``information_schema.tables`` through ``self.session``.

    The schema name is sent as a bound parameter instead of being spliced
    into the SQL text. The previous version interpolated it inside a quoted
    literal and then called ``.format()`` on the already-formatted string,
    which allowed SQL injection and broke on names containing quotes or
    braces.

    Args:
        db_name: Value matched against ``table_schema``.

    Returns:
        A list of 2-tuples ``(table_name, table_comment)``, one per table
        found in the schema.
    """
    statement = text(
        "SELECT table_name, table_comment FROM information_schema.tables "
        "WHERE table_schema = :db_name"
    )
    cursor = self.session.execute(statement, {"db_name": db_name})
    table_comments = cursor.fetchall()
    return [
        (table_comment[0], table_comment[1]) for table_comment in table_comments
    ]
|
||||||
|
|
||||||
|
def get_database_list(self):
    """Return database names from ``show databases``, minus built-in schemas.

    Returns:
        A list with the first column of every result row, skipping
        ``information_schema``, ``performance_schema``, ``sys`` and
        ``mysql``.
    """
    excluded = ("information_schema", "performance_schema", "sys", "mysql")
    db_session = self._db_sessions()
    rows = db_session.execute(text(" show databases;")).fetchall()
    names = []
    for row in rows:
        name = row[0]
        if name in excluded:
            continue
        names.append(name)
    return names
|
||||||
|
|
||||||
|
def get_database_names(self):
    """Return database names from ``show databases``, minus built-in schemas.

    Returns:
        A list with the first column of every result row, skipping
        ``information_schema``, ``performance_schema``, ``sys`` and
        ``mysql``.
    """
    # NOTE(review): this performs the same query and filtering as
    # get_database_list -- confirm whether both entry points are needed.
    hidden = ("information_schema", "performance_schema", "sys", "mysql")
    db_session = self._db_sessions()
    result = db_session.execute(text(" show databases;"))
    candidates = (row[0] for row in result.fetchall())
    return [name for name in candidates if name not in hidden]
|
||||||
|
@ -53,7 +53,7 @@ async def validation_exception_handler(request: Request, exc: RequestValidationE
|
|||||||
message = ""
|
message = ""
|
||||||
for error in exc.errors():
|
for error in exc.errors():
|
||||||
message += ".".join(error.get("loc")) + ":" + error.get("msg") + ";"
|
message += ".".join(error.get("loc")) + ":" + error.get("msg") + ";"
|
||||||
return Result.faild(code= "E0001", msg=message)
|
return Result.faild(code="E0001", msg=message)
|
||||||
|
|
||||||
|
|
||||||
def __get_conv_user_message(conversations: dict):
|
def __get_conv_user_message(conversations: dict):
|
||||||
@ -71,11 +71,10 @@ def __new_conversation(chat_mode, user_id) -> ConversationVo:
|
|||||||
|
|
||||||
|
|
||||||
def get_db_list():
|
def get_db_list():
|
||||||
db = CFG.local_db
|
dbs = CFG.LOCAL_DB_MANAGE.get_db_list()
|
||||||
dbs = db.get_database_list()
|
|
||||||
params: dict = {}
|
params: dict = {}
|
||||||
for name in dbs:
|
for item in dbs:
|
||||||
params.update({name: name})
|
params.update({item["db_name"]: item["comment"]})
|
||||||
return params
|
return params
|
||||||
|
|
||||||
|
|
||||||
@ -95,8 +94,9 @@ def knowledge_list():
|
|||||||
params.update({space.name: space.name})
|
params.update({space.name: space.name})
|
||||||
return params
|
return params
|
||||||
|
|
||||||
|
|
||||||
@router.get("/v1/chat/dialogue/list", response_model=Result[ConversationVo])
|
@router.get("/v1/chat/dialogue/list", response_model=Result[ConversationVo])
|
||||||
async def dialogue_list( user_id: str = None):
|
async def dialogue_list(user_id: str = None):
|
||||||
dialogues: List = []
|
dialogues: List = []
|
||||||
datas = DuckdbHistoryMemory.conv_list(user_id)
|
datas = DuckdbHistoryMemory.conv_list(user_id)
|
||||||
|
|
||||||
@ -126,11 +126,10 @@ async def dialogue_scenes():
|
|||||||
ChatScene.ChatExecution,
|
ChatScene.ChatExecution,
|
||||||
]
|
]
|
||||||
for scene in new_modes:
|
for scene in new_modes:
|
||||||
|
|
||||||
scene_vo = ChatSceneVo(
|
scene_vo = ChatSceneVo(
|
||||||
chat_scene=scene.value(),
|
chat_scene=scene.value(),
|
||||||
scene_name=scene.scene_name(),
|
scene_name=scene.scene_name(),
|
||||||
scene_describe= scene.describe(),
|
scene_describe=scene.describe(),
|
||||||
param_title=",".join(scene.param_types()),
|
param_title=",".join(scene.param_types()),
|
||||||
)
|
)
|
||||||
scene_vos.append(scene_vo)
|
scene_vos.append(scene_vo)
|
||||||
|
@ -43,7 +43,7 @@ class ChatDashboard(BaseChat):
|
|||||||
)
|
)
|
||||||
self.db_name = db_name
|
self.db_name = db_name
|
||||||
self.report_name = report_name
|
self.report_name = report_name
|
||||||
self.database = CFG.local_db
|
self.database = CFG.LOCAL_DB_MANAGE.get_connect(db_name)
|
||||||
# 准备DB信息(拿到指定库的链接)
|
# 准备DB信息(拿到指定库的链接)
|
||||||
self.db_connect = self.database.get_session(self.db_name)
|
self.db_connect = self.database.get_session(self.db_name)
|
||||||
self.top_k: int = 5
|
self.top_k: int = 5
|
||||||
|
@ -33,9 +33,8 @@ class ChatWithDbAutoExecute(BaseChat):
|
|||||||
f"{ChatScene.ChatWithDbExecute.value} mode should chose db!"
|
f"{ChatScene.ChatWithDbExecute.value} mode should chose db!"
|
||||||
)
|
)
|
||||||
self.db_name = db_name
|
self.db_name = db_name
|
||||||
self.database = CFG.local_db
|
self.database = CFG.LOCAL_DB_MANAGE.get_connect(db_name)
|
||||||
# 准备DB信息(拿到指定库的链接)
|
self.db_connect = self.database.session
|
||||||
self.db_connect = self.database.get_session(self.db_name)
|
|
||||||
self.top_k: int = 5
|
self.top_k: int = 5
|
||||||
|
|
||||||
def generate_input_values(self):
|
def generate_input_values(self):
|
||||||
|
@ -28,9 +28,8 @@ class ChatWithDbQA(BaseChat):
|
|||||||
)
|
)
|
||||||
self.db_name = db_name
|
self.db_name = db_name
|
||||||
if db_name:
|
if db_name:
|
||||||
self.database = CFG.local_db
|
self.database = CFG.LOCAL_DB_MANAGE.get_connect(db_name)
|
||||||
# 准备DB信息(拿到指定库的链接)
|
self.db_connect = self.database.session
|
||||||
self.db_connect = self.database.get_session(self.db_name)
|
|
||||||
self.tables = self.database.get_table_names()
|
self.tables = self.database.get_table_names()
|
||||||
|
|
||||||
self.top_k = (
|
self.top_k = (
|
||||||
|
@ -84,12 +84,6 @@ priority = {"vicuna-13b": "aaa"}
|
|||||||
|
|
||||||
CHAT_FACTORY = ChatFactory()
|
CHAT_FACTORY = ChatFactory()
|
||||||
|
|
||||||
DB_SETTINGS = {
|
|
||||||
"user": CFG.LOCAL_DB_USER,
|
|
||||||
"password": CFG.LOCAL_DB_PASSWORD,
|
|
||||||
"host": CFG.LOCAL_DB_HOST,
|
|
||||||
"port": CFG.LOCAL_DB_PORT,
|
|
||||||
}
|
|
||||||
|
|
||||||
llm_native_dialogue = get_lang_text("knowledge_qa_type_llm_native_dialogue")
|
llm_native_dialogue = get_lang_text("knowledge_qa_type_llm_native_dialogue")
|
||||||
default_knowledge_base_dialogue = get_lang_text(
|
default_knowledge_base_dialogue = get_lang_text(
|
||||||
@ -144,14 +138,6 @@ def get_simlar(q):
|
|||||||
return "\n".join(contents)
|
return "\n".join(contents)
|
||||||
|
|
||||||
|
|
||||||
def gen_sqlgen_conversation(dbname):
|
|
||||||
message = ""
|
|
||||||
db_connect = CFG.local_db.get_session(dbname)
|
|
||||||
schemas = CFG.local_db.table_simple_info(db_connect)
|
|
||||||
for s in schemas:
|
|
||||||
message += s + ";"
|
|
||||||
return get_lang_text("sql_schema_info").format(dbname, message)
|
|
||||||
|
|
||||||
|
|
||||||
def plugins_select_info():
|
def plugins_select_info():
|
||||||
plugins_infos: dict = {}
|
plugins_infos: dict = {}
|
||||||
@ -703,7 +689,7 @@ if __name__ == "__main__":
|
|||||||
# init server config
|
# init server config
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
server_init(args)
|
server_init(args)
|
||||||
dbs = CFG.local_db.get_database_list()
|
dbs = CFG.LOCAL_DB_MANAGE.get_db_names()
|
||||||
demo = build_webdemo()
|
demo = build_webdemo()
|
||||||
demo.queue(
|
demo.queue(
|
||||||
concurrency_count=args.concurrency_count,
|
concurrency_count=args.concurrency_count,
|
||||||
|
@ -15,6 +15,7 @@ from pilot.configs.model_config import (
|
|||||||
)
|
)
|
||||||
from pilot.common.plugins import scan_plugins, load_native_plugins
|
from pilot.common.plugins import scan_plugins, load_native_plugins
|
||||||
from pilot.utils import build_logger
|
from pilot.utils import build_logger
|
||||||
|
from pilot.connections.manages.connection_manager import ConnectManager
|
||||||
|
|
||||||
ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||||
sys.path.append(ROOT_PATH)
|
sys.path.append(ROOT_PATH)
|
||||||
@ -38,6 +39,10 @@ def server_init(args):
|
|||||||
|
|
||||||
# init config
|
# init config
|
||||||
cfg = Config()
|
cfg = Config()
|
||||||
|
# init connect manage
|
||||||
|
conn_manage = ConnectManager()
|
||||||
|
cfg.LOCAL_DB_MANAGE = conn_manage
|
||||||
|
|
||||||
load_native_plugins(cfg)
|
load_native_plugins(cfg)
|
||||||
signal.signal(signal.SIGINT, signal_handler)
|
signal.signal(signal.SIGINT, signal_handler)
|
||||||
async_db_summery()
|
async_db_summery()
|
||||||
|
@ -11,6 +11,8 @@ from pilot.embedding_engine.knowledge_embedding import KnowledgeEmbedding
|
|||||||
from pilot.embedding_engine.string_embedding import StringEmbedding
|
from pilot.embedding_engine.string_embedding import StringEmbedding
|
||||||
from pilot.summary.mysql_db_summary import MysqlSummary
|
from pilot.summary.mysql_db_summary import MysqlSummary
|
||||||
from pilot.scene.chat_factory import ChatFactory
|
from pilot.scene.chat_factory import ChatFactory
|
||||||
|
from pilot.common.schema import DBType
|
||||||
|
|
||||||
|
|
||||||
CFG = Config()
|
CFG = Config()
|
||||||
chat_factory = ChatFactory()
|
chat_factory = ChatFactory()
|
||||||
@ -24,10 +26,13 @@ class DBSummaryClient:
|
|||||||
def __init__(self):
|
def __init__(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def db_summary_embedding(self, dbname):
|
def db_summary_embedding(self, dbname, db_type):
|
||||||
"""put db profile and table profile summary into vector store"""
|
"""put db profile and table profile summary into vector store"""
|
||||||
if CFG.LOCAL_DB_HOST is not None and CFG.LOCAL_DB_PORT is not None:
|
if DBType.Mysql.value() == db_type:
|
||||||
db_summary_client = MysqlSummary(dbname)
|
db_summary_client = MysqlSummary(dbname)
|
||||||
|
else:
|
||||||
|
raise ValueError("Unsupport summary DbType!" + db_type)
|
||||||
|
|
||||||
embeddings = HuggingFaceEmbeddings(
|
embeddings = HuggingFaceEmbeddings(
|
||||||
model_name=LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL]
|
model_name=LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL]
|
||||||
)
|
)
|
||||||
@ -120,10 +125,10 @@ class DBSummaryClient:
|
|||||||
return related_table_summaries
|
return related_table_summaries
|
||||||
|
|
||||||
def init_db_summary(self):
|
def init_db_summary(self):
|
||||||
db = CFG.local_db
|
db_mange = CFG.LOCAL_DB_MANAGE
|
||||||
dbs = db.get_database_list()
|
dbs = db_mange.get_db_list()
|
||||||
for dbname in dbs:
|
for item in dbs:
|
||||||
self.db_summary_embedding(dbname)
|
self.db_summary_embedding(item["db_name"], item["db_type"])
|
||||||
|
|
||||||
def init_db_profile(self, db_summary_client, dbname, embeddings):
|
def init_db_profile(self, db_summary_client, dbname, embeddings):
|
||||||
profile_store_config = {
|
profile_store_config = {
|
||||||
|
@ -56,8 +56,8 @@ class MysqlSummary(DBSummary):
|
|||||||
self.vector_tables_info = []
|
self.vector_tables_info = []
|
||||||
# self.tables_summary = {}
|
# self.tables_summary = {}
|
||||||
|
|
||||||
self.db = CFG.local_db
|
self.db = CFG.LOCAL_DB_MANAGE.get_connect()
|
||||||
self.db.get_session(name)
|
|
||||||
|
|
||||||
self.metadata = """user info :{users}, grant info:{grant}, charset:{charset}, collation:{collation}""".format(
|
self.metadata = """user info :{users}, grant info:{grant}, charset:{charset}, collation:{collation}""".format(
|
||||||
users=self.db.get_users(),
|
users=self.db.get_users(),
|
||||||
|
BIN
pilot/xx.db
Normal file
BIN
pilot/xx.db
Normal file
Binary file not shown.
Loading…
Reference in New Issue
Block a user