mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-07 03:50:42 +00:00
refactor(agent): Refactor resource of agents (#1518)
This commit is contained in:
203
dbgpt/agent/resource/database.py
Normal file
203
dbgpt/agent/resource/database.py
Normal file
@@ -0,0 +1,203 @@
|
||||
"""Database resource module."""
|
||||
|
||||
import dataclasses
|
||||
import logging
|
||||
from concurrent.futures import Executor, ThreadPoolExecutor
|
||||
from typing import Any, Generic, List, Optional, Tuple, Union
|
||||
|
||||
import cachetools
|
||||
|
||||
from dbgpt.datasource.rdbms.base import RDBMSConnector
|
||||
from dbgpt.util.cache_utils import cached
|
||||
from dbgpt.util.executor_utils import blocking_func_to_async
|
||||
|
||||
from .base import P, Resource, ResourceParameters, ResourceType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_DEFAULT_PROMPT_TEMPLATE = (
|
||||
"Database type: {db_type}, related table structure definition: {schemas}"
|
||||
)
|
||||
_DEFAULT_PROMPT_TEMPLATE_ZH = "数据库类型:{db_type},相关表结构定义:{schemas}"
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class DBParameters(ResourceParameters):
|
||||
"""DB parameters class."""
|
||||
|
||||
db_name: str = dataclasses.field(metadata={"help": "DB name"})
|
||||
|
||||
|
||||
class DBResource(Resource[P], Generic[P]):
|
||||
"""Database resource class."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
db_type: Optional[str] = None,
|
||||
db_name: Optional[str] = None,
|
||||
dialect: Optional[str] = None,
|
||||
executor: Optional[Executor] = None,
|
||||
prompt_template: str = _DEFAULT_PROMPT_TEMPLATE,
|
||||
):
|
||||
"""Initialize the DB resource."""
|
||||
self._name = name
|
||||
self._db_type = db_type
|
||||
self._db_name = db_name
|
||||
self._dialect = dialect or db_type
|
||||
# Executor for running async tasks
|
||||
self._executor = executor or ThreadPoolExecutor()
|
||||
self._prompt_template = prompt_template
|
||||
|
||||
@classmethod
|
||||
def type(cls) -> ResourceType:
|
||||
"""Return the resource type."""
|
||||
return ResourceType.DB
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
"""Return the resource name."""
|
||||
return self._name
|
||||
|
||||
@property
|
||||
def db_type(self) -> str:
|
||||
"""Return the resource name."""
|
||||
if not self._db_type:
|
||||
raise ValueError("Database type is not set.")
|
||||
return self._db_type
|
||||
|
||||
@property
|
||||
def dialect(self) -> str:
|
||||
"""Return the resource name."""
|
||||
if not self._dialect:
|
||||
raise ValueError("Dialect is not set.")
|
||||
return self._dialect
|
||||
|
||||
@cached(cachetools.TTLCache(maxsize=100, ttl=10))
|
||||
async def get_prompt(
|
||||
self,
|
||||
*,
|
||||
lang: str = "en",
|
||||
prompt_type: str = "default",
|
||||
question: Optional[str] = None,
|
||||
resource_name: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
"""Get the prompt."""
|
||||
if not self._db_name:
|
||||
return "No database name provided."
|
||||
schema_info = await blocking_func_to_async(
|
||||
self._executor, self.get_schema_link, db=self._db_name, question=question
|
||||
)
|
||||
return self._prompt_template.format(db_type=self._db_type, schemas=schema_info)
|
||||
|
||||
def execute(self, *args, resource_name: Optional[str] = None, **kwargs) -> Any:
|
||||
"""Execute the resource."""
|
||||
copy_kwargs = kwargs.copy()
|
||||
if "db" not in copy_kwargs:
|
||||
copy_kwargs["db"] = self._db_name
|
||||
return self._sync_query(*args, **copy_kwargs)
|
||||
|
||||
async def async_execute(
|
||||
self, *args, resource_name: Optional[str] = None, **kwargs
|
||||
) -> Any:
|
||||
"""Execute the resource asynchronously."""
|
||||
copy_kwargs = kwargs.copy()
|
||||
if "db" not in copy_kwargs:
|
||||
copy_kwargs["db"] = self._db_name
|
||||
return await self.query(*args, **copy_kwargs)
|
||||
|
||||
@property
|
||||
def is_async(self) -> bool:
|
||||
"""Return whether the resource is asynchronous."""
|
||||
return True
|
||||
|
||||
def get_schema_link(
|
||||
self, db: str, question: Optional[str] = None
|
||||
) -> Union[str, List[str]]:
|
||||
"""Return the schema link of the database."""
|
||||
raise NotImplementedError("The run method should be implemented in a subclass.")
|
||||
|
||||
async def query_to_df(self, sql: str, db: Optional[str] = None):
|
||||
"""Return the query result as a DataFrame."""
|
||||
import pandas as pd
|
||||
|
||||
field_names, result = await self.query(sql, db=db)
|
||||
return pd.DataFrame(result, columns=field_names)
|
||||
|
||||
async def query(self, sql: str, db: Optional[str] = None):
|
||||
"""Return the query result."""
|
||||
db_name = db or self._db_name
|
||||
return await blocking_func_to_async(
|
||||
self._executor, self._sync_query, db=db_name, sql=sql
|
||||
)
|
||||
|
||||
def _sync_query(self, db: str, sql: str):
|
||||
"""Return the query result."""
|
||||
raise NotImplementedError("The run method should be implemented in a subclass.")
|
||||
|
||||
|
||||
class RDBMSConnectorResource(DBResource[DBParameters]):
|
||||
"""Connector resource class."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
connector: Optional[RDBMSConnector] = None,
|
||||
db_name: Optional[str] = None,
|
||||
db_type: Optional[str] = None,
|
||||
dialect: Optional[str] = None,
|
||||
executor: Optional[Executor] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize the connector resource."""
|
||||
if not db_type and connector:
|
||||
db_type = connector.db_type
|
||||
if not dialect and connector:
|
||||
dialect = connector.dialect
|
||||
if not db_name and connector:
|
||||
db_name = connector.get_current_db_name()
|
||||
self._connector = connector
|
||||
super().__init__(
|
||||
name,
|
||||
db_type=db_type,
|
||||
db_name=db_name,
|
||||
dialect=dialect,
|
||||
executor=executor,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@property
|
||||
def connector(self) -> RDBMSConnector:
|
||||
"""Return the connector."""
|
||||
if not self._connector:
|
||||
raise ValueError("Connector is not set.")
|
||||
return self._connector
|
||||
|
||||
def get_schema_link(
|
||||
self, db: str, question: Optional[str] = None
|
||||
) -> Union[str, List[str]]:
|
||||
"""Return the schema link of the database."""
|
||||
from dbgpt.rag.summary.rdbms_db_summary import _parse_db_summary
|
||||
|
||||
return _parse_db_summary(self.connector)
|
||||
|
||||
def _sync_query(self, db: str, sql: str) -> Tuple[Tuple, List]:
|
||||
"""Return the query result."""
|
||||
result_lst = self.connector.run(sql)
|
||||
columns = result_lst[0]
|
||||
values = result_lst[1:]
|
||||
return columns, values
|
||||
|
||||
|
||||
class SQLiteDBResource(RDBMSConnectorResource):
|
||||
"""SQLite database resource class."""
|
||||
|
||||
def __init__(
|
||||
self, name: str, db_name: str, executor: Optional[Executor] = None, **kwargs
|
||||
):
|
||||
"""Initialize the SQLite database resource."""
|
||||
from dbgpt.datasource.rdbms.conn_sqlite import SQLiteConnector
|
||||
|
||||
conn = SQLiteConnector.from_file_path(db_name)
|
||||
super().__init__(name, conn, executor=executor, **kwargs)
|
Reference in New Issue
Block a user