Multi DB support

This commit is contained in:
yhjun1026 2023-07-21 10:15:32 +08:00
parent 73b3b7069f
commit 288e55a97f
23 changed files with 2347 additions and 157 deletions

File diff suppressed because it is too large Load Diff

View File

@ -12,3 +12,31 @@ class SeparatorStyle(Enum):
class ExampleType(Enum):
    """Example modes, identified by the strings "one_shot" and "few_shot"."""

    # A single example.
    ONE_SHOT = "one_shot"
    # Several examples.
    FEW_SHOT = "few_shot"
class DbInfo:
    """Description of one supported database kind.

    Args:
        name: Type name string of the database kind (for example "mysql").
        is_file_db: Flag marking the kind as a file database; defaults to False.
    """

    def __init__(self, name, is_file_db: bool = False):
        self.name = name
        self.is_file_db = is_file_db


class DBType(Enum):
    """Supported database kinds; every member wraps a :class:`DbInfo`."""

    Mysql = DbInfo("mysql")
    OCeanBase = DbInfo("oceanbase")
    DuckDb = DbInfo("duckdb", True)
    Oracle = DbInfo("oracle")
    MSSQL = DbInfo("mssql")
    Postgresql = DbInfo("postgresql")

    def value(self):
        """Return the type name string of this member.

        This is a method, so callers write ``DBType.Mysql.value()``; it
        replaces the usual ``Enum.value`` attribute on this class.
        """
        return self._value_.name

    def is_file_db(self):
        """Return the ``is_file_db`` flag of the wrapped :class:`DbInfo`."""
        return self._value_.is_file_db

    @staticmethod
    def of_db_type(db_type: str):
        """Return the member whose type name equals ``db_type``.

        Args:
            db_type: Type name string such as "mysql" or "duckdb".

        Returns:
            The matching member, or None when no member has that name.
        """
        # Iterate the members themselves: ``__members__`` yields the member
        # *names* (plain strings), and ``value()`` already returns the type
        # name, so there is no further ``.name`` to read.
        for item in DBType:
            if item.value() == db_type:
                return item
        return None

View File

@ -28,6 +28,8 @@ class Config(metaclass=Singleton):
self.skip_reprompt = False
self.temperature = float(os.getenv("TEMPERATURE", 0.7))
self.NUM_GPUS = int(os.getenv("NUM_GPUS",1))
self.execute_local_commands = (
os.getenv("EXECUTE_LOCAL_COMMANDS", "False") == "True"
)
@ -116,25 +118,15 @@ class Config(metaclass=Singleton):
os.getenv("NATIVE_SQL_CAN_RUN_WRITE", "True") == "True"
)
### Local database connection configuration
### default Local database connection configuration
self.LOCAL_DB_HOST = os.getenv("LOCAL_DB_HOST", "127.0.0.1")
self.LOCAL_DB_PATH = os.getenv("LOCAL_DB_PATH", "xx.db")
self.LOCAL_DB_NAME = os.getenv("LOCAL_DB_NAME")
self.LOCAL_DB_PORT = int(os.getenv("LOCAL_DB_PORT", 3306))
self.LOCAL_DB_USER = os.getenv("LOCAL_DB_USER", "root")
self.LOCAL_DB_PASSWORD = os.getenv("LOCAL_DB_PASSWORD", "aa123456")
### TODO Adapt to multiple types of libraries
self.local_db = Database.from_uri(
"mysql+pymysql://"
+ self.LOCAL_DB_USER
+ ":"
+ self.LOCAL_DB_PASSWORD
+ "@"
+ self.LOCAL_DB_HOST
+ ":"
+ str(self.LOCAL_DB_PORT),
engine_args={"pool_size": 10, "pool_recycle": 3600, "echo": True},
)
self.LOCAL_DB_MANAGE = None
### LLM Model Service Configuration
self.LLM_MODEL = os.getenv("LLM_MODEL", "vicuna-13b")

View File

@ -7,11 +7,8 @@ from pydantic import BaseModel, Extra, Field, root_validator
from typing import Any, Iterable, List, Optional
class BaseConnect(BaseModel, ABC):
type
driver: str
def get_session(self, db_name: str):
class BaseConnect( ABC):
def get_connect(self, db_name: str):
pass
def get_table_names(self) -> Iterable[str]:
@ -20,14 +17,41 @@ class BaseConnect(BaseModel, ABC):
def get_table_info(self, table_names: Optional[List[str]] = None) -> str:
pass
def get_table_info(self, table_names: Optional[List[str]] = None) -> str:
def get_index_info(self, table_names: Optional[List[str]] = None) -> str:
pass
def get_index_info(self, table_names: Optional[List[str]] = None) -> str:
def get_example_data(self, table: str, count: int = 3):
pass
def get_database_list(self):
pass
def get_database_names(self):
pass
def get_table_comments(self, db_name):
pass
def run(self, session, command: str, fetch: str = "all") -> List:
pass
def get_users(self):
pass
def get_grants(self):
pass
def get_collation(self):
pass
def get_charset(self):
pass
def get_fields(self, table_name):
pass
def get_show_create_table(self, table_name):
pass
def get_indexes(self, table_name):
pass

View File

View File

