feat(core): Support higher-order operators (#1984)

Co-authored-by: 谨欣 <echo.cmy@antgroup.com>
This commit is contained in:
Fangyin Cheng
2024-09-09 10:15:37 +08:00
committed by GitHub
parent f6d5fc4595
commit 65c875db20
62 changed files with 6281 additions and 386 deletions

View File

@@ -3,14 +3,41 @@ import logging
from typing import Any, List, Optional, Type, Union, cast
from dbgpt._private.config import Config
from dbgpt.agent.resource.database import DBParameters, RDBMSConnectorResource
from dbgpt.agent.resource.database import (
_DEFAULT_PROMPT_TEMPLATE,
_DEFAULT_PROMPT_TEMPLATE_ZH,
DBParameters,
RDBMSConnectorResource,
)
from dbgpt.core.awel.flow import (
TAGS_ORDER_HIGH,
FunctionDynamicOptions,
OptionValue,
Parameter,
ResourceCategory,
register_resource,
)
from dbgpt.util import ParameterDescription
from dbgpt.util.i18n_utils import _
CFG = Config()
logger = logging.getLogger(__name__)
def _load_datasource() -> List[OptionValue]:
dbs = CFG.local_db_manager.get_db_list()
results = [
OptionValue(
label="[" + db["db_type"] + "]" + db["db_name"],
name=db["db_name"],
value=db["db_name"],
)
for db in dbs
]
return results
@dataclasses.dataclass
class DatasourceDBParameters(DBParameters):
"""The DB parameters for the datasource."""
@@ -57,6 +84,44 @@ class DatasourceDBParameters(DBParameters):
return super().from_dict(copied_data, ignore_extra_fields=ignore_extra_fields)
@register_resource(
_("Datasource Resource"),
"datasource",
category=ResourceCategory.DATABASE,
description=_(
"Connect to a datasource(retrieve table schemas and execute SQL to fetch data)."
),
tags={"order": TAGS_ORDER_HIGH},
parameters=[
Parameter.build_from(
_("Datasource Name"),
"name",
str,
optional=True,
default="datasource",
description=_("The name of the datasource, default is 'datasource'."),
),
Parameter.build_from(
_("DB Name"),
"db_name",
str,
description=_("The name of the database."),
options=FunctionDynamicOptions(func=_load_datasource),
),
Parameter.build_from(
_("Prompt Template"),
"prompt_template",
str,
optional=True,
default=(
_DEFAULT_PROMPT_TEMPLATE_ZH
if CFG.LANGUAGE == "zh"
else _DEFAULT_PROMPT_TEMPLATE
),
description=_("The prompt template to build a database prompt."),
),
],
)
class DatasourceResource(RDBMSConnectorResource):
def __init__(self, name: str, db_name: Optional[str] = None, **kwargs):
conn = CFG.local_db_manager.get_connector(db_name)