"""Datasource resource backed by the local DB manager."""

import dataclasses
import logging
from typing import Any, List, Optional, Type, Union, cast

from dbgpt._private.config import Config
from dbgpt.agent.resource.database import (
    _DEFAULT_PROMPT_TEMPLATE,
    _DEFAULT_PROMPT_TEMPLATE_ZH,
    DBParameters,
    RDBMSConnectorResource,
)
from dbgpt.core.awel.flow import (
    TAGS_ORDER_HIGH,
    FunctionDynamicOptions,
    OptionValue,
    Parameter,
    ResourceCategory,
    register_resource,
)
from dbgpt.util import ParameterDescription
from dbgpt.util.i18n_utils import _

CFG = Config()

logger = logging.getLogger(__name__)


def _db_label(db: dict) -> str:
    """Build the display label ``[<db_type>]<db_name>`` for a database entry."""
    return f"[{db['db_type']}]{db['db_name']}"


def _load_datasource() -> List[OptionValue]:
    """Return the configured databases as options for the ``db_name`` parameter.

    Each option uses the database name as both its ``name`` and its ``value``.
    """
    return [
        OptionValue(label=_db_label(db), name=db["db_name"], value=db["db_name"])
        for db in CFG.local_db_manager.get_db_list()
    ]


@dataclasses.dataclass
class DatasourceDBParameters(DBParameters):
    """The DB parameters for the datasource."""

    # Name of the database whose connector the resource uses.
    db_name: str = dataclasses.field(metadata={"help": "DB name"})

    @classmethod
    def _resource_version(cls) -> str:
        """Return the resource version."""
        return "v1"

    @classmethod
    def to_configurations(
        cls,
        parameters: Type["DatasourceDBParameters"],
        version: Optional[str] = None,
        **kwargs,
    ) -> Any:
        """Convert the parameters to configurations.

        Args:
            parameters: The parameters class to describe.
            version: Resource version; defaults to ``_resource_version()``.
            **kwargs: Forwarded to the parent implementation.

        Returns:
            For version ``"v1"``: the ``valid_values`` of the ``db_name``
            parameter (``[]`` when it has none or is not present).
            For any other version: the full list of ``ParameterDescription``.
        """
        conf: List[ParameterDescription] = cast(
            List[ParameterDescription],
            super().to_configurations(parameters, **kwargs),
        )
        version = version or cls._resource_version()
        if version != "v1":
            return conf
        # Compatible with old version: only the db_name choices are returned.
        for param in conf:
            if param.param_name == "db_name":
                return param.valid_values or []
        return []

    @classmethod
    def from_dict(
        cls, data: dict, ignore_extra_fields: bool = True
    ) -> "DatasourceDBParameters":
        """Create a new instance from a dictionary.

        If ``data`` has no ``"db_name"`` but has ``"value"`` (presumably the
        older payload shape), ``"value"`` is used as ``db_name``. The input
        dict itself is not modified.
        """
        copied_data = data.copy()
        if "db_name" not in copied_data and "value" in copied_data:
            copied_data["db_name"] = copied_data.pop("value")
        return super().from_dict(copied_data, ignore_extra_fields=ignore_extra_fields)


@register_resource(
    _("Datasource Resource"),
    "datasource",
    category=ResourceCategory.DATABASE,
    description=_(
        "Connect to a datasource(retrieve table schemas and execute SQL to fetch data)."
    ),
    tags={"order": TAGS_ORDER_HIGH},
    parameters=[
        Parameter.build_from(
            _("Datasource Name"),
            "name",
            str,
            optional=True,
            default="datasource",
            description=_("The name of the datasource, default is 'datasource'."),
        ),
        Parameter.build_from(
            _("DB Name"),
            "db_name",
            str,
            description=_("The name of the database."),
            options=FunctionDynamicOptions(func=_load_datasource),
        ),
        Parameter.build_from(
            _("Prompt Template"),
            "prompt_template",
            str,
            optional=True,
            # Chosen once, at import time, from the configured language.
            default=(
                _DEFAULT_PROMPT_TEMPLATE_ZH
                if CFG.LANGUAGE == "zh"
                else _DEFAULT_PROMPT_TEMPLATE
            ),
            description=_("The prompt template to build a database prompt."),
        ),
    ],
)
class DatasourceResource(RDBMSConnectorResource):
    """RDBMS resource whose connector comes from ``CFG.local_db_manager``."""

    def __init__(self, name: str, db_name: Optional[str] = None, **kwargs):
        """Create the resource.

        Args:
            name: The resource name.
            db_name: Database whose connector is fetched from
                ``CFG.local_db_manager.get_connector``.
            **kwargs: Forwarded to ``RDBMSConnectorResource``.
        """
        conn = CFG.local_db_manager.get_connector(db_name)
        super().__init__(name, connector=conn, db_name=db_name, **kwargs)

    @classmethod
    def resource_parameters_class(cls, **kwargs) -> Type[DatasourceDBParameters]:
        """Return a parameters class whose ``db_name`` lists the available DBs.

        Args:
            **kwargs: ``user_id`` (optional) is passed to
                ``CFG.local_db_manager.get_db_list``.

        Returns:
            A ``DatasourceDBParameters`` subclass whose ``db_name`` field
            metadata carries ``valid_values`` of ``{"label", "key",
            "description"}`` dicts, one per database.
        """
        dbs = CFG.local_db_manager.get_db_list(user_id=kwargs.get("user_id"))
        results = [
            {
                "label": _db_label(db),
                "key": db["db_name"],
                # Entries without a comment get None instead of raising KeyError.
                "description": db.get("comment"),
            }
            for db in dbs
        ]

        # Built per call so the valid values reflect the current database list.
        @dataclasses.dataclass
        class _DynDBParameters(DatasourceDBParameters):
            db_name: str = dataclasses.field(
                metadata={"help": "DB name", "valid_values": results}
            )

        return _DynDBParameters

    def get_schema_link(
        self, db: str, question: Optional[str] = None
    ) -> Union[str, List[str]]:
        """Return the schema link of the database.

        First asks ``DBSummaryClient.get_db_summary`` (with ``question`` and
        ``CFG.KNOWLEDGE_SEARCH_TOP_SIZE``). If that raises or returns an empty
        result, falls back to the connector's ``table_simple_info()``.

        Args:
            db: Name of the database.
            question: Optional user question passed to the summary lookup.

        Raises:
            ValueError: If ``DBSummaryClient`` cannot be imported.
        """
        try:
            from dbgpt.rag.summary.db_summary_client import DBSummaryClient
        except ImportError as e:
            # Keep the original ImportError as the cause for debugging.
            raise ValueError("Could not import DBSummaryClient. ") from e

        client = DBSummaryClient(system_app=CFG.SYSTEM_APP)
        table_infos = None
        try:
            table_infos = client.get_db_summary(
                db,
                question,
                CFG.KNOWLEDGE_SEARCH_TOP_SIZE,
            )
        except Exception as e:
            # Best-effort: a summary failure falls through to live table info.
            logger.warning("db summary find error!%s", e)

        if not table_infos:
            conn = CFG.local_db_manager.get_connector(db)
            table_infos = conn.table_simple_info()
        return table_infos