@ -0,0 +1,116 @@
import os
import duckdb
from typing import List
# Directory named "message" under the current working directory.
default_db_path = os.path.join(os.getcwd(), "message")
# Path of the DuckDB file; the DB_DUCKDB_PATH environment variable, when set,
# replaces the default location inside default_db_path.
duckdb_path = os.environ.get("DB_DUCKDB_PATH", f"{default_db_path}/connect_config.db")
# Name of the table that is looked up when the store is initialised.
table_name = "connect_config"
class DuckdbConnectConfig:
    """Store database connection settings in a local DuckDB file.

    Every registered database is one row of the ``connect_config`` table;
    row ids are drawn from the ``seq_id`` sequence.
    """

    # Shared by add_url_db and add_file_db. The eight placeholders follow the
    # column order db_name, db_type, db_path, db_host, db_port, db_user,
    # db_pwd, comment.
    _INSERT_SQL = (
        "INSERT INTO connect_config(id, db_name, db_type, db_path, db_host, db_port, db_user, db_pwd, comment)"
        "VALUES(nextval('seq_id'),?,?,?,?,?,?,?,?)"
    )

    def __init__(self):
        """Open (creating if needed) the config database and its table."""
        os.makedirs(default_db_path, exist_ok=True)
        self.connect = duckdb.connect(duckdb_path)
        self.__init_config_tables()

    def __init_config_tables(self):
        """Create the config table and its id sequence when the table is absent."""
        # check config table
        result = self.connect.execute(
            "SELECT name FROM sqlite_master WHERE type='table' AND name=?", [table_name]
        ).fetchall()

        if not result:
            # create config table
            self.connect.execute(
                "CREATE TABLE connect_config (id integer primary key, db_name VARCHAR(100) UNIQUE, db_type VARCHAR(50), db_path VARCHAR(255) NULL, db_host VARCHAR(255) NULL, db_port INTEGER NULL, db_user VARCHAR(255) NULL, db_pwd VARCHAR(255) NULL, comment TEXT NULL)"
            )
            self.connect.execute("CREATE SEQUENCE seq_id START 1;")

    def add_url_db(self, db_name, db_type, db_host: str, db_port: int, db_user: str, db_pwd: str, comment: str = ""):
        """Register a database that is reached through host/port/credentials.

        Best effort: a failure (for example a ``db_name`` that is already
        registered, which violates the UNIQUE constraint) is printed and not
        raised.
        """
        try:
            cursor = self.connect.cursor()
            cursor.execute(
                self._INSERT_SQL,
                [db_name, db_type, "", db_host, db_port, db_user, db_pwd, comment],
            )
            cursor.commit()
            self.connect.commit()
        except Exception as e:
            print("add db connect info error1" + str(e))

    def get_file_db_name(self, path):
        """Return the database name DuckDB reports for the file at ``path``.

        Raises:
            ValueError: when the file cannot be opened or queried.
        """
        try:
            conn = duckdb.connect(path)
            try:
                return conn.execute('SELECT current_database()').fetchone()[0]
            finally:
                # The probe connection is only needed for this one query.
                conn.close()
        except Exception as e:
            # A bare string cannot be raised in Python 3; raise a real
            # exception and keep the original cause attached.
            raise ValueError("Unusable duckdb database path:" + path) from e

    def add_file_db(self, db_name, db_type, db_path: str):
        """Register a database that is reached through a file path.

        Best effort, like add_url_db: failures are printed and not raised.
        """
        try:
            cursor = self.connect.cursor()
            cursor.execute(
                self._INSERT_SQL,
                # db_port is an INTEGER column, so the unused port is stored
                # as NULL; an empty string cannot be converted to an integer.
                [db_name, db_type, db_path, "", None, "", "", ""],
            )
            cursor.commit()
            self.connect.commit()
        except Exception as e:
            print("add db connect info error2" + str(e))

    def delete_db(self, db_name):
        """Delete the row registered under ``db_name``; always returns True."""
        cursor = self.connect.cursor()
        cursor.execute(
            "DELETE FROM connect_config where db_name=?", [db_name]
        )
        cursor.commit()
        return True

    def get_db_config(self, db_name):
        """Return the stored settings of ``db_name`` as a column -> value dict.

        Returns:
            The row as a dict, or an empty dict when the config file does not
            exist or no row is registered under ``db_name``.

        Raises:
            ValueError: when ``db_name`` is empty or None.
        """
        if not os.path.isfile(duckdb_path):
            return {}
        if not db_name:
            # str() keeps the message buildable when db_name is None.
            raise ValueError("Cannot get database by name" + str(db_name))

        # Reuse the connection opened in __init__ instead of opening (and
        # never closing) a new one on every call.
        cursor = self.connect.cursor()
        cursor.execute(
            "SELECT * FROM connect_config where db_name=? ", [db_name]
        )
        fields = [field[0] for field in cursor.description]
        row = cursor.fetchone()
        if row is None:
            return {}
        # Pair each column name with the value at the same position in the row.
        return dict(zip(fields, row))

    def get_db_list(self):
        """Return one dict (db_name, db_type, comment) per registered database."""
        if not os.path.isfile(duckdb_path):
            return []
        cursor = self.connect.cursor()
        cursor.execute(
            "SELECT db_name, db_type, comment FROM connect_config "
        )
        fields = [field[0] for field in cursor.description]
        return [dict(zip(fields, row)) for row in cursor.fetchall()]

    def get_db_names(self):
        """Return the names of all registered databases."""
        if not os.path.isfile(duckdb_path):
            return []
        cursor = self.connect.cursor()
        cursor.execute(
            "SELECT db_name FROM connect_config "
        )
        return [row[0] for row in cursor.fetchall()]

