Multi DB support

This commit is contained in:
yhjun1026 2023-07-21 10:15:32 +08:00
parent 73b3b7069f
commit 288e55a97f
23 changed files with 2347 additions and 157 deletions

File diff suppressed because it is too large Load Diff

View File

@ -12,3 +12,31 @@ class SeparatorStyle(Enum):
class ExampleType(Enum): class ExampleType(Enum):
ONE_SHOT = "one_shot" ONE_SHOT = "one_shot"
FEW_SHOT = "few_shot" FEW_SHOT = "few_shot"
class DbInfo:
def __init__(self, name, is_file_db: bool = False):
self.name = name
self.is_file_db = is_file_db
class DBType(Enum):
Mysql = DbInfo("mysql")
OCeanBase = DbInfo("oceanbase")
DuckDb = DbInfo("duckdb", True)
Oracle = DbInfo("oracle")
MSSQL = DbInfo("mssql")
Postgresql = DbInfo("postgresql")
def value(self):
return self._value_.name;
def is_file_db(self):
return self._value_.is_file_db
@staticmethod
def of_db_type(db_type: str):
for item in DBType.__members__:
if item.value().name == db_type:
return item
return None

View File

@ -28,6 +28,8 @@ class Config(metaclass=Singleton):
self.skip_reprompt = False self.skip_reprompt = False
self.temperature = float(os.getenv("TEMPERATURE", 0.7)) self.temperature = float(os.getenv("TEMPERATURE", 0.7))
self.NUM_GPUS = int(os.getenv("NUM_GPUS",1))
self.execute_local_commands = ( self.execute_local_commands = (
os.getenv("EXECUTE_LOCAL_COMMANDS", "False") == "True" os.getenv("EXECUTE_LOCAL_COMMANDS", "False") == "True"
) )
@ -116,25 +118,15 @@ class Config(metaclass=Singleton):
os.getenv("NATIVE_SQL_CAN_RUN_WRITE", "True") == "True" os.getenv("NATIVE_SQL_CAN_RUN_WRITE", "True") == "True"
) )
### Local database connection configuration ### default Local database connection configuration
self.LOCAL_DB_HOST = os.getenv("LOCAL_DB_HOST", "127.0.0.1") self.LOCAL_DB_HOST = os.getenv("LOCAL_DB_HOST", "127.0.0.1")
self.LOCAL_DB_PATH = os.getenv("LOCAL_DB_PATH", "xx.db") self.LOCAL_DB_PATH = os.getenv("LOCAL_DB_PATH", "xx.db")
self.LOCAL_DB_NAME = os.getenv("LOCAL_DB_NAME")
self.LOCAL_DB_PORT = int(os.getenv("LOCAL_DB_PORT", 3306)) self.LOCAL_DB_PORT = int(os.getenv("LOCAL_DB_PORT", 3306))
self.LOCAL_DB_USER = os.getenv("LOCAL_DB_USER", "root") self.LOCAL_DB_USER = os.getenv("LOCAL_DB_USER", "root")
self.LOCAL_DB_PASSWORD = os.getenv("LOCAL_DB_PASSWORD", "aa123456") self.LOCAL_DB_PASSWORD = os.getenv("LOCAL_DB_PASSWORD", "aa123456")
### TODO Adapt to multiple types of libraries self.LOCAL_DB_MANAGE = None
self.local_db = Database.from_uri(
"mysql+pymysql://"
+ self.LOCAL_DB_USER
+ ":"
+ self.LOCAL_DB_PASSWORD
+ "@"
+ self.LOCAL_DB_HOST
+ ":"
+ str(self.LOCAL_DB_PORT),
engine_args={"pool_size": 10, "pool_recycle": 3600, "echo": True},
)
### LLM Model Service Configuration ### LLM Model Service Configuration
self.LLM_MODEL = os.getenv("LLM_MODEL", "vicuna-13b") self.LLM_MODEL = os.getenv("LLM_MODEL", "vicuna-13b")

View File

@ -7,11 +7,8 @@ from pydantic import BaseModel, Extra, Field, root_validator
from typing import Any, Iterable, List, Optional from typing import Any, Iterable, List, Optional
class BaseConnect(BaseModel, ABC): class BaseConnect( ABC):
type def get_connect(self, db_name: str):
driver: str
def get_session(self, db_name: str):
pass pass
def get_table_names(self) -> Iterable[str]: def get_table_names(self) -> Iterable[str]:
@ -20,14 +17,41 @@ class BaseConnect(BaseModel, ABC):
def get_table_info(self, table_names: Optional[List[str]] = None) -> str: def get_table_info(self, table_names: Optional[List[str]] = None) -> str:
pass pass
def get_table_info(self, table_names: Optional[List[str]] = None) -> str: def get_index_info(self, table_names: Optional[List[str]] = None) -> str:
pass pass
def get_index_info(self, table_names: Optional[List[str]] = None) -> str: def get_example_data(self, table: str, count: int = 3):
pass pass
def get_database_list(self): def get_database_list(self):
pass pass
def get_database_names(self):
pass
def get_table_comments(self, db_name):
pass
def run(self, session, command: str, fetch: str = "all") -> List: def run(self, session, command: str, fetch: str = "all") -> List:
pass pass
def get_users(self):
pass
def get_grants(self):
pass
def get_collation(self):
pass
def get_charset(self):
pass
def get_fields(self, table_name):
pass
def get_show_create_table(self, table_name):
pass
def get_indexes(self, table_name):
pass

