mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-07-29 14:57:35 +00:00
173 lines
5.2 KiB
Python
173 lines
5.2 KiB
Python
import dataclasses
|
|
import logging
|
|
from typing import Any, List, Optional, Type, Union, cast
|
|
|
|
from dbgpt._private.config import Config
|
|
from dbgpt.agent.resource.database import (
|
|
_DEFAULT_PROMPT_TEMPLATE,
|
|
_DEFAULT_PROMPT_TEMPLATE_ZH,
|
|
DBParameters,
|
|
RDBMSConnectorResource,
|
|
)
|
|
from dbgpt.core.awel.flow import (
|
|
TAGS_ORDER_HIGH,
|
|
FunctionDynamicOptions,
|
|
OptionValue,
|
|
Parameter,
|
|
ResourceCategory,
|
|
register_resource,
|
|
)
|
|
from dbgpt.util import ParameterDescription
|
|
from dbgpt.util.i18n_utils import _
|
|
|
|
CFG = Config()
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
def _load_datasource() -> List[OptionValue]:
|
|
dbs = CFG.local_db_manager.get_db_list()
|
|
results = [
|
|
OptionValue(
|
|
label="[" + db["db_type"] + "]" + db["db_name"],
|
|
name=db["db_name"],
|
|
value=db["db_name"],
|
|
)
|
|
for db in dbs
|
|
]
|
|
return results
|
|
|
|
|
|
@dataclasses.dataclass
|
|
class DatasourceDBParameters(DBParameters):
|
|
"""The DB parameters for the datasource."""
|
|
|
|
db_name: str = dataclasses.field(metadata={"help": "DB name"})
|
|
|
|
@classmethod
|
|
def _resource_version(cls) -> str:
|
|
"""Return the resource version."""
|
|
return "v1"
|
|
|
|
@classmethod
|
|
def to_configurations(
|
|
cls,
|
|
parameters: Type["DatasourceDBParameters"],
|
|
version: Optional[str] = None,
|
|
**kwargs,
|
|
) -> Any:
|
|
"""Convert the parameters to configurations."""
|
|
conf: List[ParameterDescription] = cast(
|
|
List[ParameterDescription],
|
|
super().to_configurations(
|
|
parameters,
|
|
**kwargs,
|
|
),
|
|
)
|
|
version = version or cls._resource_version()
|
|
if version != "v1":
|
|
return conf
|
|
# Compatible with old version
|
|
for param in conf:
|
|
if param.param_name == "db_name":
|
|
return param.valid_values or []
|
|
return []
|
|
|
|
@classmethod
|
|
def from_dict(
|
|
cls, data: dict, ignore_extra_fields: bool = True
|
|
) -> "DatasourceDBParameters":
|
|
"""Create a new instance from a dictionary."""
|
|
copied_data = data.copy()
|
|
if "db_name" not in copied_data and "value" in copied_data:
|
|
copied_data["db_name"] = copied_data.pop("value")
|
|
return super().from_dict(copied_data, ignore_extra_fields=ignore_extra_fields)
|
|
|
|
|
|
@register_resource(
|
|
_("Datasource Resource"),
|
|
"datasource",
|
|
category=ResourceCategory.DATABASE,
|
|
description=_(
|
|
"Connect to a datasource(retrieve table schemas and execute SQL to fetch data)."
|
|
),
|
|
tags={"order": TAGS_ORDER_HIGH},
|
|
parameters=[
|
|
Parameter.build_from(
|
|
_("Datasource Name"),
|
|
"name",
|
|
str,
|
|
optional=True,
|
|
default="datasource",
|
|
description=_("The name of the datasource, default is 'datasource'."),
|
|
),
|
|
Parameter.build_from(
|
|
_("DB Name"),
|
|
"db_name",
|
|
str,
|
|
description=_("The name of the database."),
|
|
options=FunctionDynamicOptions(func=_load_datasource),
|
|
),
|
|
Parameter.build_from(
|
|
_("Prompt Template"),
|
|
"prompt_template",
|
|
str,
|
|
optional=True,
|
|
default=(
|
|
_DEFAULT_PROMPT_TEMPLATE_ZH
|
|
if CFG.LANGUAGE == "zh"
|
|
else _DEFAULT_PROMPT_TEMPLATE
|
|
),
|
|
description=_("The prompt template to build a database prompt."),
|
|
),
|
|
],
|
|
)
|
|
class DatasourceResource(RDBMSConnectorResource):
|
|
def __init__(self, name: str, db_name: Optional[str] = None, **kwargs):
|
|
conn = CFG.local_db_manager.get_connector(db_name)
|
|
super().__init__(name, connector=conn, db_name=db_name, **kwargs)
|
|
|
|
@classmethod
|
|
def resource_parameters_class(cls, **kwargs) -> Type[DatasourceDBParameters]:
|
|
dbs = CFG.local_db_manager.get_db_list(user_id=kwargs.get("user_id", None))
|
|
results = [
|
|
{
|
|
"label": "[" + db["db_type"] + "]" + db["db_name"],
|
|
"key": db["db_name"],
|
|
"description": db["comment"],
|
|
}
|
|
for db in dbs
|
|
]
|
|
|
|
@dataclasses.dataclass
|
|
class _DynDBParameters(DatasourceDBParameters):
|
|
db_name: str = dataclasses.field(
|
|
metadata={"help": "DB name", "valid_values": results}
|
|
)
|
|
|
|
return _DynDBParameters
|
|
|
|
def get_schema_link(
|
|
self, db: str, question: Optional[str] = None
|
|
) -> Union[str, List[str]]:
|
|
"""Return the schema link of the database."""
|
|
try:
|
|
from dbgpt.rag.summary.db_summary_client import DBSummaryClient
|
|
except ImportError:
|
|
raise ValueError("Could not import DBSummaryClient. ")
|
|
client = DBSummaryClient(system_app=CFG.SYSTEM_APP)
|
|
table_infos = None
|
|
try:
|
|
table_infos = client.get_db_summary(
|
|
db,
|
|
question,
|
|
CFG.KNOWLEDGE_SEARCH_TOP_SIZE,
|
|
)
|
|
except Exception as e:
|
|
logger.warning(f"db summary find error!{str(e)}")
|
|
if not table_infos:
|
|
conn = CFG.local_db_manager.get_connector(db)
|
|
table_infos = conn.table_simple_info()
|
|
|
|
return table_infos
|