mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-11 22:09:44 +00:00
feat(core): Support simple DB query for sdk (#917)
Co-authored-by: chengfangyin2 <chengfangyin3@jd.com>
This commit is contained in:
@@ -2,58 +2,104 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
|
||||
"""We need to design a base class. That other connector can Write with this"""
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Iterable, List, Optional
|
||||
from abc import ABC
|
||||
from typing import Iterable, List, Optional
|
||||
|
||||
|
||||
class BaseConnect(ABC):
    """Base interface that datasource connectors build on.

    Every method here has a default body (``pass`` or a fixed value), so a
    subclass only needs to override the operations it supports.
    """

    def get_connect(self, db_name: str):
        """Hook that receives a database name.

        The base implementation does nothing and returns ``None``.

        Args:
            db_name (str): database name
        """
        pass

    def get_table_names(self) -> Iterable[str]:
        """Get all table names.

        The base implementation returns ``None``.
        """
        pass

    def get_table_info(self, table_names: Optional[List[str]] = None) -> str:
        """Get table info about specified table.

        The base implementation returns ``None``.

        Args:
            table_names (Optional[List[str]]): table names

        Returns:
            str: Table information joined by two newline characters
        """
        pass

    def get_index_info(self, table_names: Optional[List[str]] = None) -> str:
        """Get index info about specified table.

        The base implementation returns ``None``.

        Args:
            table_names (Optional[List[str]]): table names
        """
        pass

    def get_example_data(self, table: str, count: int = 3):
        """Get example data about specified table.

        Not used now.

        Args:
            table (str): table name
            count (int): example data count
        """
        pass

    def get_database_list(self) -> List[str]:
        """Get database list.

        The base implementation returns ``None``.

        Returns:
            List[str]: database list
        """
        pass

    def get_database_names(self):
        """Get database names."""
        pass

    def get_table_comments(self, db_name):
        """Get table comments.

        Args:
            db_name (str): database name
        """
        pass

    def run(self, command: str, fetch: str = "all") -> List:
        """Execute sql command.

        The base implementation returns ``None``.

        Args:
            command (str): sql command
            fetch (str): fetch type
        """
        pass

    def run_to_df(self, command: str, fetch: str = "all"):
        """Execute sql command and return dataframe."""
        pass

    def get_users(self):
        """Get user info.

        The base implementation returns an empty list.
        """
        return []

    def get_grants(self):
        """Get grant info.

        The base implementation returns an empty list.
        """
        return []

    def get_collation(self):
        """Get collation.

        The base implementation returns ``None``.
        """
        return None

    def get_charset(self) -> str:
        """Get character_set of current database.

        The base implementation returns ``"utf-8"``.
        """
        return "utf-8"

    def get_fields(self, table_name):
        """Get column fields about specified table."""
        pass

    def get_show_create_table(self, table_name):
        """Get the creation table sql about specified table."""
        pass

    def get_indexes(self, table_name):
        """Get table indexes about specified table."""
        pass

    @classmethod
    def is_normal_type(cls) -> bool:
        """Return whether the connector is a normal type.

        The base implementation returns ``True``.
        """
        return True
|
||||
|
@@ -1,4 +1,4 @@
|
||||
from dbgpt._private.pydantic import BaseModel, Field
|
||||
from dbgpt._private.pydantic import BaseModel
|
||||
|
||||
|
||||
class DBConfig(BaseModel):
|
||||
|
@@ -1,3 +1,4 @@
|
||||
from typing import List, Type
|
||||
from dbgpt.datasource import ConnectConfigDao
|
||||
from dbgpt.storage.schema import DBType
|
||||
from dbgpt.component import SystemApp, ComponentType
|
||||
@@ -21,7 +22,7 @@ from dbgpt.datasource.rdbms.conn_doris import DorisConnect
|
||||
class ConnectManager:
|
||||
"""db connect manager"""
|
||||
|
||||
def get_all_subclasses(self, cls):
|
||||
def get_all_subclasses(self, cls: Type[BaseConnect]) -> List[Type[BaseConnect]]:
|
||||
subclasses = cls.__subclasses__()
|
||||
for subclass in subclasses:
|
||||
subclasses += self.get_all_subclasses(subclass)
|
||||
@@ -31,7 +32,7 @@ class ConnectManager:
|
||||
chat_classes = self.get_all_subclasses(BaseConnect)
|
||||
support_types = []
|
||||
for cls in chat_classes:
|
||||
if cls.db_type:
|
||||
if cls.db_type and cls.is_normal_type():
|
||||
support_types.append(DBType.of_db_type(cls.db_type))
|
||||
return support_types
|
||||
|
||||
@@ -39,7 +40,7 @@ class ConnectManager:
|
||||
chat_classes = self.get_all_subclasses(BaseConnect)
|
||||
result = None
|
||||
for cls in chat_classes:
|
||||
if cls.db_type == db_type:
|
||||
if cls.db_type == db_type and cls.is_normal_type():
|
||||
result = cls
|
||||
if not result:
|
||||
raise ValueError("Unsupported Db Type!" + db_type)
|
||||
|
0
dbgpt/datasource/operator/__init__.py
Normal file
0
dbgpt/datasource/operator/__init__.py
Normal file
16
dbgpt/datasource/operator/datasource_operator.py
Normal file
16
dbgpt/datasource/operator/datasource_operator.py
Normal file
@@ -0,0 +1,16 @@
|
||||
from typing import Any
|
||||
from dbgpt.core.awel import MapOperator
|
||||
from dbgpt.core.awel.task.base import IN, OUT
|
||||
from dbgpt.datasource.base import BaseConnect
|
||||
|
||||
|
||||
class DatasourceOperator(MapOperator[str, Any]):
    """Map operator that executes its input against a datasource connection.

    The input string is handed to the connection's ``run_to_df`` and whatever
    that call returns becomes the operator's output.
    """

    def __init__(self, connection: BaseConnect, **kwargs):
        """Store the connection used to execute inputs.

        Args:
            connection (BaseConnect): connector whose ``run_to_df`` is called.
            **kwargs: forwarded unchanged to the ``MapOperator`` constructor.
        """
        super().__init__(**kwargs)
        self._connection = connection

    # Annotated with the concrete types this class binds in
    # ``MapOperator[str, Any]`` rather than the unbound IN/OUT type variables.
    async def map(self, input_value: str) -> Any:
        """Run :meth:`query` on ``input_value`` through ``blocking_func_to_async``.

        NOTE(review): ``blocking_func_to_async`` presumably runs the blocking
        call off the event loop -- confirm against ``MapOperator``.
        """
        return await self.blocking_func_to_async(self.query, input_value)

    def query(self, input_value: str) -> Any:
        """Execute ``input_value`` with the connection's ``run_to_df``."""
        return self._connection.run_to_df(input_value)
|
@@ -4,9 +4,12 @@
|
||||
import os
|
||||
from typing import Optional, Any, Iterable
|
||||
from sqlalchemy import create_engine, text
|
||||
|
||||
import tempfile
|
||||
import logging
|
||||
from dbgpt.datasource.rdbms.base import RDBMSDatabase
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SQLiteConnect(RDBMSDatabase):
|
||||
"""Connect SQLite Database fetch MetaData
|
||||
@@ -127,3 +130,116 @@ class SQLiteConnect(RDBMSDatabase):
|
||||
|
||||
results.append(f"{table_name}({','.join(table_colums)});")
|
||||
return results
|
||||
|
||||
|
||||
class SQLiteTempConnect(SQLiteConnect):
    """A temporary SQLite database connection.

    The database file is deleted when the connection is closed: explicitly via
    :meth:`close`, on leaving a ``with`` block, or when the object is finalized.
    """

    def __init__(self, engine, temp_file_path, *args, **kwargs):
        """Wrap ``engine`` and remember the backing temporary file.

        Args:
            engine: SQLAlchemy engine bound to the temporary database file.
            temp_file_path: path of the file removed by :meth:`close`.
            *args: forwarded to the parent constructor.
            **kwargs: forwarded to the parent constructor.
        """
        super().__init__(engine, *args, **kwargs)
        self.temp_file_path = temp_file_path
        self._is_closed = False

    # NOTE(review): the docstring examples in this class pass ``db.session`` as
    # the first argument to ``run``/``query_ex``; confirm that still matches the
    # parent class's signatures.
    @classmethod
    def create_temporary_db(
        cls, engine_args: Optional[dict] = None, **kwargs: Any
    ) -> "SQLiteTempConnect":
        """Create a temporary SQLite database with a temporary file.

        Examples:
            .. code-block:: python

                with SQLiteTempConnect.create_temporary_db() as db:
                    db.run(db.session, "CREATE TABLE test (id INTEGER PRIMARY KEY);")
                    db.run(db.session, "insert into test(id) values (1)")
                    db.run(db.session, "insert into test(id) values (2)")
                    field_names, result = db.query_ex(db.session, "select * from test")
                    assert field_names == ["id"]
                    assert result == [(1,), (2,)]

        Args:
            engine_args (Optional[dict]): SQLAlchemy engine arguments. The dict
                is copied, never modified. ``connect_args["check_same_thread"]``
                is always forced to ``False``; other ``connect_args`` entries
                are passed through.
            **kwargs: forwarded to the class constructor.

        Returns:
            SQLiteTempConnect: A SQLiteTempConnect instance.
        """
        # Work on copies so the caller's dict (and its nested connect_args)
        # is left untouched and can be reused.
        _engine_args = dict(engine_args) if engine_args else {}
        connect_args = dict(_engine_args.get("connect_args") or {})
        connect_args["check_same_thread"] = False
        _engine_args["connect_args"] = connect_args

        # delete=False: only a unique path is needed here; the file itself is
        # removed later by close().
        temp_file = tempfile.NamedTemporaryFile(delete=False)
        temp_file_path = temp_file.name
        temp_file.close()

        try:
            engine = create_engine(f"sqlite:///{temp_file_path}", **_engine_args)
            return cls(engine, temp_file_path, **kwargs)
        except Exception:
            # No instance will own the file, so remove it here before re-raising.
            try:
                if os.path.exists(temp_file_path):
                    os.remove(temp_file_path)
            except OSError as e:
                logger.error("Error removing temporary database file: %s", e)
            raise

    def close(self):
        """Close the connection.

        Disposes the engine and deletes the temporary database file. Calling it
        again after a successful close is a no-op. A failure to delete the file
        is logged, not raised.
        """
        # getattr with a True default: __del__ may call this on an instance
        # whose __init__ never got far enough to set the attribute.
        if getattr(self, "_is_closed", True):
            return
        engine = getattr(self, "_engine", None)
        if engine:
            engine.dispose()
        try:
            if os.path.exists(self.temp_file_path):
                os.remove(self.temp_file_path)
        except Exception as e:
            logger.error(f"Error removing temporary database file: {e}")
        self._is_closed = True

    def create_temp_tables(self, tables_info: dict) -> None:
        """Create temporary tables with data.

        Examples:
            .. code-block:: python

                tables_info = {
                    "test": {
                        "columns": {
                            "id": "INTEGER PRIMARY KEY",
                            "name": "TEXT",
                            "age": "INTEGER",
                        },
                        "data": [
                            (1, "Tom", 20),
                            (2, "Jack", 21),
                            (3, "Alice", 22),
                        ],
                    },
                }
                with SQLiteTempConnect.create_temporary_db() as db:
                    db.create_temp_tables(tables_info)
                    field_names, result = db.query_ex(db.session, "select * from test")
                    assert field_names == ["id", "name", "age"]
                    assert result == [(1, "Tom", 20), (2, "Jack", 21), (3, "Alice", 22)]

        Args:
            tables_info (dict): A dictionary of table information. Each key is
                a table name; each value has a required ``"columns"`` mapping
                (column name -> SQL type) and an optional ``"data"`` list of
                row tuples.
        """
        # SECURITY: table names and column definitions are interpolated into
        # the SQL text (identifiers cannot be bound as parameters), so
        # tables_info must come from a trusted source. Row values are bound.
        for table_name, table_data in tables_info.items():
            columns = ", ".join(
                f"{col} {dtype}" for col, dtype in table_data["columns"].items()
            )
            self.session.execute(text(f"CREATE TABLE {table_name} ({columns});"))

            for row in table_data.get("data", []):
                # One named bind parameter per positional value in the row.
                param_dict = {f"param{index}": value for index, value in enumerate(row)}
                placeholders = ", ".join(f":{name}" for name in param_dict)
                insert_sql = f"INSERT INTO {table_name} VALUES ({placeholders});"
                self.session.execute(text(insert_sql), param_dict)
            self.session.commit()
        self._sync_tables_from_db()

    def __enter__(self):
        """Return the connection itself for use in a ``with`` block."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Close the connection on leaving the ``with`` block."""
        self.close()

    def __del__(self):
        """Best-effort cleanup when the object is finalized."""
        # A finalizer can run during interpreter shutdown, when module globals
        # may already be gone; errors here cannot be handled by anyone.
        try:
            self.close()
        except Exception:
            pass

    @classmethod
    def is_normal_type(cls) -> bool:
        """Return ``False``: this connector is not a normal type.

        ``ConnectManager`` only considers connectors whose ``is_normal_type()``
        is true.
        """
        return False
|
||||
|
Reference in New Issue
Block a user