View File

View File

@ -0,0 +1,116 @@
import os
import duckdb
from typing import List
default_db_path = os.path.join(os.getcwd(), "message")
duckdb_path = os.getenv("DB_DUCKDB_PATH", default_db_path + "/connect_config.db")
table_name = "connect_config"
class DuckdbConnectConfig:
def __init__(self):
os.makedirs(default_db_path, exist_ok=True)
self.connect = duckdb.connect(duckdb_path)
self.__init_config_tables()
def __init_config_tables(self):
# check config table
result = self.connect.execute(
"SELECT name FROM sqlite_master WHERE type='table' AND name=?", [table_name]
).fetchall()
if not result:
# create config table
self.connect.execute(
"CREATE TABLE connect_config (id integer primary key, db_name VARCHAR(100) UNIQUE, db_type VARCHAR(50), db_path VARCHAR(255) NULL, db_host VARCHAR(255) NULL, db_port INTEGER NULL, db_user VARCHAR(255) NULL, db_pwd VARCHAR(255) NULL, comment TEXT NULL)"
)
self.connect.execute("CREATE SEQUENCE seq_id START 1;")
def add_url_db(self, db_name, db_type, db_host: str, db_port: int, db_user: str, db_pwd: str, comment: str = ""):
try:
cursor = self.connect.cursor()
cursor.execute(
"INSERT INTO connect_config(id, db_name, db_type, db_path, db_host, db_port, db_user, db_pwd, comment)VALUES(nextval('seq_id'),?,?,?,?,?,?,?,?)",
[db_name, db_type, "", db_host, db_port, db_user, db_pwd, comment],
)
cursor.commit()
self.connect.commit()
except Exception as e:
print("add db connect info error1" + str(e))
def get_file_db_name(self, path):
try:
conn = duckdb.connect(path)
result = conn.execute('SELECT current_database()').fetchone()[0]
return result
except Exception as e:
raise "Unusable duckdb database path:" + path
def add_file_db(self, db_name, db_type, db_path: str):
try:
cursor = self.connect.cursor()
cursor.execute(
"INSERT INTO connect_config(id, db_name, db_type, db_path, db_host, db_port, db_user, db_pwd, comment)VALUES(nextval('seq_id'),?,?,?,?,?,?,?,?)",
[db_name, db_type, db_path, "", "", "", "", ""],
)
cursor.commit()
self.connect.commit()
except Exception as e:
print("add db connect info error2" + str(e))
def delete_db(self, db_name):
cursor = self.connect.cursor()
cursor.execute(
"DELETE FROM connect_config where db_name=?", [db_name]
)
cursor.commit()
return True
def get_db_config(self, db_name):
if os.path.isfile(duckdb_path):
cursor = duckdb.connect(duckdb_path).cursor()
if db_name:
cursor.execute(
"SELECT * FROM connect_config where db_name=? ", [db_name]
)
else:
raise ValueError("Cannot get database by name" + db_name)
fields = [field[0] for field in cursor.description]
row_dict = {}
for row in cursor.fetchall()[0]:
for i, field in enumerate(fields):
row_dict[field] = row[i]
return row_dict
return {}
def get_db_list(self):
if os.path.isfile(duckdb_path):
cursor = duckdb.connect(duckdb_path).cursor()
cursor.execute(
"SELECT db_name, db_type, comment FROM connect_config "
)
fields = [field[0] for field in cursor.description]
data = []
for row in cursor.fetchall():
row_dict = {}
for i, field in enumerate(fields):
row_dict[field] = row[i]
data.append(row_dict)
return data
return []
def get_db_names(self):
if os.path.isfile(duckdb_path):
cursor = duckdb.connect(duckdb_path).cursor()
cursor.execute(
"SELECT db_name FROM connect_config "
)
data = []
for row in cursor.fetchall():
data.append(row[0])
return data
return []

View File

