mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-10-22 09:28:42 +00:00
128 lines
4.3 KiB
Python
128 lines
4.3 KiB
Python
"""Database resource client API."""
|
|
import logging
|
|
from contextlib import contextmanager
|
|
from typing import TYPE_CHECKING, Iterator, List, Optional, Union
|
|
|
|
from .resource_api import AgentResource, ResourceClient, ResourceType
|
|
|
|
# Module-level logger; SqliteLoadClient.query logs each SQL statement through it.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class ResourceDbClient(ResourceClient):
    """Database resource client API.

    Base class for clients that expose a database as an agent resource.
    ``get_schema_link``, ``query_to_df``, ``query`` and ``run_sql`` must be
    overridden by subclasses; the versions here only raise
    ``NotImplementedError``.
    """

    @property
    def type(self):
        """Return the resource type (always ``ResourceType.DB``)."""
        return ResourceType.DB

    def get_data_type(self, resource: AgentResource) -> str:
        """Return the data type of the resource.

        Delegates to the parent class implementation; subclasses may override
        this to report a concrete database type.
        """
        return super().get_data_type(resource)

    async def get_data_introduce(
        self, resource: AgentResource, question: Optional[str] = None
    ) -> Union[str, List[str]]:
        """Return the data introduce of the resource.

        ``resource.value`` is used as the database identifier and is passed,
        together with ``question``, to ``get_schema_link``.

        Args:
            resource: The agent resource describing the database.
            question: Optional question forwarded to ``get_schema_link``.

        Returns:
            Whatever ``get_schema_link`` returns for that database.
        """
        return await self.get_schema_link(resource.value, question)

    async def get_schema_link(
        self, db: str, question: Optional[str] = None
    ) -> Union[str, List[str]]:
        """Return the schema link of the database.

        Raises:
            NotImplementedError: Always; subclasses must implement this.
        """
        raise NotImplementedError(
            "The get_schema_link method should be implemented in a subclass."
        )

    async def query_to_df(self, dbe: str, sql: str):
        """Return the query result as a DataFrame.

        NOTE: the database argument is named ``dbe`` here (subclasses name it
        ``db``); the name is kept for backward compatibility with keyword
        callers.

        Raises:
            NotImplementedError: Always; subclasses must implement this.
        """
        raise NotImplementedError(
            "The query_to_df method should be implemented in a subclass."
        )

    async def query(self, db: str, sql: str):
        """Return the query result.

        Raises:
            NotImplementedError: Always; subclasses must implement this.
        """
        raise NotImplementedError(
            "The query method should be implemented in a subclass."
        )

    async def run_sql(self, db: str, sql: str):
        """Run the SQL.

        Raises:
            NotImplementedError: Always; subclasses must implement this.
        """
        raise NotImplementedError(
            "The run_sql method should be implemented in a subclass."
        )
|
|
|
|
|
|
class SqliteLoadClient(ResourceDbClient):
    """SQLite resource client.

    Every operation opens a fresh SQLAlchemy session on the SQLite file whose
    path is given as ``db`` (appended to a ``sqlite:///`` URL).
    """

    if TYPE_CHECKING:
        from sqlalchemy.orm.session import Session

    def __init__(self):
        """Create a SQLite resource client."""
        super().__init__()

    def get_data_type(self, resource: AgentResource) -> str:
        """Return the data type of the resource (always ``"sqlite"``)."""
        return "sqlite"

    @staticmethod
    def _quote_identifier(name: str) -> str:
        """Quote ``name`` as an SQLite identifier safe to embed in ``text()``.

        Embedded double quotes are doubled (SQLite identifier quoting) and
        colons are backslash-escaped so SQLAlchemy's ``text()`` does not
        mistake them for ``:bind`` parameters.
        """
        escaped = name.replace('"', '""').replace(":", "\\:")
        return f'"{escaped}"'

    @contextmanager
    def connect(self, db: str) -> Iterator["Session"]:
        """Open a session on the SQLite file ``db`` for a ``with`` block.

        The session is committed if the ``with`` body completes normally, and
        rolled back (with the exception re-raised) if it raises. Afterwards the
        session is closed and the per-call engine is disposed so its pooled
        connections are released.

        Args:
            db: Path of the SQLite database file.

        Yields:
            A SQLAlchemy session bound to that database.
        """
        from sqlalchemy import create_engine
        from sqlalchemy.orm import sessionmaker

        # echo=True makes SQLAlchemy log every statement it emits.
        engine = create_engine("sqlite:///" + db, echo=True)
        session_factory = sessionmaker(bind=engine)
        session = session_factory()
        try:
            yield session
            session.commit()
        except Exception:
            session.rollback()
            raise
        finally:
            session.close()
            # A new engine is built on every call; dispose it so its
            # connection pool does not outlive this context.
            engine.dispose()

    async def get_schema_link(
        self, db: str, question: Optional[str] = None
    ) -> Union[str, List[str]]:
        """Return the schema link of the database.

        Args:
            db: Path of the SQLite database file.
            question: Not used by this implementation.

        Returns:
            One ``"table_name(col1,col2,...);"`` string per entry in
            ``sqlite_master`` whose ``type`` is ``'table'``.
        """
        from sqlalchemy import text

        with self.connect(db) as session:
            table_rows = session.execute(
                text("""
        SELECT name FROM sqlite_master WHERE type='table'
        """)
            ).fetchall()
            results = []
            for table_row in table_rows:
                table_name = table_row[0]
                # Quote the name so tables with spaces, quotes, keywords or
                # colons in their names still produce a valid PRAGMA.
                pragma_sql = f"PRAGMA table_info({self._quote_identifier(table_name)})"
                column_rows = session.execute(text(pragma_sql)).fetchall()
                # Column 1 of PRAGMA table_info output is the column name.
                column_names = [list(col)[1] for col in column_rows]
                results.append(f"{table_name}({','.join(column_names)});")
            return results

    async def query_to_df(self, db: str, sql: str):
        """Return the query result as a DataFrame.

        Args:
            db: Path of the SQLite database file.
            sql: SQL statement to execute.

        Returns:
            A ``pandas.DataFrame`` with the result rows and column names. It is
            empty when ``query`` yields no row set (empty ``sql`` or a
            statement that returns no rows).
        """
        import pandas as pd

        query_result = await self.query(db, sql)
        if not query_result:
            # query() returns [] for empty SQL and None for statements that
            # return no rows; unpacking either would raise.
            return pd.DataFrame()
        field_names, result = query_result
        return pd.DataFrame(result, columns=field_names)

    async def query(self, db: str, sql: str):
        """Return the query result.

        Args:
            db: Path of the SQLite database file.
            sql: SQL statement to execute.

        Returns:
            ``(field_names, rows)`` when the statement returns rows. Returns
            ``[]`` when ``sql`` is empty, and ``None`` when the statement
            returns no rows.
        """
        from sqlalchemy import text

        with self.connect(db) as session:
            logger.info("Query[%s]", sql)
            if not sql:
                return []
            cursor = session.execute(text(sql))
            if cursor.returns_rows:  # type: ignore
                result = cursor.fetchall()
                field_names = tuple(cursor.keys())
                return field_names, result
            return None
|