mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-06 11:31:12 +00:00
feat(core): Support higher-order operators (#1984)
Co-authored-by: 谨欣 <echo.cmy@antgroup.com>
This commit is contained in:
@@ -3,14 +3,41 @@ import logging
|
||||
from typing import Any, List, Optional, Type, Union, cast
|
||||
|
||||
from dbgpt._private.config import Config
|
||||
from dbgpt.agent.resource.database import DBParameters, RDBMSConnectorResource
|
||||
from dbgpt.agent.resource.database import (
|
||||
_DEFAULT_PROMPT_TEMPLATE,
|
||||
_DEFAULT_PROMPT_TEMPLATE_ZH,
|
||||
DBParameters,
|
||||
RDBMSConnectorResource,
|
||||
)
|
||||
from dbgpt.core.awel.flow import (
|
||||
TAGS_ORDER_HIGH,
|
||||
FunctionDynamicOptions,
|
||||
OptionValue,
|
||||
Parameter,
|
||||
ResourceCategory,
|
||||
register_resource,
|
||||
)
|
||||
from dbgpt.util import ParameterDescription
|
||||
from dbgpt.util.i18n_utils import _
|
||||
|
||||
CFG = Config()
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _load_datasource() -> List[OptionValue]:
|
||||
dbs = CFG.local_db_manager.get_db_list()
|
||||
results = [
|
||||
OptionValue(
|
||||
label="[" + db["db_type"] + "]" + db["db_name"],
|
||||
name=db["db_name"],
|
||||
value=db["db_name"],
|
||||
)
|
||||
for db in dbs
|
||||
]
|
||||
return results
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class DatasourceDBParameters(DBParameters):
|
||||
"""The DB parameters for the datasource."""
|
||||
@@ -57,6 +84,44 @@ class DatasourceDBParameters(DBParameters):
|
||||
return super().from_dict(copied_data, ignore_extra_fields=ignore_extra_fields)
|
||||
|
||||
|
||||
@register_resource(
|
||||
_("Datasource Resource"),
|
||||
"datasource",
|
||||
category=ResourceCategory.DATABASE,
|
||||
description=_(
|
||||
"Connect to a datasource(retrieve table schemas and execute SQL to fetch data)."
|
||||
),
|
||||
tags={"order": TAGS_ORDER_HIGH},
|
||||
parameters=[
|
||||
Parameter.build_from(
|
||||
_("Datasource Name"),
|
||||
"name",
|
||||
str,
|
||||
optional=True,
|
||||
default="datasource",
|
||||
description=_("The name of the datasource, default is 'datasource'."),
|
||||
),
|
||||
Parameter.build_from(
|
||||
_("DB Name"),
|
||||
"db_name",
|
||||
str,
|
||||
description=_("The name of the database."),
|
||||
options=FunctionDynamicOptions(func=_load_datasource),
|
||||
),
|
||||
Parameter.build_from(
|
||||
_("Prompt Template"),
|
||||
"prompt_template",
|
||||
str,
|
||||
optional=True,
|
||||
default=(
|
||||
_DEFAULT_PROMPT_TEMPLATE_ZH
|
||||
if CFG.LANGUAGE == "zh"
|
||||
else _DEFAULT_PROMPT_TEMPLATE
|
||||
),
|
||||
description=_("The prompt template to build a database prompt."),
|
||||
),
|
||||
],
|
||||
)
|
||||
class DatasourceResource(RDBMSConnectorResource):
|
||||
def __init__(self, name: str, db_name: Optional[str] = None, **kwargs):
|
||||
conn = CFG.local_db_manager.get_connector(db_name)
|
||||
|
Reference in New Issue
Block a user