View File

@ -0,0 +1,87 @@
from pilot.configs.config import Config
from pilot.connections.manages.connect_storage_duckdb import DuckdbConnectConfig
from pilot.common.schema import DBType
from pilot.connections.rdbms.mysql import MySQLConnect
from pilot.connections.base import BaseConnect
from pilot.connections.rdbms.rdbms_connect import RDBMSDatabase
from pilot.singleton import Singleton
from pilot.common.sql_database import Database
CFG = Config()
class ConnectManager:
    """Registry of configured databases and factory for their connections.

    Connection settings live in a DuckdbConnectConfig store; connector
    classes are discovered among the subclasses of BaseConnect by their
    ``db_type`` class attribute.
    """

    @staticmethod
    def _find_connect_class(db_type_name):
        """Return the BaseConnect subclass whose ``db_type`` equals the name.

        The whole subclass tree is searched, not only direct children. Only
        classes whose modules have already been imported can be found this
        way.

        Raises:
            ValueError: when no subclass declares that ``db_type``.
        """
        pending = list(BaseConnect.__subclasses__())
        while pending:
            candidate = pending.pop()
            # Intermediate base classes may not declare db_type at all.
            if getattr(candidate, "db_type", None) == db_type_name:
                return candidate
            pending.extend(candidate.__subclasses__())
        raise ValueError(f"Invalid db connect implementationDbType:{db_type_name}")

    @staticmethod
    def get_instance_by_dbtype(db_type, **kwargs):
        """Instantiate the connector class registered for ``db_type``.

        Args:
            db_type: A type name string or a DBType member.
            **kwargs: Passed to the connector class constructor.

        Raises:
            ValueError: when no connector class matches.
        """
        type_name = db_type.value() if isinstance(db_type, DBType) else db_type
        return ConnectManager._find_connect_class(type_name)(**kwargs)

    def __init__(self):
        self.storage = DuckdbConnectConfig()
        self.__load_config_db()

    def __load_config_db(self):
        """Register the databases described by the default configuration."""
        if CFG.LOCAL_DB_HOST:
            # default mysql
            if CFG.LOCAL_DB_NAME:
                self.storage.add_url_db(CFG.LOCAL_DB_NAME, DBType.Mysql.value(), CFG.LOCAL_DB_HOST, CFG.LOCAL_DB_PORT,
                                        CFG.LOCAL_DB_USER, CFG.LOCAL_DB_PASSWORD, "")
            else:
                # get all default mysql database
                default_mysql = Database.from_uri(
                    "mysql+pymysql://"
                    + CFG.LOCAL_DB_USER
                    + ":"
                    + CFG.LOCAL_DB_PASSWORD
                    + "@"
                    + CFG.LOCAL_DB_HOST
                    + ":"
                    + str(CFG.LOCAL_DB_PORT),
                    engine_args={"pool_size": 10, "pool_recycle": 3600, "echo": True},
                )
                dbs = default_mysql.get_database_list()
                for name in dbs:
                    self.storage.add_url_db(name, DBType.Mysql.value(), CFG.LOCAL_DB_HOST, CFG.LOCAL_DB_PORT,
                                            CFG.LOCAL_DB_USER, CFG.LOCAL_DB_PASSWORD, "")
        if CFG.LOCAL_DB_PATH:
            # default file db is duckdb
            db_name = self.storage.get_file_db_name(CFG.LOCAL_DB_PATH)
            if db_name:
                self.storage.add_file_db(db_name, DBType.DuckDb.value(), CFG.LOCAL_DB_PATH)

    def get_connect(self, db_name):
        """Build a connector for the registered database ``db_name``.

        Raises:
            ValueError: when the database is not registered, its stored type
                is not a DBType, or no connector class declares that type.
        """
        db_config = self.storage.get_db_config(db_name)
        type_name = db_config.get('db_type')
        # Resolve the stored type name to its DBType member.
        db_type = next((item for item in DBType if item.value() == type_name), None)
        if db_type is None:
            raise ValueError(f"Unknown database or unsupported db type: {db_name}")

        # The connector factories are classmethods, so the class itself is
        # what is needed here; constructing an instance would require an engine.
        connect_cls = self._find_connect_class(db_type.value())
        if db_type.is_file_db():
            db_path = db_config.get('db_path')
            return connect_cls.from_file_path(db_path)

        db_host = db_config.get('db_host')
        db_port = db_config.get('db_port')
        db_user = db_config.get('db_user')
        db_pwd = db_config.get('db_pwd')
        return connect_cls.from_uri_db(db_host, db_port, db_user, db_pwd, db_name)

    def get_db_list(self):
        """Return the registered databases as stored by the config store."""
        return self.storage.get_db_list()

    def get_db_names(self):
        """Return the names of the registered databases."""
        return self.storage.get_db_names()

View File

@ -10,7 +10,7 @@ CFG = Config()
class ClickHouseConnector(RDBMSDatabase):
"""ClickHouseConnector"""
type: str = "DUCKDB"
db_type: str = "duckdb"
driver: str = "duckdb"

View File

