DB-GPT/dbgpt/serve/agent/resource/datasource.py
2024-08-29 19:39:42 +08:00

173 lines
5.2 KiB
Python

import dataclasses
import logging
from typing import Any, List, Optional, Type, Union, cast
from dbgpt._private.config import Config
from dbgpt.agent.resource.database import (
_DEFAULT_PROMPT_TEMPLATE,
_DEFAULT_PROMPT_TEMPLATE_ZH,
DBParameters,
RDBMSConnectorResource,
)
from dbgpt.core.awel.flow import (
TAGS_ORDER_HIGH,
FunctionDynamicOptions,
OptionValue,
Parameter,
ResourceCategory,
register_resource,
)
from dbgpt.util import ParameterDescription
from dbgpt.util.i18n_utils import _
# Shared configuration object; this module reads the local DB manager, the
# language, the system app and the knowledge-search top size from it.
CFG = Config()
# Logger named after this module.
logger = logging.getLogger(__name__)
def _load_datasource() -> List[OptionValue]:
    """Build the selectable datasource options from the local DB manager.

    Returns:
        One ``OptionValue`` per entry returned by
        ``CFG.local_db_manager.get_db_list()``. The label has the form
        ``[<db_type>]<db_name>``; both ``name`` and ``value`` are the db name.
    """
    options: List[OptionValue] = []
    for entry in CFG.local_db_manager.get_db_list():
        db_type = entry["db_type"]
        db_name = entry["db_name"]
        options.append(
            OptionValue(
                label="[" + db_type + "]" + db_name,
                name=db_name,
                value=db_name,
            )
        )
    return options
@dataclasses.dataclass
class DatasourceDBParameters(DBParameters):
    """Parameters that identify a datasource by its database name."""

    # Name of the database this datasource refers to.
    db_name: str = dataclasses.field(metadata={"help": "DB name"})

    @classmethod
    def _resource_version(cls) -> str:
        """Report the version tag of this resource (``"v1"``)."""
        return "v1"

    @classmethod
    def to_configurations(
        cls,
        parameters: Type["DatasourceDBParameters"],
        version: Optional[str] = None,
        **kwargs,
    ) -> Any:
        """Turn a parameters class into its configuration representation.

        Args:
            parameters: The parameters class to describe.
            version: Requested version; when falsy, ``_resource_version()``
                is used instead.
            **kwargs: Forwarded unchanged to the parent implementation.

        Returns:
            For any version other than ``"v1"``, the result of the parent
            implementation. For ``"v1"``, the ``valid_values`` of the
            ``db_name`` parameter, or an empty list when that parameter is
            missing or its ``valid_values`` is falsy.
        """
        descriptions: List[ParameterDescription] = cast(
            List[ParameterDescription],
            super().to_configurations(parameters, **kwargs),
        )
        effective_version = version or cls._resource_version()
        if effective_version != "v1":
            return descriptions

        # Old-version compatibility: expose only the db_name choices.
        db_name_param = next(
            (item for item in descriptions if item.param_name == "db_name"),
            None,
        )
        if db_name_param is None:
            return []
        return db_name_param.valid_values or []

    @classmethod
    def from_dict(
        cls, data: dict, ignore_extra_fields: bool = True
    ) -> "DatasourceDBParameters":
        """Instantiate from a dictionary, accepting ``value`` as an alias.

        Args:
            data: Source dictionary; it is copied, not mutated.
            ignore_extra_fields: Forwarded to the parent ``from_dict``.

        Returns:
            The instance produced by the parent ``from_dict``.
        """
        payload = data.copy()
        has_db_name = "db_name" in payload
        # When only "value" is supplied, move it over to "db_name".
        if not has_db_name and "value" in payload:
            payload["db_name"] = payload.pop("value")
        return super().from_dict(payload, ignore_extra_fields=ignore_extra_fields)
