refactor(agent): Refactor resource of agents (#1518)

This commit is contained in:
Fangyin Cheng
2024-05-15 09:57:19 +08:00
committed by GitHub
parent db4d318a5f
commit 559affe87d
102 changed files with 2633 additions and 2549 deletions

View File

View File

@@ -0,0 +1,95 @@
import dataclasses
import logging
from typing import Any, List, Optional, Type, Union, cast
from dbgpt._private.config import Config
from dbgpt.agent.resource.database import DBParameters, RDBMSConnectorResource
from dbgpt.util import ParameterDescription
CFG = Config()
logger = logging.getLogger(__name__)
@dataclasses.dataclass
class DatasourceDBParameters(DBParameters):
"""The DB parameters for the datasource."""
db_name: str = dataclasses.field(metadata={"help": "DB name"})
@classmethod
def _resource_version(cls) -> str:
"""Return the resource version."""
return "v1"
@classmethod
def to_configurations(
cls,
parameters: Type["DatasourceDBParameters"],
version: Optional[str] = None,
) -> Any:
"""Convert the parameters to configurations."""
conf: List[ParameterDescription] = cast(
List[ParameterDescription], super().to_configurations(parameters)
)
version = version or cls._resource_version()
if version != "v1":
return conf
# Compatible with old version
for param in conf:
if param.param_name == "db_name":
return param.valid_values or []
return []
@classmethod
def from_dict(
cls, data: dict, ignore_extra_fields: bool = True
) -> "DatasourceDBParameters":
"""Create a new instance from a dictionary."""
copied_data = data.copy()
if "db_name" not in copied_data and "value" in copied_data:
copied_data["db_name"] = copied_data.pop("value")
return super().from_dict(copied_data, ignore_extra_fields=ignore_extra_fields)
class DatasourceResource(RDBMSConnectorResource):
def __init__(self, name: str, db_name: Optional[str] = None, **kwargs):
conn = CFG.local_db_manager.get_connector(db_name)
super().__init__(name, connector=conn, db_name=db_name, **kwargs)
@classmethod
def resource_parameters_class(cls) -> Type[DatasourceDBParameters]:
dbs = CFG.local_db_manager.get_db_list()
results = [db["db_name"] for db in dbs]
@dataclasses.dataclass
class _DynDBParameters(DatasourceDBParameters):
db_name: str = dataclasses.field(
metadata={"help": "DB name", "valid_values": results}
)
return _DynDBParameters
def get_schema_link(
self, db: str, question: Optional[str] = None
) -> Union[str, List[str]]:
"""Return the schema link of the database."""
try:
from dbgpt.rag.summary.db_summary_client import DBSummaryClient
except ImportError:
raise ValueError("Could not import DBSummaryClient. ")
client = DBSummaryClient(system_app=CFG.SYSTEM_APP)
table_infos = None
try:
table_infos = client.get_db_summary(
db,
question,
CFG.KNOWLEDGE_SEARCH_TOP_SIZE,
)
except Exception as e:
logger.warning(f"db summary find error!{str(e)}")
if not table_infos:
conn = CFG.local_db_manager.get_connector(db)
table_infos = conn.table_simple_info()
return table_infos

View File

@@ -0,0 +1,89 @@
import dataclasses
import logging
from typing import Any, List, Optional, Type, cast
from dbgpt._private.config import Config
from dbgpt.agent.resource.knowledge import (
RetrieverResource,
RetrieverResourceParameters,
)
from dbgpt.serve.rag.retriever.knowledge_space import KnowledgeSpaceRetriever
from dbgpt.util import ParameterDescription
CFG = Config()
logger = logging.getLogger(__name__)
@dataclasses.dataclass
class KnowledgeSpaceLoadResourceParameters(RetrieverResourceParameters):
space_name: str = dataclasses.field(
default=None, metadata={"help": "Knowledge space name"}
)
@classmethod
def _resource_version(cls) -> str:
"""Return the resource version."""
return "v1"
@classmethod
def to_configurations(
cls,
parameters: Type["KnowledgeSpaceLoadResourceParameters"],
version: Optional[str] = None,
) -> Any:
"""Convert the parameters to configurations."""
conf: List[ParameterDescription] = cast(
List[ParameterDescription], super().to_configurations(parameters)
)
version = version or cls._resource_version()
if version != "v1":
return conf
# Compatible with old version
for param in conf:
if param.param_name == "space_name":
return param.valid_values or []
return []
@classmethod
def from_dict(
cls, data: dict, ignore_extra_fields: bool = True
) -> "KnowledgeSpaceLoadResourceParameters":
"""Create a new instance from a dictionary."""
copied_data = data.copy()
if "space_name" not in copied_data and "value" in copied_data:
copied_data["space_name"] = copied_data.pop("value")
return super().from_dict(copied_data, ignore_extra_fields=ignore_extra_fields)
class KnowledgeSpaceRetrieverResource(RetrieverResource):
"""Knowledge Space retriever resource."""
def __init__(self, name: str, space_name: str):
retriever = KnowledgeSpaceRetriever(space_name=space_name)
super().__init__(name, retriever=retriever)
@classmethod
def resource_parameters_class(cls) -> Type[KnowledgeSpaceLoadResourceParameters]:
from dbgpt.app.knowledge.request.request import KnowledgeSpaceRequest
from dbgpt.app.knowledge.service import KnowledgeService
knowledge_space_service = KnowledgeService()
knowledge_spaces = knowledge_space_service.get_knowledge_space(
KnowledgeSpaceRequest()
)
results = [ks.name for ks in knowledge_spaces]
@dataclasses.dataclass
class _DynamicKnowledgeSpaceLoadResourceParameters(
KnowledgeSpaceLoadResourceParameters
):
space_name: str = dataclasses.field(
default=None,
metadata={
"help": "Knowledge space name",
"valid_values": results,
},
)
return _DynamicKnowledgeSpaceLoadResourceParameters

View File

@@ -0,0 +1,92 @@
import dataclasses
import logging
from typing import Any, List, Optional, Type, cast
from dbgpt._private.config import Config
from dbgpt.agent.resource.pack import PackResourceParameters
from dbgpt.agent.resource.tool.pack import ToolPack
from dbgpt.component import ComponentType
from dbgpt.serve.agent.hub.controller import ModulePlugin
from dbgpt.util.parameter_utils import ParameterDescription
CFG = Config()
logger = logging.getLogger(__name__)
@dataclasses.dataclass
class PluginPackResourceParameters(PackResourceParameters):
tool_name: str = dataclasses.field(metadata={"help": "Tool name"})
@classmethod
def _resource_version(cls) -> str:
"""Return the resource version."""
return "v1"
@classmethod
def to_configurations(
cls,
parameters: Type["PluginPackResourceParameters"],
version: Optional[str] = None,
) -> Any:
"""Convert the parameters to configurations."""
conf: List[ParameterDescription] = cast(
List[ParameterDescription], super().to_configurations(parameters)
)
version = version or cls._resource_version()
if version != "v1":
return conf
# Compatible with old version
for param in conf:
if param.param_name == "tool_name":
return param.valid_values or []
return []
@classmethod
def from_dict(
cls, data: dict, ignore_extra_fields: bool = True
) -> "PluginPackResourceParameters":
"""Create a new instance from a dictionary."""
copied_data = data.copy()
if "tool_name" not in copied_data and "value" in copied_data:
copied_data["tool_name"] = copied_data.pop("value")
return super().from_dict(copied_data, ignore_extra_fields=ignore_extra_fields)
class PluginToolPack(ToolPack):
def __init__(self, tool_name: str, **kwargs):
kwargs.pop("name")
super().__init__([], name="Plugin Tool Pack", **kwargs)
# Select tool name
self._tool_name = tool_name
@classmethod
def type_alias(cls) -> str:
return "tool(autogpt_plugins)"
@classmethod
def resource_parameters_class(cls) -> Type[PluginPackResourceParameters]:
agent_module: ModulePlugin = CFG.SYSTEM_APP.get_component(
ComponentType.PLUGIN_HUB, ModulePlugin
)
tool_names = []
for name, sub_tool in agent_module.tools._resources.items():
tool_names.append(name)
@dataclasses.dataclass
class _DynPluginPackResourceParameters(PluginPackResourceParameters):
tool_name: str = dataclasses.field(
metadata={"help": "Tool name", "valid_values": tool_names}
)
return _DynPluginPackResourceParameters
def preload_resource(self):
"""Preload the resource."""
agent_module: ModulePlugin = CFG.SYSTEM_APP.get_component(
ComponentType.PLUGIN_HUB, ModulePlugin
)
tool = agent_module.tools._resources.get(self._tool_name)
if not tool:
raise ValueError(f"Tool {self._tool_name} not found")
self._resources = {tool.name: tool}