@ -1,7 +1,14 @@
from typing import Optional, Any
from typing import Optional, Any, Iterable
from sqlalchemy import (
MetaData,
Table,
create_engine,
inspect,
select,
text,
)
from pilot.connections.rdbms.rdbms_connect import RDBMSDatabase
from pilot.configs.config import Config
CFG = Config()
@ -13,29 +20,15 @@ class DuckDbConnect(RDBMSDatabase):
Usage:
"""
type: str = "DUCKDB"
def table_simple_info(self) -> Iterable[str]:
return super().get_table_names()
driver: str = "duckdb"
file_path: str
default_db = ["information_schema", "performance_schema", "sys", "mysql"]
db_type: str = "duckdb"
@classmethod
def from_config(cls) -> RDBMSDatabase:
"""
Todo password encryption
Returns:
"""
return cls.from_uri_db(
cls,
CFG.LOCAL_DB_PATH,
engine_args={"pool_size": 10, "pool_recycle": 3600, "echo": True},
)
@classmethod
def from_file_path(
    cls, file_path: str, engine_args: Optional[dict] = None, **kwargs: Any
) -> RDBMSDatabase:
    """Construct a connector on a SQLAlchemy engine for a DuckDB file.

    Args:
        file_path: Path of the DuckDB database file.
        engine_args: Optional keyword arguments for ``create_engine``.
        **kwargs: Passed to the class constructor.
    """
    _engine_args = engine_args or {}
    # File URLs take three slashes ("duckdb:///path"). With only two, the
    # first segment of the path lands in the host position of the URL and
    # the database part stays empty.
    return cls(create_engine("duckdb:///" + file_path, **_engine_args), **kwargs)

View File

@ -11,8 +11,8 @@ class MSSQLConnect(RDBMSDatabase):
Usage:
"""
type: str = "MSSQL"
dialect: str = "mssql"
db_type: str = "mssql"
db_dialect: str = "mssql"
driver: str = "pyodbc"
default_db = ["master", "model", "msdb", "tempdb", "modeldb", "resource"]

View File