@ -0,0 +1,87 @@
from pilot.configs.config import Config
from pilot.connections.manages.connect_storage_duckdb import DuckdbConnectConfig
from pilot.common.schema import DBType
from pilot.connections.rdbms.mysql import MySQLConnect
from pilot.connections.base import BaseConnect
from pilot.connections.rdbms.rdbms_connect import RDBMSDatabase
from pilot.singleton import Singleton
from pilot.common.sql_database import Database
CFG = Config()
class ConnectManager:
def get_instance_by_dbtype(db_type, **kwargs):
chat_classes = BaseConnect.__subclasses__()
implementation = None
for cls in chat_classes:
if cls.db_type == db_type:
implementation = cls(**kwargs)
if implementation == None:
raise Exception(f"Invalid db connect implementationDbType:{db_type}")
return implementation
def __init__(self):
self.storage = DuckdbConnectConfig()
self.__load_config_db()
def __load_config_db(self):
if CFG.LOCAL_DB_HOST:
# default mysql
if CFG.LOCAL_DB_NAME:
self.storage.add_url_db(CFG.LOCAL_DB_NAME, DBType.Mysql.value(), CFG.LOCAL_DB_HOST, CFG.LOCAL_DB_PORT,
CFG.LOCAL_DB_USER, CFG.LOCAL_DB_PASSWORD, "")
else:
# get all default mysql database
default_mysql = Database.from_uri(
"mysql+pymysql://"
+ CFG.LOCAL_DB_USER
+ ":"
+ CFG.LOCAL_DB_PASSWORD
+ "@"
+ CFG.LOCAL_DB_HOST
+ ":"
+ str(CFG.LOCAL_DB_PORT),
engine_args={"pool_size": 10, "pool_recycle": 3600, "echo": True},
)
# default_mysql = MySQLConnect.from_uri(
# "mysql+pymysql://"
# + CFG.LOCAL_DB_USER
# + ":"
# + CFG.LOCAL_DB_PASSWORD
# + "@"
# + CFG.LOCAL_DB_HOST
# + ":"
# + str(CFG.LOCAL_DB_PORT),
# engine_args={"pool_size": 10, "pool_recycle": 3600, "echo": True},
# )
dbs = default_mysql.get_database_list()
for name in dbs:
self.storage.add_url_db(name, DBType.Mysql.value(), CFG.LOCAL_DB_HOST, CFG.LOCAL_DB_PORT,
CFG.LOCAL_DB_USER, CFG.LOCAL_DB_PASSWORD, "")
if CFG.LOCAL_DB_PATH:
# default file db is duckdb
db_name = self.storage.get_file_db_name(CFG.LOCAL_DB_PATH)
if db_name:
self.storage.add_file_db(db_name, DBType.DuckDb.value(), CFG.LOCAL_DB_PATH)
def get_connect(self, db_name):
db_config = self.storage.get_db_config(db_name)
db_type = DBType.of_db_type(db_config.get('db_type'))
connect_instance = self.get_instance_by_dbtype(db_type)
if db_type.is_file_db():
db_path = db_config.get('db_path')
return connect_instance.from_file(db_path)
else:
db_host = db_config.get('db_host')
db_port = db_config.get('db_port')
db_user = db_config.get('db_user')
db_pwd = db_config.get('db_pwd')
return connect_instance.from_uri_db(db_host, db_port, db_user, db_pwd, db_name)
def get_db_list(self):
return self.storage.get_db_list()
def get_db_names(self):
return self.storage.get_db_names()

View File

@ -10,7 +10,7 @@ CFG = Config()
class ClickHouseConnector(RDBMSDatabase): class ClickHouseConnector(RDBMSDatabase):
"""ClickHouseConnector""" """ClickHouseConnector"""
type: str = "DUCKDB" db_type: str = "duckdb"
driver: str = "duckdb" driver: str = "duckdb"

View File

@ -1,7 +1,14 @@
from typing import Optional, Any from typing import Optional, Any, Iterable
from sqlalchemy import (
MetaData,
Table,
create_engine,
inspect,
select,
text,
)
from pilot.connections.rdbms.rdbms_connect import RDBMSDatabase from pilot.connections.rdbms.rdbms_connect import RDBMSDatabase
from pilot.configs.config import Config from pilot.configs.config import Config
CFG = Config() CFG = Config()
@ -13,29 +20,15 @@ class DuckDbConnect(RDBMSDatabase):
Usage: Usage:
""" """
type: str = "DUCKDB" def table_simple_info(self) -> Iterable[str]:
return super().get_table_names()
driver: str = "duckdb" db_type: str = "duckdb"
file_path: str
default_db = ["information_schema", "performance_schema", "sys", "mysql"]
@classmethod @classmethod
def from_config(cls) -> RDBMSDatabase: def from_file_path(
""" cls, file_path: str, engine_args: Optional[dict] = None, **kwargs: Any
Todo password encryption
Returns:
"""
return cls.from_uri_db(
cls,
CFG.LOCAL_DB_PATH,
engine_args={"pool_size": 10, "pool_recycle": 3600, "echo": True},
)
@classmethod
def from_uri_db(
cls, db_path: str, engine_args: Optional[dict] = None, **kwargs: Any
) -> RDBMSDatabase: ) -> RDBMSDatabase:
db_url: str = cls.connect_driver + "://" + db_path """Construct a SQLAlchemy engine from URI."""
return cls.from_uri(db_url, engine_args, **kwargs) _engine_args = engine_args or {}
return cls(create_engine("duckdb://" + file_path, **_engine_args), **kwargs)

