mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-05 19:11:52 +00:00
refactor(agent): Refactor resource of agents (#1518)
This commit is contained in:
0
dbgpt/serve/agent/resource/__init__.py
Normal file
0
dbgpt/serve/agent/resource/__init__.py
Normal file
95
dbgpt/serve/agent/resource/datasource.py
Normal file
95
dbgpt/serve/agent/resource/datasource.py
Normal file
@@ -0,0 +1,95 @@
|
||||
import dataclasses
|
||||
import logging
|
||||
from typing import Any, List, Optional, Type, Union, cast
|
||||
|
||||
from dbgpt._private.config import Config
|
||||
from dbgpt.agent.resource.database import DBParameters, RDBMSConnectorResource
|
||||
from dbgpt.util import ParameterDescription
|
||||
|
||||
CFG = Config()
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class DatasourceDBParameters(DBParameters):
|
||||
"""The DB parameters for the datasource."""
|
||||
|
||||
db_name: str = dataclasses.field(metadata={"help": "DB name"})
|
||||
|
||||
@classmethod
|
||||
def _resource_version(cls) -> str:
|
||||
"""Return the resource version."""
|
||||
return "v1"
|
||||
|
||||
@classmethod
|
||||
def to_configurations(
|
||||
cls,
|
||||
parameters: Type["DatasourceDBParameters"],
|
||||
version: Optional[str] = None,
|
||||
) -> Any:
|
||||
"""Convert the parameters to configurations."""
|
||||
conf: List[ParameterDescription] = cast(
|
||||
List[ParameterDescription], super().to_configurations(parameters)
|
||||
)
|
||||
version = version or cls._resource_version()
|
||||
if version != "v1":
|
||||
return conf
|
||||
# Compatible with old version
|
||||
for param in conf:
|
||||
if param.param_name == "db_name":
|
||||
return param.valid_values or []
|
||||
return []
|
||||
|
||||
@classmethod
|
||||
def from_dict(
|
||||
cls, data: dict, ignore_extra_fields: bool = True
|
||||
) -> "DatasourceDBParameters":
|
||||
"""Create a new instance from a dictionary."""
|
||||
copied_data = data.copy()
|
||||
if "db_name" not in copied_data and "value" in copied_data:
|
||||
copied_data["db_name"] = copied_data.pop("value")
|
||||
return super().from_dict(copied_data, ignore_extra_fields=ignore_extra_fields)
|
||||
|
||||
|
||||
class DatasourceResource(RDBMSConnectorResource):
|
||||
def __init__(self, name: str, db_name: Optional[str] = None, **kwargs):
|
||||
conn = CFG.local_db_manager.get_connector(db_name)
|
||||
super().__init__(name, connector=conn, db_name=db_name, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def resource_parameters_class(cls) -> Type[DatasourceDBParameters]:
|
||||
dbs = CFG.local_db_manager.get_db_list()
|
||||
results = [db["db_name"] for db in dbs]
|
||||
|
||||
@dataclasses.dataclass
|
||||
class _DynDBParameters(DatasourceDBParameters):
|
||||
db_name: str = dataclasses.field(
|
||||
metadata={"help": "DB name", "valid_values": results}
|
||||
)
|
||||
|
||||
return _DynDBParameters
|
||||
|
||||
def get_schema_link(
|
||||
self, db: str, question: Optional[str] = None
|
||||
) -> Union[str, List[str]]:
|
||||
"""Return the schema link of the database."""
|
||||
try:
|
||||
from dbgpt.rag.summary.db_summary_client import DBSummaryClient
|
||||
except ImportError:
|
||||
raise ValueError("Could not import DBSummaryClient. ")
|
||||
client = DBSummaryClient(system_app=CFG.SYSTEM_APP)
|
||||
table_infos = None
|
||||
try:
|
||||
table_infos = client.get_db_summary(
|
||||
db,
|
||||
question,
|
||||
CFG.KNOWLEDGE_SEARCH_TOP_SIZE,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"db summary find error!{str(e)}")
|
||||
if not table_infos:
|
||||
conn = CFG.local_db_manager.get_connector(db)
|
||||
table_infos = conn.table_simple_info()
|
||||
|
||||
return table_infos
|
89
dbgpt/serve/agent/resource/knowledge.py
Normal file
89
dbgpt/serve/agent/resource/knowledge.py
Normal file
@@ -0,0 +1,89 @@
|
||||
import dataclasses
|
||||
import logging
|
||||
from typing import Any, List, Optional, Type, cast
|
||||
|
||||
from dbgpt._private.config import Config
|
||||
from dbgpt.agent.resource.knowledge import (
|
||||
RetrieverResource,
|
||||
RetrieverResourceParameters,
|
||||
)
|
||||
from dbgpt.serve.rag.retriever.knowledge_space import KnowledgeSpaceRetriever
|
||||
from dbgpt.util import ParameterDescription
|
||||
|
||||
CFG = Config()
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class KnowledgeSpaceLoadResourceParameters(RetrieverResourceParameters):
|
||||
space_name: str = dataclasses.field(
|
||||
default=None, metadata={"help": "Knowledge space name"}
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _resource_version(cls) -> str:
|
||||
"""Return the resource version."""
|
||||
return "v1"
|
||||
|
||||
@classmethod
|
||||
def to_configurations(
|
||||
cls,
|
||||
parameters: Type["KnowledgeSpaceLoadResourceParameters"],
|
||||
version: Optional[str] = None,
|
||||
) -> Any:
|
||||
"""Convert the parameters to configurations."""
|
||||
conf: List[ParameterDescription] = cast(
|
||||
List[ParameterDescription], super().to_configurations(parameters)
|
||||
)
|
||||
version = version or cls._resource_version()
|
||||
if version != "v1":
|
||||
return conf
|
||||
# Compatible with old version
|
||||
for param in conf:
|
||||
if param.param_name == "space_name":
|
||||
return param.valid_values or []
|
||||
return []
|
||||
|
||||
@classmethod
|
||||
def from_dict(
|
||||
cls, data: dict, ignore_extra_fields: bool = True
|
||||
) -> "KnowledgeSpaceLoadResourceParameters":
|
||||
"""Create a new instance from a dictionary."""
|
||||
copied_data = data.copy()
|
||||
if "space_name" not in copied_data and "value" in copied_data:
|
||||
copied_data["space_name"] = copied_data.pop("value")
|
||||
return super().from_dict(copied_data, ignore_extra_fields=ignore_extra_fields)
|
||||
|
||||
|
||||
class KnowledgeSpaceRetrieverResource(RetrieverResource):
|
||||
"""Knowledge Space retriever resource."""
|
||||
|
||||
def __init__(self, name: str, space_name: str):
|
||||
retriever = KnowledgeSpaceRetriever(space_name=space_name)
|
||||
super().__init__(name, retriever=retriever)
|
||||
|
||||
@classmethod
|
||||
def resource_parameters_class(cls) -> Type[KnowledgeSpaceLoadResourceParameters]:
|
||||
from dbgpt.app.knowledge.request.request import KnowledgeSpaceRequest
|
||||
from dbgpt.app.knowledge.service import KnowledgeService
|
||||
|
||||
knowledge_space_service = KnowledgeService()
|
||||
knowledge_spaces = knowledge_space_service.get_knowledge_space(
|
||||
KnowledgeSpaceRequest()
|
||||
)
|
||||
results = [ks.name for ks in knowledge_spaces]
|
||||
|
||||
@dataclasses.dataclass
|
||||
class _DynamicKnowledgeSpaceLoadResourceParameters(
|
||||
KnowledgeSpaceLoadResourceParameters
|
||||
):
|
||||
space_name: str = dataclasses.field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "Knowledge space name",
|
||||
"valid_values": results,
|
||||
},
|
||||
)
|
||||
|
||||
return _DynamicKnowledgeSpaceLoadResourceParameters
|
92
dbgpt/serve/agent/resource/plugin.py
Normal file
92
dbgpt/serve/agent/resource/plugin.py
Normal file
@@ -0,0 +1,92 @@
|
||||
import dataclasses
|
||||
import logging
|
||||
from typing import Any, List, Optional, Type, cast
|
||||
|
||||
from dbgpt._private.config import Config
|
||||
from dbgpt.agent.resource.pack import PackResourceParameters
|
||||
from dbgpt.agent.resource.tool.pack import ToolPack
|
||||
from dbgpt.component import ComponentType
|
||||
from dbgpt.serve.agent.hub.controller import ModulePlugin
|
||||
from dbgpt.util.parameter_utils import ParameterDescription
|
||||
|
||||
CFG = Config()
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class PluginPackResourceParameters(PackResourceParameters):
|
||||
tool_name: str = dataclasses.field(metadata={"help": "Tool name"})
|
||||
|
||||
@classmethod
|
||||
def _resource_version(cls) -> str:
|
||||
"""Return the resource version."""
|
||||
return "v1"
|
||||
|
||||
@classmethod
|
||||
def to_configurations(
|
||||
cls,
|
||||
parameters: Type["PluginPackResourceParameters"],
|
||||
version: Optional[str] = None,
|
||||
) -> Any:
|
||||
"""Convert the parameters to configurations."""
|
||||
conf: List[ParameterDescription] = cast(
|
||||
List[ParameterDescription], super().to_configurations(parameters)
|
||||
)
|
||||
version = version or cls._resource_version()
|
||||
if version != "v1":
|
||||
return conf
|
||||
# Compatible with old version
|
||||
for param in conf:
|
||||
if param.param_name == "tool_name":
|
||||
return param.valid_values or []
|
||||
return []
|
||||
|
||||
@classmethod
|
||||
def from_dict(
|
||||
cls, data: dict, ignore_extra_fields: bool = True
|
||||
) -> "PluginPackResourceParameters":
|
||||
"""Create a new instance from a dictionary."""
|
||||
copied_data = data.copy()
|
||||
if "tool_name" not in copied_data and "value" in copied_data:
|
||||
copied_data["tool_name"] = copied_data.pop("value")
|
||||
return super().from_dict(copied_data, ignore_extra_fields=ignore_extra_fields)
|
||||
|
||||
|
||||
class PluginToolPack(ToolPack):
|
||||
def __init__(self, tool_name: str, **kwargs):
|
||||
kwargs.pop("name")
|
||||
super().__init__([], name="Plugin Tool Pack", **kwargs)
|
||||
# Select tool name
|
||||
self._tool_name = tool_name
|
||||
|
||||
@classmethod
|
||||
def type_alias(cls) -> str:
|
||||
return "tool(autogpt_plugins)"
|
||||
|
||||
@classmethod
|
||||
def resource_parameters_class(cls) -> Type[PluginPackResourceParameters]:
|
||||
agent_module: ModulePlugin = CFG.SYSTEM_APP.get_component(
|
||||
ComponentType.PLUGIN_HUB, ModulePlugin
|
||||
)
|
||||
tool_names = []
|
||||
for name, sub_tool in agent_module.tools._resources.items():
|
||||
tool_names.append(name)
|
||||
|
||||
@dataclasses.dataclass
|
||||
class _DynPluginPackResourceParameters(PluginPackResourceParameters):
|
||||
tool_name: str = dataclasses.field(
|
||||
metadata={"help": "Tool name", "valid_values": tool_names}
|
||||
)
|
||||
|
||||
return _DynPluginPackResourceParameters
|
||||
|
||||
def preload_resource(self):
|
||||
"""Preload the resource."""
|
||||
agent_module: ModulePlugin = CFG.SYSTEM_APP.get_component(
|
||||
ComponentType.PLUGIN_HUB, ModulePlugin
|
||||
)
|
||||
tool = agent_module.tools._resources.get(self._tool_name)
|
||||
if not tool:
|
||||
raise ValueError(f"Tool {self._tool_name} not found")
|
||||
self._resources = {tool.name: tool}
|
Reference in New Issue
Block a user