@ -11,8 +11,8 @@ class MySQLConnect(RDBMSDatabase):
Usage:
"""
type: str = "MySQL"
dialect: str = "mysql"
db_type: str = "mysql"
db_dialect: str = "mysql"
driver: str = "pymysql"
default_db = ["information_schema", "performance_schema", "sys", "mysql"]

View File

@ -6,7 +6,7 @@ from pilot.connections.rdbms.rdbms_connect import RDBMSDatabase
class OracleConnector(RDBMSDatabase):
"""OracleConnector"""
type: str = "ORACLE"
db_type: str = "oracle"
driver: str = "oracle"

View File

@ -6,7 +6,7 @@ from pilot.connections.rdbms.rdbms_connect import RDBMSDatabase
class PostgresConnector(RDBMSDatabase):
"""PostgresConnector is a class which Connector"""
type: str = "POSTGRESQL"
db_type: str = "postgresql"
driver: str = "postgresql"
default_db = ["information_schema", "performance_schema", "sys", "mysql"]

View File

@ -1,6 +1,9 @@
from __future__ import annotations
import warnings
import sqlparse
import regex as re
from typing import Any, Iterable, List, Optional
from pydantic import BaseModel, Field, root_validator, validator, Extra
from abc import ABC, abstractmethod
@ -35,12 +38,12 @@ class RDBMSDatabase(BaseConnect):
"""SQLAlchemy wrapper around a database."""
def __init__(
self,
engine,
schema: Optional[str] = None,
metadata: Optional[MetaData] = None,
ignore_tables: Optional[List[str]] = None,
include_tables: Optional[List[str]] = None,
self,
engine,
schema: Optional[str] = None,
metadata: Optional[MetaData] = None,
ignore_tables: Optional[List[str]] = None,
include_tables: Optional[List[str]] = None,
):
"""Create engine from database URI."""
self._engine = engine
@ -48,58 +51,41 @@ class RDBMSDatabase(BaseConnect):
if include_tables and ignore_tables:
raise ValueError("Cannot specify both include_tables and ignore_tables")
self._inspector = inspect(self._engine)
self._inspector = inspect(engine)
session_factory = sessionmaker(bind=engine)
Session = scoped_session(session_factory)
self._db_sessions = Session
@classmethod
def from_config(cls) -> RDBMSDatabase:
"""
Todo password encryption
Returns:
"""
return cls.from_uri_db(
cls,
CFG.LOCAL_DB_HOST,
CFG.LOCAL_DB_PORT,
CFG.LOCAL_DB_USER,
CFG.LOCAL_DB_PASSWORD,
engine_args={"pool_size": 10, "pool_recycle": 3600, "echo": True},
)
Session_Manages = scoped_session(session_factory)
self._db_sessions = Session_Manages
self.session = self.get_session()
@classmethod
def from_uri_db(
    cls,
    host: str,
    port: int,
    user: str,
    pwd: str,
    db_name: str,
    engine_args: Optional[dict] = None,
    **kwargs: Any,
) -> RDBMSDatabase:
    """Build a connector from separate connection settings.

    The URL is assembled from the arguments, so each database connects with
    its own host, port and credentials.

    Args:
        host: Server host name or address.
        port: Server port.
        user: Login name.
        pwd: Login password.
        db_name: Database to select, appended as the URL path.
        engine_args: Optional keyword arguments for the engine.
        **kwargs: Passed through to ``from_uri``.
    """
    from urllib.parse import quote_plus

    # NOTE(review): ``connect_driver`` is not defined in the visible part of
    # this class hierarchy (subclasses declare ``driver`` / ``db_dialect``);
    # confirm which attribute is meant to supply the URL scheme.
    db_url: str = (
        cls.connect_driver
        + "://"
        # Escape the credentials so characters such as "@" or ":" in them
        # cannot be mistaken for URL delimiters.
        + quote_plus(user)
        + ":"
        + quote_plus(pwd)
        + "@"
        + host
        + ":"
        + str(port)
        + "/"
        + db_name
    )
    return cls.from_uri(db_url, engine_args, **kwargs)
@classmethod
def from_uri(
cls, database_uri: str, engine_args: Optional[dict] = None, **kwargs: Any
cls, database_uri: str, engine_args: Optional[dict] = None, **kwargs: Any
) -> RDBMSDatabase:
"""Construct a SQLAlchemy engine from URI."""
_engine_args = engine_args or {}
@ -123,24 +109,17 @@ class RDBMSDatabase(BaseConnect):
)
return self.get_usable_table_names()
def get_session(self, db_name: str):
def get_session(self):
session = self._db_sessions()
self._metadata = MetaData()
# sql = f"use {db_name}"
sql = text(f"use `{db_name}`")
session.execute(sql)
# 处理表信息数据
self._metadata.reflect(bind=self._engine, schema=db_name)
self._metadata.reflect(bind=self._engine)
# including view support by adding the views as well as tables to the all
# tables list if view_support is True
self._all_tables = set(
self._inspector.get_table_names(schema=db_name)
self._inspector.get_table_names()
+ (
self._inspector.get_view_names(schema=db_name)
self._inspector.get_view_names()
if self.view_support
else []
)
@ -148,14 +127,14 @@ class RDBMSDatabase(BaseConnect):
return session
def get_current_db_name(self, session) -> str:
return session.execute(text("SELECT DATABASE()")).scalar()
def get_current_db_name(self) -> str:
return self.session.execute(text("SELECT DATABASE()")).scalar()
def table_simple_info(self, session):
def table_simple_info(self):
_sql = f"""
select concat(table_name, "(" , group_concat(column_name), ")") as schema_info from information_schema.COLUMNS where table_schema="{self.get_current_db_name(session)}" group by TABLE_NAME;
select concat(table_name, "(" , group_concat(column_name), ")") as schema_info from information_schema.COLUMNS where table_schema="{self.get_current_db_name()}" group by TABLE_NAME;
"""
cursor = session.execute(text(_sql))
cursor = self.session.execute(text(_sql))
results = cursor.fetchall()
return results
@ -185,7 +164,7 @@ class RDBMSDatabase(BaseConnect):
tbl
for tbl in self._metadata.sorted_tables
if tbl.name in set(all_table_names)
and not (self.dialect == "sqlite" and tbl.name.startswith("sqlite_"))
and not (self.dialect == "sqlite" and tbl.name.startswith("sqlite_"))
]
tables = []
@ -198,7 +177,7 @@ class RDBMSDatabase(BaseConnect):
create_table = str(CreateTable(table).compile(self._engine))
table_info = f"{create_table.rstrip()}"
has_extra_info = (
self._indexes_in_table_info or self._sample_rows_in_table_info
self._indexes_in_table_info or self._sample_rows_in_table_info
)
if has_extra_info:
table_info += "\n\n/*"
@ -255,9 +234,31 @@ class RDBMSDatabase(BaseConnect):
"""Format the error message"""
return f"Error: {e}"
def run(self, session, command: str, fetch: str = "all") -> List:
"""Execute a SQL command and return a string representing the results."""
cursor = session.execute(text(command))
def __write(self, session, write_sql):
    """Run a write statement, commit it and return the affected row count.

    The database reported by ``get_session_db`` before the write is selected
    again with ``use`` after the commit.
    """
    print(f"Write[{write_sql}]")
    current_db = self.get_session_db(session)
    outcome = session.execute(text(write_sql))
    session.commit()
    # TODO Subsequent optimization of dynamically specified database submission loss target problem
    restore_stmt = text(f"use `{current_db}`")
    session.execute(restore_stmt)
    print(f"SQL[{write_sql}], result:{outcome.rowcount}")
    return outcome.rowcount
def __query(self, session, query, fetch: str = "all"):
"""
only for query
Args:
session:
query:
fetch:
Returns:
"""
print(f"Query[{query}]")
if not query:
return []
cursor = session.execute(text(query))
if cursor.returns_rows:
if fetch == "all":
result = cursor.fetchall()
@ -271,6 +272,62 @@ class RDBMSDatabase(BaseConnect):
result.insert(0, field_names)
return result
def query_ex(self, session, query, fetch: str = "all"):
    """Run a query and return its column names together with its rows.

    Args:
        session: Session used to execute the query.
        query: SQL text of the query.
        fetch: "all" for every row, "one" for at most the first row.

    Returns:
        ``(field_names, rows)`` where both are lists. An empty list is
        returned for an empty query, and None when the statement produces no
        result set.

    Raises:
        ValueError: when ``fetch`` is neither "one" nor "all".
    """
    print(f"Query[{query}]")
    if not query:
        return []
    cursor = session.execute(text(query))
    if not cursor.returns_rows:
        return None

    if fetch == "all":
        result = list(cursor.fetchall())
    elif fetch == "one":
        # Keep the row intact and wrap it in a list, so the result has the
        # same shape as for "all"; taking the first column and calling
        # list() on it fails for numbers and splits strings into characters.
        row = cursor.fetchone()
        result = [row] if row is not None else []
    else:
        raise ValueError("Fetch parameter must be either 'one' or 'all'")
    field_names = list(cursor.keys())
    return field_names, result
def run(self, session, command: str, fetch: str = "all") -> List:
    """Execute a SQL command and return its result rows.

    SELECT statements are queried directly. Other DML statements are
    executed and committed, then the affected data is read back through the
    SELECT built by ``convert_sql_write_to_select``. Anything else is
    executed and committed as is.

    Args:
        session: Session used to execute the command.
        command: SQL text.
        fetch: Fetch mode passed to the SELECT path ("all" or "one").

    Returns:
        A list whose first element holds the column names, followed by the
        rows. An empty list is returned for an empty command and for
        statements that produce no rows and name no table to describe.
    """
    # Check before printing: concatenating None would raise.
    if not command:
        return []
    print("SQL:" + command)
    parsed, ttype, sql_type = self.__sql_parse(command)
    if ttype == sqlparse.tokens.DML:
        if sql_type == "SELECT":
            return self.__query(session, command, fetch)
        self.__write(session, command)
        select_sql = self.convert_sql_write_to_select(command)
        print(f"write result query:{select_sql}")
        return self.__query(session, select_sql)

    print(f"DDL execution determines whether to enable through configuration ")
    cursor = session.execute(text(command))
    session.commit()
    if cursor.returns_rows:
        result = list(cursor.fetchall())
        result.insert(0, tuple(cursor.keys()))
        print("DDL Result:" + str(result))
        return result

    # No result set: show the columns of the table the statement created or
    # altered, taken from the statement itself rather than a fixed name.
    ddl_table = re.match(
        r"\s*(?:create|alter)\s+table\s+(?:if\s+not\s+exists\s+)?([\w.`]+)",
        command,
        re.IGNORECASE,
    )
    if ddl_table:
        return self.__query(session, f"SHOW COLUMNS FROM {ddl_table.group(1)}")
    return []
def run_no_throw(self, session, command: str, fetch: str = "all") -> List:
"""Execute a SQL command and return a string representing the results.
@ -294,3 +351,152 @@ class RDBMSDatabase(BaseConnect):
for d in results
if d[0] not in ["information_schema", "performance_schema", "sys", "mysql"]
]
def convert_sql_write_to_select(self, write_sql):
    """Build a SELECT statement that shows the data touched by a write.

    Keywords are matched case-insensitively; table names, column names and
    literal values keep the case they have in ``write_sql``.

    Author: xiangh8

    Args:
        write_sql: An INSERT, DELETE or UPDATE statement.

    Returns:
        The SELECT statement, or None for an INSERT that is not of the form
        ``INSERT INTO table (columns) VALUES (values)``.

    Raises:
        ValueError: for any other statement type, or when the statement is
            too short to contain a table name or SET clause.
    """
    # A trailing ";" would otherwise end up inside the table name or clause.
    sql = write_sql.strip().rstrip(";").strip()
    parts = sql.split()
    # Lower-cased copy used only to recognise keywords.
    keywords = [part.lower() for part in parts]
    # Command type: insert, delete or update.
    cmd_type = keywords[0] if keywords else ""

    if cmd_type == "insert":
        match = re.match(
            r"insert\s+into\s+([\w.`]+)\s*\((.*?)\)\s*values\s*\((.*?)\)",
            sql,
            re.IGNORECASE | re.DOTALL,
        )
        if match is None:
            return None
        table_name, columns, values = match.groups()
        # Pair every column with its value to form the WHERE clause.
        conditions = []
        for col, val in zip(columns.split(","), values.split(",")):
            col, val = col.strip(), val.strip()
            if val.upper() == "NULL":
                # "col=NULL" never matches a row.
                conditions.append(f"{col} IS NULL")
            else:
                conditions.append(f"{col}={val}")
        where_clause = " AND ".join(conditions)
        return f"SELECT * FROM {table_name} WHERE {where_clause}"

    if cmd_type == "delete":
        if len(parts) < 3:
            raise ValueError(f"Unsupported SQL command type: {cmd_type}")
        table_name = parts[2]  # delete from <table_name> ...
        # Select everything that is left in the table.
        return f"SELECT * FROM {table_name} "

    if cmd_type == "update":
        if len(parts) < 2 or "set" not in keywords:
            raise ValueError(f"Unsupported SQL command type: {cmd_type}")
        table_name = parts[1]
        set_idx = keywords.index("set")
        has_where = "where" in keywords
        where_idx = keywords.index("where") if has_where else len(parts)
        assignments = parts[set_idx + 1: where_idx]
        if not assignments:
            raise ValueError(f"Unsupported SQL command type: {cmd_type}")
        # Name of the first column assigned in the SET clause.
        set_clause = assignments[0].split("=")[0].strip()
        if not has_where:
            # An UPDATE without WHERE touches every row.
            return f"SELECT {set_clause} FROM {table_name}"
        # Condition text after WHERE.
        where_clause = " ".join(parts[where_idx + 1:])
        return f"SELECT {set_clause} FROM {table_name} WHERE {where_clause}"

    raise ValueError(f"Unsupported SQL command type: {cmd_type}")
