feat(agent): add app starter role in multi agent (#2265)

Co-authored-by: cinjospeh <joseph.cjn@alibaba-inc.com>
This commit is contained in:
cinjoseph 2025-01-03 15:04:09 +08:00 committed by GitHub
parent ad1e8e27a5
commit 0e3b2dc818
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 671 additions and 0 deletions

166
dbgpt/agent/resource/app.py Normal file
View File

@ -0,0 +1,166 @@
import dataclasses
import uuid
from typing import Optional, Tuple, Dict, Type, Any, List, cast
from dbgpt.agent import ConversableAgent, AgentMessage, AgentContext
from dbgpt.serve.agent.agents.app_agent_manage import get_app_manager
from dbgpt.util import ParameterDescription
from .base import Resource, ResourceParameters, ResourceType
def get_app_list():
    """Build the selectable option list of apps.

    Returns:
        A list with one dict per app returned by the app manager's
        ``get_dbgpts()``. Each dict carries ``label`` (``"name(code)"``),
        ``key`` (the app code) and ``description`` (the app description).
    """
    options = []
    for app in get_app_manager().get_dbgpts():
        options.append(
            {
                "label": f"{app.app_name}({app.app_code})",
                "key": app.app_code,
                "description": app.app_describe,
            }
        )
    return options
@dataclasses.dataclass
class AppResourceParameters(ResourceParameters):
    """Parameters that identify the app wrapped by an app resource."""

    # NOTE(review): ``get_app_list()`` is evaluated once, while this class body
    # is executed, so ``valid_values`` is a snapshot taken at import time.
    app_code: str = dataclasses.field(
        default=None,
        metadata={
            "help": "app code",
            "valid_values": get_app_list(),
        },
    )

    @classmethod
    def to_configurations(
        cls,
        parameters: Type["AppResourceParameters"],
        version: Optional[str] = None,
        **kwargs,
    ) -> Any:
        """Convert the parameters to configurations.

        Args:
            parameters: The parameters class to describe.
            version: Requested configuration version. When falsy, the value of
                ``cls._resource_version()`` is used instead.
            **kwargs: Accepted for compatibility; not used here.

        Returns:
            The parameter descriptions produced by the parent class, unless
            the resolved version is ``"v1"``; in that case only the
            ``valid_values`` of the ``app_code`` parameter (or an empty list).
        """
        descriptions: List[ParameterDescription] = cast(
            List[ParameterDescription], super().to_configurations(parameters)
        )
        resolved_version = version or cls._resource_version()
        if resolved_version != "v1":
            return descriptions

        # Compatible with old version: expose just the app_code options.
        app_code_params = [
            item for item in descriptions if item.param_name == "app_code"
        ]
        if app_code_params:
            return app_code_params[0].valid_values or []
        return []

    @classmethod
    def from_dict(
        cls, data: dict, ignore_extra_fields: bool = True
    ) -> ResourceParameters:
        """Create a new instance from a dictionary.

        When ``data`` has no ``app_code`` key but has a ``value`` key, the
        ``value`` entry is moved to ``app_code`` on a copy of ``data`` before
        delegating to the parent implementation. ``data`` is not modified.
        """
        payload = data.copy()
        if "app_code" not in payload:
            if "value" in payload:
                payload["app_code"] = payload.pop("value")
        return super().from_dict(payload, ignore_extra_fields=ignore_extra_fields)
class AppResource(Resource[AppResourceParameters]):
    """Resource that lets an agent talk to another application.

    The resource is bound to one application (looked up by ``app_code`` through
    the app manager) and forwards a user query to an agent created for that
    application.
    """

    def __init__(self, name: str, app_code: str, **kwargs):
        """Bind the resource to an application.

        Args:
            name: The resource name returned by :attr:`name`.
            app_code: Code of the application to interact with.
            **kwargs: Accepted for compatibility; not used here.
        """
        self._resource_name = name
        self._app_code = app_code
        # Cache name/description once; they are used to render the prompt.
        app = get_app_manager().get_app(self._app_code)
        self._app_name = app.app_name
        self._app_desc = app.app_describe

    @classmethod
    def type(cls) -> ResourceType:
        """Return the resource type (``ResourceType.App``)."""
        return ResourceType.App

    @property
    def name(self) -> str:
        """Return the resource name given at construction."""
        return self._resource_name

    @classmethod
    def resource_parameters_class(cls, **kwargs) -> Type[ResourceParameters]:
        """Return the resource parameters class."""
        return AppResourceParameters

    async def get_prompt(
        self,
        *,
        lang: str = "en",
        prompt_type: str = "default",
        question: Optional[str] = None,
        resource_name: Optional[str] = None,
        **kwargs,
    ) -> Tuple[str, Optional[Dict]]:
        """Get the prompt describing this resource.

        Args:
            lang: ``"en"`` selects the English template; any other value
                selects the Chinese template.
            prompt_type: Not used by this resource.
            question: Not used by this resource.
            resource_name: Not used by this resource.
            **kwargs: Not used by this resource.

        Returns:
            A tuple of the rendered prompt text and ``None``.
        """
        prompt_template_zh = (
            "{name}:调用此资源与应用 {app_name} 进行交互。"
            "应用 {app_name} 有什么用?{description}"
        )
        # The resource name is separated from the instruction with a colon,
        # matching the Chinese template above (the separator was missing, which
        # glued the name onto the word "Call").
        prompt_template_en = (
            "{name}: Call this resource to interact with the application {app_name} ."
            "What is the application {app_name} useful for? {description} "
        )
        template = prompt_template_en if lang == "en" else prompt_template_zh

        return (
            template.format(
                name=self.name,
                app_name=self._app_name,
                description=self._app_desc,
            ),
            None,
        )

    @property
    def is_async(self) -> bool:
        """Return whether the tool is asynchronous."""
        return True

    async def execute(
        self, *args, resource_name: Optional[str] = None, **kwargs
    ) -> Any:
        """Reject synchronous-style execution of this resource.

        This resource only works through :meth:`async_execute`.

        NOTE(review): this method is itself a coroutine function, so the error
        is only raised once the returned coroutine is awaited.

        Raises:
            RuntimeError: When :attr:`is_async` is true (always, for this
                class).
        """
        if self.is_async:
            # It is the synchronous path that is unsupported here.
            raise RuntimeError("Sync execution is not supported")

    async def async_execute(
        self,
        *args,
        resource_name: Optional[str] = None,
        **kwargs,
    ) -> Any:
        """Execute the tool asynchronously.

        Args:
            *args: The positional arguments (not used).
            resource_name (str, optional): The tool name to be executed(not used for
                specific tool).
            **kwargs: The keyword arguments. ``user_input`` is the query sent
                to the application and ``parent_agent`` is passed as the sender
                of that query.

        Returns:
            The ``content`` of the reply message produced by the application.
        """
        user_input = kwargs.get("user_input")
        parent_agent = kwargs.get("parent_agent")

        reply_message = await self.chat_2_app_once(
            self._app_code, user_input=user_input, sender=parent_agent
        )
        return reply_message.content

    async def chat_2_app_once(
        self,
        app_code: str,
        user_input: str,
        conv_uid: Optional[str] = None,
        sender: Optional[ConversableAgent] = None,
    ) -> AgentMessage:
        """Send a single message to an application and return its reply.

        Args:
            app_code: Code of the application to talk to.
            user_input: The message content; also used as the current goal.
            conv_uid: Conversation id. A new UUID4 string is generated when
                ``None``.
            sender: The agent passed as ``sender`` to ``generate_reply``.

        Returns:
            The reply message generated by the application's agent.
        """
        # create a new conv_uid
        conv_uid = str(uuid.uuid4()) if conv_uid is None else conv_uid

        gpts_app = get_app_manager().get_app(app_code)
        app_agent = await get_app_manager().create_agent_by_app_code(
            gpts_app, conv_uid=conv_uid
        )

        agent_message = AgentMessage(
            content=user_input,
            current_goal=user_input,
            context={
                "conv_uid": conv_uid,
            },
            rounds=0,
        )

        reply_message: AgentMessage = await app_agent.generate_reply(
            received_message=agent_message, sender=sender
        )
        return reply_message

View File

@ -28,6 +28,7 @@ class ResourceType(str, Enum):
ExcelFile = "excel_file"
ImageFile = "image_file"
AWELFlow = "awel_flow"
App = "app"
# Resource type for resource pack
Pack = "pack"

View File

@ -111,12 +111,14 @@ def _initialize_resource_manager(system_app: SystemApp):
from dbgpt.serve.agent.resource.datasource import DatasourceResource
from dbgpt.serve.agent.resource.knowledge import KnowledgeSpaceRetrieverResource
from dbgpt.serve.agent.resource.plugin import PluginToolPack
from dbgpt.agent.resource.app import AppResource
initialize_resource(system_app)
rm = get_resource_manager(system_app)
rm.register_resource(DatasourceResource)
rm.register_resource(KnowledgeSpaceRetrieverResource)
rm.register_resource(PluginToolPack, resource_type=ResourceType.Tool)
rm.register_resource(AppResource)
# Register a search tool
rm.register_resource(resource_instance=baidu_search)
rm.register_resource(resource_instance=list_dbgpt_support_models)

View File

@ -0,0 +1,276 @@
import logging
import uuid
from abc import ABC
from typing import List, Type
from dbgpt.agent import (
AgentContext,
AgentMemory,
ConversableAgent,
DefaultAWELLayoutManager,
GptsMemory,
LLMConfig,
UserProxyAgent, get_agent_manager,
)
from dbgpt.agent.core.schema import Status
from dbgpt.agent.resource import get_resource_manager
from dbgpt.agent.util.llm.llm import LLMStrategyType
from dbgpt.app.component_configs import CFG
from dbgpt.component import BaseComponent, ComponentType, SystemApp
from dbgpt.core import LLMClient
from dbgpt.core import PromptTemplate
from dbgpt.model.cluster import WorkerManagerFactory
from dbgpt.model.cluster.client import DefaultLLMClient
from dbgpt.serve.prompt.api.endpoints import get_service
from .db_gpts_memory import MetaDbGptsMessageMemory, MetaDbGptsPlansMemory
from ..db import GptsMessagesDao
from ..db.gpts_app import GptsApp, GptsAppDao, GptsAppQuery
from ..db.gpts_app import GptsAppDetail
from ..db.gpts_conversations_db import GptsConversationsDao
from ..team.base import TeamMode
logger = logging.getLogger(__name__)
class AppManager(BaseComponent, ABC):
    """Component that looks up GPTs apps and builds the agents that run them."""

    name = "dbgpt_agent_app_manager"

    def __init__(self, system_app: SystemApp):
        """Create the DAOs and memory, then register with ``system_app``."""
        self.gpts_conversations = GptsConversationsDao()
        self.gpts_messages_dao = GptsMessagesDao()
        self.gpts_app = GptsAppDao()

        plans_memory = MetaDbGptsPlansMemory()
        message_memory = MetaDbGptsMessageMemory()
        self.memory = GptsMemory(
            plans_memory=plans_memory,
            message_memory=message_memory,
        )
        self.agent_memory_map = {}

        super().__init__(system_app)
        self.system_app = system_app

    def init_app(self, system_app: SystemApp):
        """Store the system app on this component."""
        self.system_app = system_app

    def get_dbgpts(self, user_code: str = None, sys_code: str = None):
        """Return the apps matching the given user code and system code."""
        query = GptsAppQuery(user_code=user_code, sys_code=sys_code)
        return self.gpts_app.app_list(query).app_list

    def get_app(self, app_code) -> GptsApp:
        """Return the app detail for ``app_code``."""
        return self.gpts_app.app_detail(app_code)

    async def user_chat_2_app(
        self,
        user_query: str,
        conv_uid: str,
        gpts_app: GptsApp,
        agent_memory: AgentMemory,
        is_retry_chat: bool = False,
        last_speaker_name: str = None,
        init_message_rounds: int = 0,
        enable_verbose: bool = True,
        **ext_info,
    ) -> Status:
        """Run one user turn against the app's agent through a user proxy.

        Args:
            user_query: The message sent to the app agent.
            conv_uid: Conversation id used for the agent context.
            gpts_app: The app to chat with.
            agent_memory: Memory bound to the user proxy and the app agents.
            is_retry_chat: When true, the conversation is first marked as
                running and the flag is forwarded to ``initiate_chat``.
            last_speaker_name: Forwarded to ``initiate_chat``.
            init_message_rounds: Forwarded as ``message_rounds``.
            enable_verbose: Sets ``enable_vis_message`` on the agent context.
            **ext_info: Extra keyword arguments forwarded to
                ``initiate_chat``.

        Returns:
            ``Status.WAITING`` if the user proxy reports that the user was
            asked a question, otherwise ``Status.COMPLETE``.
        """
        context: AgentContext = AgentContext(
            conv_id=conv_uid,
            gpts_app_code=gpts_app.app_code,
            gpts_app_name=gpts_app.app_name,
            language=gpts_app.language,
            enable_vis_message=enable_verbose,
        )
        recipient = await self.create_app_agent(gpts_app, agent_memory, context)

        if is_retry_chat:
            # A retried conversation goes back to the running state.
            self.gpts_conversations.update(conv_uid, Status.RUNNING.value)

        # The user proxy is the sender that starts the chat.
        proxy_builder = UserProxyAgent().bind(context).bind(agent_memory)
        user_proxy: UserProxyAgent = await proxy_builder.build()
        await user_proxy.initiate_chat(
            recipient=recipient,
            message=user_query,
            is_retry_chat=is_retry_chat,
            last_speaker_name=last_speaker_name,
            message_rounds=init_message_rounds,
            **ext_info,
        )

        # Check if the user has received a question.
        return Status.WAITING if user_proxy.have_ask_user() else Status.COMPLETE

    async def create_app_agent(
        self,
        gpts_app: GptsApp,
        agent_memory: AgentMemory,
        context: AgentContext,
    ) -> ConversableAgent:
        """Build the agent that represents ``gpts_app``.

        One agent is created per entry in ``gpts_app.details`` and these are
        handed to ``create_agent_of_gpts_app`` as the team's employees. Every
        created agent gets the app name as its ``name_prefix``.
        """
        # init default llm provider
        worker_manager = self.system_app.get_component(
            ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
        ).create()
        llm_provider = DefaultLLMClient(worker_manager, auto_convert_message=True)

        # init team employees
        # TODO employee has it own llm provider
        employees: List[ConversableAgent] = []
        for detail in gpts_app.details:
            employee = await create_agent_from_gpt_detail(
                detail, llm_provider, context, agent_memory
            )
            employee.name_prefix = gpts_app.app_name
            employees.append(employee)

        app_agent: ConversableAgent = await create_agent_of_gpts_app(
            gpts_app, llm_provider, context, agent_memory, employees
        )
        app_agent.name_prefix = gpts_app.app_name
        return app_agent

    async def create_agent_by_app_code(
        self,
        gpts_app: GptsApp,
        conv_uid: str = None,
        agent_memory: AgentMemory = None,
        context: AgentContext = None,
    ) -> ConversableAgent:
        """Create a conversable agent for an application.

        Parameters:
            gpts_app (GptsApp): The application.
            conv_uid (str, optional): The unique identifier of the
                conversation. A new UUID is generated when not provided.
            agent_memory (AgentMemory, optional): The memory object for the
                agent. A default memory object is created when not provided.
            context (AgentContext, optional): The context object for the
                agent. A default one (with vis messages disabled) is created
                when not provided. Its app code, app name and language are
                always overwritten with the values from ``gpts_app``, so a
                context passed in by the caller is modified.

        Returns:
            ConversableAgent: The created conversable agent object.
        """
        if conv_uid is None:
            conv_uid = str(uuid.uuid4())

        from dbgpt.agent.core.memory.gpts import (
            DefaultGptsMessageMemory,
            DefaultGptsPlansMemory,
        )

        if agent_memory is None:
            gpt_memory = GptsMemory(
                plans_memory=DefaultGptsPlansMemory(),
                message_memory=DefaultGptsMessageMemory(),
            )
            gpt_memory.init(conv_uid)
            agent_memory = AgentMemory(gpts_memory=gpt_memory)

        if context is None:
            context = AgentContext(
                conv_id=conv_uid,
                gpts_app_code=gpts_app.app_code,
                gpts_app_name=gpts_app.app_name,
                language=gpts_app.language,
                enable_vis_message=False,
            )

        # Make the context point at this app even if it was supplied.
        context.gpts_app_code = gpts_app.app_code
        context.gpts_app_name = gpts_app.app_name
        context.language = gpts_app.language

        return await self.create_app_agent(gpts_app, agent_memory, context)
async def create_agent_from_gpt_detail(
    record: GptsAppDetail,
    llm_client: LLMClient,
    agent_context: AgentContext,
    agent_memory: AgentMemory,
) -> ConversableAgent:
    """Build the agent described by a ``GptsAppDetail`` record.

    Args:
        record: The app detail naming the agent class, its LLM strategy, an
            optional prompt template code and its resources.
        llm_client: LLM client used for the agent's LLM config.
        agent_context: Context bound to the agent.
        agent_memory: Memory bound to the agent.

    Returns:
        The built agent.
    """
    agent_cls: Type[ConversableAgent] = get_agent_manager().get_by_name(
        record.agent_name
    )

    llm_config = LLMConfig(
        llm_client=llm_client,
        llm_strategy=LLMStrategyType(record.llm_strategy),
        strategy_context=record.llm_strategy_value,
    )

    # The prompt template is optional; it is only fetched when a code is set.
    prompt_template = (
        get_service().get_template(prompt_code=record.prompt_template)
        if record.prompt_template
        else None
    )

    depend_resource = get_resource_manager().build_resource(
        record.resources, version="v1"
    )

    builder = (
        agent_cls()
        .bind(agent_context)
        .bind(agent_memory)
        .bind(llm_config)
        .bind(depend_resource)
        .bind(prompt_template)
    )
    return await builder.build()
async def create_agent_of_gpts_app(
    gpts_app: GptsApp,
    llm_client: LLMClient,
    context: AgentContext,
    memory: AgentMemory,
    employees: List[ConversableAgent],
) -> ConversableAgent:
    """Build the top-level agent of an app according to its team mode.

    Args:
        gpts_app: The app; its ``team_mode`` and ``team_context`` are read.
        llm_client: LLM client used for the manager's LLM config.
        context: Agent context bound to the manager.
        memory: Agent memory bound to the manager.
        employees: The agents already built for the app.

    Returns:
        In single-agent mode, the first employee. Otherwise a built manager
        agent that has hired all employees.

    Raises:
        ValueError: If single-agent or auto-plan mode has no employees, if an
            AWEL layout app has no team context, if the app is a native app,
            or if the team mode is not handled here.
    """
    llm_config = LLMConfig(
        llm_client=llm_client,
        llm_strategy=LLMStrategyType.Default,
    )

    awel_team_context = gpts_app.team_context
    team_mode = TeamMode(gpts_app.team_mode)

    if team_mode == TeamMode.SINGLE_AGENT:
        # Fail with the same explicit error as auto-plan mode instead of a
        # bare IndexError when the app has no agent configured.
        if not employees:
            raise ValueError("APP exception no available agent")
        return employees[0]

    if team_mode == TeamMode.AUTO_PLAN:
        if not employees:
            raise ValueError("APP exception no available agent")
        from dbgpt.agent.v2 import AutoPlanChatManagerV2, MultiAgentTeamPlanner

        planner = MultiAgentTeamPlanner()
        planner.name_prefix = gpts_app.app_name
        manager = AutoPlanChatManagerV2(planner)
        manager.name_prefix = gpts_app.app_name
    elif team_mode == TeamMode.AWEL_LAYOUT:
        if not awel_team_context:
            raise ValueError(
                "Your APP has not been developed yet, please bind Flow!"
            )
        manager = DefaultAWELLayoutManager(dag=awel_team_context)
    elif team_mode == TeamMode.NATIVE_APP:
        raise ValueError("Native APP chat not supported!")
    else:
        raise ValueError(f"Unknown Agent Team Mode!{team_mode}")

    manager = await manager.bind(context).bind(memory).bind(llm_config).build()
    manager.hire(employees)
    return manager
def get_app_manager() -> AppManager:
    """Return the module-level ``AppManager`` instance."""
    return app_manager


# Single shared instance, created when this module is imported.
app_manager = AppManager(CFG.SYSTEM_APP)

View File

@ -0,0 +1,6 @@
from dbgpt.serve.agent.agents.expand.app_resource_start_assisant_agent import AppStarterAgent
__all__ = [
"AppStarterAgent",
]

View File

@ -0,0 +1,220 @@
import json
import logging
from typing import Any, Dict, List
from typing import Optional
from dbgpt._private.pydantic import BaseModel, Field
from dbgpt.agent import Action, ActionOutput, AgentResource, AgentMessage, ResourceType
from dbgpt.agent import (
Agent,
ConversableAgent,
get_agent_manager,
)
from dbgpt.agent.core.profile import DynConfig, ProfileConfig
from dbgpt.agent.resource.app import AppResource
from dbgpt.vis.tags.vis_plugin import Vis, VisPlugin
logger = logging.getLogger(__name__)
class AppResourceInput(BaseModel):
    """Structured model output: which app to call and the query to send it."""

    app_name: str = Field(
        ...,
        description=(
            "The name of a application that can be used to answer the current question"
            " or solve the current task."
        ),
    )
    app_query: str = Field(..., description="The query to the selected application")
class AppResourceAction(Action[AppResourceInput]):
    """Action that forwards a query to the app chosen by the model.

    The model output is parsed into an :class:`AppResourceInput`; the app
    resource whose app name matches is then executed with the query.
    """

    def __init__(self, **kwargs):
        """App action init."""
        super().__init__(**kwargs)
        self._render_protocol = VisPlugin()

    @property
    def resource_need(self) -> Optional[ResourceType]:
        """Return the resource type needed for the action."""
        return ResourceType.App

    @property
    def render_protocol(self) -> Optional[Vis]:
        """Return the render protocol."""
        return self._render_protocol

    @property
    def out_model_type(self):
        """Return the output model type."""
        return AppResourceInput

    @property
    def ai_out_schema(self) -> Optional[str]:
        """Return the AI output schema."""
        out_put_schema = {
            "app_name": "the agent name you selected",
            "app_query": "the query to the selected agent, must input a str, base on the natural language ",
        }
        schema_json = json.dumps(out_put_schema, indent=2, ensure_ascii=False)
        return (
            "Please response in the following json format:\n"
            f"{schema_json}\n"
            "Make sure the response is correct json and can be parsed by Python "
            "json.loads.\n"
        )

    async def run(
        self,
        ai_message: str,
        resource: Optional[AgentResource] = None,
        rely_action_out: Optional[ActionOutput] = None,
        need_vis_render: bool = True,
        **kwargs,
    ) -> ActionOutput:
        """Perform the app action.

        Args:
            ai_message (str): The AI message.
            resource (Optional[AgentResource], optional): The resource. Defaults to
                None. Not used here.
            rely_action_out (Optional[ActionOutput], optional): The rely action output.
                Defaults to None. Not used here.
            need_vis_render (bool, optional): Whether need visualization rendering.
                Defaults to True. Not used here.
            **kwargs: ``parent_agent`` is forwarded to the app resource.

        Returns:
            An ``ActionOutput``. On success its content, view and observations
            are the stringified app result. On any failure ``is_exe_success``
            is false and the content describes the failure.
        """
        try:
            try:
                param: AppResourceInput = self._input_convert(
                    ai_message, AppResourceInput
                )
            except Exception as e:
                logger.exception((str(e)))
                return ActionOutput(
                    is_exe_success=False,
                    content="The requested correctly structured answer could not be found.",
                )

            app_resource = self.__get_app_resource_of_app_name(param.app_name)

            try:
                app_result = await app_resource.async_execute(
                    user_input=param.app_query,
                    parent_agent=kwargs.get("parent_agent"),
                )
            except Exception as e:
                err_msg = f"App [{param.app_name}] execute failed! {str(e)}"
                logger.exception(err_msg)
                # Report the failure reason; previously the output carried the
                # string "None" because no result had been produced.
                return ActionOutput(
                    is_exe_success=False,
                    content=err_msg,
                    view=err_msg,
                    observations=err_msg,
                )

            result_text = str(app_result)
            return ActionOutput(
                is_exe_success=True,
                content=result_text,
                # view=self.__get_plugin_view(param, app_result, err_msg),
                view=result_text,
                observations=result_text,
            )
        except Exception as e:
            logger.exception("App Action Run Failed")
            return ActionOutput(
                is_exe_success=False, content=f"App action run failed!{str(e)}"
            )

    async def __get_plugin_view(
        self, param: AppResourceInput, app_result: Any, err_msg: str
    ):
        """Render the app call with the render protocol (currently unused).

        Returns:
            The rendered view, or ``None`` when there is no render protocol.
        """
        if not self.render_protocol:
            return None
        # raise NotImplementedError("The render_protocol should be implemented.")
        # AppResourceInput only has app_name/app_query, so use those fields.
        # NOTE(review): the shape expected for "args" by the render protocol is
        # assumed to be a mapping of argument names to values - confirm.
        plugin_param = {
            "name": param.app_name,
            "args": {"app_query": param.app_query},
            "logo": None,
            "result": str(app_result),
            "err_msg": err_msg,
        }
        return await self.render_protocol.display(content=plugin_param)

    def __get_app_resource_list(self) -> List[AppResource]:
        """Collect the app resources available to this action.

        Looks at the bound resource itself and, for a resource pack, at its
        direct sub resources. Returns an empty list when no resource is bound.
        """
        app_resource_list: List[AppResource] = []
        bound_resource = getattr(self, "resource", None)
        if bound_resource is None:
            return app_resource_list

        resource_type = bound_resource.type()
        if resource_type == ResourceType.Pack:
            for sub_resource in bound_resource.sub_resources:
                if sub_resource.type() == ResourceType.App:
                    app_resource_list.extend(AppResource.from_resource(sub_resource))
        elif resource_type == ResourceType.App:
            app_resource_list.extend(AppResource.from_resource(bound_resource))
        return app_resource_list

    def __get_app_resource_of_app_name(self, app_name: str):
        """Return the app resource whose app name equals ``app_name``.

        Raises:
            ValueError: If there are no app resources, or none of them matches.
        """
        app_resource_list: List[AppResource] = self.__get_app_resource_list()
        if not app_resource_list:
            raise ValueError("No app resource was found")

        for app_resource in app_resource_list:
            if app_resource._app_name == app_name:
                return app_resource

        raise ValueError(f"App {app_name} not found !")
class AppStarterAgent(ConversableAgent):
    """Agent that picks an app for the user's question and starts it.

    Its only action is :class:`AppResourceAction`.
    """

    profile: ProfileConfig = ProfileConfig(
        name=DynConfig(
            "AppStarter",
            category="agent",
            key="dbgpt_ant_agent_agents_app_resource_starter_assistant_agent_profile_name",
        ),
        role=DynConfig(
            "App Starter",
            category="agent",
            key="dbgpt_ant_agent_agents_app_resource_starter_assistant_agent_profile_role",
        ),
        goal=DynConfig(
            "根据用户的问题和提供的应用信息,从已知资源中选择一个合适的应用来解决和回答用户的问题,并提取用户输入的关键信息到应用意图的槽位中。",
            category="agent",
            key="dbgpt_ant_agent_agents_app_resource_starter_assistant_agent_profile_goal",
        ),
        constraints=DynConfig(
            [
                "请一步一步思考参为用户问题选择一个最匹配的应用来进行用户问题回答,可参考给出示例的应用选择逻辑.",
                "请阅读用户问题,确定问题所属领域和问题意图,按领域和意图匹配应用,如果用户问题意图缺少操作类应用需要的参数,优先使用咨询类型应用,有明确操作目标才使用操作类应用.",
                "必须从已知的应用中选出一个可用的应用来进行回答,不要瞎编应用的名称",
                "仅选择可回答问题的应用即可,不要直接回答用户问题.",
                "如果用户的问题和提供的所有应用全都不相关则应用code和name都输出为空",
                "注意应用意图定义中如果有槽位信息,再次阅读理解用户输入信息,将对应的内容填入对应槽位参数定义中.",
            ],
            category="agent",
            key="dbgpt_ant_agent_agents_app_resource_starter_assistant_agent_profile_constraints",
        ),
        desc=DynConfig(
            "根据用户问题匹配合适的应用来进行回答.",
            category="agent",
            key="dbgpt_ant_agent_agents_app_resource_starter_assistant_agent_profile_desc",
        ),
    )
    stream_out: bool = False

    def __init__(self, **kwargs):
        """Initialize the agent and register its single action."""
        super().__init__(**kwargs)
        self._init_actions([AppResourceAction])

    def prepare_act_param(
        self,
        received_message: Optional[AgentMessage],
        sender: Agent,
        rely_messages: Optional[List[AgentMessage]] = None,
        **kwargs,
    ) -> Dict[str, Any]:
        """Build the extra keyword arguments passed to the action.

        Args:
            received_message: The message being answered; may be ``None``.
            sender: The agent that sent the message (not used here).
            rely_messages: Messages this one relies on (not used here).
            **kwargs: Not used here.

        Returns:
            A dict with ``user_input`` (the message content, or ``None`` when
            there is no message), ``conv_id`` (from the agent context) and
            ``parent_agent`` (this agent).
        """
        # received_message is Optional, so do not dereference it blindly.
        user_input = (
            received_message.content if received_message is not None else None
        )
        return {
            "user_input": user_input,
            "conv_id": self.agent_context.conv_id,
            "parent_agent": self,
        }
# Register AppStarterAgent with the agent manager as a side effect of importing
# this module.
agent_manage = get_agent_manager()
agent_manage.register_agent(AppStarterAgent)