mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-10-22 17:39:02 +00:00
Co-authored-by: Fangyin Cheng <staneyffer@gmail.com> Co-authored-by: lcx01800250 <lcx01800250@alibaba-inc.com> Co-authored-by: licunxing <864255598@qq.com> Co-authored-by: Aralhi <xiaoping0501@gmail.com> Co-authored-by: xuyuan23 <643854343@qq.com> Co-authored-by: aries_ckt <916701291@qq.com> Co-authored-by: hzh97 <2976151305@qq.com>
101 lines
3.1 KiB
Python
101 lines
3.1 KiB
Python
import logging
|
|
import re
|
|
from collections import defaultdict
|
|
from typing import Dict, List, Optional, Type
|
|
|
|
from .agent import Agent
|
|
from .expand.code_assistant_agent import CodeAssistantAgent
|
|
from .expand.dashboard_assistant_agent import DashboardAssistantAgent
|
|
from .expand.data_scientist_agent import DataScientistAgent
|
|
from .expand.plugin_assistant_agent import PluginAssistantAgent
|
|
from .expand.summary_assistant_agent import SummaryAssistantAgent
|
|
|
|
# Logger shared by the helpers in this module, named after the module itself.
logger = logging.getLogger(name=__name__)
|
|
|
|
|
|
def get_all_subclasses(cls):
    """Return every class that (directly or indirectly) inherits from ``cls``.

    The result starts with the direct subclasses of ``cls`` in
    ``__subclasses__()`` order. After them come, for each direct subclass in
    turn, all of that subclass's own descendants. A class reachable through
    more than one parent (diamond inheritance) is listed once per path.
    """
    children = cls.__subclasses__()
    return children + [
        descendant
        for child in children
        for descendant in get_all_subclasses(child)
    ]
|
|
|
|
|
|
def participant_roles(agents: Optional[List["Agent"]] = None) -> str:
    """Build a roster with one ``"<agent.name>: <agent.describe>"`` line per agent.

    Args:
        agents: The agents to list. ``None`` (the default) is treated as an
            empty list, producing an empty roster.

    Returns:
        The per-agent lines joined with newlines, or ``""`` when there are no
        agents.

    A warning is logged for every agent whose ``system_message`` is missing or
    blank, since such agents may not work well with GroupChat.
    """
    # NOTE(review): an earlier comment here said "default to all agents
    # registered", but the code only did ``agents = agents`` (a no-op), so a
    # bare call crashed iterating ``None``. ``agent_manage`` in this module
    # stores agent *classes*, not instances, so no instance roster is
    # available here; fall back to an empty list instead.
    if agents is None:
        agents = []

    roles = []
    for agent in agents:
        # ``or ""`` keeps a ``None`` system message from raising on ``.strip()``.
        if not (agent.system_message or "").strip():
            logger.warning(
                "The agent '%s' has an empty system_message, and may not work well with GroupChat.",
                agent.name,
            )
        roles.append(f"{agent.name}: {agent.describe}")
    return "\n".join(roles)
|
|
|
|
|
|
def mentioned_agents(message_content: str, agents: List["Agent"]) -> Dict[str, int]:
    """Count how often each agent's name is mentioned in ``message_content``.

    A mention is an exact, case-sensitive occurrence of ``agent.name`` whose
    neighbours on both sides are non-word characters, or the start/end of the
    message. So ``"Coder_1"`` is not counted inside ``"Coder_10"``.

    Args:
        message_content: The text to scan.
        agents: Agents whose ``name`` attribute is searched for.

    Returns:
        A dict mapping agent names to mention counts. Agents with no mentions,
        and agents with an empty name, are omitted.
    """
    # Pad once, outside the loop, so a name at the very start or end of the
    # message still has a non-word neighbour for the lookarounds below.
    # Plain concatenation (not an f-string) keeps the TypeError for non-str input.
    padded = " " + message_content + " "
    mentions = {}
    for agent in agents:
        if not agent.name:
            # An empty pattern would "match" between any two adjacent
            # non-word characters and report spurious mentions.
            continue
        # Lookarounds enforce word boundaries without consuming the
        # neighbouring characters, so back-to-back mentions are all counted.
        pattern = r"(?<=\W)" + re.escape(agent.name) + r"(?=\W)"
        count = len(re.findall(pattern, padded))
        if count > 0:
            mentions[agent.name] = count
    return mentions
|
|
|
|
|
|
class AgentsManage:
    """Registry that maps an agent's ``profile`` to its agent class."""

    def __init__(self):
        # A plain dict: ``defaultdict()`` without a default factory behaves
        # exactly like ``dict`` while misleadingly suggesting auto-defaults.
        self._agents: Dict[str, Type["Agent"]] = {}

    def register_agent(self, cls):
        """Register ``cls`` under the ``profile`` of a freshly built instance.

        ``cls`` is instantiated with no arguments just to read ``profile``.
        Registering another class with the same profile replaces the earlier
        one.

        Returns:
            ``cls`` unchanged, so this method can also be used as a class
            decorator.
        """
        self._agents[cls().profile] = cls
        return cls

    def get_by_name(self, name: str) -> Optional[Type["Agent"]]:
        """Return the agent class registered under ``name``.

        Raises:
            ValueError: if no class is registered under ``name``.
        """
        if name not in self._agents:
            raise ValueError(f"Agent:{name} not register!")
        return self._agents[name]

    def get_describe_by_name(self, name: str):
        """Return the ``DEFAULT_DESCRIBE`` class attribute of the agent registered as ``name``.

        NOTE(review): presumably a human-readable description string; its
        type is defined by the agent classes, not here.

        Raises:
            KeyError: if no class is registered under ``name``. This differs
                from ``get_by_name``; it is kept for existing callers.
        """
        return self._agents[name].DEFAULT_DESCRIBE

    def all_agents(self):
        """Return ``{profile: DEFAULT_DESCRIBE}`` for every registered class."""
        return {name: cls.DEFAULT_DESCRIBE for name, cls in self._agents.items()}

    def list_agents(self):
        """Return ``[{"name": profile, "desc": goal}, ...]`` in registration order.

        Each registered class is instantiated with no arguments to read its
        ``profile`` and ``goal``.
        """
        instances = (cls() for cls in self._agents.values())
        return [{"name": agent.profile, "desc": agent.goal} for agent in instances]
|
|
|
|
|
|
# Module-wide registry, pre-populated with the built-in agents in a fixed order.
agent_manage = AgentsManage()

for _agent_cls in (
    CodeAssistantAgent,
    DashboardAssistantAgent,
    DataScientistAgent,
    SummaryAssistantAgent,
    PluginAssistantAgent,
):
    agent_manage.register_agent(_agent_cls)
# Drop the loop variable so it does not linger in the module namespace.
del _agent_cls
|