View File

@ -11,8 +11,8 @@ class MSSQLConnect(RDBMSDatabase):
Usage: Usage:
""" """
type: str = "MSSQL" db_type: str = "mssql"
dialect: str = "mssql" db_dialect: str = "mssql"
driver: str = "pyodbc" driver: str = "pyodbc"
default_db = ["master", "model", "msdb", "tempdb", "modeldb", "resource"] default_db = ["master", "model", "msdb", "tempdb", "modeldb", "resource"]

View File

@ -11,8 +11,8 @@ class MySQLConnect(RDBMSDatabase):
Usage: Usage:
""" """
type: str = "MySQL" db_type: str = "mysql"
dialect: str = "mysql" db_dialect: str = "mysql"
driver: str = "pymysql" driver: str = "pymysql"
default_db = ["information_schema", "performance_schema", "sys", "mysql"] default_db = ["information_schema", "performance_schema", "sys", "mysql"]

View File

@ -6,7 +6,7 @@ from pilot.connections.rdbms.rdbms_connect import RDBMSDatabase
class OracleConnector(RDBMSDatabase): class OracleConnector(RDBMSDatabase):
"""OracleConnector""" """OracleConnector"""
type: str = "ORACLE" db_type: str = "oracle"
driver: str = "oracle" driver: str = "oracle"

View File

@ -6,7 +6,7 @@ from pilot.connections.rdbms.rdbms_connect import RDBMSDatabase
class PostgresConnector(RDBMSDatabase): class PostgresConnector(RDBMSDatabase):
"""PostgresConnector is a class which Connector""" """PostgresConnector is a class which Connector"""
type: str = "POSTGRESQL" db_type: str = "postgresql"
driver: str = "postgresql" driver: str = "postgresql"
default_db = ["information_schema", "performance_schema", "sys", "mysql"] default_db = ["information_schema", "performance_schema", "sys", "mysql"]

View File

