Mirror of https://github.com/csunny/DB-GPT.git, synced 2025-08-01 16:18:27 +00:00

Multi DB support

This commit is contained in:
parent 73b3b7069f
commit 288e55a97f
1751  logs/DbChatOutputParser.log.2023-07-13  Normal file
File diff suppressed because it is too large
@@ -12,3 +12,31 @@ class SeparatorStyle(Enum):
class ExampleType(Enum):
    ONE_SHOT = "one_shot"
    FEW_SHOT = "few_shot"


class DbInfo:
    def __init__(self, name, is_file_db: bool = False):
        self.name = name
        self.is_file_db = is_file_db


class DBType(Enum):
    Mysql = DbInfo("mysql")
    OCeanBase = DbInfo("oceanbase")
    DuckDb = DbInfo("duckdb", True)
    Oracle = DbInfo("oracle")
    MSSQL = DbInfo("mssql")
    Postgresql = DbInfo("postgresql")

    def value(self):
        return self._value_.name

    def is_file_db(self):
        return self._value_.is_file_db

    @staticmethod
    def of_db_type(db_type: str):
        for item in DBType:
            if item.value() == db_type:
                return item
        return None
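A minimal usage sketch of the DBType helpers above; the "duckdb" string is just an example value:

# Resolve a stored type string back to the enum member and inspect it (illustrative value).
db_type = DBType.of_db_type("duckdb")
if db_type is not None and db_type.is_file_db():
    print(f"{db_type.value()} connects through a file path rather than host/port")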
@@ -28,6 +28,8 @@ class Config(metaclass=Singleton):
        self.skip_reprompt = False
        self.temperature = float(os.getenv("TEMPERATURE", 0.7))

+        self.NUM_GPUS = int(os.getenv("NUM_GPUS", 1))
+
        self.execute_local_commands = (
            os.getenv("EXECUTE_LOCAL_COMMANDS", "False") == "True"
        )
@@ -116,25 +118,15 @@ class Config(metaclass=Singleton):
            os.getenv("NATIVE_SQL_CAN_RUN_WRITE", "True") == "True"
        )

-        ### Local database connection configuration
+        ### default Local database connection configuration
        self.LOCAL_DB_HOST = os.getenv("LOCAL_DB_HOST", "127.0.0.1")
+        self.LOCAL_DB_PATH = os.getenv("LOCAL_DB_PATH", "xx.db")
+        self.LOCAL_DB_NAME = os.getenv("LOCAL_DB_NAME")
        self.LOCAL_DB_PORT = int(os.getenv("LOCAL_DB_PORT", 3306))
        self.LOCAL_DB_USER = os.getenv("LOCAL_DB_USER", "root")
        self.LOCAL_DB_PASSWORD = os.getenv("LOCAL_DB_PASSWORD", "aa123456")

-        ### TODO Adapt to multiple types of libraries
-        self.local_db = Database.from_uri(
-            "mysql+pymysql://"
-            + self.LOCAL_DB_USER
-            + ":"
-            + self.LOCAL_DB_PASSWORD
-            + "@"
-            + self.LOCAL_DB_HOST
-            + ":"
-            + str(self.LOCAL_DB_PORT),
-            engine_args={"pool_size": 10, "pool_recycle": 3600, "echo": True},
-        )
+        self.LOCAL_DB_MANAGE = None

        ### LLM Model Service Configuration
        self.LLM_MODEL = os.getenv("LLM_MODEL", "vicuna-13b")
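For reference, a sketch of the MySQL URI that the removed block used to assemble (and that the new ConnectManager assembles at startup); the values are the defaults shown in the hunk above, not a recommendation:

# Default MySQL URI built from the LOCAL_DB_* settings above.
user, pwd, host, port = "root", "aa123456", "127.0.0.1", 3306
print("mysql+pymysql://" + user + ":" + pwd + "@" + host + ":" + str(port))
# -> mysql+pymysql://root:aa123456@127.0.0.1:3306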
@@ -7,11 +7,8 @@ from pydantic import BaseModel, Extra, Field, root_validator
from typing import Any, Iterable, List, Optional


-class BaseConnect(BaseModel, ABC):
-    type
-    driver: str
-
-    def get_session(self, db_name: str):
+class BaseConnect(ABC):
+    def get_connect(self, db_name: str):
        pass

    def get_table_names(self) -> Iterable[str]:
@@ -20,14 +17,41 @@ class BaseConnect(BaseModel, ABC):
    def get_table_info(self, table_names: Optional[List[str]] = None) -> str:
        pass

-    def get_table_info(self, table_names: Optional[List[str]] = None) -> str:
    def get_index_info(self, table_names: Optional[List[str]] = None) -> str:
        pass

-    def get_index_info(self, table_names: Optional[List[str]] = None) -> str:
    def get_example_data(self, table: str, count: int = 3):
        pass

    def get_database_list(self):
        pass

    def get_database_names(self):
        pass

    def get_table_comments(self, db_name):
        pass

    def run(self, session, command: str, fetch: str = "all") -> List:
        pass

    def get_users(self):
        pass

    def get_grants(self):
        pass

    def get_collation(self):
        pass

    def get_charset(self):
        pass

    def get_fields(self, table_name):
        pass

    def get_show_create_table(self, table_name):
        pass

    def get_indexes(self, table_name):
        pass
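For illustration, a minimal in-memory connector satisfying this interface could look like the sketch below; DictConnect and its dict backing are hypothetical, while the db_type class attribute mirrors what the real connectors in this commit declare:

class DictConnect(BaseConnect):
    # Hypothetical connector backed by a plain dict, for illustration only.
    db_type = "dict"

    def __init__(self, tables: dict):
        self._tables = tables

    def get_table_names(self):
        return list(self._tables.keys())

    def get_fields(self, table_name):
        return self._tables.get(table_name, [])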
0  pilot/connections/manages/__init__.py  Normal file
116  pilot/connections/manages/connect_storage_duckdb.py  Normal file
@@ -0,0 +1,116 @@
import os
import duckdb
from typing import List

default_db_path = os.path.join(os.getcwd(), "message")
duckdb_path = os.getenv("DB_DUCKDB_PATH", default_db_path + "/connect_config.db")
table_name = "connect_config"


class DuckdbConnectConfig:
    def __init__(self):
        os.makedirs(default_db_path, exist_ok=True)
        self.connect = duckdb.connect(duckdb_path)
        self.__init_config_tables()

    def __init_config_tables(self):
        # check config table
        result = self.connect.execute(
            "SELECT name FROM sqlite_master WHERE type='table' AND name=?", [table_name]
        ).fetchall()

        if not result:
            # create config table
            self.connect.execute(
                "CREATE TABLE connect_config (id integer primary key, db_name VARCHAR(100) UNIQUE, db_type VARCHAR(50), db_path VARCHAR(255) NULL, db_host VARCHAR(255) NULL, db_port INTEGER NULL, db_user VARCHAR(255) NULL, db_pwd VARCHAR(255) NULL, comment TEXT NULL)"
            )
            self.connect.execute("CREATE SEQUENCE seq_id START 1;")

    def add_url_db(self, db_name, db_type, db_host: str, db_port: int, db_user: str, db_pwd: str, comment: str = ""):
        try:
            cursor = self.connect.cursor()
            cursor.execute(
                "INSERT INTO connect_config(id, db_name, db_type, db_path, db_host, db_port, db_user, db_pwd, comment) VALUES(nextval('seq_id'),?,?,?,?,?,?,?,?)",
                [db_name, db_type, "", db_host, db_port, db_user, db_pwd, comment],
            )
            cursor.commit()
            self.connect.commit()
        except Exception as e:
            print("add db connect info error1!" + str(e))

    def get_file_db_name(self, path):
        try:
            conn = duckdb.connect(path)
            result = conn.execute("SELECT current_database()").fetchone()[0]
            return result
        except Exception as e:
            raise ValueError("Unusable duckdb database path: " + path) from e

    def add_file_db(self, db_name, db_type, db_path: str):
        try:
            cursor = self.connect.cursor()
            cursor.execute(
                "INSERT INTO connect_config(id, db_name, db_type, db_path, db_host, db_port, db_user, db_pwd, comment) VALUES(nextval('seq_id'),?,?,?,?,?,?,?,?)",
                [db_name, db_type, db_path, "", "", "", "", ""],
            )
            cursor.commit()
            self.connect.commit()
        except Exception as e:
            print("add db connect info error2!" + str(e))

    def delete_db(self, db_name):
        cursor = self.connect.cursor()
        cursor.execute(
            "DELETE FROM connect_config where db_name=?", [db_name]
        )
        cursor.commit()
        return True

    def get_db_config(self, db_name):
        if os.path.isfile(duckdb_path):
            cursor = duckdb.connect(duckdb_path).cursor()
            if db_name:
                cursor.execute(
                    "SELECT * FROM connect_config where db_name=? ", [db_name]
                )
            else:
                raise ValueError("Cannot get database by name " + db_name)

            fields = [field[0] for field in cursor.description]
            row_dict = {}
            for row in cursor.fetchall():
                for i, field in enumerate(fields):
                    row_dict[field] = row[i]
            return row_dict
        return {}

    def get_db_list(self):
        if os.path.isfile(duckdb_path):
            cursor = duckdb.connect(duckdb_path).cursor()
            cursor.execute(
                "SELECT db_name, db_type, comment FROM connect_config "
            )

            fields = [field[0] for field in cursor.description]
            data = []
            for row in cursor.fetchall():
                row_dict = {}
                for i, field in enumerate(fields):
                    row_dict[field] = row[i]
                data.append(row_dict)
            return data

        return []

    def get_db_names(self):
        if os.path.isfile(duckdb_path):
            cursor = duckdb.connect(duckdb_path).cursor()
            cursor.execute(
                "SELECT db_name FROM connect_config "
            )
            data = []
            for row in cursor.fetchall():
                data.append(row[0])
            return data
        return []
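A brief usage sketch of DuckdbConnectConfig; the database name, credentials, and comment are placeholders:

store = DuckdbConnectConfig()
store.add_url_db("orders", "mysql", "127.0.0.1", 3306, "root", "aa123456", "demo MySQL database")
print(store.get_db_names())                      # ['orders']
print(store.get_db_config("orders")["db_host"])  # 127.0.0.1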
87  pilot/connections/manages/connection_manager.py  Normal file
@@ -0,0 +1,87 @@
from pilot.configs.config import Config
from pilot.connections.manages.connect_storage_duckdb import DuckdbConnectConfig
from pilot.common.schema import DBType
from pilot.connections.rdbms.mysql import MySQLConnect
from pilot.connections.base import BaseConnect
from pilot.connections.rdbms.rdbms_connect import RDBMSDatabase
from pilot.singleton import Singleton
from pilot.common.sql_database import Database

CFG = Config()


class ConnectManager:

    def get_instance_by_dbtype(self, db_type, **kwargs):
        chat_classes = BaseConnect.__subclasses__()
        implementation = None
        for cls in chat_classes:
            if cls.db_type == db_type:
                implementation = cls(**kwargs)
        if implementation is None:
            raise Exception(f"Invalid db connect implementation! DbType: {db_type}")
        return implementation

    def __init__(self):
        self.storage = DuckdbConnectConfig()
        self.__load_config_db()

    def __load_config_db(self):
        if CFG.LOCAL_DB_HOST:
            # default mysql
            if CFG.LOCAL_DB_NAME:
                self.storage.add_url_db(CFG.LOCAL_DB_NAME, DBType.Mysql.value(), CFG.LOCAL_DB_HOST, CFG.LOCAL_DB_PORT,
                                        CFG.LOCAL_DB_USER, CFG.LOCAL_DB_PASSWORD, "")
            else:
                # get all default mysql database
                default_mysql = Database.from_uri(
                    "mysql+pymysql://"
                    + CFG.LOCAL_DB_USER
                    + ":"
                    + CFG.LOCAL_DB_PASSWORD
                    + "@"
                    + CFG.LOCAL_DB_HOST
                    + ":"
                    + str(CFG.LOCAL_DB_PORT),
                    engine_args={"pool_size": 10, "pool_recycle": 3600, "echo": True},
                )
                # default_mysql = MySQLConnect.from_uri(
                #     "mysql+pymysql://"
                #     + CFG.LOCAL_DB_USER
                #     + ":"
                #     + CFG.LOCAL_DB_PASSWORD
                #     + "@"
                #     + CFG.LOCAL_DB_HOST
                #     + ":"
                #     + str(CFG.LOCAL_DB_PORT),
                #     engine_args={"pool_size": 10, "pool_recycle": 3600, "echo": True},
                # )
                dbs = default_mysql.get_database_list()
                for name in dbs:
                    self.storage.add_url_db(name, DBType.Mysql.value(), CFG.LOCAL_DB_HOST, CFG.LOCAL_DB_PORT,
                                            CFG.LOCAL_DB_USER, CFG.LOCAL_DB_PASSWORD, "")
        if CFG.LOCAL_DB_PATH:
            # default file db is duckdb
            db_name = self.storage.get_file_db_name(CFG.LOCAL_DB_PATH)
            if db_name:
                self.storage.add_file_db(db_name, DBType.DuckDb.value(), CFG.LOCAL_DB_PATH)

    def get_connect(self, db_name):
        db_config = self.storage.get_db_config(db_name)
        db_type = DBType.of_db_type(db_config.get('db_type'))
        connect_instance = self.get_instance_by_dbtype(db_type.value())
        if db_type.is_file_db():
            db_path = db_config.get('db_path')
            return connect_instance.from_file_path(db_path)
        else:
            db_host = db_config.get('db_host')
            db_port = db_config.get('db_port')
            db_user = db_config.get('db_user')
            db_pwd = db_config.get('db_pwd')
            return connect_instance.from_uri_db(db_host, db_port, db_user, db_pwd, db_name)

    def get_db_list(self):
        return self.storage.get_db_list()

    def get_db_names(self):
        return self.storage.get_db_names()
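A rough usage sketch of ConnectManager, assuming it is created at server startup as later hunks do; the "orders" name is a placeholder:

manager = ConnectManager()
print(manager.get_db_names())         # names registered in the DuckDB config store
conn = manager.get_connect("orders")  # resolves the stored type and builds a connector
print(conn.get_table_names())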
@@ -10,7 +10,7 @@ CFG = Config()
class ClickHouseConnector(RDBMSDatabase):
    """ClickHouseConnector"""

-    type: str = "DUCKDB"
+    db_type: str = "duckdb"

    driver: str = "duckdb"

@@ -1,7 +1,14 @@
-from typing import Optional, Any
+from typing import Optional, Any, Iterable
+from sqlalchemy import (
+    MetaData,
+    Table,
+    create_engine,
+    inspect,
+    select,
+    text,
+)

from pilot.connections.rdbms.rdbms_connect import RDBMSDatabase

from pilot.configs.config import Config

CFG = Config()
@@ -13,29 +20,15 @@ class DuckDbConnect(RDBMSDatabase):
    Usage:
    """

-    type: str = "DUCKDB"
+    def table_simple_info(self) -> Iterable[str]:
+        return super().get_table_names()

-    driver: str = "duckdb"
-
-    file_path: str
-
-    default_db = ["information_schema", "performance_schema", "sys", "mysql"]
+    db_type: str = "duckdb"

-    @classmethod
-    def from_config(cls) -> RDBMSDatabase:
-        """
-        Todo password encryption
-        Returns:
-        """
-        return cls.from_uri_db(
-            cls,
-            CFG.LOCAL_DB_PATH,
-            engine_args={"pool_size": 10, "pool_recycle": 3600, "echo": True},
-        )
-
    @classmethod
-    def from_uri_db(
-        cls, db_path: str, engine_args: Optional[dict] = None, **kwargs: Any
+    def from_file_path(
+        cls, file_path: str, engine_args: Optional[dict] = None, **kwargs: Any
    ) -> RDBMSDatabase:
-        db_url: str = cls.connect_driver + "://" + db_path
-        return cls.from_uri(db_url, engine_args, **kwargs)
+        """Construct a SQLAlchemy engine from URI."""
+        _engine_args = engine_args or {}
+        return cls(create_engine("duckdb://" + file_path, **_engine_args), **kwargs)
@@ -11,8 +11,8 @@ class MSSQLConnect(RDBMSDatabase):
    Usage:
    """

-    type: str = "MSSQL"
-    dialect: str = "mssql"
+    db_type: str = "mssql"
+    db_dialect: str = "mssql"
    driver: str = "pyodbc"

    default_db = ["master", "model", "msdb", "tempdb", "modeldb", "resource"]
@@ -11,8 +11,8 @@ class MySQLConnect(RDBMSDatabase):
    Usage:
    """

-    type: str = "MySQL"
-    dialect: str = "mysql"
+    db_type: str = "mysql"
+    db_dialect: str = "mysql"
    driver: str = "pymysql"

    default_db = ["information_schema", "performance_schema", "sys", "mysql"]
@@ -6,7 +6,7 @@ from pilot.connections.rdbms.rdbms_connect import RDBMSDatabase
class OracleConnector(RDBMSDatabase):
    """OracleConnector"""

-    type: str = "ORACLE"
+    db_type: str = "oracle"

    driver: str = "oracle"

@@ -6,7 +6,7 @@ from pilot.connections.rdbms.rdbms_connect import RDBMSDatabase
class PostgresConnector(RDBMSDatabase):
    """PostgresConnector is a class which Connector"""

-    type: str = "POSTGRESQL"
+    db_type: str = "postgresql"
    driver: str = "postgresql"

    default_db = ["information_schema", "performance_schema", "sys", "mysql"]
@@ -1,6 +1,9 @@
from __future__ import annotations

+import warnings
+import sqlparse
+import regex as re

from typing import Any, Iterable, List, Optional
from pydantic import BaseModel, Field, root_validator, validator, Extra
from abc import ABC, abstractmethod
@@ -35,12 +38,12 @@ class RDBMSDatabase(BaseConnect):
    """SQLAlchemy wrapper around a database."""

    def __init__(
        self,
        engine,
        schema: Optional[str] = None,
        metadata: Optional[MetaData] = None,
        ignore_tables: Optional[List[str]] = None,
        include_tables: Optional[List[str]] = None,
    ):
        """Create engine from database URI."""
        self._engine = engine
@@ -48,58 +51,41 @@ class RDBMSDatabase(BaseConnect):
        if include_tables and ignore_tables:
            raise ValueError("Cannot specify both include_tables and ignore_tables")

-        self._inspector = inspect(self._engine)
+        self._inspector = inspect(engine)
        session_factory = sessionmaker(bind=engine)
-        Session = scoped_session(session_factory)
-
-        self._db_sessions = Session
-
-    @classmethod
-    def from_config(cls) -> RDBMSDatabase:
-        """
-        Todo password encryption
-        Returns:
-        """
-        return cls.from_uri_db(
-            cls,
-            CFG.LOCAL_DB_HOST,
-            CFG.LOCAL_DB_PORT,
-            CFG.LOCAL_DB_USER,
-            CFG.LOCAL_DB_PASSWORD,
-            engine_args={"pool_size": 10, "pool_recycle": 3600, "echo": True},
-        )
+        Session_Manages = scoped_session(session_factory)
+        self._db_sessions = Session_Manages
+        self.session = self.get_session()

    @classmethod
    def from_uri_db(
        cls,
        host: str,
        port: int,
        user: str,
        pwd: str,
-        db_name: str = None,
+        db_name: str,
        engine_args: Optional[dict] = None,
        **kwargs: Any,
    ) -> RDBMSDatabase:
        db_url: str = (
            cls.connect_driver
            + "://"
            + CFG.LOCAL_DB_USER
            + ":"
            + CFG.LOCAL_DB_PASSWORD
            + "@"
            + CFG.LOCAL_DB_HOST
            + ":"
            + str(CFG.LOCAL_DB_PORT)
+            + "/"
+            + db_name
        )
-        if cls.dialect:
-            db_url = cls.dialect + "+" + db_url
-        if db_name:
-            db_url = db_url + "/" + db_name
        return cls.from_uri(db_url, engine_args, **kwargs)

    @classmethod
    def from_uri(
        cls, database_uri: str, engine_args: Optional[dict] = None, **kwargs: Any
    ) -> RDBMSDatabase:
        """Construct a SQLAlchemy engine from URI."""
        _engine_args = engine_args or {}
@@ -123,24 +109,17 @@ class RDBMSDatabase(BaseConnect):
        )
        return self.get_usable_table_names()

-    def get_session(self, db_name: str):
+    def get_session(self):
        session = self._db_sessions()

        self._metadata = MetaData()
-        # sql = f"use {db_name}"
-        sql = text(f"use `{db_name}`")
-        session.execute(sql)
-
-        # Process the table metadata
-
-        self._metadata.reflect(bind=self._engine, schema=db_name)
+        self._metadata.reflect(bind=self._engine)

        # including view support by adding the views as well as tables to the all
        # tables list if view_support is True
        self._all_tables = set(
-            self._inspector.get_table_names(schema=db_name)
+            self._inspector.get_table_names()
            + (
-                self._inspector.get_view_names(schema=db_name)
+                self._inspector.get_view_names()
                if self.view_support
                else []
            )
@@ -148,14 +127,14 @@ class RDBMSDatabase(BaseConnect):

        return session

-    def get_current_db_name(self, session) -> str:
-        return session.execute(text("SELECT DATABASE()")).scalar()
+    def get_current_db_name(self) -> str:
+        return self.session.execute(text("SELECT DATABASE()")).scalar()

-    def table_simple_info(self, session):
+    def table_simple_info(self):
        _sql = f"""
-            select concat(table_name, "(" , group_concat(column_name), ")") as schema_info from information_schema.COLUMNS where table_schema="{self.get_current_db_name(session)}" group by TABLE_NAME;
+            select concat(table_name, "(" , group_concat(column_name), ")") as schema_info from information_schema.COLUMNS where table_schema="{self.get_current_db_name()}" group by TABLE_NAME;
        """
-        cursor = session.execute(text(_sql))
+        cursor = self.session.execute(text(_sql))
        results = cursor.fetchall()
        return results

@@ -185,7 +164,7 @@ class RDBMSDatabase(BaseConnect):
            tbl
            for tbl in self._metadata.sorted_tables
            if tbl.name in set(all_table_names)
            and not (self.dialect == "sqlite" and tbl.name.startswith("sqlite_"))
        ]

        tables = []
@@ -198,7 +177,7 @@ class RDBMSDatabase(BaseConnect):
            create_table = str(CreateTable(table).compile(self._engine))
            table_info = f"{create_table.rstrip()}"
            has_extra_info = (
                self._indexes_in_table_info or self._sample_rows_in_table_info
            )
            if has_extra_info:
                table_info += "\n\n/*"
@@ -255,9 +234,31 @@ class RDBMSDatabase(BaseConnect):
        """Format the error message"""
        return f"Error: {e}"

-    def run(self, session, command: str, fetch: str = "all") -> List:
-        """Execute a SQL command and return a string representing the results."""
-        cursor = session.execute(text(command))
+    def __write(self, session, write_sql):
+        print(f"Write[{write_sql}]")
+        db_cache = self.get_session_db(session)
+        result = session.execute(text(write_sql))
+        session.commit()
+        # TODO Subsequent optimization of dynamically specified database submission loss target problem
+        session.execute(text(f"use `{db_cache}`"))
+        print(f"SQL[{write_sql}], result:{result.rowcount}")
+        return result.rowcount
+
+    def __query(self, session, query, fetch: str = "all"):
+        """
+        only for query
+        Args:
+            session:
+            query:
+            fetch:
+
+        Returns:
+
+        """
+        print(f"Query[{query}]")
+        if not query:
+            return []
+        cursor = session.execute(text(query))
        if cursor.returns_rows:
            if fetch == "all":
                result = cursor.fetchall()
@@ -271,6 +272,62 @@ class RDBMSDatabase(BaseConnect):
                result.insert(0, field_names)
        return result

    def query_ex(self, session, query, fetch: str = "all"):
        """
        only for query
        Args:
            session:
            query:
            fetch:
        Returns:
        """
        print(f"Query[{query}]")
        if not query:
            return []
        cursor = session.execute(text(query))
        if cursor.returns_rows:
            if fetch == "all":
                result = cursor.fetchall()
            elif fetch == "one":
                result = cursor.fetchone()[0]  # type: ignore
            else:
                raise ValueError("Fetch parameter must be either 'one' or 'all'")
            field_names = list(i[0:] for i in cursor.keys())

            result = list(result)
            return field_names, result

    def run(self, session, command: str, fetch: str = "all") -> List:
        """Execute a SQL command and return a string representing the results."""
        print("SQL:" + command)
        if not command:
            return []
        parsed, ttype, sql_type = self.__sql_parse(command)
        if ttype == sqlparse.tokens.DML:
            if sql_type == "SELECT":
                return self.__query(session, command, fetch)
            else:
                self.__write(session, command)
                select_sql = self.convert_sql_write_to_select(command)
                print(f"write result query:{select_sql}")
                return self.__query(session, select_sql)

        else:
            print(f"DDL execution determines whether to enable through configuration ")
            cursor = session.execute(text(command))
            session.commit()
            if cursor.returns_rows:
                result = cursor.fetchall()
                field_names = tuple(i[0:] for i in cursor.keys())
                result = list(result)
                result.insert(0, field_names)
                print("DDL Result:" + str(result))
                if not result:
                    return self.__query(session, "SHOW COLUMNS FROM test")
                return result
            else:
                return self.__query(session, "SHOW COLUMNS FROM test")

    def run_no_throw(self, session, command: str, fetch: str = "all") -> List:
        """Execute a SQL command and return a string representing the results.

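A small standalone sketch of the sqlparse-based classification that run() dispatches on; the statement is illustrative:

import sqlparse

parsed = sqlparse.parse("update user set city='sz' where id=1")[0]
first_token = parsed.token_first(skip_ws=True, skip_cm=False)
print(parsed.get_type())                         # UPDATE
print(first_token.ttype is sqlparse.tokens.DML)  # True -> routed to __write(), then re-queried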
@@ -294,3 +351,152 @@ class RDBMSDatabase(BaseConnect):
            for d in results
            if d[0] not in ["information_schema", "performance_schema", "sys", "mysql"]
        ]

    def convert_sql_write_to_select(self, write_sql):
        """
        SQL classification processing
        author: xiangh8
        Args:
            sql:

        Returns:

        """
        # Convert the SQL command to lowercase and split it on whitespace
        parts = write_sql.lower().split()
        # Get the command type (insert, delete, update)
        cmd_type = parts[0]

        # Handle each command type
        if cmd_type == "insert":
            match = re.match(
                r"insert into (\w+) \((.*?)\) values \((.*?)\)", write_sql.lower()
            )
            if match:
                table_name, columns, values = match.groups()
                # Split the column list and value list into individual columns and values
                columns = columns.split(",")
                values = values.split(",")
                # Build the WHERE clause
                where_clause = " AND ".join(
                    [
                        f"{col.strip()}={val.strip()}"
                        for col, val in zip(columns, values)
                    ]
                )
                return f"SELECT * FROM {table_name} WHERE {where_clause}"

        elif cmd_type == "delete":
            table_name = parts[2]  # delete from <table_name> ...
            # Return a SELECT statement that reads all data from the table
            return f"SELECT * FROM {table_name} "

        elif cmd_type == "update":
            table_name = parts[1]
            set_idx = parts.index("set")
            where_idx = parts.index("where")
            # Take the column name from the `set` clause
            set_clause = parts[set_idx + 1: where_idx][0].split("=")[0].strip()
            # Take the condition after `where`
            where_clause = " ".join(parts[where_idx + 1:])
            # Return a SELECT statement that reads the updated data
            return f"SELECT {set_clause} FROM {table_name} WHERE {where_clause}"
        else:
            raise ValueError(f"Unsupported SQL command type: {cmd_type}")

    def __sql_parse(self, sql):
        sql = sql.strip()
        parsed = sqlparse.parse(sql)[0]
        sql_type = parsed.get_type()

        first_token = parsed.token_first(skip_ws=True, skip_cm=False)
        ttype = first_token.ttype
        print(f"SQL:{sql}, ttype:{ttype}, sql_type:{sql_type}")
        return parsed, ttype, sql_type

    def get_indexes(self, table_name):
        """Get table indexes about specified table."""
        session = self._db_sessions()
        cursor = session.execute(text(f"SHOW INDEXES FROM {table_name}"))
        indexes = cursor.fetchall()
        return [(index[2], index[4]) for index in indexes]

    def get_show_create_table(self, table_name):
        """Get table show create table about specified table."""
        session = self._db_sessions()
        cursor = session.execute(text(f"SHOW CREATE TABLE {table_name}"))
        ans = cursor.fetchall()
        return ans[0][1]

    def get_fields(self, table_name):
        """Get column fields about specified table."""
        session = self._db_sessions()
        cursor = session.execute(
            text(
                f"SELECT COLUMN_NAME, COLUMN_TYPE, COLUMN_DEFAULT, IS_NULLABLE, COLUMN_COMMENT from information_schema.COLUMNS where table_name='{table_name}'".format(
                    table_name
                )
            )
        )
        fields = cursor.fetchall()
        return [(field[0], field[1], field[2], field[3], field[4]) for field in fields]

    def get_charset(self):
        """Get character_set."""
        session = self._db_sessions()
        cursor = session.execute(text(f"SELECT @@character_set_database"))
        character_set = cursor.fetchone()[0]
        return character_set

    def get_collation(self):
        """Get collation."""
        session = self._db_sessions()
        cursor = session.execute(text(f"SELECT @@collation_database"))
        collation = cursor.fetchone()[0]
        return collation

    def get_grants(self):
        """Get grant info."""
        session = self._db_sessions()
        cursor = session.execute(text(f"SHOW GRANTS"))
        grants = cursor.fetchall()
        return grants

    def get_users(self):
        """Get user info."""
        cursor = self.session.execute(text(f"SELECT user, host FROM mysql.user"))
        users = cursor.fetchall()
        return [(user[0], user[1]) for user in users]

    def get_table_comments(self, db_name):
        cursor = self.session.execute(
            text(
                f"""SELECT table_name, table_comment FROM information_schema.tables WHERE table_schema = '{db_name}'""".format(
                    db_name
                )
            )
        )
        table_comments = cursor.fetchall()
        return [
            (table_comment[0], table_comment[1]) for table_comment in table_comments
        ]

    def get_database_list(self):
        session = self._db_sessions()
        cursor = session.execute(text(" show databases;"))
        results = cursor.fetchall()
        return [
            d[0]
            for d in results
            if d[0] not in ["information_schema", "performance_schema", "sys", "mysql"]
        ]

    def get_database_names(self):
        session = self._db_sessions()
        cursor = session.execute(text(" show databases;"))
        results = cursor.fetchall()
        return [
            d[0]
            for d in results
            if d[0] not in ["information_schema", "performance_schema", "sys", "mysql"]
        ]
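To make the insert branch of convert_sql_write_to_select concrete, here is a standalone re-run of its regex rewrite with a made-up table and columns:

import re

sql = "insert into user (name, city) values ('a', 'sz')"
m = re.match(r"insert into (\w+) \((.*?)\) values \((.*?)\)", sql)
table, cols, vals = m.groups()
where = " AND ".join(
    f"{c.strip()}={v.strip()}" for c, v in zip(cols.split(","), vals.split(","))
)
print(f"SELECT * FROM {table} WHERE {where}")
# -> SELECT * FROM user WHERE name='a' AND city='sz'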
@@ -53,7 +53,7 @@ async def validation_exception_handler(request: Request, exc: RequestValidationE
    message = ""
    for error in exc.errors():
        message += ".".join(error.get("loc")) + ":" + error.get("msg") + ";"
-    return Result.faild(code= "E0001", msg=message)
+    return Result.faild(code="E0001", msg=message)


def __get_conv_user_message(conversations: dict):
@@ -71,11 +71,10 @@ def __new_conversation(chat_mode, user_id) -> ConversationVo:


def get_db_list():
-    db = CFG.local_db
-    dbs = db.get_database_list()
+    dbs = CFG.LOCAL_DB_MANAGE.get_db_list()
    params: dict = {}
-    for name in dbs:
-        params.update({name: name})
+    for item in dbs:
+        params.update({item["db_name"]: item["comment"]})
    return params

@@ -95,8 +94,9 @@ def knowledge_list():
        params.update({space.name: space.name})
    return params


@router.get("/v1/chat/dialogue/list", response_model=Result[ConversationVo])
-async def dialogue_list( user_id: str = None):
+async def dialogue_list(user_id: str = None):
    dialogues: List = []
    datas = DuckdbHistoryMemory.conv_list(user_id)

@@ -126,11 +126,10 @@ async def dialogue_scenes():
        ChatScene.ChatExecution,
    ]
    for scene in new_modes:
-
        scene_vo = ChatSceneVo(
            chat_scene=scene.value(),
            scene_name=scene.scene_name(),
-            scene_describe= scene.describe(),
+            scene_describe=scene.describe(),
            param_title=",".join(scene.param_types()),
        )
        scene_vos.append(scene_vo)
@@ -43,7 +43,7 @@ class ChatDashboard(BaseChat):
        )
        self.db_name = db_name
        self.report_name = report_name
-        self.database = CFG.local_db
+        self.database = CFG.LOCAL_DB_MANAGE.get_connect(db_name)
        # Prepare DB info (get the connection for the specified database)
        self.db_connect = self.database.get_session(self.db_name)
        self.top_k: int = 5
@@ -33,9 +33,8 @@ class ChatWithDbAutoExecute(BaseChat):
                f"{ChatScene.ChatWithDbExecute.value} mode should chose db!"
            )
        self.db_name = db_name
-        self.database = CFG.local_db
-        # Prepare DB info (get the connection for the specified database)
-        self.db_connect = self.database.get_session(self.db_name)
+        self.database = CFG.LOCAL_DB_MANAGE.get_connect(db_name)
+        self.db_connect = self.database.session
        self.top_k: int = 5

    def generate_input_values(self):
@@ -28,9 +28,8 @@ class ChatWithDbQA(BaseChat):
            )
        self.db_name = db_name
        if db_name:
-            self.database = CFG.local_db
-            # Prepare DB info (get the connection for the specified database)
-            self.db_connect = self.database.get_session(self.db_name)
+            self.database = CFG.LOCAL_DB_MANAGE.get_connect(db_name)
+            self.db_connect = self.database.session
            self.tables = self.database.get_table_names()

        self.top_k = (
@@ -84,12 +84,6 @@ priority = {"vicuna-13b": "aaa"}

CHAT_FACTORY = ChatFactory()

-DB_SETTINGS = {
-    "user": CFG.LOCAL_DB_USER,
-    "password": CFG.LOCAL_DB_PASSWORD,
-    "host": CFG.LOCAL_DB_HOST,
-    "port": CFG.LOCAL_DB_PORT,
-}
-
llm_native_dialogue = get_lang_text("knowledge_qa_type_llm_native_dialogue")
default_knowledge_base_dialogue = get_lang_text(
@@ -144,14 +138,6 @@ def get_simlar(q):
    return "\n".join(contents)


-def gen_sqlgen_conversation(dbname):
-    message = ""
-    db_connect = CFG.local_db.get_session(dbname)
-    schemas = CFG.local_db.table_simple_info(db_connect)
-    for s in schemas:
-        message += s + ";"
-    return get_lang_text("sql_schema_info").format(dbname, message)
-
-
def plugins_select_info():
    plugins_infos: dict = {}
@@ -703,7 +689,7 @@ if __name__ == "__main__":
    # init server config
    args = parser.parse_args()
    server_init(args)
-    dbs = CFG.local_db.get_database_list()
+    dbs = CFG.LOCAL_DB_MANAGE.get_db_names()
    demo = build_webdemo()
    demo.queue(
        concurrency_count=args.concurrency_count,
@@ -15,6 +15,7 @@ from pilot.configs.model_config import (
)
from pilot.common.plugins import scan_plugins, load_native_plugins
from pilot.utils import build_logger
+from pilot.connections.manages.connection_manager import ConnectManager

ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(ROOT_PATH)
@@ -38,6 +39,10 @@ def server_init(args):

    # init config
    cfg = Config()
+    # init connect manage
+    conn_manage = ConnectManager()
+    cfg.LOCAL_DB_MANAGE = conn_manage
+
    load_native_plugins(cfg)
    signal.signal(signal.SIGINT, signal_handler)
    async_db_summery()
@@ -11,6 +11,8 @@ from pilot.embedding_engine.knowledge_embedding import KnowledgeEmbedding
from pilot.embedding_engine.string_embedding import StringEmbedding
from pilot.summary.mysql_db_summary import MysqlSummary
from pilot.scene.chat_factory import ChatFactory
+from pilot.common.schema import DBType


CFG = Config()
chat_factory = ChatFactory()
@@ -24,10 +26,13 @@ class DBSummaryClient:
    def __init__(self):
        pass

-    def db_summary_embedding(self, dbname):
+    def db_summary_embedding(self, dbname, db_type):
        """put db profile and table profile summary into vector store"""
-        if CFG.LOCAL_DB_HOST is not None and CFG.LOCAL_DB_PORT is not None:
+        if DBType.Mysql.value() == db_type:
            db_summary_client = MysqlSummary(dbname)
+        else:
+            raise ValueError("Unsupport summary DbType!" + db_type)

        embeddings = HuggingFaceEmbeddings(
            model_name=LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL]
        )
@@ -120,10 +125,10 @@ class DBSummaryClient:
        return related_table_summaries

    def init_db_summary(self):
-        db = CFG.local_db
-        dbs = db.get_database_list()
-        for dbname in dbs:
-            self.db_summary_embedding(dbname)
+        db_mange = CFG.LOCAL_DB_MANAGE
+        dbs = db_mange.get_db_list()
+        for item in dbs:
+            self.db_summary_embedding(item["db_name"], item["db_type"])

    def init_db_profile(self, db_summary_client, dbname, embeddings):
        profile_store_config = {
@@ -56,8 +56,8 @@ class MysqlSummary(DBSummary):
        self.vector_tables_info = []
        # self.tables_summary = {}

-        self.db = CFG.local_db
-        self.db.get_session(name)
+        self.db = CFG.LOCAL_DB_MANAGE.get_connect(name)


        self.metadata = """user info :{users}, grant info:{grant}, charset:{charset}, collation:{collation}""".format(
            users=self.db.get_users(),
BIN  pilot/xx.db  Normal file
Binary file not shown.