@register_resource(
    _("Datasource Resource"),
    "datasource",
    category=ResourceCategory.DATABASE,
    description=_(
        "Connect to a datasource(retrieve table schemas and execute SQL to fetch data)."
    ),
    tags={"order": TAGS_ORDER_HIGH},
    parameters=[
        Parameter.build_from(
            _("Datasource Name"),
            "name",
            str,
            optional=True,
            default="datasource",
            description=_("The name of the datasource, default is 'datasource'."),
        ),
        Parameter.build_from(
            _("DB Name"),
            "db_name",
            str,
            description=_("The name of the database."),
            options=FunctionDynamicOptions(func=_load_datasource),
        ),
        Parameter.build_from(
            _("Prompt Template"),
            "prompt_template",
            str,
            optional=True,
            default=(
                _DEFAULT_PROMPT_TEMPLATE_ZH
                if CFG.LANGUAGE == "zh"
                else _DEFAULT_PROMPT_TEMPLATE
            ),
            description=_("The prompt template to build a database prompt."),
        ),
    ],
)
class DatasourceResource(RDBMSConnectorResource):
    """RDBMS connector resource whose connector comes from the local DB manager."""

    def __init__(self, name: str, db_name: Optional[str] = None, **kwargs):
        """Create the resource and bind it to the connector for ``db_name``.

        Args:
            name: Resource name, passed to the parent constructor.
            db_name: Database name used to look up the connector via
                ``CFG.local_db_manager.get_connector``.
            **kwargs: Forwarded unchanged to the parent constructor.
        """
        conn = CFG.local_db_manager.get_connector(db_name)
        super().__init__(name, connector=conn, db_name=db_name, **kwargs)

    @classmethod
    def resource_parameters_class(cls, **kwargs) -> Type[DatasourceDBParameters]:
        """Build a parameters class listing the currently available databases.

        Args:
            **kwargs: Only ``user_id`` is read; it is passed to
                ``CFG.local_db_manager.get_db_list`` (``None`` when absent).

        Returns:
            A dataclass derived from ``DatasourceDBParameters`` whose
            ``db_name`` field carries the available databases as
            ``valid_values`` metadata.
        """
        dbs = CFG.local_db_manager.get_db_list(user_id=kwargs.get("user_id"))
        results = [
            {
                "label": "[" + db["db_type"] + "]" + db["db_name"],
                "key": db["db_name"],
                "description": db["comment"],
            }
            for db in dbs
        ]

        @dataclasses.dataclass
        class _DynDBParameters(DatasourceDBParameters):
            # Same field as the parent, with the choices built above attached.
            db_name: str = dataclasses.field(
                metadata={"help": "DB name", "valid_values": results}
            )

        return _DynDBParameters

    def get_schema_link(
        self, db: str, question: Optional[str] = None
    ) -> Union[str, List[str]]:
        """Return the schema link of the database.

        The DB summary client is queried first; if it raises or returns a
        falsy result, the connector's ``table_simple_info()`` is used instead.

        Args:
            db: Database name.
            question: Optional question passed to the summary lookup.

        Returns:
            The table information from the summary client, or from the
            connector when the summary lookup fails or yields nothing.

        Raises:
            ValueError: If ``DBSummaryClient`` cannot be imported.
        """
        try:
            from dbgpt.rag.summary.db_summary_client import DBSummaryClient
        except ImportError as e:
            # Keep the ValueError contract but record the import failure as
            # the explicit cause.
            raise ValueError("Could not import DBSummaryClient. ") from e

        client = DBSummaryClient(system_app=CFG.SYSTEM_APP)
        table_infos = None
        try:
            table_infos = client.get_db_summary(
                db,
                question,
                CFG.KNOWLEDGE_SEARCH_TOP_SIZE,
            )
        except Exception as e:
            # Deliberate best-effort: any summary failure falls through to
            # the connector-based lookup below.
            logger.warning("db summary find error!%s", e)
        if not table_infos:
            conn = CFG.local_db_manager.get_connector(db)
            table_infos = conn.table_simple_info()
        return table_infos