Files
DB-GPT/dbgpt/agent/agents/agents_manage.py
明天 d5afa6e206 Native data AI application framework based on AWEL+AGENT (#1152)
Co-authored-by: Fangyin Cheng <staneyffer@gmail.com>
Co-authored-by: lcx01800250 <lcx01800250@alibaba-inc.com>
Co-authored-by: licunxing <864255598@qq.com>
Co-authored-by: Aralhi <xiaoping0501@gmail.com>
Co-authored-by: xuyuan23 <643854343@qq.com>
Co-authored-by: aries_ckt <916701291@qq.com>
Co-authored-by: hzh97 <2976151305@qq.com>
2024-02-07 17:43:27 +08:00

101 lines
3.1 KiB
Python

import logging
import re
from collections import defaultdict
from typing import Dict, List, Optional, Type
from .agent import Agent
from .expand.code_assistant_agent import CodeAssistantAgent
from .expand.dashboard_assistant_agent import DashboardAssistantAgent
from .expand.data_scientist_agent import DataScientistAgent
from .expand.plugin_assistant_agent import PluginAssistantAgent
from .expand.summary_assistant_agent import SummaryAssistantAgent
# Logger for this module; named after the module so records can be filtered by package path.
logger = logging.getLogger(name=__name__)
def get_all_subclasses(cls: type) -> List[type]:
    """Return every transitive subclass of ``cls`` (``cls`` itself excluded).

    Direct subclasses come first, followed by the descendants of each direct
    subclass in turn (recursively). A class reachable through more than one
    inheritance path (diamond inheritance) is listed once, at its first
    position.

    Args:
        cls: The class whose descendants are collected.

    Returns:
        A list of classes with no duplicates.
    """
    direct_subclasses = cls.__subclasses__()
    ordered = list(direct_subclasses)
    for subclass in direct_subclasses:
        ordered.extend(get_all_subclasses(subclass))
    # dict preserves insertion order, so this drops repeats reached via other
    # inheritance paths while keeping each class's first occurrence.
    return list(dict.fromkeys(ordered))
def participant_roles(agents: Optional[List["Agent"]] = None) -> str:
    """Build a newline-separated ``"<name>: <describe>"`` listing of agents.

    Args:
        agents: The agents to list. ``None`` (the default) is treated as an
            empty list. NOTE(review): the original comment said "Default to
            all agents registered", but the code did ``agents = agents`` (a
            no-op) and then failed iterating ``None``; the module registry only
            stores agent classes, not instances, so ``None`` now yields "".

    Returns:
        One line per agent, joined by newlines; an empty string when there are
        no agents.
    """
    if agents is None:
        agents = []
    roles = []
    for agent in agents:
        # A missing or blank system prompt is only warned about; the agent is
        # still listed.
        if not (agent.system_message or "").strip():
            logger.warning(
                "The agent '%s' has an empty system_message, and may not work well with GroupChat.",
                agent.name,
            )
        roles.append(f"{agent.name}: {agent.describe}")
    return "\n".join(roles)
def mentioned_agents(message_content: str, agents: List[Agent]) -> Dict:
    """Count how often each agent's name is mentioned in ``message_content``.

    A mention only counts when the name is bounded on both sides by a non-word
    character. The text is padded with one space at each end so names at the
    very start or end of the message still match.

    Args:
        message_content: The text to scan.
        agents: Agents whose ``name`` attribute is searched for.

    Returns:
        A dict mapping agent name to mention count; agents that are never
        mentioned are omitted.
    """

    def _count(agent_name):
        # Lookbehind/lookahead on \W emulate word boundaries around the
        # escaped (literal) agent name.
        pattern = r"(?<=\W)" + re.escape(agent_name) + r"(?=\W)"
        return len(re.findall(pattern, " " + message_content + " "))

    tallies = ((candidate.name, _count(candidate.name)) for candidate in agents)
    return {agent_name: hits for agent_name, hits in tallies if hits > 0}
class AgentsManage:
    """Registry mapping an agent's ``profile`` to its agent class."""

    def __init__(self):
        # A plain dict: the previous ``defaultdict()`` had no default factory,
        # so it behaved identically but wrongly suggested auto-created entries.
        self._agents: Dict[str, Type["Agent"]] = {}

    def register_agent(self, cls):
        """Register ``cls`` under the ``profile`` of a fresh instance.

        ``cls`` is instantiated with no arguments only to read ``profile``.
        Registering a class whose profile is already present replaces the
        earlier entry.

        Returns:
            ``cls`` unchanged, so this method can also be used as a class
            decorator.
        """
        self._agents[cls().profile] = cls
        return cls

    def get_by_name(self, name: str) -> Optional[Type["Agent"]]:
        """Return the agent class registered under ``name``.

        Raises:
            ValueError: If no agent is registered under ``name``.
        """
        try:
            return self._agents[name]
        except KeyError:
            raise ValueError(f"Agent:{name} not register!") from None

    def get_describe_by_name(self, name: str) -> Optional[str]:
        """Return the ``DEFAULT_DESCRIBE`` attribute of the class registered as ``name``.

        Raises:
            KeyError: If no agent is registered under ``name`` (unlike
                ``get_by_name``, which raises ``ValueError``; kept for callers
                that catch ``KeyError``).
        """
        # NOTE(review): the annotation previously claimed Type[Agent]; the
        # value returned is DEFAULT_DESCRIBE, presumably a str — confirm.
        return self._agents[name].DEFAULT_DESCRIBE

    def all_agents(self):
        """Return ``{profile: DEFAULT_DESCRIBE}`` for every registered class."""
        return {name: cls.DEFAULT_DESCRIBE for name, cls in self._agents.items()}

    def list_agents(self):
        """Return one ``{"name": profile, "desc": goal}`` dict per registered agent.

        Each registered class is instantiated with no arguments to read its
        ``profile`` and ``goal``.
        """
        result = []
        for cls in self._agents.values():
            instance = cls()
            result.append({"name": instance.profile, "desc": instance.goal})
        return result
# Module-wide registry, pre-populated with the built-in expand agents in a
# fixed order; register_agent() instantiates each class once to read its profile.
agent_manage = AgentsManage()
for _builtin_agent in (
    CodeAssistantAgent,
    DashboardAssistantAgent,
    DataScientistAgent,
    SummaryAssistantAgent,
    PluginAssistantAgent,
):
    agent_manage.register_agent(_builtin_agent)
# Remove the loop variable so the module namespace is unchanged.
del _builtin_agent