@ -1,6 +1,9 @@
from __future__ import annotations from __future__ import annotations
import warnings import warnings
import sqlparse
import regex as re
from typing import Any, Iterable, List, Optional from typing import Any, Iterable, List, Optional
from pydantic import BaseModel, Field, root_validator, validator, Extra from pydantic import BaseModel, Field, root_validator, validator, Extra
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
@ -48,26 +51,11 @@ class RDBMSDatabase(BaseConnect):
if include_tables and ignore_tables: if include_tables and ignore_tables:
raise ValueError("Cannot specify both include_tables and ignore_tables") raise ValueError("Cannot specify both include_tables and ignore_tables")
self._inspector = inspect(self._engine) self._inspector = inspect(engine)
session_factory = sessionmaker(bind=engine) session_factory = sessionmaker(bind=engine)
Session = scoped_session(session_factory) Session_Manages = scoped_session(session_factory)
self._db_sessions = Session_Manages
self._db_sessions = Session self.session = self.get_session()
@classmethod
def from_config(cls) -> RDBMSDatabase:
"""
Todo password encryption
Returns:
"""
return cls.from_uri_db(
cls,
CFG.LOCAL_DB_HOST,
CFG.LOCAL_DB_PORT,
CFG.LOCAL_DB_USER,
CFG.LOCAL_DB_PASSWORD,
engine_args={"pool_size": 10, "pool_recycle": 3600, "echo": True},
)
@classmethod @classmethod
def from_uri_db( def from_uri_db(
@ -76,7 +64,7 @@ class RDBMSDatabase(BaseConnect):
port: int, port: int,
user: str, user: str,
pwd: str, pwd: str,
db_name: str = None, db_name: str,
engine_args: Optional[dict] = None, engine_args: Optional[dict] = None,
**kwargs: Any, **kwargs: Any,
) -> RDBMSDatabase: ) -> RDBMSDatabase:
@ -90,11 +78,9 @@ class RDBMSDatabase(BaseConnect):
+ CFG.LOCAL_DB_HOST + CFG.LOCAL_DB_HOST
+ ":" + ":"
+ str(CFG.LOCAL_DB_PORT) + str(CFG.LOCAL_DB_PORT)
+ "/"
+ db_name
) )
if cls.dialect:
db_url = cls.dialect + "+" + db_url
if db_name:
db_url = db_url + "/" + db_name
return cls.from_uri(db_url, engine_args, **kwargs) return cls.from_uri(db_url, engine_args, **kwargs)
@classmethod @classmethod
@ -123,24 +109,17 @@ class RDBMSDatabase(BaseConnect):
) )
return self.get_usable_table_names() return self.get_usable_table_names()
def get_session(self, db_name: str): def get_session(self):
session = self._db_sessions() session = self._db_sessions()
self._metadata = MetaData() self._metadata = MetaData()
# sql = f"use {db_name}" self._metadata.reflect(bind=self._engine)
sql = text(f"use `{db_name}`")
session.execute(sql)
# 处理表信息数据
self._metadata.reflect(bind=self._engine, schema=db_name)
# including view support by adding the views as well as tables to the all # including view support by adding the views as well as tables to the all
# tables list if view_support is True # tables list if view_support is True
self._all_tables = set( self._all_tables = set(
self._inspector.get_table_names(schema=db_name) self._inspector.get_table_names()
+ ( + (
self._inspector.get_view_names(schema=db_name) self._inspector.get_view_names()
if self.view_support if self.view_support
else [] else []
) )
@ -148,14 +127,14 @@ class RDBMSDatabase(BaseConnect):
return session return session
def get_current_db_name(self, session) -> str: def get_current_db_name(self) -> str:
return session.execute(text("SELECT DATABASE()")).scalar() return self.session.execute(text("SELECT DATABASE()")).scalar()
def table_simple_info(self, session): def table_simple_info(self):
_sql = f""" _sql = f"""
select concat(table_name, "(" , group_concat(column_name), ")") as schema_info from information_schema.COLUMNS where table_schema="{self.get_current_db_name(session)}" group by TABLE_NAME; select concat(table_name, "(" , group_concat(column_name), ")") as schema_info from information_schema.COLUMNS where table_schema="{self.get_current_db_name()}" group by TABLE_NAME;
""" """
cursor = session.execute(text(_sql)) cursor = self.session.execute(text(_sql))
results = cursor.fetchall() results = cursor.fetchall()
return results return results
@ -255,9 +234,31 @@ class RDBMSDatabase(BaseConnect):
"""Format the error message""" """Format the error message"""
return f"Error: {e}" return f"Error: {e}"
def run(self, session, command: str, fetch: str = "all") -> List: def __write(self, session, write_sql):
"""Execute a SQL command and return a string representing the results.""" print(f"Write[{write_sql}]")
cursor = session.execute(text(command)) db_cache = self.get_session_db(session)
result = session.execute(text(write_sql))
session.commit()
# TODO Subsequent optimization of dynamically specified database submission loss target problem
session.execute(text(f"use `{db_cache}`"))
print(f"SQL[{write_sql}], result:{result.rowcount}")
return result.rowcount
def __query(self, session, query, fetch: str = "all"):
"""
only for query
Args:
session:
query:
fetch:
Returns:
"""
print(f"Query[{query}]")
if not query:
return []
cursor = session.execute(text(query))
if cursor.returns_rows: if cursor.returns_rows:
if fetch == "all": if fetch == "all":
result = cursor.fetchall() result = cursor.fetchall()
@ -271,6 +272,62 @@ class RDBMSDatabase(BaseConnect):
result.insert(0, field_names) result.insert(0, field_names)
return result return result
def query_ex(self, session, query, fetch: str = "all"):
"""
only for query
Args:
session:
query:
fetch:
Returns:
"""
print(f"Query[{query}]")
if not query:
return []
cursor = session.execute(text(query))
if cursor.returns_rows:
if fetch == "all":
result = cursor.fetchall()
elif fetch == "one":
result = cursor.fetchone()[0] # type: ignore
else:
raise ValueError("Fetch parameter must be either 'one' or 'all'")
field_names = list(i[0:] for i in cursor.keys())
result = list(result)
return field_names, result
def run(self, session, command: str, fetch: str = "all") -> List:
"""Execute a SQL command and return a string representing the results."""
print("SQL:" + command)
if not command:
return []
parsed, ttype, sql_type = self.__sql_parse(command)
if ttype == sqlparse.tokens.DML:
if sql_type == "SELECT":
return self.__query(session, command, fetch)
else:
self.__write(session, command)
select_sql = self.convert_sql_write_to_select(command)
print(f"write result query:{select_sql}")
return self.__query(session, select_sql)
else:
print(f"DDL execution determines whether to enable through configuration ")
cursor = session.execute(text(command))
session.commit()
if cursor.returns_rows:
result = cursor.fetchall()
field_names = tuple(i[0:] for i in cursor.keys())
result = list(result)
result.insert(0, field_names)
print("DDL Result:" + str(result))
if not result:
return self.__query(session, "SHOW COLUMNS FROM test")
return result
else:
return self.__query(session, "SHOW COLUMNS FROM test")
def run_no_throw(self, session, command: str, fetch: str = "all") -> List: def run_no_throw(self, session, command: str, fetch: str = "all") -> List:
"""Execute a SQL command and return a string representing the results. """Execute a SQL command and return a string representing the results.
@ -294,3 +351,152 @@ class RDBMSDatabase(BaseConnect):
for d in results for d in results
if d[0] not in ["information_schema", "performance_schema", "sys", "mysql"] if d[0] not in ["information_schema", "performance_schema", "sys", "mysql"]
] ]
def convert_sql_write_to_select(self, write_sql):
"""
SQL classification processing
author:xiangh8
Args:
sql:
Returns:
"""
# 将SQL命令转换为小写并按空格拆分
parts = write_sql.lower().split()
# 获取命令类型insert, delete, update
cmd_type = parts[0]
# 根据命令类型进行处理
if cmd_type == "insert":
match = re.match(
r"insert into (\w+) \((.*?)\) values \((.*?)\)", write_sql.lower()
)
if match:
table_name, columns, values = match.groups()
# 将字段列表和值列表分割为单独的字段和值
columns = columns.split(",")
values = values.split(",")
# 构造 WHERE 子句
where_clause = " AND ".join(
[
f"{col.strip()}={val.strip()}"
for col, val in zip(columns, values)
]
)
return f"SELECT * FROM {table_name} WHERE {where_clause}"
elif cmd_type == "delete":
table_name = parts[2] # delete from <table_name> ...
# 返回一个select语句它选择该表的所有数据
return f"SELECT * FROM {table_name} "
elif cmd_type == "update":
table_name = parts[1]
set_idx = parts.index("set")
where_idx = parts.index("where")
# 截取 `set` 子句中的字段名
set_clause = parts[set_idx + 1: where_idx][0].split("=")[0].strip()
# 截取 `where` 之后的条件语句
where_clause = " ".join(parts[where_idx + 1:])
# 返回一个select语句它选择更新的数据
return f"SELECT {set_clause} FROM {table_name} WHERE {where_clause}"
else:
raise ValueError(f"Unsupported SQL command type: {cmd_type}")
def __sql_parse(self, sql):
sql = sql.strip()
parsed = sqlparse.parse(sql)[0]
sql_type = parsed.get_type()
first_token = parsed.token_first(skip_ws=True, skip_cm=False)
ttype = first_token.ttype
print(f"SQL:{sql}, ttype:{ttype}, sql_type:{sql_type}")
return parsed, ttype, sql_type
def get_indexes(self, table_name):
"""Get table indexes about specified table."""
session = self._db_sessions()
cursor = session.execute(text(f"SHOW INDEXES FROM {table_name}"))
indexes = cursor.fetchall()
return [(index[2], index[4]) for index in indexes]
def get_show_create_table(self, table_name):
"""Get table show create table about specified table."""
session = self._db_sessions()
cursor = session.execute(text(f"SHOW CREATE TABLE {table_name}"))
ans = cursor.fetchall()
return ans[0][1]
def get_fields(self, table_name):
"""Get column fields about specified table."""
session = self._db_sessions()
cursor = session.execute(
text(
f"SELECT COLUMN_NAME, COLUMN_TYPE, COLUMN_DEFAULT, IS_NULLABLE, COLUMN_COMMENT from information_schema.COLUMNS where table_name='{table_name}'".format(
table_name
)
)
)
fields = cursor.fetchall()
return [(field[0], field[1], field[2], field[3], field[4]) for field in fields]
def get_charset(self):
"""Get character_set."""
session = self._db_sessions()
cursor = session.execute(text(f"SELECT @@character_set_database"))
character_set = cursor.fetchone()[0]
return character_set
def get_collation(self):
"""Get collation."""
session = self._db_sessions()
cursor = session.execute(text(f"SELECT @@collation_database"))
collation = cursor.fetchone()[0]
return collation
def get_grants(self):
"""Get grant info."""
session = self._db_sessions()
cursor = session.execute(text(f"SHOW GRANTS"))
grants = cursor.fetchall()
return grants
def get_users(self):
"""Get user info."""
cursor = self.session.execute(text(f"SELECT user, host FROM mysql.user"))
users = cursor.fetchall()
return [(user[0], user[1]) for user in users]
def get_table_comments(self, db_name):
cursor = self.session.execute(
text(
f"""SELECT table_name, table_comment FROM information_schema.tables WHERE table_schema = '{db_name}'""".format(
db_name
)
)
)
table_comments = cursor.fetchall()
return [
(table_comment[0], table_comment[1]) for table_comment in table_comments
]
def get_database_list(self):
session = self._db_sessions()
cursor = session.execute(text(" show databases;"))
results = cursor.fetchall()
return [
d[0]
for d in results
if d[0] not in ["information_schema", "performance_schema", "sys", "mysql"]
]
def get_database_names(self):
session = self._db_sessions()
cursor = session.execute(text(" show databases;"))
results = cursor.fetchall()
return [
d[0]
for d in results
if d[0] not in ["information_schema", "performance_schema", "sys", "mysql"]
]