def __sql_parse(self, sql):
    """Parse a SQL string with sqlparse and classify its first statement.

    Returns:
        A tuple ``(statement, token_type, statement_type)``: the first parsed
        statement, the ``ttype`` of its first token (whitespace skipped,
        comments not skipped) and the value of its ``get_type()``.
    """
    stripped = sql.strip()
    statement = sqlparse.parse(stripped)[0]
    statement_type = statement.get_type()
    token_type = statement.token_first(skip_ws=True, skip_cm=False).ttype
    print(f"SQL:{stripped}, ttype:{token_type}, sql_type:{statement_type}")
    return statement, token_type, statement_type
def get_indexes(self, table_name):
    """Return a pair per row of ``SHOW INDEXES`` for the given table.

    Each pair holds the third and fifth column of the row (presumably the
    index name and the column name in MySQL's output; confirm against the
    server in use).
    """
    statement = text(f"SHOW INDEXES FROM {table_name}")
    rows = self._db_sessions().execute(statement).fetchall()
    pairs = []
    for row in rows:
        pairs.append((row[2], row[4]))
    return pairs
def get_show_create_table(self, table_name):
    """Return the second column of the first ``SHOW CREATE TABLE`` row."""
    statement = text(f"SHOW CREATE TABLE {table_name}")
    rows = self._db_sessions().execute(statement).fetchall()
    first_row = rows[0]
    return first_row[1]
