refactor(agent): Refactor resource of agents (#1518)

Author: Fangyin Cheng
Date: 2024-05-15 09:57:19 +08:00
Committed by: GitHub
Parent: db4d318a5f
Commit: 559affe87d
102 changed files with 2633 additions and 2549 deletions

View File

@@ -17,7 +17,8 @@ from dbgpt.agent.core.memory.gpts.gpts_memory import GptsMemory
from dbgpt.agent.core.plan import AutoPlanChatManager, DefaultAWELLayoutManager
from dbgpt.agent.core.schema import Status
from dbgpt.agent.core.user_proxy_agent import UserProxyAgent
from dbgpt.agent.resource.resource_loader import ResourceLoader
from dbgpt.agent.resource.base import Resource
from dbgpt.agent.resource.manage import get_resource_manager
from dbgpt.agent.util.llm.llm import LLMConfig, LLMStrategyType
from dbgpt.app.openapi.api_view_model import Result
from dbgpt.app.scene.base import ChatScene
@@ -32,9 +33,6 @@ from dbgpt.util.json_utils import serialize
from ..db.gpts_app import GptsApp, GptsAppDao, GptsAppQuery
from ..db.gpts_conversations_db import GptsConversationsDao, GptsConversationsEntity
from ..db.gpts_manage_db import GptsInstanceEntity
from ..resource_loader.datasource_load_client import DatasourceLoadClient
from ..resource_loader.knowledge_space_load_client import KnowledgeSpaceLoadClient
from ..resource_loader.plugin_hub_load_client import PluginHubLoadClient
from ..team.base import TeamMode
from .db_gpts_memory import MetaDbGptsMessageMemory, MetaDbGptsPlansMemory
@@ -93,8 +91,6 @@ class MultiAgents(BaseComponent, ABC):
from dbgpt.agent.core.memory.hybrid import HybridMemory
from dbgpt.configs.model_config import EMBEDDING_MODEL_CONFIG
from dbgpt.rag.embedding.embedding_factory import EmbeddingFactory
from dbgpt.storage.vector_store.base import VectorStoreConfig
from dbgpt.storage.vector_store.connector import VectorStoreConnector
memory_key = f"{dbgpts_name}_{conv_id}"
if memory_key in self.agent_memory_map:
@@ -105,13 +101,17 @@ class MultiAgents(BaseComponent, ABC):
model_name=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL]
)
vstore_name = f"_chroma_agent_memory_{dbgpts_name}_{conv_id}"
vector_store_connector = VectorStoreConnector(
vector_store_type=CFG.VECTOR_STORE_TYPE,
vector_store_config=VectorStoreConfig(
name=vstore_name, embedding_fn=embedding_fn
),
# Just use chroma store now
# vector_store_connector = VectorStoreConnector(
# vector_store_type=CFG.VECTOR_STORE_TYPE,
# vector_store_config=VectorStoreConfig(
# name=vstore_name, embedding_fn=embedding_fn
# ),
# )
memory = HybridMemory[AgentMemoryFragment].from_chroma(
vstore_name=vstore_name,
embeddings=embedding_fn,
)
memory = HybridMemory[AgentMemoryFragment].from_vstore(vector_store_connector)
agent_memory = AgentMemory(memory, gpts_memory=self.memory)
self.agent_memory_map[memory_key] = agent_memory
return agent_memory
@@ -243,14 +243,7 @@ class MultiAgents(BaseComponent, ABC):
agent_memory: Optional[AgentMemory] = None,
):
employees: List[Agent] = []
# Prepare resource loader
resource_loader = ResourceLoader()
plugin_hub_loader = PluginHubLoadClient()
resource_loader.register_resource_api(plugin_hub_loader)
datasource_loader = DatasourceLoadClient()
resource_loader.register_resource_api(datasource_loader)
knowledge_space_loader = KnowledgeSpaceLoadClient()
resource_loader.register_resource_api(knowledge_space_loader)
rm = get_resource_manager()
context: AgentContext = AgentContext(
conv_id=conv_uid,
gpts_app_name=gpts_app.app_name,
@@ -264,6 +257,7 @@ class MultiAgents(BaseComponent, ABC):
).create()
self.llm_provider = DefaultLLMClient(worker_manager, auto_convert_message=True)
depend_resource: Optional[Resource] = None
for record in gpts_app.details:
cls: Type[ConversableAgent] = get_agent_manager().get_by_name(
record.agent_name
@@ -273,12 +267,13 @@ class MultiAgents(BaseComponent, ABC):
llm_strategy=LLMStrategyType(record.llm_strategy),
strategy_context=record.llm_strategy_value,
)
depend_resource = rm.build_resource(record.resources, version="v1")
agent = (
await cls()
.bind(context)
.bind(llm_config)
.bind(record.resources)
.bind(resource_loader)
.bind(depend_resource)
.bind(agent_memory)
.build()
)
@@ -298,7 +293,7 @@ class MultiAgents(BaseComponent, ABC):
manager = (
await manager.bind(context)
.bind(llm_config)
.bind(resource_loader)
.bind(depend_resource)
.bind(agent_memory)
.build()
)
@@ -308,7 +303,7 @@ class MultiAgents(BaseComponent, ABC):
user_proxy: UserProxyAgent = (
await UserProxyAgent()
.bind(context)
.bind(resource_loader)
.bind(depend_resource)
.bind(agent_memory)
.build()
)
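
Taken together, the hunks above drop the per-client ResourceLoader wiring in favor of a single resource manager that builds one dependent Resource per app record. A minimal sketch of the new flow, using only names visible in this diff (context, llm_config, agent_memory and the app's detail records are assumed to be set up as before):

    # Sketch only: mirrors the calls in this hunk, not a full implementation.
    from dbgpt.agent.resource.manage import get_resource_manager

    rm = get_resource_manager()
    # Old: ResourceLoader + PluginHubLoadClient/DatasourceLoadClient/KnowledgeSpaceLoadClient.
    # New: the manager resolves the app's resource records into one Resource.
    depend_resource = rm.build_resource(record.resources, version="v1")

    agent = (
        await cls()
        .bind(context)
        .bind(llm_config)
        .bind(depend_resource)  # replaces .bind(record.resources) + .bind(resource_loader)
        .bind(agent_memory)
        .build()
    )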

View File

@@ -4,10 +4,8 @@ from fastapi import APIRouter
from dbgpt._private.config import Config
from dbgpt.agent.core.agent_manage import get_agent_manager
from dbgpt.agent.resource.resource_api import ResourceType
from dbgpt.agent.resource.manage import get_resource_manager
from dbgpt.agent.util.llm.llm import LLMStrategyType
from dbgpt.app.knowledge.api import knowledge_space_service
from dbgpt.app.knowledge.request.request import KnowledgeSpaceRequest
from dbgpt.app.openapi.api_view_model import Result
from dbgpt.serve.agent.app.gpts_server import available_llms
from dbgpt.serve.agent.db.gpts_app import (
@@ -17,7 +15,6 @@ from dbgpt.serve.agent.db.gpts_app import (
GptsAppQuery,
GptsAppResponse,
)
from dbgpt.serve.agent.hub.plugin_hub import plugin_hub
from dbgpt.serve.agent.team.base import TeamMode
CFG = Config()
@@ -109,7 +106,8 @@ async def team_mode_list():
@router.get("/v1/resource-type/list")
async def team_mode_list():
try:
return Result.succ([type.value for type in ResourceType])
resources = get_resource_manager().get_supported_resources(version="v1")
return Result.succ(list(resources.keys()))
except Exception as ex:
return Result.failed(code="E000X", msg=f"query resource type list error: {ex}")
@@ -146,29 +144,8 @@ async def app_resources(
Get agent resources, such as db, knowledge, internet, plugin.
"""
try:
results = []
match type:
case ResourceType.DB.value:
dbs = CFG.local_db_manager.get_db_list()
results = [db["db_name"] for db in dbs]
if name:
results = [r for r in results if name in r]
case ResourceType.Knowledge.value:
knowledge_spaces = knowledge_space_service.get_knowledge_space(
KnowledgeSpaceRequest()
)
results = [ks.name for ks in knowledge_spaces]
if name:
results = [r for r in results if name in r]
case ResourceType.Plugin.value:
plugins = plugin_hub.get_my_plugin(user_code)
results = [plugin.name for plugin in plugins]
if name:
results = [r for r in results if name in r]
case ResourceType.Internet.value:
return Result.succ(None)
case ResourceType.File.value:
return Result.succ(None)
resources = get_resource_manager().get_supported_resources("v1")
results = resources.get(type, [])
return Result.succ(results)
except Exception as ex:
return Result.failed(code="E000X", msg=f"query app resources error: {ex}")
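
Both endpoints above now read from the resource manager rather than switching on ResourceType. A short sketch of the shape this hunk implies; the exact return structure of get_supported_resources is assumed here to be a mapping from resource-type key to its selectable options, and "database" is only an illustrative key:

    from dbgpt.agent.resource.manage import get_resource_manager

    resources = get_resource_manager().get_supported_resources(version="v1")
    # /v1/resource-type/list -> the registered resource-type keys
    type_list = list(resources.keys())
    # /v1/app/resources/list?type=... -> the options registered for one type
    results = resources.get("database", [])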

View File

@@ -15,7 +15,7 @@ from dbgpt._private.pydantic import (
model_validator,
)
from dbgpt.agent.core.plan import AWELTeamContext
from dbgpt.agent.resource.resource_api import AgentResource
from dbgpt.agent.resource.base import AgentResource
from dbgpt.serve.agent.team.base import TeamMode
from dbgpt.storage.metadata import BaseDao, Model

View File

@@ -4,8 +4,8 @@ from typing import List
from fastapi import APIRouter, Body, File, UploadFile
from dbgpt.agent.plugin.generator import PluginPromptGenerator
from dbgpt.agent.plugin.plugins_util import scan_plugins
from dbgpt.agent.resource.tool.autogpt.plugins_util import scan_plugins
from dbgpt.agent.resource.tool.pack import AutoGPTPluginToolPack
from dbgpt.app.openapi.api_view_model import Result
from dbgpt.component import BaseComponent, ComponentType, SystemApp
from dbgpt.configs.model_config import PLUGINS_DIR
@@ -30,25 +30,15 @@ class ModulePlugin(BaseComponent, ABC):
def __init__(self):
# load plugins
self.plugins = scan_plugins(PLUGINS_DIR)
self.refresh_plugins()
def init_app(self, system_app: SystemApp):
system_app.app.include_router(router, prefix="/api", tags=["Agent"])
def refresh_plugins(self):
self.plugins = scan_plugins(PLUGINS_DIR)
def load_select_plugin(
self, generator: PluginPromptGenerator, select_plugins: List[str]
) -> PluginPromptGenerator:
logger.info(f"load_select_plugin:{select_plugins}")
# load select plugin
for plugin in self.plugins:
if plugin._name in select_plugins:
if not plugin.can_handle_post_prompt():
continue
generator = plugin.post_prompt(generator)
return generator
self.tools = AutoGPTPluginToolPack(PLUGINS_DIR)
self.tools.preload_resource()
module_plugin = ModulePlugin()
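
ModulePlugin no longer filters PluginPromptGenerator instances; it just preloads an AutoGPT plugin tool pack from PLUGINS_DIR. A sketch of how a single tool can then be looked up, using the internal mapping that the PluginToolPack added later in this commit relies on ("my_tool" is a hypothetical tool name):

    from dbgpt.agent.resource.tool.pack import AutoGPTPluginToolPack
    from dbgpt.configs.model_config import PLUGINS_DIR

    tools = AutoGPTPluginToolPack(PLUGINS_DIR)
    tools.preload_resource()                # scans PLUGINS_DIR and loads the plugins as tools
    tool = tools._resources.get("my_tool")  # internal dict, as used by PluginToolPack.preload_resource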

View File

@@ -9,7 +9,7 @@ from typing import Any
from fastapi import UploadFile
from dbgpt.agent.core.schema import PluginStorageType
from dbgpt.agent.plugin.plugins_util import scan_plugins, update_from_git
from dbgpt.agent.resource.tool.autogpt.plugins_util import scan_plugins, update_from_git
from dbgpt.configs.model_config import PLUGINS_DIR
from ..db.my_plugin_db import MyPluginDao, MyPluginEntity

View File

@@ -0,0 +1,95 @@
import dataclasses
import logging
from typing import Any, List, Optional, Type, Union, cast
from dbgpt._private.config import Config
from dbgpt.agent.resource.database import DBParameters, RDBMSConnectorResource
from dbgpt.util import ParameterDescription
CFG = Config()
logger = logging.getLogger(__name__)
@dataclasses.dataclass
class DatasourceDBParameters(DBParameters):
"""The DB parameters for the datasource."""
db_name: str = dataclasses.field(metadata={"help": "DB name"})
@classmethod
def _resource_version(cls) -> str:
"""Return the resource version."""
return "v1"
@classmethod
def to_configurations(
cls,
parameters: Type["DatasourceDBParameters"],
version: Optional[str] = None,
) -> Any:
"""Convert the parameters to configurations."""
conf: List[ParameterDescription] = cast(
List[ParameterDescription], super().to_configurations(parameters)
)
version = version or cls._resource_version()
if version != "v1":
return conf
# Compatible with old version
for param in conf:
if param.param_name == "db_name":
return param.valid_values or []
return []
@classmethod
def from_dict(
cls, data: dict, ignore_extra_fields: bool = True
) -> "DatasourceDBParameters":
"""Create a new instance from a dictionary."""
copied_data = data.copy()
if "db_name" not in copied_data and "value" in copied_data:
copied_data["db_name"] = copied_data.pop("value")
return super().from_dict(copied_data, ignore_extra_fields=ignore_extra_fields)
class DatasourceResource(RDBMSConnectorResource):
def __init__(self, name: str, db_name: Optional[str] = None, **kwargs):
conn = CFG.local_db_manager.get_connector(db_name)
super().__init__(name, connector=conn, db_name=db_name, **kwargs)
@classmethod
def resource_parameters_class(cls) -> Type[DatasourceDBParameters]:
dbs = CFG.local_db_manager.get_db_list()
results = [db["db_name"] for db in dbs]
@dataclasses.dataclass
class _DynDBParameters(DatasourceDBParameters):
db_name: str = dataclasses.field(
metadata={"help": "DB name", "valid_values": results}
)
return _DynDBParameters
def get_schema_link(
self, db: str, question: Optional[str] = None
) -> Union[str, List[str]]:
"""Return the schema link of the database."""
try:
from dbgpt.rag.summary.db_summary_client import DBSummaryClient
except ImportError:
raise ValueError("Could not import DBSummaryClient. ")
client = DBSummaryClient(system_app=CFG.SYSTEM_APP)
table_infos = None
try:
table_infos = client.get_db_summary(
db,
question,
CFG.KNOWLEDGE_SEARCH_TOP_SIZE,
)
except Exception as e:
logger.warning(f"db summary find error!{str(e)}")
if not table_infos:
conn = CFG.local_db_manager.get_connector(db)
table_infos = conn.table_simple_info()
return table_infos
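
The datasource resource above keeps backward compatibility with the old v1 payloads: to_configurations collapses to the list of selectable database names, and from_dict remaps the legacy "value" field onto db_name. A sketch of that round trip inside a running DB-GPT app (database names are made up; real ones come from CFG.local_db_manager):

    params_cls = DatasourceResource.resource_parameters_class()
    # v1 callers receive just the selectable names, e.g. ["orders_db", "users_db"]
    options = params_cls.to_configurations(params_cls, version="v1")
    # Old payloads sent {"value": ...}; from_dict remaps it to db_name.
    params = params_cls.from_dict({"value": "orders_db"})
    assert params.db_name == "orders_db"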

View File

@@ -0,0 +1,89 @@
import dataclasses
import logging
from typing import Any, List, Optional, Type, cast
from dbgpt._private.config import Config
from dbgpt.agent.resource.knowledge import (
RetrieverResource,
RetrieverResourceParameters,
)
from dbgpt.serve.rag.retriever.knowledge_space import KnowledgeSpaceRetriever
from dbgpt.util import ParameterDescription
CFG = Config()
logger = logging.getLogger(__name__)
@dataclasses.dataclass
class KnowledgeSpaceLoadResourceParameters(RetrieverResourceParameters):
space_name: str = dataclasses.field(
default=None, metadata={"help": "Knowledge space name"}
)
@classmethod
def _resource_version(cls) -> str:
"""Return the resource version."""
return "v1"
@classmethod
def to_configurations(
cls,
parameters: Type["KnowledgeSpaceLoadResourceParameters"],
version: Optional[str] = None,
) -> Any:
"""Convert the parameters to configurations."""
conf: List[ParameterDescription] = cast(
List[ParameterDescription], super().to_configurations(parameters)
)
version = version or cls._resource_version()
if version != "v1":
return conf
# Compatible with old version
for param in conf:
if param.param_name == "space_name":
return param.valid_values or []
return []
@classmethod
def from_dict(
cls, data: dict, ignore_extra_fields: bool = True
) -> "KnowledgeSpaceLoadResourceParameters":
"""Create a new instance from a dictionary."""
copied_data = data.copy()
if "space_name" not in copied_data and "value" in copied_data:
copied_data["space_name"] = copied_data.pop("value")
return super().from_dict(copied_data, ignore_extra_fields=ignore_extra_fields)
class KnowledgeSpaceRetrieverResource(RetrieverResource):
"""Knowledge Space retriever resource."""
def __init__(self, name: str, space_name: str):
retriever = KnowledgeSpaceRetriever(space_name=space_name)
super().__init__(name, retriever=retriever)
@classmethod
def resource_parameters_class(cls) -> Type[KnowledgeSpaceLoadResourceParameters]:
from dbgpt.app.knowledge.request.request import KnowledgeSpaceRequest
from dbgpt.app.knowledge.service import KnowledgeService
knowledge_space_service = KnowledgeService()
knowledge_spaces = knowledge_space_service.get_knowledge_space(
KnowledgeSpaceRequest()
)
results = [ks.name for ks in knowledge_spaces]
@dataclasses.dataclass
class _DynamicKnowledgeSpaceLoadResourceParameters(
KnowledgeSpaceLoadResourceParameters
):
space_name: str = dataclasses.field(
default=None,
metadata={
"help": "Knowledge space name",
"valid_values": results,
},
)
return _DynamicKnowledgeSpaceLoadResourceParameters

View File

@@ -0,0 +1,92 @@
import dataclasses
import logging
from typing import Any, List, Optional, Type, cast
from dbgpt._private.config import Config
from dbgpt.agent.resource.pack import PackResourceParameters
from dbgpt.agent.resource.tool.pack import ToolPack
from dbgpt.component import ComponentType
from dbgpt.serve.agent.hub.controller import ModulePlugin
from dbgpt.util.parameter_utils import ParameterDescription
CFG = Config()
logger = logging.getLogger(__name__)
@dataclasses.dataclass
class PluginPackResourceParameters(PackResourceParameters):
tool_name: str = dataclasses.field(metadata={"help": "Tool name"})
@classmethod
def _resource_version(cls) -> str:
"""Return the resource version."""
return "v1"
@classmethod
def to_configurations(
cls,
parameters: Type["PluginPackResourceParameters"],
version: Optional[str] = None,
) -> Any:
"""Convert the parameters to configurations."""
conf: List[ParameterDescription] = cast(
List[ParameterDescription], super().to_configurations(parameters)
)
version = version or cls._resource_version()
if version != "v1":
return conf
# Compatible with old version
for param in conf:
if param.param_name == "tool_name":
return param.valid_values or []
return []
@classmethod
def from_dict(
cls, data: dict, ignore_extra_fields: bool = True
) -> "PluginPackResourceParameters":
"""Create a new instance from a dictionary."""
copied_data = data.copy()
if "tool_name" not in copied_data and "value" in copied_data:
copied_data["tool_name"] = copied_data.pop("value")
return super().from_dict(copied_data, ignore_extra_fields=ignore_extra_fields)
class PluginToolPack(ToolPack):
def __init__(self, tool_name: str, **kwargs):
kwargs.pop("name")
super().__init__([], name="Plugin Tool Pack", **kwargs)
# Select tool name
self._tool_name = tool_name
@classmethod
def type_alias(cls) -> str:
return "tool(autogpt_plugins)"
@classmethod
def resource_parameters_class(cls) -> Type[PluginPackResourceParameters]:
agent_module: ModulePlugin = CFG.SYSTEM_APP.get_component(
ComponentType.PLUGIN_HUB, ModulePlugin
)
tool_names = []
for name, sub_tool in agent_module.tools._resources.items():
tool_names.append(name)
@dataclasses.dataclass
class _DynPluginPackResourceParameters(PluginPackResourceParameters):
tool_name: str = dataclasses.field(
metadata={"help": "Tool name", "valid_values": tool_names}
)
return _DynPluginPackResourceParameters
def preload_resource(self):
"""Preload the resource."""
agent_module: ModulePlugin = CFG.SYSTEM_APP.get_component(
ComponentType.PLUGIN_HUB, ModulePlugin
)
tool = agent_module.tools._resources.get(self._tool_name)
if not tool:
raise ValueError(f"Tool {self._tool_name} not found")
self._resources = {tool.name: tool}

View File

@@ -1,66 +0,0 @@
import logging
from typing import List, Optional, Union
from dbgpt._private.config import Config
from dbgpt.agent.resource.resource_api import AgentResource
from dbgpt.agent.resource.resource_db_api import ResourceDbClient
from dbgpt.component import ComponentType
from dbgpt.util.executor_utils import ExecutorFactory, blocking_func_to_async
from dbgpt.util.tracer import root_tracer
CFG = Config()
logger = logging.getLogger(__name__)
class DatasourceLoadClient(ResourceDbClient):
def __init__(self):
super().__init__()
# The executor to submit blocking function
self._executor = CFG.SYSTEM_APP.get_component(
ComponentType.EXECUTOR_DEFAULT, ExecutorFactory
).create()
def get_data_type(self, resource: AgentResource) -> str:
conn = CFG.local_db_manager.get_connector(resource.value)
return conn.db_type
async def get_schema_link(
self, db: str, question: Optional[str] = None
) -> Union[str, List[str]]:
try:
from dbgpt.rag.summary.db_summary_client import DBSummaryClient
except ImportError:
raise ValueError("Could not import DBSummaryClient. ")
client = DBSummaryClient(system_app=CFG.SYSTEM_APP)
table_infos = None
try:
with root_tracer.start_span("ChatWithDbAutoExecute.get_db_summary"):
table_infos = await blocking_func_to_async(
self._executor,
client.get_db_summary,
db,
question,
CFG.KNOWLEDGE_SEARCH_TOP_SIZE,
)
except Exception as e:
print("db summary find error!" + str(e))
if not table_infos:
conn = CFG.local_db_manager.get_connector(db)
table_infos = await blocking_func_to_async(
self._executor, conn.table_simple_info
)
return table_infos
async def query_to_df(self, db: str, sql: str):
conn = CFG.local_db_manager.get_connector(db)
return conn.run_to_df(sql)
async def query(self, db: str, sql: str):
conn = CFG.local_db_manager.get_connector(db)
return conn.query_ex(sql)
async def run_sql(self, db: str, sql: str):
conn = CFG.local_db_manager.get_connector(db)
return conn.run(sql)

View File

@@ -1,35 +0,0 @@
import logging
from typing import Any, List, Optional, Union
from dbgpt._private.config import Config
from dbgpt.agent.resource.resource_api import AgentResource
from dbgpt.agent.resource.resource_knowledge_api import ResourceKnowledgeClient
from dbgpt.core import Chunk
from dbgpt.serve.rag.retriever.knowledge_space import KnowledgeSpaceRetriever
CFG = Config()
logger = logging.getLogger(__name__)
class KnowledgeSpaceLoadClient(ResourceKnowledgeClient):
async def get_space_desc(self, space_name) -> str:
pass
async def get_kn(
self, space_name: str, question: Optional[str] = None
) -> List[Chunk]:
kn_retriver = KnowledgeSpaceRetriever(space_name=space_name)
chunks: List[Chunk] = kn_retriver.retrieve(question)
return chunks
async def add_kn(
self, space_name: str, kn_name: str, type: str, content: Optional[Any]
):
kn_retriver = KnowledgeSpaceRetriever(space_name=space_name)
async def get_data_introduce(
self, resource: AgentResource, question: Optional[str] = None
) -> Union[str, List[str]]:
docs = await self.get_kn(resource.value, question)
return "\n".join([doc.content for doc in docs])

View File

@@ -1,40 +0,0 @@
import json
import logging
from typing import Optional
from dbgpt._private.config import Config
from dbgpt.agent.plugin.generator import PluginPromptGenerator
from dbgpt.agent.resource.resource_plugin_api import ResourcePluginClient
from dbgpt.component import ComponentType
from dbgpt.serve.agent.hub.controller import ModulePlugin
from dbgpt.util.executor_utils import ExecutorFactory
CFG = Config()
logger = logging.getLogger(__name__)
class PluginHubLoadClient(ResourcePluginClient):
def __init__(self):
super().__init__()
# The executor to submit blocking function
self._executor = CFG.SYSTEM_APP.get_component(
ComponentType.EXECUTOR_DEFAULT, ExecutorFactory
).create()
async def load_plugin(
self, value: str, plugin_generator: Optional[PluginPromptGenerator] = None
) -> PluginPromptGenerator:
logger.info(f"PluginHubLoadClient load plugin:{value}")
if plugin_generator is None:
plugin_generator = PluginPromptGenerator()
plugin_generator.set_command_registry(CFG.command_registry)
agent_module = CFG.SYSTEM_APP.get_component(
ComponentType.PLUGIN_HUB, ModulePlugin
)
plugin_generator = agent_module.load_select_plugin(
plugin_generator, json.dumps(value)
)
return plugin_generator