View File

@ -53,7 +53,7 @@ async def validation_exception_handler(request: Request, exc: RequestValidationE
message = "" message = ""
for error in exc.errors(): for error in exc.errors():
message += ".".join(error.get("loc")) + ":" + error.get("msg") + ";" message += ".".join(error.get("loc")) + ":" + error.get("msg") + ";"
return Result.faild(code= "E0001", msg=message) return Result.faild(code="E0001", msg=message)
def __get_conv_user_message(conversations: dict): def __get_conv_user_message(conversations: dict):
@ -71,11 +71,10 @@ def __new_conversation(chat_mode, user_id) -> ConversationVo:
def get_db_list(): def get_db_list():
db = CFG.local_db dbs = CFG.LOCAL_DB_MANAGE.get_db_list()
dbs = db.get_database_list()
params: dict = {} params: dict = {}
for name in dbs: for item in dbs:
params.update({name: name}) params.update({item["db_name"]: item["comment"]})
return params return params
@ -95,8 +94,9 @@ def knowledge_list():
params.update({space.name: space.name}) params.update({space.name: space.name})
return params return params
@router.get("/v1/chat/dialogue/list", response_model=Result[ConversationVo]) @router.get("/v1/chat/dialogue/list", response_model=Result[ConversationVo])
async def dialogue_list( user_id: str = None): async def dialogue_list(user_id: str = None):
dialogues: List = [] dialogues: List = []
datas = DuckdbHistoryMemory.conv_list(user_id) datas = DuckdbHistoryMemory.conv_list(user_id)
@ -126,11 +126,10 @@ async def dialogue_scenes():
ChatScene.ChatExecution, ChatScene.ChatExecution,
] ]
for scene in new_modes: for scene in new_modes:
scene_vo = ChatSceneVo( scene_vo = ChatSceneVo(
chat_scene=scene.value(), chat_scene=scene.value(),
scene_name=scene.scene_name(), scene_name=scene.scene_name(),
scene_describe= scene.describe(), scene_describe=scene.describe(),
param_title=",".join(scene.param_types()), param_title=",".join(scene.param_types()),
) )
scene_vos.append(scene_vo) scene_vos.append(scene_vo)

