mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-08-01 16:18:27 +00:00
Multi DB support
This commit is contained in:
parent
35536df73e
commit
23dadef155
@ -46,12 +46,12 @@ class DuckdbConnectConfig:
|
||||
except Exception as e:
|
||||
raise "Unusable duckdb database path:" + path
|
||||
|
||||
def add_file_db(self, db_name, db_type, db_path: str):
|
||||
def add_file_db(self, db_name, db_type, db_path: str, comment: str = ""):
|
||||
try:
|
||||
cursor = self.connect.cursor()
|
||||
cursor.execute(
|
||||
"INSERT INTO connect_config(id, db_name, db_type, db_path, db_host, db_port, db_user, db_pwd, comment)VALUES(nextval('seq_id'),?,?,?,?,?,?,?,?)",
|
||||
[db_name, db_type, db_path, "", "", "", "", ""],
|
||||
[db_name, db_type, db_path, "", "", "", "", comment],
|
||||
)
|
||||
cursor.commit()
|
||||
self.connect.commit()
|
||||
|
@ -6,16 +6,17 @@ from pilot.connections.base import BaseConnect
|
||||
|
||||
from pilot.connections.rdbms.conn_mysql import MySQLConnect
|
||||
from pilot.connections.rdbms.conn_duckdb import DuckDbConnect
|
||||
from pilot.connections.rdbms.conn_mssql import MSSQLConnect
|
||||
from pilot.connections.rdbms.rdbms_connect import RDBMSDatabase
|
||||
from pilot.singleton import Singleton
|
||||
from pilot.common.sql_database import Database
|
||||
from pilot.connections.db_conn_info import DBConfig
|
||||
|
||||
CFG = Config()
|
||||
|
||||
|
||||
class ConnectManager:
|
||||
|
||||
|
||||
def get_all_subclasses(self, cls):
|
||||
subclasses = cls.__subclasses__()
|
||||
for subclass in subclasses:
|
||||
@ -93,5 +94,17 @@ class ConnectManager:
|
||||
def get_db_list(self):
|
||||
return self.storage.get_db_list()
|
||||
|
||||
|
||||
def get_db_names(self):
|
||||
return self.storage.get_db_names()
|
||||
|
||||
def delete_db(self, db_name: str):
|
||||
return self.storage.delete_db(db_name)
|
||||
|
||||
def add_db(self, db_info: DBConfig):
|
||||
db_type = DBType.of_db_type(db_info.db_type)
|
||||
if db_type.is_file_db():
|
||||
self.storage.add_file_db(db_info.db_name, db_info.db_type, db_info.file_path)
|
||||
else:
|
||||
self.storage.add_url_db(db_info.db_name, db_info.db_type, db_info.db_host, db_info.db_port, db_info.db_user, db_info.db_pwd, db_info.comment)
|
||||
return True
|
@ -10,10 +10,6 @@ from sqlalchemy import (
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
|
||||
from pilot.connections.rdbms.rdbms_connect import RDBMSDatabase
|
||||
from pilot.configs.config import Config
|
||||
|
||||
CFG = Config()
|
||||
Base = declarative_base()
|
||||
|
||||
class DuckDbConnect(RDBMSDatabase):
|
||||
"""Connect Duckdb Database fetch MetaData
|
||||
|
48
pilot/connections/rdbms/conn_mssql.py
Normal file
48
pilot/connections/rdbms/conn_mssql.py
Normal file
@ -0,0 +1,48 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
from typing import Optional, Any, Iterable
|
||||
|
||||
from sqlalchemy import (
|
||||
MetaData,
|
||||
Table,
|
||||
create_engine,
|
||||
inspect,
|
||||
select,
|
||||
text,
|
||||
)
|
||||
from pilot.connections.rdbms.rdbms_connect import RDBMSDatabase
|
||||
|
||||
|
||||
class MSSQLConnect(RDBMSDatabase):
|
||||
"""Connect MSSQL Database fetch MetaData
|
||||
Args:
|
||||
Usage:
|
||||
"""
|
||||
|
||||
db_type: str = "mssql"
|
||||
db_dialect: str = "mssql"
|
||||
driver: str = "mssql+pymssql"
|
||||
|
||||
default_db = ["master", "model", "msdb", "tempdb", "modeldb", "resource", "sys"]
|
||||
|
||||
|
||||
def table_simple_info(self) -> Iterable[str]:
|
||||
_tables_sql = f"""
|
||||
SELECT TABLE_NAME FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_TYPE='BASE TABLE'
|
||||
"""
|
||||
cursor = self.session.execute(text(_tables_sql))
|
||||
tables_results = cursor.fetchall()
|
||||
results =[]
|
||||
for row in tables_results:
|
||||
table_name = row[0]
|
||||
_sql = f"""
|
||||
SELECT COLUMN_NAME, DATA_TYPE FROM INFORMATION_SCHEMA.COLUMNS WHERE TABLE_NAME='{table_name}'
|
||||
"""
|
||||
cursor_colums = self.session.execute(text(_sql))
|
||||
colum_results = cursor_colums.fetchall()
|
||||
table_colums = []
|
||||
for row_col in colum_results:
|
||||
field_info = list(row_col)
|
||||
table_colums.append(field_info[0])
|
||||
results.append(f"{table_name}({','.join(table_colums)});")
|
||||
return results
|
@ -1,18 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
from typing import Optional, Any
|
||||
|
||||
from pilot.connections.rdbms.rdbms_connect import RDBMSDatabase
|
||||
|
||||
|
||||
class MSSQLConnect(RDBMSDatabase):
|
||||
"""Connect MSSQL Database fetch MetaData
|
||||
Args:
|
||||
Usage:
|
||||
"""
|
||||
|
||||
db_type: str = "mssql"
|
||||
db_dialect: str = "mssql"
|
||||
driver: str = "pyodbc"
|
||||
|
||||
default_db = ["master", "model", "msdb", "tempdb", "modeldb", "resource"]
|
@ -23,8 +23,9 @@ from pilot.openapi.api_v1.api_view_model import (
|
||||
Result,
|
||||
ConversationVo,
|
||||
MessageVo,
|
||||
ChatSceneVo,
|
||||
ChatSceneVo
|
||||
)
|
||||
from pilot.connections.db_conn_info import DBConfig
|
||||
from pilot.configs.config import Config
|
||||
from pilot.server.knowledge.service import KnowledgeService
|
||||
from pilot.server.knowledge.request.request import KnowledgeSpaceRequest
|
||||
@ -95,6 +96,25 @@ def knowledge_list():
|
||||
return params
|
||||
|
||||
|
||||
|
||||
|
||||
@router.get("/v1/chat/db/list", response_model=Result[DBConfig])
|
||||
async def dialogue_list():
|
||||
return Result.succ(CFG.LOCAL_DB_MANAGE.get_db_list())
|
||||
|
||||
@router.post("/v1/chat/db/add", response_model=Result[bool])
|
||||
async def dialogue_list(db_config: DBConfig = Body() ):
|
||||
return Result.succ(CFG.LOCAL_DB_MANAGE.add_db(db_config))
|
||||
|
||||
@router.post("/v1/chat/db/delete", response_model=Result[bool])
|
||||
async def dialogue_list(db_name: str = None):
|
||||
return Result.succ(CFG.LOCAL_DB_MANAGE.delete_db(db_name))
|
||||
|
||||
|
||||
@router.get("/v1/chat/db/support/type", response_model=Result[str])
|
||||
async def db_support_types():
|
||||
return Result[str].succ(["mysql", "mssql", "duckdb"])
|
||||
|
||||
@router.get("/v1/chat/dialogue/list", response_model=Result[ConversationVo])
|
||||
async def dialogue_list(user_id: str = None):
|
||||
dialogues: List = []
|
||||
|
@ -73,6 +73,7 @@ weaviate-client
|
||||
pymysql
|
||||
duckdb
|
||||
duckdb-engine
|
||||
pymssql
|
||||
|
||||
# Testing dependencies
|
||||
pytest
|
||||
|
Loading…
Reference in New Issue
Block a user