def get_fields(self, table_name):
    """Return the column descriptions of a table from information_schema.

    Args:
        table_name: Name of the table to describe.

    Returns:
        One tuple per column: (COLUMN_NAME, COLUMN_TYPE, COLUMN_DEFAULT,
        IS_NULLABLE, COLUMN_COMMENT).
    """
    session = self._db_sessions()
    # The table name travels as a bound parameter, so quotes or braces in it
    # cannot change the statement.
    cursor = session.execute(
        text(
            "SELECT COLUMN_NAME, COLUMN_TYPE, COLUMN_DEFAULT, IS_NULLABLE, COLUMN_COMMENT "
            "from information_schema.COLUMNS where table_name=:table_name"
        ),
        {"table_name": table_name},
    )
    fields = cursor.fetchall()
    return [(field[0], field[1], field[2], field[3], field[4]) for field in fields]
def get_charset(self):
    """Return the value of ``@@character_set_database``."""
    statement = text("SELECT @@character_set_database")
    first_row = self._db_sessions().execute(statement).fetchone()
    return first_row[0]
def get_collation(self):
    """Return the value of ``@@collation_database``."""
    statement = text("SELECT @@collation_database")
    first_row = self._db_sessions().execute(statement).fetchone()
    return first_row[0]
def get_grants(self):
    """Return every row produced by ``SHOW GRANTS``."""
    db_session = self._db_sessions()
    return db_session.execute(text("SHOW GRANTS")).fetchall()
def get_users(self):
    """Return (user, host) pairs read from ``mysql.user``."""
    rows = self.session.execute(text("SELECT user, host FROM mysql.user")).fetchall()
    accounts = []
    for row in rows:
        accounts.append((row[0], row[1]))
    return accounts
def get_table_comments(self, db_name):
    """Return (table_name, table_comment) pairs for the tables of a schema.

    Args:
        db_name: Schema whose tables are listed.
    """
    # The schema name travels as a bound parameter, so quotes or braces in it
    # cannot change the statement.
    cursor = self.session.execute(
        text(
            "SELECT table_name, table_comment FROM information_schema.tables "
            "WHERE table_schema = :db_name"
        ),
        {"db_name": db_name},
    )
    table_comments = cursor.fetchall()
    return [
        (table_comment[0], table_comment[1]) for table_comment in table_comments
    ]
def get_database_list(self):
    """Return the names from ``show databases`` without the four excluded schemas."""
    excluded = ["information_schema", "performance_schema", "sys", "mysql"]
    rows = self._db_sessions().execute(text(" show databases;")).fetchall()
    names = []
    for row in rows:
        if row[0] in excluded:
            continue
        names.append(row[0])
    return names
def get_database_names(self):
    """Return the names from ``show databases`` without the four excluded schemas."""
    excluded = ["information_schema", "performance_schema", "sys", "mysql"]
    rows = self._db_sessions().execute(text(" show databases;")).fetchall()
    names = []
    for row in rows:
        if row[0] in excluded:
            continue
        names.append(row[0])
    return names

View File

@ -53,7 +53,7 @@ async def validation_exception_handler(request: Request, exc: RequestValidationE
message = ""
for error in exc.errors():
message += ".".join(error.get("loc")) + ":" + error.get("msg") + ";"
return Result.faild(code= "E0001", msg=message)
return Result.faild(code="E0001", msg=message)
def __get_conv_user_message(conversations: dict):
@ -71,11 +71,10 @@ def __new_conversation(chat_mode, user_id) -> ConversationVo:
def get_db_list():
db = CFG.local_db
dbs = db.get_database_list()
dbs = CFG.LOCAL_DB_MANAGE.get_db_list()
params: dict = {}
for name in dbs:
params.update({name: name})
for item in dbs:
params.update({item["db_name"]: item["comment"]})
return params
@ -95,8 +94,9 @@ def knowledge_list():
params.update({space.name: space.name})
return params
@router.get("/v1/chat/dialogue/list", response_model=Result[ConversationVo])
async def dialogue_list( user_id: str = None):
async def dialogue_list(user_id: str = None):
dialogues: List = []
datas = DuckdbHistoryMemory.conv_list(user_id)
@ -126,11 +126,10 @@ async def dialogue_scenes():
ChatScene.ChatExecution,
]
for scene in new_modes:
scene_vo = ChatSceneVo(
chat_scene=scene.value(),
scene_name=scene.scene_name(),
scene_describe= scene.describe(),
scene_describe=scene.describe(),
param_title=",".join(scene.param_types()),
)
scene_vos.append(scene_vo)