View File

@ -43,7 +43,7 @@ class ChatDashboard(BaseChat):
) )
self.db_name = db_name self.db_name = db_name
self.report_name = report_name self.report_name = report_name
self.database = CFG.local_db self.database = CFG.LOCAL_DB_MANAGE.get_connect(db_name)
# 准备DB信息(拿到指定库的链接) # 准备DB信息(拿到指定库的链接)
self.db_connect = self.database.get_session(self.db_name) self.db_connect = self.database.get_session(self.db_name)
self.top_k: int = 5 self.top_k: int = 5

View File

@ -33,9 +33,8 @@ class ChatWithDbAutoExecute(BaseChat):
f"{ChatScene.ChatWithDbExecute.value} mode should chose db!" f"{ChatScene.ChatWithDbExecute.value} mode should chose db!"
) )
self.db_name = db_name self.db_name = db_name
self.database = CFG.local_db self.database = CFG.LOCAL_DB_MANAGE.get_connect(db_name)
# 准备DB信息(拿到指定库的链接) self.db_connect = self.database.session
self.db_connect = self.database.get_session(self.db_name)
self.top_k: int = 5 self.top_k: int = 5
def generate_input_values(self): def generate_input_values(self):

View File

@ -28,9 +28,8 @@ class ChatWithDbQA(BaseChat):
) )
self.db_name = db_name self.db_name = db_name
if db_name: if db_name:
self.database = CFG.local_db self.database = CFG.LOCAL_DB_MANAGE.get_connect(db_name)
# 准备DB信息(拿到指定库的链接) self.db_connect = self.database.session
self.db_connect = self.database.get_session(self.db_name)
self.tables = self.database.get_table_names() self.tables = self.database.get_table_names()
self.top_k = ( self.top_k = (

View File

@ -84,12 +84,6 @@ priority = {"vicuna-13b": "aaa"}
CHAT_FACTORY = ChatFactory() CHAT_FACTORY = ChatFactory()
DB_SETTINGS = {
"user": CFG.LOCAL_DB_USER,
"password": CFG.LOCAL_DB_PASSWORD,
"host": CFG.LOCAL_DB_HOST,
"port": CFG.LOCAL_DB_PORT,
}
llm_native_dialogue = get_lang_text("knowledge_qa_type_llm_native_dialogue") llm_native_dialogue = get_lang_text("knowledge_qa_type_llm_native_dialogue")
default_knowledge_base_dialogue = get_lang_text( default_knowledge_base_dialogue = get_lang_text(
@ -144,14 +138,6 @@ def get_simlar(q):
return "\n".join(contents) return "\n".join(contents)
def gen_sqlgen_conversation(dbname):
message = ""
db_connect = CFG.local_db.get_session(dbname)
schemas = CFG.local_db.table_simple_info(db_connect)
for s in schemas:
message += s + ";"
return get_lang_text("sql_schema_info").format(dbname, message)
def plugins_select_info(): def plugins_select_info():
plugins_infos: dict = {} plugins_infos: dict = {}
@ -703,7 +689,7 @@ if __name__ == "__main__":
# init server config # init server config
args = parser.parse_args() args = parser.parse_args()
server_init(args) server_init(args)
dbs = CFG.local_db.get_database_list() dbs = CFG.LOCAL_DB_MANAGE.get_db_names()
demo = build_webdemo() demo = build_webdemo()
demo.queue( demo.queue(
concurrency_count=args.concurrency_count, concurrency_count=args.concurrency_count,

View File

@ -15,6 +15,7 @@ from pilot.configs.model_config import (
) )
from pilot.common.plugins import scan_plugins, load_native_plugins from pilot.common.plugins import scan_plugins, load_native_plugins
from pilot.utils import build_logger from pilot.utils import build_logger
from pilot.connections.manages.connection_manager import ConnectManager
ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(ROOT_PATH) sys.path.append(ROOT_PATH)
@ -38,6 +39,10 @@ def server_init(args):
# init config # init config
cfg = Config() cfg = Config()
# init connect manage
conn_manage = ConnectManager()
cfg.LOCAL_DB_MANAGE = conn_manage
load_native_plugins(cfg) load_native_plugins(cfg)
signal.signal(signal.SIGINT, signal_handler) signal.signal(signal.SIGINT, signal_handler)
async_db_summery() async_db_summery()

View File

@ -11,6 +11,8 @@ from pilot.embedding_engine.knowledge_embedding import KnowledgeEmbedding
from pilot.embedding_engine.string_embedding import StringEmbedding from pilot.embedding_engine.string_embedding import StringEmbedding
from pilot.summary.mysql_db_summary import MysqlSummary from pilot.summary.mysql_db_summary import MysqlSummary
from pilot.scene.chat_factory import ChatFactory from pilot.scene.chat_factory import ChatFactory
from pilot.common.schema import DBType
CFG = Config() CFG = Config()
chat_factory = ChatFactory() chat_factory = ChatFactory()
@ -24,10 +26,13 @@ class DBSummaryClient:
def __init__(self): def __init__(self):
pass pass
def db_summary_embedding(self, dbname): def db_summary_embedding(self, dbname, db_type):
"""put db profile and table profile summary into vector store""" """put db profile and table profile summary into vector store"""
if CFG.LOCAL_DB_HOST is not None and CFG.LOCAL_DB_PORT is not None: if DBType.Mysql.value() == db_type:
db_summary_client = MysqlSummary(dbname) db_summary_client = MysqlSummary(dbname)
else:
raise ValueError("Unsupport summary DbType" + db_type)
embeddings = HuggingFaceEmbeddings( embeddings = HuggingFaceEmbeddings(
model_name=LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL] model_name=LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL]
) )
@ -120,10 +125,10 @@ class DBSummaryClient:
return related_table_summaries return related_table_summaries
def init_db_summary(self): def init_db_summary(self):
db = CFG.local_db db_mange = CFG.LOCAL_DB_MANAGE
dbs = db.get_database_list() dbs = db_mange.get_db_list()
for dbname in dbs: for item in dbs:
self.db_summary_embedding(dbname) self.db_summary_embedding(item["db_name"], item["db_type"])
def init_db_profile(self, db_summary_client, dbname, embeddings): def init_db_profile(self, db_summary_client, dbname, embeddings):
profile_store_config = { profile_store_config = {

View File

@ -56,8 +56,8 @@ class MysqlSummary(DBSummary):
self.vector_tables_info = [] self.vector_tables_info = []
# self.tables_summary = {} # self.tables_summary = {}
self.db = CFG.local_db self.db = CFG.LOCAL_DB_MANAGE.get_connect()
self.db.get_session(name)
self.metadata = """user info :{users}, grant info:{grant}, charset:{charset}, collation:{collation}""".format( self.metadata = """user info :{users}, grant info:{grant}, charset:{charset}, collation:{collation}""".format(
users=self.db.get_users(), users=self.db.get_users(),

BIN
pilot/xx.db Normal file

Binary file not shown.