View File

@ -43,7 +43,7 @@ class ChatDashboard(BaseChat):
)
self.db_name = db_name
self.report_name = report_name
self.database = CFG.local_db
self.database = CFG.LOCAL_DB_MANAGE.get_connect(db_name)
# 准备DB信息(拿到指定库的链接)
self.db_connect = self.database.get_session(self.db_name)
self.top_k: int = 5

View File

@ -33,9 +33,8 @@ class ChatWithDbAutoExecute(BaseChat):
f"{ChatScene.ChatWithDbExecute.value} mode should chose db!"
)
self.db_name = db_name
self.database = CFG.local_db
# 准备DB信息(拿到指定库的链接)
self.db_connect = self.database.get_session(self.db_name)
self.database = CFG.LOCAL_DB_MANAGE.get_connect(db_name)
self.db_connect = self.database.session
self.top_k: int = 5
def generate_input_values(self):

View File

@ -28,9 +28,8 @@ class ChatWithDbQA(BaseChat):
)
self.db_name = db_name
if db_name:
self.database = CFG.local_db
# 准备DB信息(拿到指定库的链接)
self.db_connect = self.database.get_session(self.db_name)
self.database = CFG.LOCAL_DB_MANAGE.get_connect(db_name)
self.db_connect = self.database.session
self.tables = self.database.get_table_names()
self.top_k = (

View File

@ -84,12 +84,6 @@ priority = {"vicuna-13b": "aaa"}
CHAT_FACTORY = ChatFactory()
DB_SETTINGS = {
"user": CFG.LOCAL_DB_USER,
"password": CFG.LOCAL_DB_PASSWORD,
"host": CFG.LOCAL_DB_HOST,
"port": CFG.LOCAL_DB_PORT,
}
llm_native_dialogue = get_lang_text("knowledge_qa_type_llm_native_dialogue")
default_knowledge_base_dialogue = get_lang_text(
@ -144,14 +138,6 @@ def get_simlar(q):
return "\n".join(contents)
def gen_sqlgen_conversation(dbname):
message = ""
db_connect = CFG.local_db.get_session(dbname)
schemas = CFG.local_db.table_simple_info(db_connect)
for s in schemas:
message += s + ";"
return get_lang_text("sql_schema_info").format(dbname, message)
def plugins_select_info():
plugins_infos: dict = {}
@ -703,7 +689,7 @@ if __name__ == "__main__":
# init server config
args = parser.parse_args()
server_init(args)
dbs = CFG.local_db.get_database_list()
dbs = CFG.LOCAL_DB_MANAGE.get_db_names()
demo = build_webdemo()
demo.queue(
concurrency_count=args.concurrency_count,

View File

@ -15,6 +15,7 @@ from pilot.configs.model_config import (
)
from pilot.common.plugins import scan_plugins, load_native_plugins
from pilot.utils import build_logger
from pilot.connections.manages.connection_manager import ConnectManager
ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(ROOT_PATH)
@ -38,6 +39,10 @@ def server_init(args):
# init config
cfg = Config()
# init connect manage
conn_manage = ConnectManager()
cfg.LOCAL_DB_MANAGE = conn_manage
load_native_plugins(cfg)
signal.signal(signal.SIGINT, signal_handler)
async_db_summery()

View File

@ -11,6 +11,8 @@ from pilot.embedding_engine.knowledge_embedding import KnowledgeEmbedding
from pilot.embedding_engine.string_embedding import StringEmbedding
from pilot.summary.mysql_db_summary import MysqlSummary
from pilot.scene.chat_factory import ChatFactory
from pilot.common.schema import DBType
CFG = Config()
chat_factory = ChatFactory()
@ -24,10 +26,13 @@ class DBSummaryClient:
def __init__(self):
pass
def db_summary_embedding(self, dbname):
def db_summary_embedding(self, dbname, db_type):
"""put db profile and table profile summary into vector store"""
if CFG.LOCAL_DB_HOST is not None and CFG.LOCAL_DB_PORT is not None:
if DBType.Mysql.value() == db_type:
db_summary_client = MysqlSummary(dbname)
else:
raise ValueError("Unsupport summary DbType" + db_type)
embeddings = HuggingFaceEmbeddings(
model_name=LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL]
)
@ -120,10 +125,10 @@ class DBSummaryClient:
return related_table_summaries
def init_db_summary(self):
    """Embed a summary for every database known to the connection manager.

    A database whose type has no summary support is reported and skipped, so
    that the remaining databases are still summarised.
    """
    db_mange = CFG.LOCAL_DB_MANAGE
    dbs = db_mange.get_db_list()
    for item in dbs:
        try:
            self.db_summary_embedding(item["db_name"], item["db_type"])
        except ValueError as e:
            # db_summary_embedding raises ValueError for a db type it has no
            # summary client for.
            print(f"db summary skipped for {item['db_name']}: {e}")
def init_db_profile(self, db_summary_client, dbname, embeddings):
profile_store_config = {

View File

@ -56,8 +56,8 @@ class MysqlSummary(DBSummary):
self.vector_tables_info = []
# self.tables_summary = {}
self.db = CFG.local_db
self.db.get_session(name)
self.db = CFG.LOCAL_DB_MANAGE.get_connect()
self.metadata = """user info :{users}, grant info:{grant}, charset:{charset}, collation:{collation}""".format(
users=self.db.get_users(),

BIN
pilot/xx.db Normal file

Binary file not shown.