refactor(agent): Agent modular refactoring (#1487)
@@ -1,6 +1,6 @@
 """DB-GPT Multi-Agents Module."""

-from .actions.action import Action, ActionOutput  # noqa: F401
+from .core.action import *  # noqa: F401, F403
 from .core.agent import (  # noqa: F401
     Agent,
     AgentContext,
@@ -13,12 +13,15 @@ from .core.agent_manage import (  # noqa: F401
     initialize_agent,
 )
 from .core.base_agent import ConversableAgent  # noqa: F401
+from .core.llm.llm import LLMConfig  # noqa: F401
+from .core.memory import *  # noqa: F401, F403
+from .core.memory.gpts.gpts_memory import GptsMemory  # noqa: F401
+from .core.plan import *  # noqa: F401, F403
+from .core.profile import *  # noqa: F401, F403
 from .core.schema import PluginStorageType  # noqa: F401
 from .core.user_proxy_agent import UserProxyAgent  # noqa: F401
-from .memory.gpts_memory import GptsMemory  # noqa: F401
 from .resource.resource_api import AgentResource, ResourceType  # noqa: F401
 from .resource.resource_loader import ResourceLoader  # noqa: F401
-from .util.llm.llm import LLMConfig  # noqa: F401

 __ALL__ = [
     "Agent",
@@ -1 +0,0 @@
-"""Actions of Agent."""
@@ -1 +1,22 @@
-"""Core Module for the Agent."""
+"""Core Module for the Agent.
+
+There are four modules in the DB-GPT agent core, according to the paper
+`A survey on large language model based autonomous agents
+<https://link.springer.com/article/10.1007/s11704-024-40231-1>`
+by `Lei Wang, Chen Ma, Xueyang Feng, et al.`:
+
+1. Profiling Module: The profiling module aims to indicate the profiles of the agent
+roles.
+
+2. Memory Module: It stores information perceived from the environment and leverages
+the recorded memories to facilitate future actions.
+
+3. Planning Module: When faced with a complex task, humans tend to deconstruct it into
+simpler subtasks and solve them individually. The planning module aims to empower the
+agents with such human capability, which is expected to make the agent behave more
+reasonably, powerfully, and reliably.
+
+4. Action Module: The action module is responsible for translating the agent’s
+decisions into specific outcomes. This module is located at the most downstream
+position and directly interacts with the environment.
+"""
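To make the four-module split concrete, here is a minimal, self-contained sketch of how the responsibilities divide. All class names below are illustrative stand-ins, not the real DB-GPT types (those are `ProfileConfig`, `AgentMemory`, the plan module, and `Action`):

```python
# Illustrative only: toy stand-ins for the four agent-core modules.
from dataclasses import dataclass, field
from typing import List


@dataclass
class Profile:
    """1. Profiling: who the agent is (name, role, goal)."""
    name: str
    role: str
    goal: str


@dataclass
class Memory:
    """2. Memory: stores observations and recalls them for future actions."""
    fragments: List[str] = field(default_factory=list)

    def write(self, observation: str) -> None:
        self.fragments.append(observation)

    def read(self, query: str) -> List[str]:
        return [f for f in self.fragments if query in f]


class Planner:
    """3. Planning: deconstructs a complex task into simpler subtasks."""

    def plan(self, task: str) -> List[str]:
        return [f"subtask 1 of {task!r}", f"subtask 2 of {task!r}"]


class Action:
    """4. Action: translates a decision into an outcome in the environment."""

    def run(self, decision: str) -> str:
        return f"executed: {decision}"
```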
dbgpt/agent/core/action/__init__.py (new file, 20 lines)
@@ -0,0 +1,20 @@
"""Action Module.

The action module is responsible for translating the agent’s decisions into specific
outcomes. This module is located at the most downstream position and directly interacts
with the environment. It is influenced by the profile, memory, and planning modules.

The Goal Of The Action Module:
------------------------------
1. Task Completion: Complete specific tasks, such as writing a function in software
development or making an iron pickaxe in a game.

2. Communication: Communicate with other agents.

3. Environment exploration: Explore unfamiliar environments to expand the agent's
perception and strike a balance between exploring and exploiting.
"""

from .base import Action, ActionOutput  # noqa: F401
from .blank_action import BlankAction  # noqa: F401
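A minimal custom action in the spirit of `BlankAction` might look like the sketch below. The `run` signature here is simplified (the real base class also receives the bound resource, the prior action's output, and rendering flags), so treat it as illustrative:

```python
# Sketch of a custom Action; the real Action.run also takes `resource`,
# `rely_action_out`, and `need_vis_render` parameters.
from dbgpt.agent.core.action import Action, ActionOutput


class EchoAction(Action):
    """Hypothetical action that just passes the model reply through."""

    async def run(self, ai_message: str, **kwargs) -> ActionOutput:
        # Translate the agent's decision (the raw LLM reply) into an outcome.
        return ActionOutput(
            is_exe_success=True,
            content=ai_message,
            view=ai_message,
        )
```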
@@ -1,4 +1,5 @@
 """Base Action class for defining agent actions."""
+
 import json
 from abc import ABC, abstractmethod
 from typing import (
@@ -21,12 +22,13 @@ from dbgpt._private.pydantic import (
     field_description,
     model_fields,
     model_to_dict,
+    model_validator,
 )
 from dbgpt.util.json_utils import find_json_objects
+from dbgpt.vis.base import Vis

-from ...vis.base import Vis
-from ..resource.resource_api import AgentResource, ResourceType
-from ..resource.resource_loader import ResourceLoader
+from ...resource.resource_api import AgentResource, ResourceType
+from ...resource.resource_loader import ResourceLoader

 T = TypeVar("T", bound=Union[BaseModel, List[BaseModel], None])

@@ -41,6 +43,20 @@ class ActionOutput(BaseModel):
     view: Optional[str] = None
     resource_type: Optional[str] = None
     resource_value: Optional[Any] = None
+    action: Optional[str] = None
+    thoughts: Optional[str] = None
+    observations: Optional[str] = None
+
+    @model_validator(mode="before")
+    @classmethod
+    def pre_fill(cls, values: Any) -> Any:
+        """Pre-fill the values."""
+        if not isinstance(values, dict):
+            return values
+        is_exe_success = values.get("is_exe_success", True)
+        if not is_exe_success and "observations" not in values:
+            values["observations"] = values.get("content")
+        return values

     @classmethod
     def from_dict(
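The new `pre_fill` validator backfills `observations` from `content` whenever an action reports failure, so failed rounds always leave something for the memory module to record. A quick illustration (values made up):

```python
# Behavior of the pre_fill validator above.
failed = ActionOutput(is_exe_success=False, content="Traceback: ...")
assert failed.observations == "Traceback: ..."  # backfilled from content

ok = ActionOutput(is_exe_success=True, content="all good")
assert ok.observations is None  # success: no backfill
```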
@@ -3,8 +3,8 @@
 import logging
 from typing import Optional

-from ..resource.resource_api import AgentResource
-from .action import Action, ActionOutput
+from ...resource.resource_api import AgentResource
+from .base import Action, ActionOutput

 logger = logging.getLogger(__name__)

@@ -33,4 +33,8 @@ class BlankAction(Action):

         Just return the AI message.
         """
-        return ActionOutput(is_exe_success=True, content=ai_message, view=ai_message)
+        return ActionOutput(
+            is_exe_success=True,
+            content=ai_message,
+            view=ai_message,
+        )
@@ -9,9 +9,9 @@ from typing import Any, Dict, List, Optional, Tuple, Union
 from dbgpt.core import LLMClient
 from dbgpt.util.annotations import PublicAPI

-from ..actions.action import ActionOutput
-from ..memory.gpts_memory import GptsMemory
+from .action.base import ActionOutput
+from .memory.agent_memory import AgentMemory
 from ..resource.resource_loader import ResourceLoader


 class Agent(ABC):
@@ -160,17 +160,20 @@ class Agent(ABC):
         verification result.
         """

+    @property
     @abstractmethod
-    def get_name(self) -> str:
-        """Return name of the agent."""
+    def name(self) -> str:
+        """Return the name of the agent."""

+    @property
     @abstractmethod
-    def get_profile(self) -> str:
-        """Return profile of the agent."""
+    def role(self) -> str:
+        """Return the role of the agent."""

+    @property
     @abstractmethod
-    def get_describe(self) -> str:
-        """Return describe of the agent."""
+    def desc(self) -> Optional[str]:
+        """Return the description of the agent."""


 @dataclasses.dataclass
@@ -204,7 +207,7 @@ class AgentGenerateContext:
     rely_messages: List[AgentMessage] = dataclasses.field(default_factory=list)
     final: Optional[bool] = True

-    memory: Optional[GptsMemory] = None
+    memory: Optional[AgentMemory] = None
     agent_context: Optional[AgentContext] = None
     resource_loader: Optional[ResourceLoader] = None
     llm_client: Optional[LLMClient] = None
@@ -302,3 +305,9 @@ class AgentMessage:
             role=self.role,
             success=self.success,
         )
+
+    def get_dict_context(self) -> Dict[str, Any]:
+        """Return the context as a dictionary."""
+        if isinstance(self.context, dict):
+            return self.context
+        return {}
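The `Agent` interface drops the `get_name()`/`get_profile()`/`get_describe()` methods in favor of `name`/`role`/`desc` properties; callers migrate as sketched below (the concrete agent class is hypothetical):

```python
from typing import Optional

# Before: sender.get_name(), sender.get_profile(), sender.get_describe()
# After:  sender.name,       sender.role,          sender.desc


class GreeterAgent:  # hypothetical concrete Agent, for illustration
    @property
    def name(self) -> str:
        return "Greeter"

    @property
    def role(self) -> str:
        return "GreetingSpecialist"

    @property
    def desc(self) -> Optional[str]:
        return "Says hello to users."


agent = GreeterAgent()
print(f"{agent.name}: {agent.desc}")  # mirrors the participant_roles() format
```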
@@ -18,7 +18,7 @@ def participant_roles(agents: List[Agent]) -> str:
     # Default to all agents registered
     roles = []
     for agent in agents:
-        roles.append(f"{agent.get_name()}: {agent.get_describe()}")
+        roles.append(f"{agent.name}: {agent.desc}")
     return "\n".join(roles)


@@ -34,13 +34,13 @@ def mentioned_agents(message_content: str, agents: List[Agent]) -> Dict:
     mentions = dict()
     for agent in agents:
         regex = (
-            r"(?<=\W)" + re.escape(agent.get_name()) + r"(?=\W)"
+            r"(?<=\W)" + re.escape(agent.name) + r"(?=\W)"
         )  # Finds agent mentions, taking word boundaries into account
         count = len(
             re.findall(regex, " " + message_content + " ")
         )  # Pad the message to help with matching
         if count > 0:
-            mentions[agent.get_name()] = count
+            mentions[agent.name] = count
     return mentions


@@ -84,7 +84,7 @@ class AgentManager(BaseComponent):
     ) -> str:
         """Register an agent."""
         inst = cls()
-        profile = inst.get_profile()
+        profile = inst.role
         if profile in self._agents and (
             profile in self._core_agents or not ignore_duplicate
         ):
@@ -110,13 +110,13 @@ class AgentManager(BaseComponent):

     def get_describe_by_name(self, name: str) -> str:
         """Return the description of an agent by name."""
-        return self._agents[name][1].desc
+        return self._agents[name][1].desc or ""

     def all_agents(self) -> Dict[str, str]:
         """Return a dictionary of all registered agents and their descriptions."""
         result = {}
         for name, value in self._agents.items():
-            result[name] = value[1].desc
+            result[name] = value[1].desc or ""
         return result

     def list_agents(self):
@@ -125,7 +125,7 @@ class AgentManager(BaseComponent):
         for name, value in self._agents.items():
             result.append(
                 {
-                    "name": value[1].profile,
+                    "name": value[1].role,
                     "desc": value[1].goal,
                 }
             )
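`mentioned_agents` counts occurrences of each agent name delimited by non-word characters; padding the message with spaces lets names at the very start or end still match. A standalone rendering of that logic:

```python
import re
from typing import Dict, List


def count_mentions(message_content: str, agent_names: List[str]) -> Dict[str, int]:
    """Same matching logic as mentioned_agents, with plain strings for agents."""
    mentions = {}
    for name in agent_names:
        regex = r"(?<=\W)" + re.escape(name) + r"(?=\W)"
        # Pad the message so a leading/trailing name still has \W on both sides.
        count = len(re.findall(regex, " " + message_content + " "))
        if count > 0:
            mentions[name] = count
    return mentions


print(count_mentions("Ask Planner, then Planner hands off to Coder.", ["Planner", "Coder"]))
# {'Planner': 2, 'Coder': 1}
```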
@@ -3,7 +3,7 @@
 import asyncio
 import json
 import logging
-from typing import Any, Dict, List, Optional, Tuple, Type, Union, cast
+from typing import Any, Dict, List, Optional, Tuple, Type, cast

 from dbgpt._private.pydantic import ConfigDict, Field
 from dbgpt.core import LLMClient, ModelMessageRoleType
@@ -11,14 +11,15 @@ from dbgpt.util.error_types import LLMChatError
 from dbgpt.util.tracer import SpanType, root_tracer
 from dbgpt.util.utils import colored

-from ..actions.action import Action, ActionOutput
-from ..memory.base import GptsMessage
-from ..memory.gpts_memory import GptsMemory
-from ..util.llm.llm import LLMConfig, LLMStrategyType
-from ..util.llm.llm_client import AIWrapper
+from .action.base import Action, ActionOutput
 from .agent import Agent, AgentContext, AgentMessage, AgentReviewInfo
+from .llm.llm import LLMConfig, LLMStrategyType
+from .llm.llm_client import AIWrapper
+from .memory.agent_memory import AgentMemory
+from .memory.gpts.base import GptsMessage
+from .memory.gpts.gpts_memory import GptsMemory
 from ..resource.resource_api import AgentResource, ResourceClient
 from ..resource.resource_loader import ResourceLoader
+from .role import Role

 logger = logging.getLogger(__name__)

@@ -33,26 +34,16 @@ class ConversableAgent(Role, Agent):
     actions: List[Action] = Field(default_factory=list)
     resources: List[AgentResource] = Field(default_factory=list)
     llm_config: Optional[LLMConfig] = None
-    memory: GptsMemory = Field(default_factory=GptsMemory)
     resource_loader: Optional[ResourceLoader] = None
     max_retry_count: int = 3
     consecutive_auto_reply_counter: int = 0
     llm_client: Optional[AIWrapper] = None
     oai_system_message: List[Dict] = Field(default_factory=list)

     def __init__(self, **kwargs):
         """Create a new agent."""
         Role.__init__(self, **kwargs)
         Agent.__init__(self)

-    def init_system_message(self) -> None:
-        """Initialize the system message."""
-        content = self.prompt_template()
-        # TODO: Don't modify the original data, need to be optimized
-        self.oai_system_message = [
-            {"content": content, "role": ModelMessageRoleType.SYSTEM}
-        ]
-
     def check_available(self) -> None:
         """Check if the agent is available.

@@ -63,7 +54,7 @@ class ConversableAgent(Role, Agent):
         # check run context
         if self.agent_context is None:
             raise ValueError(
-                f"{self.name}[{self.profile}] Missing context in which agent is "
+                f"{self.name}[{self.role}] Missing context in which agent is "
                 f"running!"
             )

@@ -90,20 +81,20 @@ class ConversableAgent(Role, Agent):
                 and action.resource_need not in have_resource_types
             ):
                 raise ValueError(
-                    f"{self.name}[{self.profile}] Missing resources required for "
+                    f"{self.name}[{self.role}] Missing resources required for "
                     "runtime!"
                 )
         else:
             if not self.is_human and not self.is_team:
                 raise ValueError(
-                    f"This agent {self.name}[{self.profile}] is missing action modules."
+                    f"This agent {self.name}[{self.role}] is missing action modules."
                 )
         # llm check
         if not self.is_human and (
             self.llm_config is None or self.llm_config.llm_client is None
         ):
             raise ValueError(
-                f"{self.name}[{self.profile}] Model configuration is missing or model "
+                f"{self.name}[{self.role}] Model configuration is missing or model "
                 "service is unavailable!"
            )

@@ -161,14 +152,19 @@ class ConversableAgent(Role, Agent):
         for action in self.actions:
             action.init_resource_loader(self.resource_loader)

-        # Initialize system messages
-        self.init_system_message()
-
         # Initialize LLM Server
         if not self.is_human:
             if not self.llm_config or not self.llm_config.llm_client:
                 raise ValueError("LLM client is not initialized!")
             self.llm_client = AIWrapper(llm_client=self.llm_config.llm_client)
+            self.memory.initialize(
+                self.name,
+                self.llm_config.llm_client,
+                importance_scorer=self.memory_importance_scorer,
+                insight_extractor=self.memory_insight_extractor,
+            )
+            # Clone the memory structure
+            self.memory = self.memory.structure_clone()
         return self

     def bind(self, target: Any) -> "ConversableAgent":
@@ -176,7 +172,7 @@ class ConversableAgent(Role, Agent):
         if isinstance(target, LLMConfig):
             self.llm_config = target
         elif isinstance(target, GptsMemory):
-            self.memory = target
+            raise ValueError("GptsMemory is not supported!")
         elif isinstance(target, AgentContext):
             self.agent_context = target
         elif isinstance(target, ResourceLoader):
@@ -186,6 +182,8 @@ class ConversableAgent(Role, Agent):
             self.actions.extend(target)
         elif _is_list_of_type(target, AgentResource):
             self.resources = target
+        elif isinstance(target, AgentMemory):
+            self.memory = target
         return self

     async def send(
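`bind` dispatches on the target's type and returns `self`, so agent configuration reads as a chain; note that binding a bare `GptsMemory` now raises, and an `AgentMemory` is bound instead. A wiring sketch (constructor arguments elided; the agent class is hypothetical):

```python
# Hypothetical wiring; llm_config, agent_context, and resource_loader are
# assumed to be constructed elsewhere.
agent = (
    MyConversableAgent()
    .bind(llm_config)        # LLMConfig
    .bind(agent_context)     # AgentContext
    .bind(resource_loader)   # ResourceLoader
    .bind(AgentMemory())     # AgentMemory; a raw GptsMemory would raise ValueError
)
```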
@@ -200,9 +198,9 @@ class ConversableAgent(Role, Agent):
         with root_tracer.start_span(
             "agent.send",
             metadata={
-                "sender": self.get_name(),
-                "recipient": recipient.get_name(),
-                "reviewer": reviewer.get_name() if reviewer else None,
+                "sender": self.name,
+                "recipient": recipient.name,
+                "reviewer": reviewer.name if reviewer else None,
                 "agent_message": message.to_dict(),
                 "request_reply": request_reply,
                 "is_recovery": is_recovery,
@@ -230,9 +228,9 @@ class ConversableAgent(Role, Agent):
         with root_tracer.start_span(
             "agent.receive",
             metadata={
-                "sender": sender.get_name(),
-                "recipient": self.get_name(),
-                "reviewer": reviewer.get_name() if reviewer else None,
+                "sender": sender.name,
+                "recipient": self.name,
+                "reviewer": reviewer.name if reviewer else None,
                 "agent_message": message.to_dict(),
                 "request_reply": request_reply,
                 "silent": silent,
@@ -271,14 +269,14 @@ class ConversableAgent(Role, Agent):
         root_span = root_tracer.start_span(
             "agent.generate_reply",
             metadata={
-                "sender": sender.get_name(),
-                "recipient": self.get_name(),
-                "reviewer": reviewer.get_name() if reviewer else None,
+                "sender": sender.name,
+                "recipient": self.name,
+                "reviewer": reviewer.name if reviewer else None,
                 "received_message": received_message.to_dict(),
                 "conv_uid": self.not_null_agent_context.conv_id,
-                "rely_messages": [msg.to_dict() for msg in rely_messages]
-                if rely_messages
-                else None,
+                "rely_messages": (
+                    [msg.to_dict() for msg in rely_messages] if rely_messages else None
+                ),
             },
         )
@@ -295,18 +293,6 @@ class ConversableAgent(Role, Agent):
             )
             span.metadata["reply_message"] = reply_message.to_dict()

-            with root_tracer.start_span(
-                "agent.generate_reply._system_message_assembly",
-                metadata={
-                    "reply_message": reply_message.to_dict(),
-                },
-            ) as span:
-                # assemble system message
-                await self._system_message_assembly(
-                    received_message.content, reply_message.context
-                )
-                span.metadata["assembled_system_messages"] = self.oai_system_message
-
             fail_reason = None
             current_retry_counter = 0
             is_success = True
@@ -325,8 +311,11 @@ class ConversableAgent(Role, Agent):
                         retry_message, self, reviewer, request_reply=False
                     )

-                thinking_messages = self._load_thinking_messages(
-                    received_message, sender, rely_messages
+                thinking_messages = await self._load_thinking_messages(
+                    received_message,
+                    sender,
+                    rely_messages,
+                    context=reply_message.get_dict_context(),
                 )
                 with root_tracer.start_span(
                     "agent.generate_reply.thinking",
@@ -345,7 +334,7 @@ class ConversableAgent(Role, Agent):

                 with root_tracer.start_span(
                     "agent.generate_reply.review",
-                    metadata={"llm_reply": llm_reply, "censored": self.get_name()},
+                    metadata={"llm_reply": llm_reply, "censored": self.name},
                 ) as span:
                     # 2.Review whether what is being done is legal
                     approve, comments = await self.review(llm_reply, self)
@@ -361,8 +350,8 @@ class ConversableAgent(Role, Agent):
                     "agent.generate_reply.act",
                     metadata={
                         "llm_reply": llm_reply,
-                        "sender": sender.get_name(),
-                        "reviewer": reviewer.get_name() if reviewer else None,
+                        "sender": sender.name,
+                        "reviewer": reviewer.name if reviewer else None,
                         "act_extent_param": act_extent_param,
                     },
                 ) as span:
@@ -383,8 +372,8 @@ class ConversableAgent(Role, Agent):
                     "agent.generate_reply.verify",
                     metadata={
                         "llm_reply": llm_reply,
-                        "sender": sender.get_name(),
-                        "reviewer": reviewer.get_name() if reviewer else None,
+                        "sender": sender.name,
+                        "reviewer": reviewer.name if reviewer else None,
                     },
                 ) as span:
                     # 4.Reply information verification
@@ -394,6 +383,9 @@ class ConversableAgent(Role, Agent):
                     is_success = check_pass
                     span.metadata["check_pass"] = check_pass
                     span.metadata["reason"] = reason
+
+                question: str = received_message.content or ""
+                ai_message: str = llm_reply or ""
                 # 5.Optimize wrong answers myself
                 if not check_pass:
                     current_retry_counter += 1
@@ -403,7 +395,20 @@ class ConversableAgent(Role, Agent):
                         reply_message, sender, reviewer, request_reply=False
                     )
                     fail_reason = reason
+                    await self.save_to_memory(
+                        question=question,
+                        ai_message=ai_message,
+                        action_output=act_out,
+                        check_pass=check_pass,
+                        check_fail_reason=fail_reason,
+                    )
                 else:
+                    await self.save_to_memory(
+                        question=question,
+                        ai_message=ai_message,
+                        action_output=act_out,
+                        check_pass=check_pass,
+                    )
                     break
             reply_message.success = is_success
             return reply_message
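With these hunks, every reply round is persisted via `save_to_memory`, including failed rounds, which also record the verification failure reason before retrying. A condensed, illustrative rendering of the loop's control flow, with stand-in method names:

```python
# Illustration only: `thinking`, `act`, and `verify` stand in for the real calls.
async def reply_loop(agent, question: str, max_retry_count: int = 3) -> str:
    ai_message, fail_reason = "", None
    for _ in range(max_retry_count):
        ai_message = await agent.thinking(question, fail_reason)
        act_out = await agent.act(ai_message)
        check_pass, reason = await agent.verify(act_out)
        if check_pass:
            await agent.save_to_memory(
                question=question, ai_message=ai_message,
                action_output=act_out, check_pass=True,
            )
            break
        fail_reason = reason
        await agent.save_to_memory(          # failed rounds are remembered too,
            question=question, ai_message=ai_message,
            action_output=act_out, check_pass=False,
            check_fail_reason=fail_reason,   # including why verification failed
        )
    return ai_message
```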
@@ -437,8 +442,6 @@ class ConversableAgent(Role, Agent):
         try:
             if prompt:
                 llm_messages = _new_system_message(prompt) + llm_messages
-            else:
-                llm_messages = self.oai_system_message + llm_messages

             if not self.llm_client:
                 raise ValueError("LLM client is not initialized!")
@@ -491,9 +494,9 @@ class ConversableAgent(Role, Agent):
             "agent.act.run",
             metadata={
                 "message": message,
-                "sender": sender.get_name() if sender else None,
-                "recipient": self.get_name(),
-                "reviewer": reviewer.get_name() if reviewer else None,
+                "sender": sender.name if sender else None,
+                "recipient": self.name,
+                "reviewer": reviewer.name if reviewer else None,
                 "need_resource": need_resource.to_dict() if need_resource else None,
                 "rely_action_out": last_out.to_dict() if last_out else None,
                 "conv_uid": self.not_null_agent_context.conv_id,
@@ -563,9 +566,9 @@ class ConversableAgent(Role, Agent):
             "agent.initiate_chat",
             span_type=SpanType.AGENT,
             metadata={
-                "sender": self.get_name(),
-                "recipient": recipient.get_name(),
-                "reviewer": reviewer.get_name() if reviewer else None,
+                "sender": self.name,
+                "recipient": recipient.name,
+                "reviewer": reviewer.name if reviewer else None,
                 "agent_message": agent_message.to_dict(),
                 "conv_uid": self.not_null_agent_context.conv_id,
             },
@@ -612,21 +615,27 @@ class ConversableAgent(Role, Agent):

         gpts_message: GptsMessage = GptsMessage(
             conv_id=self.not_null_agent_context.conv_id,
-            sender=sender.get_profile(),
-            receiver=self.profile,
+            sender=sender.role,
+            receiver=self.role,
             role=role,
             rounds=self.consecutive_auto_reply_counter,
             current_goal=oai_message.get("current_goal", None),
             content=oai_message.get("content", None),
-            context=json.dumps(oai_message["context"], ensure_ascii=False)
-            if "context" in oai_message
-            else None,
-            review_info=json.dumps(oai_message["review_info"], ensure_ascii=False)
-            if "review_info" in oai_message
-            else None,
-            action_report=json.dumps(oai_message["action_report"], ensure_ascii=False)
-            if "action_report" in oai_message
-            else None,
+            context=(
+                json.dumps(oai_message["context"], ensure_ascii=False)
+                if "context" in oai_message
+                else None
+            ),
+            review_info=(
+                json.dumps(oai_message["review_info"], ensure_ascii=False)
+                if "review_info" in oai_message
+                else None
+            ),
+            action_report=(
+                json.dumps(oai_message["action_report"], ensure_ascii=False)
+                if "action_report" in oai_message
+                else None
+            ),
             model_name=oai_message.get("model_name", None),
         )

@@ -643,10 +652,10 @@ class ConversableAgent(Role, Agent):
     def _print_received_message(self, message: AgentMessage, sender: Agent):
         # print the message received
         print("\n", "-" * 80, flush=True, sep="")
-        _print_name = self.name if self.name else self.profile
+        _print_name = self.name if self.name else self.role
         print(
             colored(
-                sender.get_name() if sender.get_name() else sender.get_profile(),
+                sender.name if sender.name else sender.role,
                 "yellow",
             ),
             "(to",
@@ -660,7 +669,7 @@ class ConversableAgent(Role, Agent):

         review_info = message.review_info
         if review_info:
-            name = sender.get_name() if sender.get_name() else sender.get_profile()
+            name = sender.name if sender.name else sender.role
             pass_msg = "Pass" if review_info.approve else "Reject"
             review_msg = f"{pass_msg}({review_info.comments})"
             approve_print = f">>>>>>>>{name} Review info: \n{review_msg}"
@@ -668,7 +677,7 @@ class ConversableAgent(Role, Agent):

         action_report = message.action_report
         if action_report:
-            name = sender.get_name() if sender.get_name() else sender.get_profile()
+            name = sender.name if sender.name else sender.role
             action_msg = (
                 "execution succeeded"
                 if action_report["is_exe_success"]
@@ -690,42 +699,32 @@ class ConversableAgent(Role, Agent):

         self._print_received_message(message, sender)

-    async def _system_message_assembly(
-        self, question: Optional[str], context: Optional[Union[str, Dict]] = None
-    ):
-        # system message
-        self.init_system_message()
-        if len(self.oai_system_message) > 0:
-            resource_prompt_list = []
-            for item in self.resources:
-                resource_client = self.not_null_resource_loader.get_resource_api(
-                    item.type, ResourceClient
-                )
-                if not resource_client:
-                    raise ValueError(
-                        f"Resource {item.type}:{item.value} missing resource loader"
-                        f" implementation,unable to read resources!"
-                    )
-                resource_prompt_list.append(
-                    await resource_client.get_resource_prompt(item, question)
-                )
-            if context is None or not isinstance(context, dict):
-                context = {}
-
-            resource_prompt = ""
-            if len(resource_prompt_list) > 0:
-                resource_prompt = "RESOURCES:" + "\n".join(resource_prompt_list)
-
-            out_schema: Optional[str] = ""
-            if self.actions and len(self.actions) > 0:
-                out_schema = self.actions[0].ai_out_schema
-            for message in self.oai_system_message:
-                new_content = message["content"].format(
-                    resource_prompt=resource_prompt,
-                    out_schema=out_schema,
-                    **context,
-                )
-                message["content"] = new_content
+    async def generate_resource_variables(
+        self, question: Optional[str] = None
+    ) -> Dict[str, Any]:
+        """Generate the resource variables."""
+        resource_prompt_list = []
+        for item in self.resources:
+            resource_client = self.not_null_resource_loader.get_resource_api(
+                item.type, ResourceClient
+            )
+            if not resource_client:
+                raise ValueError(
+                    f"Resource {item.type}:{item.value} missing resource loader"
+                    f" implementation,unable to read resources!"
+                )
+            resource_prompt_list.append(
+                await resource_client.get_resource_prompt(item, question)
+            )

+        resource_prompt = ""
+        if len(resource_prompt_list) > 0:
+            resource_prompt = "RESOURCES:" + "\n".join(resource_prompt_list)
+
+        out_schema: Optional[str] = ""
+        if self.actions and len(self.actions) > 0:
+            out_schema = self.actions[0].ai_out_schema
+        return {"resource_prompt": resource_prompt, "out_schema": out_schema}

     def _excluded_models(
         self,
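Where `_system_message_assembly` formatted `oai_system_message` in place, `generate_resource_variables` just returns a dict that the prompt builder merges with other template variables. A sketch of consuming it (the template text is hypothetical):

```python
# Hypothetical template consuming the returned variables.
TEMPLATE = (
    "You are a data analyst.\n"
    "{resource_prompt}\n"
    "Answer strictly in this schema:\n{out_schema}"
)


async def render_system_prompt(agent, question: str) -> str:
    variables = await agent.generate_resource_variables(question)
    # variables == {"resource_prompt": "RESOURCES:...", "out_schema": "..."}
    return TEMPLATE.format(**variables)
```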
@@ -774,7 +773,7 @@ class ConversableAgent(Role, Agent):
             else:
                 raise ValueError("No model service available!")
         except Exception as e:
-            logger.error(f"{self.profile} get next llm failed!{str(e)}")
+            logger.error(f"{self.role} get next llm failed!{str(e)}")
             raise ValueError(f"Failed to allocate model service,{str(e)}!")

     def _init_reply_message(self, received_message: AgentMessage) -> AgentMessage:
@@ -803,9 +802,9 @@ class ConversableAgent(Role, Agent):
             if item.role:
                 role = item.role
             else:
-                if item.receiver == self.profile:
+                if item.receiver == self.role:
                     role = ModelMessageRoleType.HUMAN
-                elif item.sender == self.profile:
+                elif item.sender == self.role:
                     role = ModelMessageRoleType.AI
                 else:
                     continue
@@ -825,14 +824,80 @@ class ConversableAgent(Role, Agent):
                 AgentMessage(
                     content=content,
                     role=role,
-                    context=json.loads(item.context)
-                    if item.context is not None
-                    else None,
+                    context=(
+                        json.loads(item.context) if item.context is not None else None
+                    ),
                 )
             )
         return oai_messages

-    def _load_thinking_messages(
+    async def _load_thinking_messages(
+        self,
+        received_message: AgentMessage,
+        sender: Agent,
+        rely_messages: Optional[List[AgentMessage]] = None,
+        context: Optional[Dict[str, Any]] = None,
+    ) -> List[AgentMessage]:
+        observation = received_message.content
+        if not observation:
+            raise ValueError("The received message content is empty!")
+        memories = await self.read_memories(observation)
+        reply_message_str = ""
+        if context is None:
+            context = {}
+        if rely_messages:
+            copied_rely_messages = [m.copy() for m in rely_messages]
+            # When directly relying on historical messages, use the execution result
+            # content as a dependency
+            for message in copied_rely_messages:
+                action_report: Optional[ActionOutput] = ActionOutput.from_dict(
+                    message.action_report
+                )
+                if action_report:
+                    # TODO: Modify in-place, need to be optimized
+                    message.content = action_report.content
+                if message.name != self.role:
+                    # TODO, use name
+                    # Rely messages are not from the current agent
+                    if message.role == ModelMessageRoleType.HUMAN:
+                        reply_message_str += f"Question: {message.content}\n"
+                    elif message.role == ModelMessageRoleType.AI:
+                        reply_message_str += f"Observation: {message.content}\n"
+        if reply_message_str:
+            memories += "\n" + reply_message_str
+
+        system_prompt = await self.build_prompt(
+            question=observation,
+            is_system=True,
+            most_recent_memories=memories,
+            **context,
+        )
+        user_prompt = await self.build_prompt(
+            question=observation,
+            is_system=False,
+            most_recent_memories=memories,
+            **context,
+        )
+
+        agent_messages = []
+        if system_prompt:
+            agent_messages.append(
+                AgentMessage(
+                    content=system_prompt,
+                    role=ModelMessageRoleType.SYSTEM,
+                )
+            )
+        if user_prompt:
+            agent_messages.append(
+                AgentMessage(
+                    content=user_prompt,
+                    role=ModelMessageRoleType.HUMAN,
+                )
+            )
+
+        return agent_messages
+
+    def _old_load_thinking_messages(
         self,
         received_message: AgentMessage,
         sender: Agent,
@@ -846,8 +911,8 @@ class ConversableAgent(Role, Agent):
         with root_tracer.start_span(
             "agent._load_thinking_messages",
             metadata={
-                "sender": sender.get_name(),
-                "recipient": self.get_name(),
+                "sender": sender.name,
+                "recipient": self.name,
                 "conv_uid": self.not_null_agent_context.conv_id,
                 "current_goal": current_goal,
             },
@@ -855,8 +920,8 @@ class ConversableAgent(Role, Agent):
             # Get historical information from the memory
             memory_messages = self.memory.message_memory.get_between_agents(
                 self.not_null_agent_context.conv_id,
-                self.profile,
-                sender.get_profile(),
+                self.role,
+                sender.role,
                 current_goal,
             )
             span.metadata["memory_messages"] = [
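The new `_load_thinking_messages` builds the LLM input from the observation plus retrieved memories: it reads memories for the observation, folds rely-message results in as extra context, renders system and user prompts via `build_prompt`, and returns them as `AgentMessage`s. The resulting shape, roughly (the sender object and rendered text are placeholders):

```python
# Rough shape of the result; the prompt text depends on the agent's profile templates.
messages = await agent._load_thinking_messages(
    received_message=AgentMessage(content="Top 10 customers by revenue?"),
    sender=user_proxy,  # hypothetical sending agent
)
# => [AgentMessage(role=ModelMessageRoleType.SYSTEM, content=<system prompt + memories>),
#     AgentMessage(role=ModelMessageRoleType.HUMAN,  content=<user prompt>)]
```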
@@ -1,13 +1,14 @@
 """Base classes for managing a group of agents in a team chat."""

 import logging
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Tuple, Union

 from dbgpt._private.pydantic import BaseModel, ConfigDict, Field

-from ..actions.action import ActionOutput
+from .action.base import ActionOutput
 from .agent import Agent, AgentMessage
 from .base_agent import ConversableAgent
+from .profile import ProfileConfig

 logger = logging.getLogger(__name__)

@@ -86,7 +87,7 @@ class Team(BaseModel):
     @property
     def agent_names(self) -> List[str]:
         """Return the names of the agents in the group chat."""
-        return [agent.get_profile() for agent in self.agents]
+        return [agent.role for agent in self.agents]

     def agent_by_name(self, name: str) -> Agent:
         """Return the agent with a given name."""

@@ -121,10 +122,14 @@ class ManagerAgent(ConversableAgent, Team):

     model_config = ConfigDict(arbitrary_types_allowed=True)

-    profile: str = "TeamManager"
-    goal: str = "manage all hired intelligent agents to complete mission objectives"
-    constraints: List[str] = []
-    desc: str = goal
+    profile: ProfileConfig = ProfileConfig(
+        name="ManagerAgent",
+        profile="TeamManager",
+        goal="manage all hired intelligent agents to complete mission objectives",
+        constraints=[],
+        desc="manage all hired intelligent agents to complete mission objectives",
+    )

     is_team: bool = True

     # The management agent does not need to retry the exception. The actual execution
@@ -149,6 +154,16 @@ class ManagerAgent(ConversableAgent, Team):
         self.messages.append(message.to_llm_message())
         return message.content, None

+    async def _load_thinking_messages(
+        self,
+        received_message: AgentMessage,
+        sender: Agent,
+        rely_messages: Optional[List[AgentMessage]] = None,
+        context: Optional[Dict[str, Any]] = None,
+    ) -> List[AgentMessage]:
+        """Load messages for thinking."""
+        return [AgentMessage(content=received_message.content)]
+
     async def act(
         self,
         message: Optional[str],
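Agent identity thus moves from four loose string fields into one structured `ProfileConfig`. Defining a custom agent's profile then looks like this (field values are illustrative):

```python
from dbgpt.agent.core.profile import ProfileConfig

# Illustrative profile, mirroring the ManagerAgent definition above.
profile: ProfileConfig = ProfileConfig(
    name="Aristotle",
    profile="Summarizer",
    goal="Summarize the user's documents accurately and concisely.",
    constraints=["Never invent content that is not in the source documents."],
    desc="Summarize the user's documents accurately and concisely.",
)
```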
dbgpt/agent/core/memory/__init__.py (new file, 16 lines)
@@ -0,0 +1,16 @@
"""Memory module for the agent."""

from .agent_memory import AgentMemory, AgentMemoryFragment  # noqa: F401
from .base import (  # noqa: F401
    ImportanceScorer,
    InsightExtractor,
    InsightMemoryFragment,
    Memory,
    MemoryFragment,
    SensoryMemory,
    ShortTermMemory,
)
from .hybrid import HybridMemory  # noqa: F401
from .llm import LLMImportanceScorer, LLMInsightExtractor  # noqa: F401
from .long_term import LongTermMemory, LongTermRetriever  # noqa: F401
from .short_term import EnhancedShortTermMemory  # noqa: F401
dbgpt/agent/core/memory/agent_memory.py (new file, 282 lines)
@@ -0,0 +1,282 @@
"""Agent memory module."""

from datetime import datetime
from typing import Callable, List, Optional, Type, cast

from dbgpt.core import LLMClient
from dbgpt.util.annotations import immutable, mutable
from dbgpt.util.id_generator import new_id

from .base import (
    DiscardedMemoryFragments,
    ImportanceScorer,
    InsightExtractor,
    Memory,
    MemoryFragment,
    ShortTermMemory,
    WriteOperation,
)
from .gpts import GptsMemory, GptsMessageMemory, GptsPlansMemory


class AgentMemoryFragment(MemoryFragment):
    """Default memory fragment for agent memory."""

    def __init__(
        self,
        observation: str,
        embeddings: Optional[List[float]] = None,
        memory_id: Optional[int] = None,
        importance: Optional[float] = None,
        last_accessed_time: Optional[datetime] = None,
        is_insight: bool = False,
    ):
        """Create a memory fragment."""
        if not memory_id:
            # Generate a new memory id, we use snowflake id generator here.
            memory_id = new_id()
        self.observation = observation
        self._embeddings = embeddings
        self.memory_id: int = cast(int, memory_id)
        self._importance: Optional[float] = importance
        self._last_accessed_time: Optional[datetime] = last_accessed_time
        self._is_insight = is_insight

    @property
    def id(self) -> int:
        """Return the memory id."""
        return self.memory_id

    @property
    def raw_observation(self) -> str:
        """Return the raw observation."""
        return self.observation

    @property
    def embeddings(self) -> Optional[List[float]]:
        """Return the embeddings of the memory fragment."""
        return self._embeddings

    def update_embeddings(self, embeddings: List[float]) -> None:
        """Update the embeddings of the memory fragment.

        Args:
            embeddings(List[float]): embeddings
        """
        self._embeddings = embeddings

    def calculate_current_embeddings(
        self, embedding_func: Callable[[List[str]], List[List[float]]]
    ) -> List[float]:
        """Calculate the embeddings of the memory fragment.

        Args:
            embedding_func(Callable[[List[str]], List[List[float]]]): Function to
                compute embeddings

        Returns:
            List[float]: Embeddings of the memory fragment
        """
        embeddings = embedding_func([self.observation])
        return embeddings[0]

    @property
    def is_insight(self) -> bool:
        """Return whether the memory fragment is an insight.

        Returns:
            bool: Whether the memory fragment is an insight
        """
        return self._is_insight

    @property
    def importance(self) -> Optional[float]:
        """Return the importance of the memory fragment.

        Returns:
            Optional[float]: Importance of the memory fragment
        """
        return self._importance

    def update_importance(self, importance: float) -> Optional[float]:
        """Update the importance of the memory fragment.

        Args:
            importance(float): Importance of the memory fragment

        Returns:
            Optional[float]: Old importance
        """
        old_importance = self._importance
        self._importance = importance
        return old_importance

    @property
    def last_accessed_time(self) -> Optional[datetime]:
        """Return the last accessed time of the memory fragment.

        Used to determine the least recently used memory fragment.

        Returns:
            Optional[datetime]: Last accessed time
        """
        return self._last_accessed_time

    def update_accessed_time(self, now: datetime) -> Optional[datetime]:
        """Update the last accessed time of the memory fragment.

        Args:
            now(datetime): Current time

        Returns:
            Optional[datetime]: Old last accessed time
        """
        old_time = self._last_accessed_time
        self._last_accessed_time = now
        return old_time

    @classmethod
    def build_from(
        cls: Type["AgentMemoryFragment"],
        observation: str,
        embeddings: Optional[List[float]] = None,
        memory_id: Optional[int] = None,
        importance: Optional[float] = None,
        is_insight: bool = False,
        last_accessed_time: Optional[datetime] = None,
        **kwargs
    ) -> "AgentMemoryFragment":
        """Build a memory fragment from the given parameters."""
        return cls(
            observation=observation,
            embeddings=embeddings,
            memory_id=memory_id,
            importance=importance,
            last_accessed_time=last_accessed_time,
            is_insight=is_insight,
        )

    def copy(self: "AgentMemoryFragment") -> "AgentMemoryFragment":
        """Return a copy of the memory fragment."""
        return AgentMemoryFragment.build_from(
            observation=self.observation,
            embeddings=self._embeddings,
            memory_id=self.memory_id,
            importance=self.importance,
            last_accessed_time=self.last_accessed_time,
            is_insight=self.is_insight,
        )


class AgentMemory(Memory[AgentMemoryFragment]):
    """Agent memory."""

    def __init__(
        self,
        memory: Optional[Memory[AgentMemoryFragment]] = None,
        importance_scorer: Optional[ImportanceScorer[AgentMemoryFragment]] = None,
        insight_extractor: Optional[InsightExtractor[AgentMemoryFragment]] = None,
        gpts_memory: Optional[GptsMemory] = None,
    ):
        """Create an agent memory.

        Args:
            memory(Memory[AgentMemoryFragment]): Memory to store fragments
            importance_scorer(ImportanceScorer[AgentMemoryFragment]): Scorer to
                calculate the importance of memory fragments
            insight_extractor(InsightExtractor[AgentMemoryFragment]): Extractor to
                extract insights from memory fragments
            gpts_memory(GptsMemory): Memory to store GPTs related information
        """
        if not memory:
            memory = ShortTermMemory(buffer_size=5)
        if not gpts_memory:
            gpts_memory = GptsMemory()
        self.memory: Memory[AgentMemoryFragment] = cast(
            Memory[AgentMemoryFragment], memory
        )
        self.importance_scorer = importance_scorer
        self.insight_extractor = insight_extractor
        self.gpts_memory = gpts_memory

    @immutable
    def structure_clone(
        self: "AgentMemory", now: Optional[datetime] = None
    ) -> "AgentMemory":
        """Return a structure clone of the memory.

        The gpts_memory is not cloned; it is shared across the whole agent memory.
        """
        m = AgentMemory(
            memory=self.memory.structure_clone(now),
            importance_scorer=self.importance_scorer,
            insight_extractor=self.insight_extractor,
            gpts_memory=self.gpts_memory,
        )
        m._copy_from(self)
        return m

    @mutable
    def initialize(
        self,
        name: Optional[str] = None,
        llm_client: Optional[LLMClient] = None,
        importance_scorer: Optional[ImportanceScorer[AgentMemoryFragment]] = None,
        insight_extractor: Optional[InsightExtractor[AgentMemoryFragment]] = None,
        real_memory_fragment_class: Optional[Type[AgentMemoryFragment]] = None,
    ) -> None:
        """Initialize the memory."""
        self.memory.initialize(
            name=name,
            llm_client=llm_client,
            importance_scorer=importance_scorer or self.importance_scorer,
            insight_extractor=insight_extractor or self.insight_extractor,
            real_memory_fragment_class=real_memory_fragment_class
            or AgentMemoryFragment,
        )

    @mutable
    async def write(
        self,
        memory_fragment: AgentMemoryFragment,
        now: Optional[datetime] = None,
        op: WriteOperation = WriteOperation.ADD,
    ) -> Optional[DiscardedMemoryFragments[AgentMemoryFragment]]:
        """Write a memory fragment to the memory."""
        return await self.memory.write(memory_fragment, now)

    @immutable
    async def read(
        self,
        observation: str,
        alpha: Optional[float] = None,
        beta: Optional[float] = None,
        gamma: Optional[float] = None,
    ) -> List[AgentMemoryFragment]:
        """Read memory fragments related to the observation.

        Args:
            observation(str): Observation
            alpha(float): Recency weight
            beta(float): Relevance weight
            gamma(float): Importance weight

        Returns:
            List[AgentMemoryFragment]: List of memory fragments
        """
        return await self.memory.read(observation, alpha, beta, gamma)

    @mutable
    async def clear(self) -> List[AgentMemoryFragment]:
        """Clear the memory."""
        return await self.memory.clear()

    @property
    def plans_memory(self) -> GptsPlansMemory:
        """Return the plan memory."""
        return self.gpts_memory.plans_memory

    @property
    def message_memory(self) -> GptsMessageMemory:
        """Return the message memory."""
        return self.gpts_memory.message_memory
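A short usage sketch of the new memory API: `AgentMemory()` falls back to a five-slot `ShortTermMemory` per the constructor above, and the `initialize` wiring is shown minimally (a real agent performs it during its build step, with an LLM client for scoring and insights):

```python
import asyncio

from dbgpt.agent import AgentMemory, AgentMemoryFragment


async def main():
    memory = AgentMemory()          # defaults to ShortTermMemory(buffer_size=5)
    memory.initialize(name="demo")  # no LLM client here: no scoring/insights
    await memory.write(AgentMemoryFragment("User prefers monthly revenue charts."))
    fragments = await memory.read(observation="revenue")
    print([f.raw_observation for f in fragments])


asyncio.run(main())
```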
776
dbgpt/agent/core/memory/base.py
Normal file
776
dbgpt/agent/core/memory/base.py
Normal file
@@ -0,0 +1,776 @@
|
||||
"""Memory for agent.
|
||||
|
||||
Human memory follows a general progression from sensory memory that registers
|
||||
perceptual inputs, to short-term memory that maintains information transiently, to
|
||||
long-term memory that consolidates information over extended periods.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from abc import ABC, abstractmethod
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
Generic,
|
||||
List,
|
||||
Optional,
|
||||
Tuple,
|
||||
Type,
|
||||
TypeVar,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
|
||||
from dbgpt.core import LLMClient
|
||||
from dbgpt.util.annotations import PublicAPI, immutable, mutable
|
||||
|
||||
T = TypeVar("T", bound="MemoryFragment")
|
||||
M = TypeVar("M", bound="Memory")
|
||||
|
||||
|
||||
class WriteOperation(str, Enum):
|
||||
"""Write operation."""
|
||||
|
||||
ADD = "add"
|
||||
RETRIEVAL = "retrieval"
|
||||
|
||||
|
||||
@PublicAPI(stability="beta")
|
||||
class MemoryFragment(ABC):
|
||||
"""Memory fragment interface.
|
||||
|
||||
It is the interface of memory fragment, which is the basic unit of memory, which
|
||||
contains the basic information of memory, such as observation, importance, whether
|
||||
it is insight, last access time, etc
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def build_from(
|
||||
cls: Type[T],
|
||||
observation: str,
|
||||
embeddings: Optional[List[float]] = None,
|
||||
memory_id: Optional[int] = None,
|
||||
importance: Optional[float] = None,
|
||||
is_insight: bool = False,
|
||||
last_accessed_time: Optional[datetime] = None,
|
||||
**kwargs,
|
||||
) -> T:
|
||||
"""Build a memory fragment from memory id and observation.
|
||||
|
||||
Args:
|
||||
observation(str): Observation
|
||||
embeddings(List[float], optional): Embeddings of the memory fragment.
|
||||
memory_id(int): Memory id
|
||||
importance(float): Importance
|
||||
is_insight(bool): Whether the memory fragment is an insight
|
||||
last_accessed_time(datetime): Last accessed time
|
||||
|
||||
Returns:
|
||||
MemoryFragment: Memory fragment
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def id(self) -> int:
|
||||
"""Return the id of the memory fragment.
|
||||
|
||||
Commonly, the id is generated by Snowflake algorithm. So we can parse the
|
||||
timestamp of when the memory fragment is created.
|
||||
|
||||
Returns:
|
||||
int: id
|
||||
"""
|
||||
|
||||
@property
|
||||
def metadata(self) -> Dict[str, Any]:
|
||||
"""Return the metadata of the memory fragment.
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: Metadata
|
||||
"""
|
||||
return {}
|
||||
|
||||
@property
|
||||
def importance(self) -> Optional[float]:
|
||||
"""Return the importance of the memory fragment.
|
||||
|
||||
It should be noted that importance only reflects the characters of the memory
|
||||
itself.
|
||||
|
||||
Returns:
|
||||
Optional[float]: importance, None means the importance is not available.
|
||||
"""
|
||||
return None
|
||||
|
||||
@abstractmethod
|
||||
def update_importance(self, importance: float) -> Optional[float]:
|
||||
"""Update the importance of the memory fragment.
|
||||
|
||||
Args:
|
||||
importance(float): importance
|
||||
|
||||
Returns:
|
||||
Optional[float]: importance
|
||||
"""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def raw_observation(self) -> str:
|
||||
"""Return the raw observation.
|
||||
|
||||
Raw observation is the original observation data, it can be an observation from
|
||||
environment or an observation after executing an action.
|
||||
|
||||
Returns:
|
||||
str: raw observation
|
||||
"""
|
||||
|
||||
@property
|
||||
def embeddings(self) -> Optional[List[float]]:
|
||||
"""Return the embeddings of the memory fragment.
|
||||
|
||||
Returns:
|
||||
Optional[List[float]]: embeddings
|
||||
"""
|
||||
return None
|
||||
|
||||
@abstractmethod
|
||||
def update_embeddings(self, embeddings: List[float]) -> None:
|
||||
"""Update the embeddings of the memory fragment.
|
||||
|
||||
Args:
|
||||
embeddings(List[float]): embeddings
|
||||
"""
|
||||
|
||||
def calculate_current_embeddings(
|
||||
self, embedding_func: Callable[[List[str]], List[List[float]]]
|
||||
) -> List[float]:
|
||||
"""Calculate the embeddings of the memory fragment.
|
||||
|
||||
Args:
|
||||
embedding_func(Callable[[List[str]], List[List[float]]]): Function to
|
||||
compute embeddings
|
||||
|
||||
Returns:
|
||||
List[float]: Embeddings of the memory fragment
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def is_insight(self) -> bool:
|
||||
"""Return whether the memory fragment is an insight.
|
||||
|
||||
Returns:
|
||||
bool: whether the memory fragment is an insight.
|
||||
"""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def last_accessed_time(self) -> Optional[datetime]:
|
||||
"""Return the last accessed time of the memory fragment.
|
||||
|
||||
Returns:
|
||||
Optional[datetime]: last accessed time
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def update_accessed_time(self, now: datetime) -> Optional[datetime]:
|
||||
"""Update the last accessed time of the memory fragment.
|
||||
|
||||
Args:
|
||||
now(datetime): The current time
|
||||
|
||||
Returns:
|
||||
Optional[datetime]: The last accessed time
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def copy(self: T) -> T:
|
||||
"""Copy the memory fragment."""
|
||||
|
||||
def reduce(self, memory_fragments: List[T], **kwargs) -> T:
|
||||
"""Reduce memory fragments to a single memory fragment.
|
||||
|
||||
Args:
|
||||
memory_fragments(List[T]): Memory fragments
|
||||
|
||||
Returns:
|
||||
T: The reduced memory fragment
|
||||
"""
|
||||
obs = []
|
||||
for memory_fragment in memory_fragments:
|
||||
obs.append(memory_fragment.raw_observation)
|
||||
new_observation = ";".join(obs)
|
||||
return self.current_class.build_from(new_observation, **kwargs) # type: ignore
|
||||
|
||||
@property
|
||||
def current_class(self: T) -> Type[T]:
|
||||
"""Return the current class."""
|
||||
return self.__class__
|
||||
|
||||
|
||||
class InsightMemoryFragment(Generic[T]):
|
||||
"""Insight memory fragment.
|
||||
|
||||
Insight memory fragment is a memory fragment that contains insights.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
original_memory_fragment: Union[T, List[T]],
|
||||
insights: Union[List[T], List[str]],
|
||||
):
|
||||
"""Create an insight memory fragment.
|
||||
|
||||
Insight is also a memory fragment.
|
||||
"""
|
||||
if insights and isinstance(insights[0], str):
|
||||
mf = (
|
||||
original_memory_fragment[0]
|
||||
if isinstance(original_memory_fragment, list)
|
||||
else original_memory_fragment
|
||||
)
|
||||
insights = [
|
||||
mf.current_class.build_from(i, is_insight=True) for i in insights # type: ignore # noqa
|
||||
]
|
||||
self._original_memory_fragment = original_memory_fragment
|
||||
self._insights: List[T] = cast(List[T], insights)
|
||||
|
||||
@property
|
||||
def original_memory_fragment(self) -> Union[T, List[T]]:
|
||||
"""Return the original memory fragment."""
|
||||
return self._original_memory_fragment
|
||||
|
||||
@property
|
||||
def insights(self) -> List[T]:
|
||||
"""Return the insights."""
|
||||
return self._insights
|
||||
|
||||
|
||||
class DiscardedMemoryFragments(Generic[T]):
|
||||
"""Discarded memory fragments.
|
||||
|
||||
Sometimes, we need to discard some memory fragments, there are following cases:
|
||||
1. Memory duplicated, the same/similar action is executed multiple times and the
|
||||
same/similar observation from environment is received.
|
||||
2. Memory overflow. The memory is full and the new memory fragment needs to be
|
||||
written.
|
||||
3. The memory fragment is not important enough.
|
||||
4. Simulation of forgetting mechanism.
|
||||
|
||||
The discarded memory fragments may be transferred to another memory.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
discarded_memory_fragments: List[T],
|
||||
discarded_insights: Optional[List[InsightMemoryFragment[T]]] = None,
|
||||
):
|
||||
"""Create a discarded memory fragments."""
|
||||
if discarded_insights is None:
|
||||
discarded_insights = []
|
||||
self._discarded_memory_fragments = discarded_memory_fragments
|
||||
self._discarded_insights = discarded_insights
|
||||
|
||||
@property
|
||||
def discarded_memory_fragments(self) -> List[T]:
|
||||
"""Return the discarded memory fragments."""
|
||||
return self._discarded_memory_fragments
|
||||
|
||||
@property
|
||||
def discarded_insights(self) -> List[InsightMemoryFragment[T]]:
|
||||
"""Return the discarded insights."""
|
||||
return self._discarded_insights
|
||||
|
||||
|
||||
class InsightExtractor(ABC, Generic[T]):
|
||||
"""Insight extractor interface.
|
||||
|
||||
Obtain high-level insights from memories.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def extract_insights(
|
||||
self,
|
||||
memory_fragment: T,
|
||||
llm_client: Optional[LLMClient] = None,
|
||||
) -> InsightMemoryFragment[T]:
|
||||
"""Extract insights from memory fragments.
|
||||
|
||||
Args:
|
||||
memory_fragment(T): Memory fragment
|
||||
llm_client(Optional[LLMClient]): LLM client
|
||||
|
||||
Returns:
|
||||
InsightMemoryFragment: The insights of the memory fragment.
|
||||
"""
|
||||
|
||||
|
||||
class ImportanceScorer(ABC, Generic[T]):
|
||||
"""Importance scorer interface.
|
||||
|
||||
Score the importance of memories.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def score_importance(
|
||||
self,
|
||||
memory_fragment: T,
|
||||
llm_client: Optional[LLMClient] = None,
|
||||
) -> float:
|
||||
"""Score the importance of memory fragment.
|
||||
|
||||
Args:
|
||||
memory_fragment(T): Memory fragment.
|
||||
llm_client(Optional[LLMClient]): LLM client
|
||||
|
||||
Returns:
|
||||
float: The importance of the memory fragment.
|
||||
"""
|
||||
|
||||
|
||||
@PublicAPI(stability="beta")
class Memory(ABC, Generic[T]):
    """Memory interface."""

    name: Optional[str] = None
    llm_client: Optional[LLMClient] = None
    importance_scorer: Optional[ImportanceScorer] = None
    insight_extractor: Optional[InsightExtractor] = None
    _real_memory_fragment_class: Optional[Type[T]] = None
    importance_weight: float = 0.15

    @mutable
    def initialize(
        self,
        name: Optional[str] = None,
        llm_client: Optional[LLMClient] = None,
        importance_scorer: Optional[ImportanceScorer] = None,
        insight_extractor: Optional[InsightExtractor] = None,
        real_memory_fragment_class: Optional[Type[T]] = None,
    ) -> None:
        """Initialize memory.

        Some agents may need to initialize memory before using it.
        """
        self.name = name
        self.llm_client = llm_client
        self.importance_scorer = importance_scorer
        self.insight_extractor = insight_extractor
        self._real_memory_fragment_class = real_memory_fragment_class

    @abstractmethod
    @immutable
    def structure_clone(self: M, now: Optional[datetime] = None) -> M:
        """Return a structure clone of the memory.

        Sometimes we need to clone the structure of the memory, but not the
        content. There are some cases:

        1. When we need to reset the memory, we can use this method to create a
        new one, and the new memory has the same structure as the old one.
        2. When we create a new agent, the new agent has the same memory
        structure as the planner.

        Args:
            now(Optional[datetime]): The current time

        Returns:
            M: The structure clone of the memory
        """
        raise NotImplementedError

    @mutable
    def _copy_from(self, memory: "Memory") -> None:
        """Copy memory from another memory.

        Args:
            memory(Memory): Another memory
        """
        self.name = memory.name
        self.llm_client = memory.llm_client
        self.importance_scorer = memory.importance_scorer
        self.insight_extractor = memory.insight_extractor
        self._real_memory_fragment_class = memory._real_memory_fragment_class

    @abstractmethod
    @mutable
    async def write(
        self,
        memory_fragment: T,
        now: Optional[datetime] = None,
        op: WriteOperation = WriteOperation.ADD,
    ) -> Optional[DiscardedMemoryFragments[T]]:
        """Write a memory fragment to memory.

        Two situations need to be noted here:

        1. Memory duplication: the same or a similar action is executed
        multiple times, and the same or a similar observation is received from
        the environment.

        2. Memory overflow: the memory is full and a new memory fragment needs
        to be written; the common strategy is to discard some memory fragments.

        Args:
            memory_fragment(T): Memory fragment
            now(Optional[datetime]): The current time
            op(WriteOperation): Write operation

        Returns:
            Optional[DiscardedMemoryFragments]: The discarded memory fragments,
                None means no memory fragments are discarded.
        """

    @mutable
    async def write_batch(
        self, memory_fragments: List[T], now: Optional[datetime] = None
    ) -> Optional[DiscardedMemoryFragments[T]]:
        """Write a batch of memory fragments to memory.

        Args:
            memory_fragments(List[T]): Memory fragments
            now(Optional[datetime]): The current time

        Returns:
            Optional[DiscardedMemoryFragments]: The discarded memory fragments,
                None means no memory fragments are discarded.
        """
        raise NotImplementedError

    @abstractmethod
    @immutable
    async def read(
        self,
        observation: str,
        alpha: Optional[float] = None,
        beta: Optional[float] = None,
        gamma: Optional[float] = None,
    ) -> List[T]:
        r"""Read memory fragments by observation.

        Usually, there are three commonly used criteria for information
        extraction: recency, relevance, and importance.

        Memories that are more recent, relevant, and important are more likely
        to be extracted. Formally, we conclude the following equation from the
        existing literature for memory information extraction:

        .. math::

            m^* = \arg\min_{m \in M} \alpha s^{\text{rec}}(q, m) + \\
                \beta s^{\text{rel}}(q, m) + \gamma s^{\text{imp}}(m), \tag{1}

        Args:
            observation(str): observation (the query)
            alpha(float, optional): Recency coefficient. Default is None.
            beta(float, optional): Relevance coefficient. Default is None.
            gamma(float, optional): Importance coefficient. Default is None.

        Returns:
            List[T]: memory fragments
        """

    @immutable
    async def reflect(self, memory_fragments: List[T]) -> List[T]:
        """Reflect memory fragments by observation.

        Args:
            memory_fragments(List[T]): memory fragments to be reflected.

        Returns:
            List[T]: memory fragments after reflection.
        """
        return memory_fragments

    @immutable
    async def handle_duplicated(
        self, memory_fragments: List[T], new_memory_fragments: List[T]
    ) -> List[T]:
        """Handle duplicated memory fragments.

        Args:
            memory_fragments(List[T]): Existing memory fragments
            new_memory_fragments(List[T]): New memory fragments

        Returns:
            List[T]: The new memory fragments after handling duplicated memory
                fragments.
        """
        return memory_fragments + new_memory_fragments

    @mutable
    async def handle_overflow(
        self, memory_fragments: List[T]
    ) -> Tuple[List[T], List[T]]:
        """Handle memory overflow.

        Args:
            memory_fragments(List[T]): Existing memory fragments

        Returns:
            Tuple[List[T], List[T]]: The memory fragments after handling the
                overflow and the discarded memory fragments.
        """
        return memory_fragments, []

    @abstractmethod
    @mutable
    async def clear(self) -> List[T]:
        """Clear all memory fragments.

        Returns:
            List[T]: All cleared memory fragments.
        """

    @immutable
    async def get_insights(
        self, memory_fragments: List[T]
    ) -> List[InsightMemoryFragment[T]]:
        """Get insights from memory fragments.

        Args:
            memory_fragments(List[T]): Memory fragments

        Returns:
            List[InsightMemoryFragment]: The insights of the memory fragments.
        """
        if not self.insight_extractor:
            return []
        # Obtain insights from the memory fragments in parallel
        tasks = []
        for memory_fragment in memory_fragments:
            tasks.append(
                self.insight_extractor.extract_insights(
                    memory_fragment, self.llm_client
                )
            )
        insights = await asyncio.gather(*tasks)
        result = []
        for insight in insights:
            if not insight:
                continue
            result.append(insight)
        if len(result) != len(insights):
            raise ValueError(
                "The number of insights is not equal to the number of memory "
                "fragments."
            )
        return result

    @immutable
    async def score_memory_importance(self, memory_fragments: List[T]) -> List[float]:
        """Score the importance of memory fragments.

        Args:
            memory_fragments(List[T]): Memory fragments

        Returns:
            List[float]: The importance of memory fragments.
        """
        if not self.importance_scorer:
            return [5 * self.importance_weight for _ in memory_fragments]
        tasks = []
        for memory_fragment in memory_fragments:
            tasks.append(
                self.importance_scorer.score_importance(
                    memory_fragment, self.llm_client
                )
            )
        result = []
        for importance in await asyncio.gather(*tasks):
            real_score = importance * self.importance_weight
            result.append(real_score)
        return result

    @property
    @immutable
    def real_memory_fragment_class(self) -> Type[T]:
        """Return the real memory fragment class."""
        if not self._real_memory_fragment_class:
            raise ValueError("The real memory fragment class is not set.")
        return self._real_memory_fragment_class
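A quick numeric sketch of the retrieval rule in Eq. (1) above — not part of this commit. The docstring treats memories scoring higher on recency, relevance, and importance as more likely to be extracted, so the sketch ranks candidates by the weighted sum; the per-memory scores here are stand-ins for the real score functions.

from typing import List, Tuple


def rank_memories(
    candidates: List[Tuple[float, float, float]],  # (s_rec, s_rel, s_imp)
    alpha: float = 1.0,
    beta: float = 1.0,
    gamma: float = 1.0,
) -> List[int]:
    """Return candidate indices sorted by the combined score, best first."""
    scores = [
        alpha * s_rec + beta * s_rel + gamma * s_imp
        for (s_rec, s_rel, s_imp) in candidates
    ]
    return sorted(range(len(candidates)), key=lambda i: scores[i], reverse=True)


# The second memory wins on relevance and importance: prints [1, 0]
print(rank_memories([(0.9, 0.1, 0.2), (0.5, 0.8, 0.7)]))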
class SensoryMemory(Memory, Generic[T]):
    """Sensory memory."""

    importance_weight: float = 0.9
    threshold_to_short_term: float = 0.1

    def __init__(self, buffer_size: int = 0):
        """Create a sensory memory."""
        self._buffer_size = buffer_size
        self._fragments: List[T] = []
        self._lock = asyncio.Lock()

    def structure_clone(
        self: "SensoryMemory[T]", now: Optional[datetime] = None
    ) -> "SensoryMemory[T]":
        """Return a structure clone of the memory."""
        m: SensoryMemory[T] = SensoryMemory(buffer_size=self._buffer_size)
        m._copy_from(self)
        return m

    @mutable
    async def write(
        self,
        memory_fragment: T,
        now: Optional[datetime] = None,
        op: WriteOperation = WriteOperation.ADD,
    ) -> Optional[DiscardedMemoryFragments[T]]:
        """Write a memory fragment to sensory memory."""
        fragments = await self.handle_duplicated(self._fragments, [memory_fragment])
        discarded_fragments: List[T] = []
        if len(fragments) > self._buffer_size:
            fragments, discarded_fragments = await self.handle_overflow(fragments)

        async with self._lock:
            await self.clear()
            self._fragments = fragments
            if not discarded_fragments:
                return None
            return DiscardedMemoryFragments(discarded_fragments, [])

    @immutable
    async def read(
        self,
        observation: str,
        alpha: Optional[float] = None,
        beta: Optional[float] = None,
        gamma: Optional[float] = None,
    ) -> List[T]:
        """Read memory fragments by observation."""
        return self._fragments

    @mutable
    async def handle_overflow(
        self, memory_fragments: List[T]
    ) -> Tuple[List[T], List[T]]:
        """Handle memory overflow.

        For sensory memory, the overflow strategy is to transfer all memory
        fragments to short-term memory.

        Args:
            memory_fragments(List[T]): Existing memory fragments

        Returns:
            Tuple[List[T], List[T]]: The memory fragments after handling the
                overflow and the discarded memory fragments; the discarded
                memory fragments should be transferred to short-term memory.
        """
        scores = await self.score_memory_importance(memory_fragments)
        result = []
        for i, memory in enumerate(memory_fragments):
            if scores[i] >= self.threshold_to_short_term:
                memory.update_importance(scores[i])
                result.append(memory)
        return [], result

    @mutable
    async def clear(self) -> List[T]:
        """Clear all memory fragments."""
        # Don't acquire self._lock here: write() calls clear() while already
        # holding the lock, and asyncio.Lock is not reentrant.
        fragments = self._fragments
        self._fragments = []
        return fragments
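With the default buffer_size=0, every write to SensoryMemory overflows immediately: each fragment is scored (5 * importance_weight = 4.5 when no scorer is configured, well above threshold_to_short_term = 0.1) and handed back as discarded. A hypothetical caller-side sketch, with concrete fragment types omitted — not part of this commit:

import asyncio


async def promote(sensory_memory, fragment):
    discarded = await sensory_memory.write(fragment)
    if discarded:
        for f in discarded.discarded_memory_fragments:
            # HybridMemory.write forwards these to short-term memory
            print("promote to short-term:", f.raw_observation)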
class ShortTermMemory(Memory, Generic[T]):
    """Short-term memory.

    All memories are stored in computer memory.
    """

    def __init__(self, buffer_size: int = 5):
        """Create a short-term memory."""
        self._buffer_size = buffer_size
        self._fragments: List[T] = []
        self._lock = asyncio.Lock()

    def structure_clone(
        self: "ShortTermMemory[T]", now: Optional[datetime] = None
    ) -> "ShortTermMemory[T]":
        """Return a structure clone of the memory."""
        m: ShortTermMemory[T] = ShortTermMemory(buffer_size=self._buffer_size)
        m._copy_from(self)
        return m

    @mutable
    async def write(
        self,
        memory_fragment: T,
        now: Optional[datetime] = None,
        op: WriteOperation = WriteOperation.ADD,
    ) -> Optional[DiscardedMemoryFragments[T]]:
        """Write a memory fragment to short-term memory.

        Args:
            memory_fragment(T): New memory fragment
            now(Optional[datetime]): The current time
            op(WriteOperation): Write operation

        Returns:
            Optional[DiscardedMemoryFragments]: The discarded memory fragments,
                None means no memory fragments are discarded. The discarded
                memory fragments should be transferred to and stored in
                long-term memory.
        """
        fragments = await self.handle_duplicated(self._fragments, [memory_fragment])

        async with self._lock:
            await self.clear()
            self._fragments = fragments
        discarded_memories = await self.transfer_to_long_term(memory_fragment)
        fragments, discarded_fragments = await self.handle_overflow(self._fragments)
        self._fragments = fragments
        return discarded_memories

    @immutable
    async def read(
        self,
        observation: str,
        alpha: Optional[float] = None,
        beta: Optional[float] = None,
        gamma: Optional[float] = None,
    ) -> List[T]:
        """Read memory fragments by observation."""
        return self._fragments

    @mutable
    async def transfer_to_long_term(
        self, memory_fragment: T
    ) -> Optional[DiscardedMemoryFragments[T]]:
        """Transfer the oldest memories to long-term memory.

        This is a very simple strategy: just transfer the oldest memories to
        long-term memory.
        """
        if len(self._fragments) > self._buffer_size:
            overflow_cnt = len(self._fragments) - self._buffer_size
            # Take the oldest memories for long-term transfer *before* trimming
            # the buffer; trimming first would transfer the wrong fragments
            overflow_fragments = self._fragments[:overflow_cnt]
            # Just keep the most recent memories in short-term memory
            self._fragments = self._fragments[overflow_cnt:]
            insights = await self.get_insights(overflow_fragments)
            return DiscardedMemoryFragments(overflow_fragments, insights)
        else:
            return None

    @mutable
    async def clear(self) -> List[T]:
        """Clear all memory fragments."""
        # Don't acquire self._lock here: write() calls clear() while already
        # holding the lock.
        fragments = self._fragments
        self._fragments = []
        return fragments

    @property
    @immutable
    def short_term_memories(self) -> List[T]:
        """Return short-term memories."""
        return self._fragments
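The transfer arithmetic in `transfer_to_long_term` above is easy to misread, so a worked case (not from the commit): with buffer_size=5 and six buffered fragments, overflow_cnt = 6 - 5 = 1, so the single oldest fragment is sliced off for long-term storage and the five most recent remain.

fragments = ["f0", "f1", "f2", "f3", "f4", "f5"]  # f0 is the oldest
buffer_size = 5
overflow_cnt = len(fragments) - buffer_size  # 1
oldest = fragments[:overflow_cnt]            # ["f0"] -> goes to long-term memory
fragments = fragments[overflow_cnt:]         # ["f1", ..., "f5"] stay buffered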
19
dbgpt/agent/core/memory/gpts/__init__.py
Normal file
@@ -0,0 +1,19 @@
"""Memory module for GPTS messages and plans.

It stores the messages and plans generated by multiple agents in the
conversation.

It is different from the agent memory: it is a formatted structure for storing
the messages and plans, and it can be stored in a database or a file.
"""

from .base import (  # noqa: F401
    GptsMessage,
    GptsMessageMemory,
    GptsPlan,
    GptsPlansMemory,
)
from .default_gpts_memory import (  # noqa: F401
    DefaultGptsMessageMemory,
    DefaultGptsPlansMemory,
)
from .gpts_memory import GptsMemory  # noqa: F401
@@ -1,10 +1,11 @@
"""Base memory interface for agents."""

import dataclasses
from abc import ABC, abstractmethod
from datetime import datetime
from typing import Any, Dict, List, Optional

from dbgpt.agent.core.schema import Status
from ...schema import Status


@dataclasses.dataclass
@@ -5,7 +5,7 @@ from typing import List, Optional

import pandas as pd

from ..core.schema import Status
from ...schema import Status
from .base import GptsMessage, GptsMessageMemory, GptsPlan, GptsPlansMemory

@@ -1,11 +1,12 @@
"""GPTs memory."""

import json
from collections import OrderedDict, defaultdict
from typing import Dict, List, Optional

from dbgpt.vis.client import VisAgentMessages, VisAgentPlans, vis_client

from ..actions.action import ActionOutput
from ...action.base import ActionOutput
from .base import GptsMessage, GptsMessageMemory, GptsPlansMemory
from .default_gpts_memory import DefaultGptsMessageMemory, DefaultGptsPlansMemory

288
dbgpt/agent/core/memory/hybrid.py
Normal file
@@ -0,0 +1,288 @@
"""Hybrid memory module.

This structure explicitly models the human short-term and long-term memories.
The short-term memory temporarily buffers recent perceptions, while long-term
memory consolidates important information over time.
"""

import os.path
from concurrent.futures import Executor, ThreadPoolExecutor
from datetime import datetime
from typing import TYPE_CHECKING, Generic, List, Optional, Tuple, Type

from dbgpt.core import Embeddings, LLMClient
from dbgpt.util.annotations import immutable, mutable

from .base import (
    DiscardedMemoryFragments,
    ImportanceScorer,
    InsightExtractor,
    Memory,
    SensoryMemory,
    ShortTermMemory,
    T,
    WriteOperation,
)
from .long_term import LongTermMemory
from .short_term import EnhancedShortTermMemory

if TYPE_CHECKING:
    from dbgpt.storage.vector_store.connector import VectorStoreConnector


class HybridMemory(Memory, Generic[T]):
    """Hybrid memory for the agent."""

    importance_weight: float = 0.9

    def __init__(
        self,
        now: datetime,
        sensory_memory: SensoryMemory[T],
        short_term_memory: ShortTermMemory[T],
        long_term_memory: LongTermMemory[T],
        default_insight_extractor: Optional[InsightExtractor] = None,
        default_importance_scorer: Optional[ImportanceScorer] = None,
    ):
        """Create a hybrid memory."""
        self.now = now
        self._sensory_memory = sensory_memory
        self._short_term_memory = short_term_memory
        self._long_term_memory = long_term_memory
        self._default_insight_extractor = default_insight_extractor
        self._default_importance_scorer = default_importance_scorer

    def structure_clone(
        self: "HybridMemory[T]", now: Optional[datetime] = None
    ) -> "HybridMemory[T]":
        """Return a structure clone of the memory."""
        now = now or self.now
        m = HybridMemory(
            now=now,
            sensory_memory=self._sensory_memory.structure_clone(now),
            short_term_memory=self._short_term_memory.structure_clone(now),
            long_term_memory=self._long_term_memory.structure_clone(now),
        )
        m._copy_from(self)
        return m

    @classmethod
    def from_chroma(
        cls,
        vstore_name: Optional[str] = "_chroma_agent_memory_",
        vstore_path: Optional[str] = None,
        embeddings: Optional[Embeddings] = None,
        executor: Optional[Executor] = None,
        now: Optional[datetime] = None,
        sensory_memory: Optional[SensoryMemory[T]] = None,
        short_term_memory: Optional[ShortTermMemory[T]] = None,
        long_term_memory: Optional[LongTermMemory[T]] = None,
        **kwargs
    ):
        """Create a hybrid memory from a Chroma vector store."""
        from dbgpt.configs.model_config import DATA_DIR
        from dbgpt.storage.vector_store.chroma_store import ChromaVectorConfig
        from dbgpt.storage.vector_store.connector import VectorStoreConnector

        if not embeddings:
            from dbgpt.rag.embedding import DefaultEmbeddingFactory

            embeddings = DefaultEmbeddingFactory.openai()

        vstore_path = vstore_path or os.path.join(DATA_DIR, "agent_memory")

        vector_store_connector = VectorStoreConnector.from_default(
            vector_store_type="Chroma",
            embedding_fn=embeddings,
            vector_store_config=ChromaVectorConfig(
                name=vstore_name,
                persist_path=vstore_path,
            ),
        )
        return cls.from_vstore(
            vector_store_connector=vector_store_connector,
            embeddings=embeddings,
            executor=executor,
            now=now,
            sensory_memory=sensory_memory,
            short_term_memory=short_term_memory,
            long_term_memory=long_term_memory,
            **kwargs
        )

    @classmethod
    def from_vstore(
        cls,
        vector_store_connector: "VectorStoreConnector",
        embeddings: Optional[Embeddings] = None,
        executor: Optional[Executor] = None,
        now: Optional[datetime] = None,
        sensory_memory: Optional[SensoryMemory[T]] = None,
        short_term_memory: Optional[ShortTermMemory[T]] = None,
        long_term_memory: Optional[LongTermMemory[T]] = None,
        **kwargs
    ):
        """Create a hybrid memory from a vector store."""
        if not embeddings:
            embeddings = vector_store_connector.current_embeddings
        if not executor:
            executor = ThreadPoolExecutor()
        if not now:
            now = datetime.now()

        if not sensory_memory:
            sensory_memory = SensoryMemory()
        if not short_term_memory:
            if not embeddings:
                raise ValueError("embeddings is required.")
            short_term_memory = EnhancedShortTermMemory(embeddings, executor)
        if not long_term_memory:
            long_term_memory = LongTermMemory(
                executor,
                vector_store_connector,
                now=now,
            )
        return cls(now, sensory_memory, short_term_memory, long_term_memory, **kwargs)

    def initialize(
        self,
        name: Optional[str] = None,
        llm_client: Optional[LLMClient] = None,
        importance_scorer: Optional[ImportanceScorer[T]] = None,
        insight_extractor: Optional[InsightExtractor[T]] = None,
        real_memory_fragment_class: Optional[Type[T]] = None,
    ) -> None:
        """Initialize the memory.

        It will initialize all the memories.
        """
        memories = [
            self._sensory_memory,
            self._short_term_memory,
            self._long_term_memory,
        ]
        kwargs = {
            "name": name,
            "llm_client": llm_client,
            "importance_scorer": importance_scorer,
            "insight_extractor": insight_extractor,
            "real_memory_fragment_class": real_memory_fragment_class,
        }
        for memory in memories:
            memory.initialize(**kwargs)
        super().initialize(**kwargs)

    @mutable
    async def write(
        self,
        memory_fragment: T,
        now: Optional[datetime] = None,
        op: WriteOperation = WriteOperation.ADD,
    ) -> Optional[DiscardedMemoryFragments[T]]:
        """Write a memory fragment to the memory."""
        # First write to sensory memory
        sen_discarded_memories = await self._sensory_memory.write(memory_fragment)
        if not sen_discarded_memories:
            return None
        short_term_discarded_memories = []
        discarded_memory_fragments = []
        discarded_insights = []
        for sen_memory in sen_discarded_memories.discarded_memory_fragments:
            # Write to short-term memory
            short_discarded_memory = await self._short_term_memory.write(sen_memory)
            if short_discarded_memory:
                short_term_discarded_memories.append(short_discarded_memory)
                discarded_memory_fragments.extend(
                    short_discarded_memory.discarded_memory_fragments
                )
                for insight in short_discarded_memory.discarded_insights:
                    # Just keep the first insight
                    discarded_insights.append(insight.insights[0])
        # Score the importance of the discarded insights
        insight_scores = await self.score_memory_importance(discarded_insights)
        # Attach the importance score to each insight
        for i, ins in enumerate(discarded_insights):
            ins.update_importance(insight_scores[i])
        all_memories = discarded_memory_fragments + discarded_insights
        if self._long_term_memory:
            # Write to long-term memory
            await self._long_term_memory.write_batch(all_memories, self.now)
        return None

    @immutable
    async def read(
        self,
        observation: str,
        alpha: Optional[float] = None,
        beta: Optional[float] = None,
        gamma: Optional[float] = None,
    ) -> List[T]:
        """Read memories from the memory."""
        (
            retrieved_long_term_memories,
            short_term_discarded_memories,
        ) = await self.fetch_memories(observation, self._short_term_memory)

        await self.save_memories_after_retrieval(short_term_discarded_memories)
        return retrieved_long_term_memories

    @immutable
    async def fetch_memories(
        self,
        observation: str,
        short_term_memory: Optional[ShortTermMemory[T]] = None,
    ) -> Tuple[List[T], List[DiscardedMemoryFragments[T]]]:
        """Fetch memories from long-term memory.

        If short_term_memory is provided, write the fetched memories to the
        short-term memory.
        """
        retrieved_long_term_memories = await self._long_term_memory.fetch_memories(
            observation
        )
        if not short_term_memory:
            return retrieved_long_term_memories, []
        short_term_discarded_memories: List[DiscardedMemoryFragments[T]] = []
        discarded_memory_fragments: List[T] = []
        for ltm in retrieved_long_term_memories:
            short_discarded_memory = await short_term_memory.write(
                ltm, op=WriteOperation.RETRIEVAL
            )
            if short_discarded_memory:
                short_term_discarded_memories.append(short_discarded_memory)
                discarded_memory_fragments.extend(
                    short_discarded_memory.discarded_memory_fragments
                )
        for stm in short_term_memory.short_term_memories:
            retrieved_long_term_memories.append(
                stm.current_class.build_from(
                    observation=stm.raw_observation,
                    importance=stm.importance,
                )
            )
        return retrieved_long_term_memories, short_term_discarded_memories

    async def save_memories_after_retrieval(
        self, fragments: List[DiscardedMemoryFragments[T]]
    ):
        """Save memories after retrieval."""
        discarded_memory_fragments = []
        discarded_memory_insights: List[T] = []
        for f in fragments:
            discarded_memory_fragments.extend(f.discarded_memory_fragments)
            for fi in f.discarded_insights:
                discarded_memory_insights.append(fi.insights[0])
        insights_importance = await self.score_memory_importance(
            discarded_memory_insights
        )
        for i, ins in enumerate(discarded_memory_insights):
            ins.update_importance(insights_importance[i])
        all_memories = discarded_memory_fragments + discarded_memory_insights
        await self._long_term_memory.write_batch(all_memories, self.now)

    async def clear(self) -> List[T]:
        """Clear the memory.

        TODO: Implement this method.
        """
        return []
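An end-to-end sketch of the hybrid memory — not part of this commit. The fragment class and its `dbgpt.agent` export are assumptions (any fragment type providing `build_from` should behave the same), and `from_chroma` falls back to `DefaultEmbeddingFactory.openai()` when no embeddings are passed, so an OpenAI-compatible API key is assumed to be configured.

import asyncio

from dbgpt.agent import AgentMemoryFragment  # assumed export
from dbgpt.agent.core.memory.hybrid import HybridMemory


async def main():
    memory = HybridMemory.from_chroma(vstore_name="demo_agent_memory")
    memory.initialize(real_memory_fragment_class=AgentMemoryFragment)
    await memory.write(
        AgentMemoryFragment.build_from(observation="User prefers SQL examples.")
    )
    print(await memory.read("What does the user prefer?"))


asyncio.run(main())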
174
dbgpt/agent/core/memory/llm.py
Normal file
@@ -0,0 +1,174 @@
"""LLM Utility For Agent Memory."""

import re
from typing import List, Optional, Union

from dbgpt._private.pydantic import BaseModel
from dbgpt.core import (
    ChatPromptTemplate,
    HumanPromptTemplate,
    LLMClient,
    ModelMessage,
    ModelRequest,
)

from .base import ImportanceScorer, InsightExtractor, InsightMemoryFragment, T


class BaseLLMCaller(BaseModel):
    """Base class for LLM caller."""

    prompt: str = ""
    model: Optional[str] = None

    async def call_llm(
        self,
        prompt: Union[ChatPromptTemplate, str],
        llm_client: Optional[LLMClient] = None,
        **kwargs,
    ) -> str:
        """Call the LLM client to generate a response.

        Args:
            prompt(ChatPromptTemplate): prompt
            llm_client(LLMClient): LLM client
            **kwargs: other keyword arguments

        Returns:
            str: response
        """
        if not llm_client:
            raise ValueError("LLM client is required.")
        if isinstance(prompt, str):
            prompt = ChatPromptTemplate(
                messages=[HumanPromptTemplate.from_template(prompt)]
            )
        model = self.model
        if not model:
            model = await self.get_model(llm_client)
        prompt_kwargs = {}
        prompt_kwargs.update(kwargs)
        pass_kwargs = {
            k: v for k, v in prompt_kwargs.items() if k in prompt.input_variables
        }
        messages = prompt.format_messages(**pass_kwargs)
        model_messages = ModelMessage.from_base_messages(messages)
        model_request = ModelRequest.build_request(model, messages=model_messages)
        model_output = await llm_client.generate(model_request)
        if not model_output.success:
            raise ValueError("Call LLM failed.")
        return model_output.text

    async def get_model(self, llm_client: LLMClient) -> str:
        """Get the model.

        Args:
            llm_client(LLMClient): LLM client

        Returns:
            str: model
        """
        models = await llm_client.models()
        if not models:
            raise ValueError("No models available.")
        self.model = models[0].model
        return self.model

    @staticmethod
    def _parse_list(text: str) -> List[str]:
        """Parse a newline-separated string into a list of strings.

        1. First, split by newline
        2. Remove whitespace from each line
        """
        lines = re.split(r"\n", text.strip())
        lines = [line for line in lines if line.strip()]  # remove empty lines
        # Use a regular expression to remove the numbers and dots at the
        # beginning of each line
        return [re.sub(r"^\s*\d+\.\s*", "", line).strip() for line in lines]

    @staticmethod
    def _parse_number(text: str, importance_weight: Optional[float] = None) -> float:
        """Parse a number from a string."""
        match = re.search(r"^\D*(\d+)", text)
        if match:
            score = float(match.group(1))
            if importance_weight is not None:
                score = (score / 10) * importance_weight
            return score
        else:
            return 0.0


class LLMInsightExtractor(BaseLLMCaller, InsightExtractor[T]):
    """LLM Insight Extractor.

    Get high-level insights from memories.
    """

    prompt: str = (
        "There are some memories: {content}\nCan you infer from the "
        "above memories the high-level insight for this person's character? The "
        "insight needs to be significantly different from the content and "
        "structure of the original memories. Respond in one sentence.\n\n"
        "Results:"
    )

    async def extract_insights(
        self,
        memory_fragment: T,
        llm_client: Optional[LLMClient] = None,
    ) -> InsightMemoryFragment[T]:
        """Extract insights from memory fragments.

        Args:
            memory_fragment(T): Memory fragment
            llm_client(Optional[LLMClient]): LLM client

        Returns:
            InsightMemoryFragment: The insights of the memory fragment.
        """
        insights_str: str = await self.call_llm(
            self.prompt, llm_client, content=memory_fragment.raw_observation
        )
        insights_list = self._parse_list(insights_str)
        return InsightMemoryFragment(memory_fragment, insights_list)


class LLMImportanceScorer(BaseLLMCaller, ImportanceScorer[T]):
    """LLM Importance Scorer.

    Score the importance of memories.
    """

    prompt: str = (
        "Please give an importance score between 1 to 10 for the following "
        "observation. Higher score indicates the observation is more important. "
        "More rules that should be followed are:"
        "\n(1): Learning experience of a certain skill is important"
        "\n(2): The occurrence of a particular event is important"
        "\n(3): User thoughts and emotions matter"
        "\n(4): More informative indicates more important."
        "\nPlease respond with a single integer."
        "\nObservation:{content}"
        "\nRating:"
    )

    async def score_importance(
        self,
        memory_fragment: T,
        llm_client: Optional[LLMClient] = None,
    ) -> float:
        """Score the importance of memory fragments.

        Args:
            memory_fragment(T): Memory fragment
            llm_client(Optional[LLMClient]): LLM client

        Returns:
            float: The importance score of the memory fragment.
        """
        score: str = await self.call_llm(
            self.prompt, llm_client, content=memory_fragment.raw_observation
        )
        return self._parse_number(score)
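The two static parsing helpers are deterministic, so their behavior can be pinned down with a small sanity check (not part of the commit):

from dbgpt.agent.core.memory.llm import BaseLLMCaller

print(BaseLLMCaller._parse_list("1. insight A\n2. insight B"))
# -> ['insight A', 'insight B']
print(BaseLLMCaller._parse_number("Rating: 8"))
# -> 8.0
print(BaseLLMCaller._parse_number("8", importance_weight=0.15))
# -> (8 / 10) * 0.15 ≈ 0.12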
192
dbgpt/agent/core/memory/long_term.py
Normal file
@@ -0,0 +1,192 @@
"""Long-term memory module."""

from concurrent.futures import Executor
from datetime import datetime
from typing import Generic, List, Optional

from dbgpt.core import Chunk
from dbgpt.rag.retriever.time_weighted import TimeWeightedEmbeddingRetriever
from dbgpt.storage.vector_store.connector import VectorStoreConnector
from dbgpt.storage.vector_store.filters import MetadataFilters
from dbgpt.util.annotations import immutable, mutable
from dbgpt.util.executor_utils import blocking_func_to_async

from .base import DiscardedMemoryFragments, Memory, T, WriteOperation

_FORGET_PLACEHOLDER = "[FORGET]"
_MERGE_PLACEHOLDER = "[MERGE]"
_METADATA_BUFFER_IDX = "buffer_idx"
_METADATA_LAST_ACCESSED_AT = "last_accessed_at"
_METADAT_IMPORTANCE = "importance"


class LongTermRetriever(TimeWeightedEmbeddingRetriever):
    """Long-term retriever."""

    def __init__(self, now: datetime, **kwargs):
        """Create a long-term retriever."""
        self.now = now
        super().__init__(**kwargs)

    @mutable
    def _retrieve(
        self, query: str, filters: Optional[MetadataFilters] = None
    ) -> List[Chunk]:
        """Retrieve memories."""
        current_time = self.now
        docs_and_scores = {
            doc.metadata[_METADATA_BUFFER_IDX]: (doc, self.default_salience)
            # Calculate for all memories.
            for doc in self.memory_stream
        }
        # If a doc is considered salient, update the salience score
        docs_and_scores.update(self.get_salient_docs(query))
        rescored_docs = [
            (doc, self._get_combined_score(doc, relevance, current_time))
            for doc, relevance in docs_and_scores.values()
        ]
        rescored_docs.sort(key=lambda x: x[1], reverse=True)
        result = []
        # Ensure frequently accessed memories aren't forgotten
        retrieved_num = 0
        for doc, _ in rescored_docs:
            if (
                retrieved_num < self._k
                and doc.content.find(_FORGET_PLACEHOLDER) == -1
                and doc.content.find(_MERGE_PLACEHOLDER) == -1
            ):
                retrieved_num += 1
                buffered_doc = self.memory_stream[doc.metadata[_METADATA_BUFFER_IDX]]
                buffered_doc.metadata[_METADATA_LAST_ACCESSED_AT] = current_time
                result.append(buffered_doc)
        return result


class LongTermMemory(Memory, Generic[T]):
    """Long-term memory."""

    importance_weight: float = 0.15

    def __init__(
        self,
        executor: Executor,
        vector_store_connector: VectorStoreConnector,
        now: Optional[datetime] = None,
        reflection_threshold: Optional[float] = None,
    ):
        """Create a long-term memory."""
        self.now = now or datetime.now()
        self.executor = executor
        self.reflecting: bool = False
        self.forgetting: bool = False
        self.reflection_threshold: Optional[float] = reflection_threshold
        self.aggregate_importance: float = 0.0
        self._vector_store_connector = vector_store_connector
        self.memory_retriever = LongTermRetriever(
            now=self.now, vector_store_connector=vector_store_connector
        )

    @immutable
    def structure_clone(
        self: "LongTermMemory[T]", now: Optional[datetime] = None
    ) -> "LongTermMemory[T]":
        """Create a structure clone of the long-term memory."""
        new_name = self.name
        if not new_name:
            raise ValueError("name is required.")
        m: LongTermMemory[T] = LongTermMemory(
            now=now,
            executor=self.executor,
            vector_store_connector=self._vector_store_connector.new_connector(new_name),
            reflection_threshold=self.reflection_threshold,
        )
        m._copy_from(self)
        return m

    @mutable
    async def write(
        self,
        memory_fragment: T,
        now: Optional[datetime] = None,
        op: WriteOperation = WriteOperation.ADD,
    ) -> Optional[DiscardedMemoryFragments[T]]:
        """Write a memory fragment to the memory."""
        importance = memory_fragment.importance
        last_accessed_time = memory_fragment.last_accessed_time
        if importance is None:
            raise ValueError("importance is required.")
        if not self.reflecting:
            self.aggregate_importance += importance

        memory_idx = len(self.memory_retriever.memory_stream)
        document = Chunk(
            page_content="[{}] ".format(memory_idx)
            + str(memory_fragment.raw_observation),
            metadata={
                _METADAT_IMPORTANCE: importance,
                _METADATA_LAST_ACCESSED_AT: last_accessed_time,
            },
        )
        await blocking_func_to_async(
            self.executor,
            self.memory_retriever.load_document,
            [document],
            current_time=now,
        )

        return None

    @mutable
    async def write_batch(
        self, memory_fragments: List[T], now: Optional[datetime] = None
    ) -> Optional[DiscardedMemoryFragments[T]]:
        """Write a batch of memory fragments to the memory."""
        current_datetime = self.now
        if not now:
            raise ValueError("Now time is required.")
        for short_term_memory in memory_fragments:
            short_term_memory.update_accessed_time(now=now)
            await self.write(short_term_memory, now=current_datetime)
        # TODO(fangyinc): Reflect on the memories and get high-level insights.
        # TODO(fangyinc): Forget memories that are not important.
        return None

    @immutable
    async def read(
        self,
        observation: str,
        alpha: Optional[float] = None,
        beta: Optional[float] = None,
        gamma: Optional[float] = None,
    ) -> List[T]:
        """Read memory fragments related to the observation."""
        return await self.fetch_memories(observation=observation, now=self.now)

    @immutable
    async def fetch_memories(
        self, observation: str, now: Optional[datetime] = None
    ) -> List[T]:
        """Fetch memories related to the observation."""
        # TODO: Mock now?
        retrieved_memories = []
        retrieved_list = await blocking_func_to_async(
            self.executor,
            self.memory_retriever.retrieve,
            observation,
        )
        for retrieved_chunk in retrieved_list:
            retrieved_memories.append(
                self.real_memory_fragment_class.build_from(
                    observation=retrieved_chunk.content,
                    importance=retrieved_chunk.metadata[_METADAT_IMPORTANCE],
                )
            )
        return retrieved_memories

    @mutable
    async def clear(self) -> List[T]:
        """Clear the memory.

        TODO: Implement this method.
        """
        return []
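The combined score used by LongTermRetriever comes from TimeWeightedEmbeddingRetriever._get_combined_score, which is not shown in this commit. A common formulation — assumed here, in the style of the time-weighted retrievers this class is modeled on — adds an exponentially decaying recency term to the semantic relevance:

from datetime import datetime


def combined_score(
    relevance: float,
    last_accessed_at: datetime,
    now: datetime,
    decay_rate: float = 0.01,
) -> float:
    """Assumed scoring: semantic relevance plus an exponential recency term."""
    hours_passed = (now - last_accessed_at).total_seconds() / 3600.0
    recency = (1.0 - decay_rate) ** hours_passed
    return relevance + recency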
203
dbgpt/agent/core/memory/short_term.py
Normal file
@@ -0,0 +1,203 @@
"""Short-term memory module."""

import random
from concurrent.futures import Executor
from datetime import datetime
from typing import Dict, List, Optional, Tuple

from dbgpt.core import Embeddings
from dbgpt.util.annotations import immutable, mutable
from dbgpt.util.executor_utils import blocking_func_to_async
from dbgpt.util.similarity_util import cosine_similarity, sigmoid_function

from .base import (
    DiscardedMemoryFragments,
    InsightMemoryFragment,
    ShortTermMemory,
    T,
    WriteOperation,
)


class EnhancedShortTermMemory(ShortTermMemory[T]):
    """Enhanced short-term memory."""

    def __init__(
        self,
        embeddings: Embeddings,
        executor: Executor,
        buffer_size: int = 2,
        enhance_similarity_threshold: float = 0.7,
        enhance_threshold: int = 3,
    ):
        """Initialize enhanced short-term memory."""
        super().__init__(buffer_size=buffer_size)
        self._executor = executor
        self._embeddings = embeddings
        self.short_embeddings: List[List[float]] = []
        self.enhance_cnt: List[int] = [0 for _ in range(self._buffer_size)]
        self.enhance_memories: List[List[T]] = [[] for _ in range(self._buffer_size)]
        self.enhance_similarity_threshold = enhance_similarity_threshold
        self.enhance_threshold = enhance_threshold

    @immutable
    def structure_clone(
        self: "EnhancedShortTermMemory[T]", now: Optional[datetime] = None
    ) -> "EnhancedShortTermMemory[T]":
        """Return a structure clone of the memory."""
        m: EnhancedShortTermMemory[T] = EnhancedShortTermMemory(
            embeddings=self._embeddings,
            executor=self._executor,
            buffer_size=self._buffer_size,
            enhance_similarity_threshold=self.enhance_similarity_threshold,
            enhance_threshold=self.enhance_threshold,
        )
        m._copy_from(self)
        return m

    @mutable
    async def write(
        self,
        memory_fragment: T,
        now: Optional[datetime] = None,
        op: WriteOperation = WriteOperation.ADD,
    ) -> Optional[DiscardedMemoryFragments[T]]:
        """Write a memory fragment to short-term memory.

        Reference: https://github.com/RUC-GSAI/YuLan-Rec/blob/main/agents/recagent_memory.py#L336 # noqa
        """
        # Calculate the embeddings of the current memory fragment
        memory_fragment_embeddings = await blocking_func_to_async(
            self._executor,
            memory_fragment.calculate_current_embeddings,
            self._embeddings.embed_documents,
        )
        memory_fragment.update_embeddings(memory_fragment_embeddings)
        for idx, memory_embedding in enumerate(self.short_embeddings):
            similarity = await blocking_func_to_async(
                self._executor,
                cosine_similarity,
                memory_embedding,
                memory_fragment_embeddings,
            )
            # Sigmoid probability, transform similarity to [0, 1]
            sigmoid_prob: float = await blocking_func_to_async(
                self._executor, sigmoid_function, similarity
            )
            if (
                sigmoid_prob >= self.enhance_similarity_threshold
                and random.random() < sigmoid_prob
            ):
                self.enhance_cnt[idx] += 1
                self.enhance_memories[idx].append(memory_fragment)
        discard_memories = await self.transfer_to_long_term(memory_fragment)
        if op == WriteOperation.ADD:
            self._fragments.append(memory_fragment)
            self.short_embeddings.append(memory_fragment_embeddings)
            await self.handle_overflow(self._fragments)
        return discard_memories

    @mutable
    async def transfer_to_long_term(
        self, memory_fragment: T
    ) -> Optional[DiscardedMemoryFragments[T]]:
        """Transfer a memory fragment to long-term memory."""
        transfer_flag = False
        existing_memory = [True for _ in range(len(self.short_term_memories))]

        enhance_memories: List[T] = []
        to_get_insight_memories: List[T] = []
        for idx, memory in enumerate(self.short_term_memories):
            # If the memory exceeds the enhancement threshold
            if (
                self.enhance_cnt[idx] >= self.enhance_threshold
                and existing_memory[idx] is True
            ):
                existing_memory[idx] = False
                transfer_flag = True
                # Combine all existing related memories with the current memory
                # in short-term memories
                content = [memory]
                # Do not repeatedly add the observation memory to the summary,
                # so use [:-1].
                for enhance_memory in self.enhance_memories[idx][:-1]:
                    content.append(enhance_memory)
                content.append(memory_fragment)
                # Merge the enhanced memories into a single memory
                merged_enhance_memory: T = memory.reduce(
                    content, merged_memory=memory.importance
                )
                to_get_insight_memories.append(merged_enhance_memory)
                enhance_memories.append(merged_enhance_memory)
        # Get insights for every enhanced memory
        enhance_insights: List[InsightMemoryFragment] = await self.get_insights(
            to_get_insight_memories
        )

        if transfer_flag:
            # Rebuild the indexes of short-term memories after removing the
            # summarized memories
            new_memories: List[T] = []
            new_embeddings: List[List[float]] = []
            new_enhance_memories: List[List[T]] = [[] for _ in range(self._buffer_size)]
            new_enhance_cnt: List[int] = [0 for _ in range(self._buffer_size)]
            for idx, memory in enumerate(self.short_term_memories):
                if existing_memory[idx]:
                    # Keep the memories that were not summarized, re-indexed
                    new_enhance_memories[len(new_memories)] = self.enhance_memories[idx]
                    new_enhance_cnt[len(new_memories)] = self.enhance_cnt[idx]
                    new_memories.append(memory)
                    new_embeddings.append(self.short_embeddings[idx])
            self._fragments = new_memories
            self.short_embeddings = new_embeddings
            self.enhance_memories = new_enhance_memories
            self.enhance_cnt = new_enhance_cnt
        return DiscardedMemoryFragments(enhance_memories, enhance_insights)

    @mutable
    async def handle_overflow(
        self, memory_fragments: List[T]
    ) -> Tuple[List[T], List[T]]:
        """Handle overflow of short-term memory.

        Discard the least important memory fragment if the buffer size is
        exceeded.
        """
        if len(self.short_term_memories) > self._buffer_size:
            id2fragments: Dict[int, Dict] = {}
            for idx in range(len(self.short_term_memories) - 1):
                # Do not discard the most recent memory
                memory = self.short_term_memories[idx]
                id2fragments[idx] = {
                    "enhance_count": self.enhance_cnt[idx],
                    "importance": memory.importance,
                }
            # Sort by importance and enhance count; the least important is
            # discarded first
            sorted_ids = sorted(
                id2fragments.keys(),
                key=lambda x: (
                    id2fragments[x]["importance"],
                    id2fragments[x]["enhance_count"],
                ),
            )
            pop_id = sorted_ids[0]
            pop_raw_observation = self.short_term_memories[pop_id].raw_observation
            self.enhance_cnt.pop(pop_id)
            self.enhance_cnt.append(0)
            self.enhance_memories.pop(pop_id)
            self.enhance_memories.append([])

            discard_memory = self._fragments.pop(pop_id)
            self.short_embeddings.pop(pop_id)

            # Remove the discarded memory from the other short-term memories'
            # enhancement lists (pop in reverse so earlier indices stay valid)
            for idx in range(len(self.short_term_memories)):
                current_enhance_memories: List[T] = self.enhance_memories[idx]
                to_remove_idx = []
                for i, ehf in enumerate(current_enhance_memories):
                    if ehf.raw_observation == pop_raw_observation:
                        to_remove_idx.append(i)
                for i in reversed(to_remove_idx):
                    current_enhance_memories.pop(i)
                self.enhance_cnt[idx] -= len(to_remove_idx)

            return memory_fragments, [discard_memory]
        return memory_fragments, []
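A compact restatement of the enhancement test in write() above — not from the commit, and assuming sigmoid_function is the standard logistic function: a new fragment "enhances" a buffered memory only when the sigmoid of their cosine similarity clears the threshold and a Bernoulli draw at that probability also succeeds.

import math
import random


def is_enhanced(similarity: float, threshold: float = 0.7) -> bool:
    sigmoid_prob = 1.0 / (1.0 + math.exp(-similarity))  # assumed logistic
    return sigmoid_prob >= threshold and random.random() < sigmoid_prob


# sigmoid(0.85) ~= 0.70, so only very similar embeddings ever qualify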
@@ -1,4 +1,5 @@
"""Agent Operator for AWEL."""

from abc import ABC
from typing import List, Optional, Type

@@ -16,10 +17,10 @@ from dbgpt.core.interface.message import ModelMessageRoleType
# TODO: Don't depend on MixinLLMOperator
from dbgpt.model.operators.llm_operator import MixinLLMOperator

from ...core.agent import Agent, AgentGenerateContext, AgentMessage
from ...core.agent_manage import get_agent_manager
from ...core.base_agent import ConversableAgent
from ...core.llm.llm import LLMConfig
from ....util.llm.llm import LLMConfig
from ...agent import Agent, AgentGenerateContext, AgentMessage
from ...agent_manage import get_agent_manager
from ...base_agent import ConversableAgent
from .agent_operator_resource import AWELAgent


@@ -61,9 +62,7 @@ class WrappedAgentOperator(
        input_message = input_value.message.copy()

        # Isolate the message delivery mechanism and pass it to the operator
        _goal = (
            self.agent.get_name() if self.agent.get_name() else self.agent.get_profile()
        )
        _goal = self.agent.name if self.agent.name else self.agent.role
        current_goal = f"[{_goal}]:"

        if input_message.content:
@@ -95,7 +94,7 @@ class WrappedAgentOperator(

        if not is_success:
            raise ValueError(
                f"The task failed at step {self.agent.get_profile()} and the attempt "
                f"The task failed at step {self.agent.role} and the attempt "
                f"to repair it failed. The final reason for "
                f"failure:{agent_reply_message.content}!"
            )
@@ -170,18 +169,14 @@ class AWELAgentOperator(
        agent = await self.get_agent(input_value)
        if agent.fixed_subgoal and len(agent.fixed_subgoal) > 0:
            # Isolate the message delivery mechanism and pass it to the operator
            current_goal = (
                f"[{agent.get_name() if agent.get_name() else agent.get_profile()}]:"
            )
            current_goal = f"[{agent.name if agent.name else agent.role}]:"
        if agent.fixed_subgoal:
            current_goal += agent.fixed_subgoal
            input_message.current_goal = current_goal
            input_message.content = agent.fixed_subgoal
        else:
            # Isolate the message delivery mechanism and pass it to the operator
            current_goal = (
                f"[{agent.get_name() if agent.get_name() else agent.get_profile()}]:"
            )
            current_goal = f"[{agent.name if agent.name else agent.role}]:"
            if input_message.content:
                current_goal += input_message.content
            input_message.current_goal = current_goal
@@ -213,7 +208,7 @@ class AWELAgentOperator(

        if not is_success:
            raise ValueError(
                f"The task failed at step {agent.get_profile()} and the attempt to "
                f"The task failed at step {agent.role} and the attempt to "
                f"repair it failed. The final reason for "
                f"failure:{agent_reply_message.content}!"
            )
@@ -231,7 +226,7 @@ class AWELAgentOperator(
            # Default single step transfer of information
            rely_messages=now_rely_messages,
            silent=input_value.silent,
            memory=input_value.memory,
            memory=input_value.memory.structure_clone() if input_value.memory else None,
            agent_context=input_value.agent_context,
            resource_loader=input_value.resource_loader,
            llm_client=input_value.llm_client,
@@ -1,4 +1,5 @@
"""The AWEL Agent Operator Resource."""

from typing import Any, Dict, List, Optional

from dbgpt._private.pydantic import BaseModel, ConfigDict, Field, model_validator
@@ -11,9 +12,9 @@ from dbgpt.core.awel.flow import (
    register_resource,
)

from ...core.agent_manage import get_agent_manager
from ...core.llm.llm import LLMConfig, LLMStrategyType
from ...resource.resource_api import AgentResource, ResourceType
from ....resource.resource_api import AgentResource, ResourceType
from ....util.llm.llm import LLMConfig, LLMStrategyType
from ...agent_manage import get_agent_manager


@register_resource(
@@ -50,7 +51,10 @@ from ...resource.resource_api import AgentResource, ResourceType
            description="The agent resource value.",
        ),
    ],
    alias=["dbgpt.serve.agent.team.layout.agent_operator_resource.AwelAgentResource"],
    alias=[
        "dbgpt.serve.agent.team.layout.agent_operator_resource.AwelAgentResource",
        "dbgpt.agent.plan.awel.agent_operator_resource.AWELAgentResource",
    ],
)
class AWELAgentResource(AgentResource):
    """AWEL Agent Resource."""
@@ -107,7 +111,10 @@ class AWELAgentResource(AgentResource):
            description="The agent LLM Strategy Value.",
        ),
    ],
    alias=["dbgpt.serve.agent.team.layout.agent_operator_resource.AwelAgentConfig"],
    alias=[
        "dbgpt.serve.agent.team.layout.agent_operator_resource.AwelAgentConfig",
        "dbgpt.agent.plan.awel.agent_operator_resource.AWELAgentConfig",
    ],
)
class AWELAgentConfig(LLMConfig):
    """AWEL Agent Config."""
@@ -168,7 +175,10 @@ def _agent_resource_option_values() -> List[OptionValue]:
            description="The agent llm config.",
        ),
    ],
    alias=["dbgpt.serve.agent.team.layout.agent_operator_resource.AwelAgent"],
    alias=[
        "dbgpt.serve.agent.team.layout.agent_operator_resource.AwelAgent",
        "dbgpt.agent.plan.awel.agent_operator_resource.AWELAgent",
    ],
)
class AWELAgent(BaseModel):
    """AWEL Agent."""
@@ -2,7 +2,7 @@

import logging
from abc import ABC, abstractmethod
from typing import List, Optional, cast
from typing import Optional, cast

from dbgpt._private.config import Config
from dbgpt._private.pydantic import (
@@ -15,9 +15,10 @@ from dbgpt._private.pydantic import (
from dbgpt.core.awel import DAG
from dbgpt.core.awel.dag.dag_manager import DAGManager

from ...actions.action import ActionOutput
from ...core.agent import Agent, AgentGenerateContext, AgentMessage
from ...core.base_team import ManagerAgent
from ...action.base import ActionOutput
from ...agent import Agent, AgentGenerateContext, AgentMessage
from ...base_team import ManagerAgent
from ...profile import DynConfig, ProfileConfig
from .agent_operator import AWELAgentOperator, WrappedAgentOperator

logger = logging.getLogger(__name__)
@@ -84,11 +85,24 @@ class AWELBaseManager(ManagerAgent, ABC):

    model_config = ConfigDict(arbitrary_types_allowed=True)

    goal: str = (
        "Promote and solve user problems according to the process arranged by AWEL."
    profile: ProfileConfig = ProfileConfig(
        name="AWELBaseManager",
        role=DynConfig(
            "PlanManager", category="agent", key="dbgpt_agent_plan_awel_profile_name"
        ),
        goal=DynConfig(
            "Promote and solve user problems according to the process arranged "
            "by AWEL.",
            category="agent",
            key="dbgpt_agent_plan_awel_profile_goal",
        ),
        desc=DynConfig(
            "Promote and solve user problems according to the process arranged "
            "by AWEL.",
            category="agent",
            key="dbgpt_agent_plan_awel_profile_desc",
        ),
    )
    constraints: List[str] = []
    desc: str = goal

    async def _a_process_received_message(self, message: AgentMessage, sender: Agent):
        """Process the received message."""
@@ -116,7 +130,7 @@ class AWELBaseManager(ManagerAgent, ABC):
            message=AgentMessage(content=message, current_goal=message),
            sender=sender,
            reviewer=reviewer,
            memory=self.memory,
            memory=self.memory.structure_clone(),
            agent_context=self.agent_context,
            resource_loader=self.resource_loader,
            llm_client=self.not_null_llm_config.llm_client,
@@ -162,8 +176,6 @@ class WrappedAWELLayoutManager(AWELBaseManager):

    model_config = ConfigDict(arbitrary_types_allowed=True)

    profile: str = "WrappedAWELLayoutManager"

    dag: Optional[DAG] = Field(None, description="The DAG of the manager")

    def get_dag(self) -> DAG:
@@ -238,8 +250,6 @@ class DefaultAWELLayoutManager(AWELBaseManager):

    model_config = ConfigDict(arbitrary_types_allowed=True)

    profile: str = "DefaultAWELLayoutManager"

    dag: AWELTeamContext = Field(...)

    @validator("dag")
@@ -6,12 +6,12 @@ from typing import List, Optional
from dbgpt._private.pydantic import BaseModel, Field
from dbgpt.vis.tags.vis_agent_plans import Vis, VisAgentPlans

from ..actions.action import Action, ActionOutput
from ..core.agent import AgentContext
from ..core.schema import Status
from ..memory.base import GptsPlan
from ..memory.gpts_memory import GptsPlansMemory
from ..resource.resource_api import AgentResource
from ...resource.resource_api import AgentResource
from ..action.base import Action, ActionOutput
from ..agent import AgentContext
from ..memory.gpts.base import GptsPlan
from ..memory.gpts.gpts_memory import GptsPlansMemory
from ..schema import Status

logger = logging.getLogger(__name__)

165
dbgpt/agent/core/plan/planner_agent.py
Normal file
@@ -0,0 +1,165 @@
"""Planner Agent."""

from typing import Any, Dict, List

from dbgpt._private.pydantic import Field

from ..agent import AgentMessage
from ..base_agent import ConversableAgent
from ..plan.plan_action import PlanAction
from ..profile import DynConfig, ProfileConfig


class PlannerAgent(ConversableAgent):
    """Planner Agent.

    Planner agent, realizing task goal planning decomposition through LLM.
    """

    agents: List[ConversableAgent] = Field(default_factory=list)

    profile: ProfileConfig = ProfileConfig(
        name=DynConfig(
            "Planner",
            category="agent",
            key="dbgpt_agent_plan_planner_agent_profile_name",
        ),
        role=DynConfig(
            "Planner",
            category="agent",
            key="dbgpt_agent_plan_planner_agent_profile_role",
        ),
        goal=DynConfig(
            "Understand each of the following intelligent agents and their "
            "capabilities, using the provided resources, solve user problems by "
            "coordinating intelligent agents. Please utilize your LLM's knowledge "
            "and understanding ability to comprehend the intent and goals of the "
            "user's problem, generating a task plan that can be completed through"
            " the collaboration of intelligent agents without user assistance.",
            category="agent",
            key="dbgpt_agent_plan_planner_agent_profile_goal",
        ),
        expand_prompt=DynConfig(
            "Available Intelligent Agents:\n {{ agents }}",
            category="agent",
            key="dbgpt_agent_plan_planner_agent_profile_expand_prompt",
        ),
        constraints=DynConfig(
            [
                "Every step of the task plan should exist to advance towards solving "
                "the user's goals. Do not generate meaningless task steps; ensure "
                "that each step has a clear goal and its content is complete.",
                "Pay attention to the dependencies and logic of each step in the task "
                "plan. For the steps that are depended upon, consider the data they "
                "depend on and whether it can be obtained based on the current goal. "
                "If it cannot be obtained, please indicate in the goal that the "
                "dependent data needs to be generated.",
                "Each step must be an independently achievable goal. Ensure that the "
                "logic and information are complete. Avoid steps with unclear "
                "objectives, like 'Analyze the retrieved issues data,' where it's "
                "unclear what specific content needs to be analyzed.",
                "Please ensure that only the intelligent agents mentioned above are "
                "used, and you may use only the necessary parts of them. Allocate "
                "them to appropriate steps strictly based on their described "
                "capabilities and limitations. Each intelligent agent can be reused.",
                "Utilize the provided resources to assist in generating the plan "
                "steps according to the actual needs of the user's goals. Do not use "
                "unnecessary resources.",
                "Each step should ideally use only one type of resource to accomplish "
                "a sub-goal. If the current goal can be broken down into multiple "
                "subtasks of the same type, you can create mutually independent "
                "parallel tasks.",
                "Data resources can be loaded and utilized by the appropriate "
                "intelligent agents without the need to consider the issues related "
                "to data loading links.",
                "Try to merge continuous steps that have sequential dependencies. If "
                "the user's goal does not require splitting, you can create a "
                "single-step task with content that is the user's goal.",
                "Carefully review the plan to ensure it comprehensively covers all "
                "information involved in the user's problem and can ultimately "
                "achieve the goal. Confirm whether each step includes the necessary "
                "resource information, such as URLs, resource names, etc.",
            ],
            category="agent",
            key="dbgpt_agent_plan_planner_agent_profile_constraints",
        ),
        desc=DynConfig(
            "You are a task planning expert! You can coordinate intelligent agents"
            " and allocate resources to achieve complex task goals.",
            category="agent",
            key="dbgpt_agent_plan_planner_agent_profile_desc",
        ),
        examples=DynConfig(
            """
user:help me build a sales report summarizing our key metrics and trends
assistants:[
    {{
        "serial_number": "1",
        "agent": "DataScientist",
        "content": "Retrieve total sales, average sales, and number of transactions grouped by "product_category"'.",
        "rely": ""
    }},
    {{
        "serial_number": "2",
        "agent": "DataScientist",
        "content": "Retrieve monthly sales and transaction number trends.",
        "rely": ""
    }},
    {{
        "serial_number": "3",
        "agent": "Reporter",
        "content": "Integrate analytical data into the format required to build sales reports.",
        "rely": "1,2"
    }}
]""",  # noqa: E501
            category="agent",
||||
key="dbgpt_agent_plan_planner_agent_profile_examples",
|
||||
),
|
||||
)
|
||||
_goal_zh: str = (
|
||||
"理解下面每个智能体(agent)和他们的能力,使用给出的资源,通过协调智能体来解决"
|
||||
"用户问题。 请发挥你LLM的知识和理解能力,理解用户问题的意图和目标,生成一个可以在没有用户帮助"
|
||||
"下,由智能体协作完成目标的任务计划。"
|
||||
)
|
||||
_expand_prompt_zh: str = "可用智能体(agent):\n {{ agents }}"
|
||||
|
||||
_constraints_zh: List[str] = [
|
||||
"任务计划的每个步骤都应该是为了推进解决用户目标而存在,不要生成无意义的任务步骤,确保每个步骤内目标明确内容完整。",
|
||||
"关注任务计划每个步骤的依赖关系和逻辑,被依赖步骤要考虑被依赖的数据,是否能基于当前目标得到,如果不能请在目标中提示要生成被依赖数据。",
|
||||
"每个步骤都是一个独立可完成的目标,一定要确保逻辑和信息完整,不要出现类似:"
|
||||
"'Analyze the retrieved issues data'这样目标不明确,不知道具体要分析啥内容的步骤",
|
||||
"请确保只使用上面提到的智能体,并且可以只使用其中需要的部分,严格根据描述能力和限制分配给合适的步骤,每个智能体都可以重复使用。",
|
||||
"根据用户目标的实际需要使用提供的资源来协助生成计划步骤,不要使用不需要的资源。",
|
||||
"每个步骤最好只使用一种资源完成一个子目标,如果当前目标可以分解为同类型的多个子任务,可以生成相互不依赖的并行任务。",
|
||||
"数据资源可以被合适的智能体加载使用,不用考虑数据资源的加载链接问题",
|
||||
"尽量合并有顺序依赖的连续相同步骤,如果用户目标无拆分必要,可以生成内容为用户目标的单步任务。",
|
||||
"仔细检查计划,确保计划完整的包含了用户问题所涉及的所有信息,并且最终能完成目标,确认每个步骤是否包含了需要用到的资源信息,如URL、资源名等. ",
|
||||
]
|
||||
_desc_zh: str = "你是一个任务规划专家!可以协调智能体,分配资源完成复杂的任务目标。"
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
"""Create a new PlannerAgent instance."""
|
||||
super().__init__(**kwargs)
|
||||
self._init_actions([PlanAction])
|
||||
|
||||
def _init_reply_message(self, received_message: AgentMessage):
|
||||
reply_message = super()._init_reply_message(received_message)
|
||||
reply_message.context = {
|
||||
"agents": "\n".join([f"- {item.role}:{item.desc}" for item in self.agents]),
|
||||
}
|
||||
return reply_message
|
||||
|
||||
def bind_agents(self, agents: List[ConversableAgent]) -> ConversableAgent:
|
||||
"""Bind the agents to the planner agent."""
|
||||
self.agents = agents
|
||||
for agent in self.agents:
|
||||
if agent.resources and len(agent.resources) > 0:
|
||||
self.resources.extend(agent.resources)
|
||||
return self
|
||||
|
||||
def prepare_act_param(self) -> Dict[str, Any]:
|
||||
"""Prepare the parameters for the act method."""
|
||||
return {
|
||||
"context": self.not_null_agent_context,
|
||||
"plans_memory": self.memory.plans_memory,
|
||||
}
|
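Note: a minimal sketch of how the refactored PlannerAgent above might be wired up, assuming the bind/build flow used elsewhere in DB-GPT's agent examples; llm_client, data_scientist, and reporter are placeholders, not part of this commit:

from dbgpt.agent import AgentContext, AgentMemory, LLMConfig
from dbgpt.agent.core.plan.planner_agent import PlannerAgent

async def build_planner(llm_client, data_scientist, reporter) -> PlannerAgent:
    # Illustrative values only; conv_id identifies the conversation.
    context = AgentContext(conv_id="demo_conv")
    planner = (
        await PlannerAgent()
        .bind(context)
        .bind(AgentMemory())
        .bind(LLMConfig(llm_client=llm_client))
        .build()
    )
    # bind_agents also collects the sub-agents' resources into the planner.
    planner.bind_agents([data_scientist, reporter])
    return planner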
@@ -5,14 +5,15 @@ from typing import Dict, List, Optional, Tuple

from dbgpt.core.interface.message import ModelMessageRoleType

from ..actions.action import ActionOutput
from ..core.agent import Agent, AgentMessage
from ..core.agent_manage import mentioned_agents, participant_roles
from ..core.base_agent import ConversableAgent
from ..core.base_team import ManagerAgent
from ..core.schema import Status
from ..memory.base import GptsPlan
from .planner_agent import PlannerAgent
from ..action.base import ActionOutput
from ..agent import Agent, AgentMessage
from ..agent_manage import mentioned_agents, participant_roles
from ..base_agent import ConversableAgent
from ..base_team import ManagerAgent
from ..memory.gpts.base import GptsPlan
from ..plan.planner_agent import PlannerAgent
from ..profile import DynConfig, ProfileConfig
from ..schema import Status

logger = logging.getLogger(__name__)

@@ -20,14 +21,30 @@ logger = logging.getLogger(__name__)
class AutoPlanChatManager(ManagerAgent):
"""A chat manager agent that can manage a team chat of multiple agents."""

profile: str = "PlanManager"
goal: str = (
"Advance the task plan generated by the planning agent. If the plan "
"does not pre-allocate an agent, it needs to be coordinated with the "
"appropriate agent to complete."
profile: ProfileConfig = ProfileConfig(
name=DynConfig(
"AutoPlanChatManager",
category="agent",
key="dbgpt_agent_plan_team_auto_plan_profile_name",
),
role=DynConfig(
"PlanManager",
category="agent",
key="dbgpt_agent_plan_team_auto_plan_profile_role",
),
goal=DynConfig(
"Advance the task plan generated by the planning agent. If the plan "
"does not pre-allocate an agent, it needs to be coordinated with the "
"appropriate agent to complete.",
category="agent",
key="dbgpt_agent_plan_team_auto_plan_profile_goal",
),
desc=DynConfig(
"Advance the task plan generated by the planning agent.",
category="agent",
key="dbgpt_agent_plan_team_auto_plan_profile_desc",
),
)
constraints: List[str] = []
desc: str = "Advance the task plan generated by the planning agent."

def __init__(self, **kwargs):
"""Create a new AutoPlanChatManager instance."""
@@ -56,19 +73,21 @@ class AutoPlanChatManager(ManagerAgent):
{
"content": rely_task.sub_task_content,
"role": ModelMessageRoleType.HUMAN,
"name": rely_task.sub_task_agent,
}
)
rely_messages.append(
{
"content": rely_task.result,
"role": ModelMessageRoleType.AI,
"name": rely_task.sub_task_agent,
}
)
return rely_prompt, rely_messages

def select_speaker_msg(self, agents: List[Agent]) -> str:
"""Return the message for selecting the next speaker."""
agent_names = [agent.get_name() for agent in agents]
agent_names = [agent.name for agent in agents]
return (
"You are in a role play game. The following roles are available:\n"
f" {participant_roles(agents)}.\n"
@@ -95,7 +114,7 @@ class AutoPlanChatManager(ManagerAgent):
else:
# auto speaker selection
# TODO: selector a_thinking has been overwritten and cannot be used.
agent_names = [agent.get_name() for agent in agents]
agent_names = [agent.name for agent in agents]
fina_name, model = await selector.thinking(
messages=[
AgentMessage(
@@ -104,7 +123,7 @@ class AutoPlanChatManager(ManagerAgent):
" assign the appropriate role to complete the task.\n"
f"Task content: {now_goal_context},\n"
f"Select the role from: {agent_names},\n"
f"Please only return the role, such as: {agents[0].get_name()}",
f"Please only return the role, such as: {agents[0].name}",
)
],
prompt=self.select_speaker_msg(agents),
@@ -269,7 +288,7 @@ class AutoPlanChatManager(ManagerAgent):
now_plan.sub_task_num,
Status.FAILED.value,
now_plan.retry_times + 1,
speaker.get_name(),
speaker.name,
"",
plan_result,
)
31
dbgpt/agent/core/profile/__init__.py
Normal file
@@ -0,0 +1,31 @@
"""Profiling module.

Autonomous agents typically perform tasks by assuming specific roles, such as coders,
teachers and domain experts.

The profiling module aims to indicate the profiles of the agent roles, which are usually
written into the prompt to influence the LLM behaviors.

Agent profiles typically encompass basic information such as age, gender, and career,
as well as psychology information, reflecting the personalities of the agent, and social
information, detailing the relationships between agents.

The choice of profiling information depends heavily on the application scenario.

How to create a profile:
1. Handcrafting method
2. LLM-generation method
3. Dataset alignment method
"""

from dbgpt.util.configure import DynConfig  # noqa: F401

from .base import (  # noqa: F401
CompositeProfileFactory,
DatasetProfileFactory,
DefaultProfile,
LLMProfileFactory,
Profile,
ProfileConfig,
ProfileFactory,
)
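Note: the handcrafting method above maps directly onto the DefaultProfile class introduced below; a minimal sketch (the field values here are illustrative, not from this commit):

from dbgpt.agent.core.profile import DefaultProfile

# Handcrafted profile: every field is spelled out by hand.
profile = DefaultProfile(
    name="Ada",
    role="DataAnalyst",
    goal="Answer analytics questions over the sales database.",
    constraints=["Only use tables that exist in the schema."],
    desc="An analyst persona for demos.",
)
assert profile.get_role() == "DataAnalyst"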
413
dbgpt/agent/core/profile/base.py
Normal file
@@ -0,0 +1,413 @@
"""Profile module."""

from abc import ABC, abstractmethod
from typing import List, Optional

import cachetools

from dbgpt._private.pydantic import BaseModel, ConfigDict, Field, model_validator
from dbgpt.util.configure import ConfigInfo, DynConfig

VALID_TEMPLATE_KEYS = {
"role",
"name",
"goal",
"resource_prompt",
"expand_prompt",
"language",
"constraints",
"examples",
"out_schema",
"most_recent_memories",
"question",
}

_DEFAULT_SYSTEM_TEMPLATE = """
You are a {{ role }}, {% if name %}named {{ name }}, {% endif %}your goal is {{ goal }}.
Please think step by step to achieve the goal. You can use the resources given below.
At the same time, please strictly abide by the constraints and specifications in IMPORTANT REMINDER.
{% if resource_prompt %} {{ resource_prompt }} {% endif %}
{% if expand_prompt %} {{ expand_prompt }} {% endif %}

*** IMPORTANT REMINDER ***
{% if language == 'zh' %}
Please answer in simplified Chinese.
{% else %}
Please answer in English.
{% endif %}

{% if constraints %}
{% for constraint in constraints %}
{{ loop.index }}. {{ constraint }}
{% endfor %}
{% endif %}

{% if examples %}
You can refer to the following examples:
{{ examples }}
{% endif %}

{% if out_schema %} {{ out_schema }} {% endif %}
"""  # noqa

_DEFAULT_USER_TEMPLATE = """
{% if most_recent_memories %}
Most recent observations:
{{ most_recent_memories }}
{% endif %}

{% if question %}
Question: {{ question }}
{% endif %}
"""

_DEFAULT_SAVE_MEMORY_TEMPLATE = """
{% if question %}Question: {{ question }} {% endif %}
{% if thought %}Thought: {{ thought }} {% endif %}
{% if action %}Action: {{ action }} {% endif %}
{% if observation %}Observation: {{ observation }} {% endif %}
"""


class Profile(ABC):
"""Profile interface."""

@abstractmethod
def get_name(self) -> str:
"""Return the name of current agent."""

@abstractmethod
def get_role(self) -> str:
"""Return the role of current agent."""

def get_goal(self) -> Optional[str]:
"""Return the goal of current agent."""
return None

def get_constraints(self) -> Optional[List[str]]:
"""Return the constraints of current agent."""
return None

def get_description(self) -> Optional[str]:
"""Return the description of current agent.

It will not be used to generate prompt.
"""
return None

def get_expand_prompt(self) -> Optional[str]:
"""Return the expand prompt of current agent."""
return None

def get_examples(self) -> Optional[str]:
"""Return the examples of current agent."""
return None

@abstractmethod
def get_system_prompt_template(self) -> str:
"""Return the prompt template of current agent."""

@abstractmethod
def get_user_prompt_template(self) -> str:
"""Return the user prompt template of current agent."""

@abstractmethod
def get_save_memory_template(self) -> str:
"""Return the save memory template of current agent."""


class DefaultProfile(BaseModel, Profile):
"""Default profile."""

name: str = Field("", description="The name of the agent.")
role: str = Field("", description="The role of the agent.")
goal: Optional[str] = Field(None, description="The goal of the agent.")
constraints: Optional[List[str]] = Field(
None, description="The constraints of the agent."
)

desc: Optional[str] = Field(
None, description="The description of the agent, not used to generate prompt."
)

expand_prompt: Optional[str] = Field(
None, description="The expand prompt of the agent."
)

examples: Optional[str] = Field(
None, description="The examples of the agent prompt."
)

system_prompt_template: str = Field(
_DEFAULT_SYSTEM_TEMPLATE, description="The system prompt template of the agent."
)
user_prompt_template: str = Field(
_DEFAULT_USER_TEMPLATE, description="The user prompt template of the agent."
)

save_memory_template: str = Field(
_DEFAULT_SAVE_MEMORY_TEMPLATE,
description="The save memory template of the agent.",
)

def get_name(self) -> str:
"""Return the name of current agent."""
return self.name

def get_role(self) -> str:
"""Return the role of current agent."""
return self.role

def get_goal(self) -> Optional[str]:
"""Return the goal of current agent."""
return self.goal

def get_constraints(self) -> Optional[List[str]]:
"""Return the constraints of current agent."""
return self.constraints

def get_description(self) -> Optional[str]:
"""Return the description of current agent.

It will not be used to generate prompt.
"""
return self.desc

def get_expand_prompt(self) -> Optional[str]:
"""Return the expand prompt of current agent."""
return self.expand_prompt

def get_examples(self) -> Optional[str]:
"""Return the examples of current agent."""
return self.examples

def get_system_prompt_template(self) -> str:
"""Return the prompt template of current agent."""
return self.system_prompt_template

def get_user_prompt_template(self) -> str:
"""Return the user prompt template of current agent."""
return self.user_prompt_template

def get_save_memory_template(self) -> str:
"""Return the save memory template of current agent."""
return self.save_memory_template


class ProfileFactory:
"""Profile factory interface.

It is used to create a profile.
"""

@abstractmethod
def create_profile(
self,
profile_id: int,
name: Optional[str] = None,
role: Optional[str] = None,
goal: Optional[str] = None,
prefer_prompt_language: Optional[str] = None,
prefer_model: Optional[str] = None,
) -> Optional[Profile]:
"""Create a profile."""


class LLMProfileFactory(ProfileFactory):
"""Create a profile by LLM.

Based on LLM automatic generation, it usually specifies the rules of the generation
configuration first, clarifies the composition and attributes of the agent
configuration in the target population, and then gives a small number of samples,
and finally LLM generates the configuration of all agents.
"""

def create_profile(
self,
profile_id: int,
name: Optional[str] = None,
role: Optional[str] = None,
goal: Optional[str] = None,
prefer_prompt_language: Optional[str] = None,
prefer_model: Optional[str] = None,
) -> Optional[Profile]:
"""Create a profile by LLM.

TODO: Implement this method.
"""
pass


class DatasetProfileFactory(ProfileFactory):
"""Create a profile by dataset.

Use existing data sets to generate agent configurations.

In some cases, the data set contains a large amount of information about real people,
first organize the information about real people in the data set into a natural
language prompt, which is then used to generate the agent configuration.
"""

def create_profile(
self,
profile_id: int,
name: Optional[str] = None,
role: Optional[str] = None,
goal: Optional[str] = None,
prefer_prompt_language: Optional[str] = None,
prefer_model: Optional[str] = None,
) -> Optional[Profile]:
"""Create a profile by dataset.

TODO: Implement this method.
"""
pass


class CompositeProfileFactory(ProfileFactory):
"""Create a profile by combining multiple profile factories."""

def __init__(self, factories: List[ProfileFactory]):
"""Create a composite profile factory."""
self.factories = factories

def create_profile(
self,
profile_id: int,
name: Optional[str] = None,
role: Optional[str] = None,
goal: Optional[str] = None,
prefer_prompt_language: Optional[str] = None,
prefer_model: Optional[str] = None,
) -> Optional[Profile]:
"""Create a profile by combining multiple profile factories.

TODO: Implement this method.
"""
pass


class ProfileConfig(BaseModel):
"""Profile configuration.

If factory is not specified, name and role must be specified.
If factory is specified and name and role are also specified, the factory will be
preferred.
"""

model_config = ConfigDict(arbitrary_types_allowed=True)

profile_id: int = Field(0, description="The profile ID.")
name: str | ConfigInfo | None = DynConfig(..., description="The name of the agent.")
role: str | ConfigInfo | None = DynConfig(..., description="The role of the agent.")
goal: str | ConfigInfo | None = DynConfig(None, description="The goal.")
constraints: List[str] | ConfigInfo | None = DynConfig(None, is_list=True)
desc: str | ConfigInfo | None = DynConfig(
None, description="The description of the agent."
)
expand_prompt: str | ConfigInfo | None = DynConfig(
None, description="The expand prompt."
)
examples: str | ConfigInfo | None = DynConfig(None, description="The examples.")

system_prompt_template: str | ConfigInfo | None = DynConfig(
_DEFAULT_SYSTEM_TEMPLATE, description="The prompt template."
)
user_prompt_template: str | ConfigInfo | None = DynConfig(
_DEFAULT_USER_TEMPLATE, description="The user prompt template."
)
save_memory_template: str | ConfigInfo | None = DynConfig(
_DEFAULT_SAVE_MEMORY_TEMPLATE, description="The save memory template."
)
factory: ProfileFactory | None = Field(None, description="The profile factory.")

@model_validator(mode="before")
@classmethod
def check_before(cls, values):
"""Check before validation."""
if isinstance(values, dict):
return values
if values["factory"] is None:
if values["name"] is None:
raise ValueError("name must be specified if factory is not specified")
if values["role"] is None:
raise ValueError("role must be specified if factory is not specified")
return values

@cachetools.cached(cachetools.TTLCache(maxsize=100, ttl=10))
def create_profile(
self,
profile_id: Optional[int] = None,
prefer_prompt_language: Optional[str] = None,
prefer_model: Optional[str] = None,
) -> Profile:
"""Create a profile.

If factory is specified, use the factory to create the profile.
"""
factory_profile = None
if profile_id is None:
profile_id = self.profile_id
name = self.name
role = self.role
goal = self.goal
constraints = self.constraints
desc = self.desc
expand_prompt = self.expand_prompt
system_prompt_template = self.system_prompt_template
user_prompt_template = self.user_prompt_template
save_memory_template = self.save_memory_template
examples = self.examples
call_args = {
"prefer_prompt_language": prefer_prompt_language,
"prefer_model": prefer_model,
}
if isinstance(name, ConfigInfo):
name = name.query(**call_args)
if isinstance(role, ConfigInfo):
role = role.query(**call_args)
if isinstance(goal, ConfigInfo):
goal = goal.query(**call_args)
if isinstance(constraints, ConfigInfo):
constraints = constraints.query(**call_args)
if isinstance(desc, ConfigInfo):
desc = desc.query(**call_args)
if isinstance(expand_prompt, ConfigInfo):
expand_prompt = expand_prompt.query(**call_args)
if isinstance(examples, ConfigInfo):
examples = examples.query(**call_args)
if isinstance(system_prompt_template, ConfigInfo):
system_prompt_template = system_prompt_template.query(**call_args)
if isinstance(user_prompt_template, ConfigInfo):
user_prompt_template = user_prompt_template.query(**call_args)
if isinstance(save_memory_template, ConfigInfo):
save_memory_template = save_memory_template.query(**call_args)

if self.factory is not None:
factory_profile = self.factory.create_profile(
profile_id,
name,
role,
goal,
prefer_prompt_language,
prefer_model,
)

if factory_profile is not None:
return factory_profile
return DefaultProfile(
name=name,
role=role,
goal=goal,
constraints=constraints,
desc=desc,
expand_prompt=expand_prompt,
examples=examples,
system_prompt_template=system_prompt_template,
user_prompt_template=user_prompt_template,
save_memory_template=save_memory_template,
)

def __hash__(self):
"""Return the hash value."""
return hash(self.profile_id)
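Note: a short sketch of how ProfileConfig resolves into a concrete Profile, using only the classes defined above (the field values are illustrative):

from dbgpt.agent.core.profile import ProfileConfig

config = ProfileConfig(
    name="Aristotle",
    role="Summarizer",
    goal="Summarize documents for the user.",
)
# With no factory set, create_profile() falls back to a DefaultProfile
# built from the fields above, resolving any DynConfig values first.
profile = config.create_profile()
print(profile.get_role())  # -> "Summarizer"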
@@ -1,112 +1,93 @@
"""Role class for role-based conversation."""

from abc import ABC
from typing import List, Optional
from typing import Any, Dict, List, Optional, Set

from jinja2.meta import find_undeclared_variables
from jinja2.sandbox import SandboxedEnvironment

from dbgpt._private.pydantic import BaseModel, ConfigDict, Field

from .action.base import ActionOutput
from .memory.agent_memory import AgentMemory, AgentMemoryFragment
from .memory.llm import LLMImportanceScorer, LLMInsightExtractor
from .profile import Profile, ProfileConfig


class Role(ABC, BaseModel):
"""Role class for role-based conversation."""

model_config = ConfigDict(arbitrary_types_allowed=True)

profile: str = ""
name: str = ""
resource_introduction: str = ""
goal: str = ""

expand_prompt: str = ""
profile: ProfileConfig = Field(
...,
description="The profile of the role.",
)
memory: AgentMemory = Field(default_factory=AgentMemory)

fixed_subgoal: Optional[str] = Field(None, description="Fixed subgoal")

constraints: List[str] = Field(default_factory=list, description="Constraints")
examples: str = ""
desc: str = ""
language: str = "en"
is_human: bool = False
is_team: bool = False

def prompt_template(
template_env: SandboxedEnvironment = Field(default_factory=SandboxedEnvironment)

async def build_prompt(
self,
specified_prompt: Optional[str] = None,
question: Optional[str] = None,
is_system: bool = True,
most_recent_memories: Optional[str] = None,
**kwargs
) -> str:
"""Return the prompt template for the role.

Args:
specified_prompt (str, optional): The specified prompt. Defaults to None.

Returns:
str: The prompt template.
"""
if specified_prompt:
return specified_prompt

expand_prompt = self.expand_prompt if len(self.expand_prompt) > 0 else ""
examples_prompt = (
"You can refer to the following examples:\n"
if len(self.examples) > 0
else ""
prompt_template = (
self.system_prompt_template if is_system else self.user_prompt_template
)
examples = self.examples if len(self.examples) > 0 else ""
template = (
f"{self.role_prompt}\n"
"Please think step by step to achieve the goal. You can use the resources "
"given below. At the same time, please strictly abide by the constraints "
"and specifications in IMPORTANT REMINDER.\n\n"
f"{{resource_prompt}}\n\n"
f"{expand_prompt}\n\n"
"*** IMPORTANT REMINDER ***\n"
f"{self.language_require_prompt}\n"
f"{self.constraints_prompt}\n"
f"{examples_prompt}{examples}\n\n"
f"{{out_schema}}"
)
return template
template_vars = self._get_template_variables(prompt_template)
_sub_render_keys = {"role", "name", "goal", "expand_prompt", "constraints"}
pass_vars = {
"role": self.role,
"name": self.name,
"goal": self.goal,
"expand_prompt": self.expand_prompt,
"language": self.language,
"constraints": self.constraints,
"most_recent_memories": (
most_recent_memories if most_recent_memories else None
),
"examples": self.examples,
# "out_schema": out_schema if out_schema else None,
# "resource_prompt": resource_prompt if resource_prompt else None,
"question": question,
}
resource_vars = await self.generate_resource_variables(question)
pass_vars.update(resource_vars)
pass_vars.update(kwargs)
filtered_data = {
key: pass_vars[key] for key in template_vars if key in pass_vars
}
for key in filtered_data.keys():
value = filtered_data[key]
if key in _sub_render_keys and value:
if isinstance(value, str):
# Render the sub-template
filtered_data[key] = self._render_template(value, **pass_vars)
elif isinstance(value, list):
for i, item in enumerate(value):
if isinstance(item, str):
value[i] = self._render_template(item, **pass_vars)
return self._render_template(prompt_template, **filtered_data)

@property
def role_prompt(self) -> str:
"""Return the role prompt.

You are a {self.profile}, named {self.name}, your goal is {self.goal}.

Returns:
str: The role prompt.
"""
profile_prompt = f"You are a {self.profile},"
name_prompt = f"named {self.name}," if len(self.name) > 0 else ""
goal_prompt = f"your goal is {self.goal}"
prompt = f"""{profile_prompt}{name_prompt}{goal_prompt}"""
return prompt

@property
def constraints_prompt(self) -> str:
"""Return the constraints prompt.

Return:
str: The constraints prompt.
"""
if len(self.constraints) > 0:
return "\n".join(
f"{i + 1}. {item}" for i, item in enumerate(self.constraints)
)
return ""

@property
def language_require_prompt(self) -> str:
"""Return the language requirement prompt.

Returns:
str: The language requirement prompt.
"""
if self.language == "zh":
return "Please answer in simplified Chinese."
else:
return "Please answer in English."

@property
def introduce(self) -> str:
"""Introduce the role."""
return self.desc
async def generate_resource_variables(
self, question: Optional[str] = None
) -> Dict[str, Any]:
"""Generate the resource variables."""
return {}

def identity_check(self) -> None:
"""Check the identity of the role."""
@@ -114,12 +95,123 @@ class Role(ABC, BaseModel):

def get_name(self) -> str:
"""Get the name of the role."""
return self.name
return self.current_profile.get_name()

def get_profile(self) -> str:
"""Get the profile of the role."""
return self.profile
@property
def current_profile(self) -> Profile:
"""Return the current profile."""
profile = self.profile.create_profile()
return profile

def get_describe(self) -> str:
"""Get the describe of the role."""
return self.desc
@property
def name(self) -> str:
"""Return the name of the role."""
return self.current_profile.get_name()

@property
def examples(self) -> Optional[str]:
"""Return the examples of the role."""
return self.current_profile.get_examples()

@property
def role(self) -> str:
"""Return the role of the role."""
return self.current_profile.get_role()

@property
def goal(self) -> Optional[str]:
"""Return the goal of the role."""
return self.current_profile.get_goal()

@property
def constraints(self) -> Optional[List[str]]:
"""Return the constraints of the role."""
return self.current_profile.get_constraints()

@property
def desc(self) -> Optional[str]:
"""Return the description of the role."""
return self.current_profile.get_description()

@property
def expand_prompt(self) -> Optional[str]:
"""Return the expand prompt of the role."""
return self.current_profile.get_expand_prompt()

@property
def system_prompt_template(self) -> str:
"""Return the current system prompt template."""
return self.current_profile.get_system_prompt_template()

@property
def user_prompt_template(self) -> str:
"""Return the current user prompt template."""
return self.current_profile.get_user_prompt_template()

@property
def save_memory_template(self) -> str:
"""Return the current save memory template."""
return self.current_profile.get_save_memory_template()

def _get_template_variables(self, template: str) -> Set[str]:
parsed_content = self.template_env.parse(template)
return find_undeclared_variables(parsed_content)

def _render_template(self, template: str, **kwargs):
r_template = self.template_env.from_string(template)
return r_template.render(**kwargs)

@property
def memory_importance_scorer(self) -> Optional[LLMImportanceScorer]:
"""Create the memory importance scorer.

The memory importance scorer is used to score the importance of a memory
fragment.
"""
return None

@property
def memory_insight_extractor(self) -> Optional[LLMInsightExtractor]:
"""Create the memory insight extractor.

The memory insight extractor is used to extract a high-level insight from a
memory fragment.
"""
return None

async def read_memories(
self,
question: str,
) -> str:
"""Read the memories from the memory."""
memories = await self.memory.read(question)
recent_messages = [m.raw_observation for m in memories]
return "".join(recent_messages)

async def save_to_memory(
self,
question: str,
ai_message: str,
action_output: Optional[ActionOutput] = None,
check_pass: bool = True,
check_fail_reason: Optional[str] = None,
) -> None:
"""Save the role to the memory."""
if not action_output:
raise ValueError("Action output is required to save to memory.")

mem_thoughts = action_output.thoughts or ai_message
observation = action_output.observations or action_output.content
if not check_pass and check_fail_reason:
observation += "\n" + check_fail_reason

memory_map = {
"question": question,
"thought": mem_thoughts,
"action": action_output.action,
"observation": observation,
}
save_memory_template = self.save_memory_template
memory_content = self._render_template(save_memory_template, **memory_map)
fragment = AgentMemoryFragment(memory_content)
await self.memory.write(fragment)
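Note: the variable filtering in build_prompt above relies on jinja2's static analysis; a standalone sketch of the same technique (template text and values here are illustrative):

from jinja2.meta import find_undeclared_variables
from jinja2.sandbox import SandboxedEnvironment

env = SandboxedEnvironment()
template = "You are a {{ role }}{% if name %}, named {{ name }}{% endif %}."

# Collect the variables the template actually references...
parsed = env.parse(template)
wanted = find_undeclared_variables(parsed)  # -> {"role", "name"}

# ...then pass only those through to render, as build_prompt does.
available = {"role": "Summarizer", "name": "Aristotle", "unused": "dropped"}
print(env.from_string(template).render(
    {k: v for k, v in available.items() if k in wanted}
))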
@@ -1,5 +1,6 @@
"""A proxy agent for the user."""
from .base_agent import ConversableAgent
from .profile import ProfileConfig


class UserProxyAgent(ConversableAgent):
@@ -8,12 +9,13 @@ class UserProxyAgent(ConversableAgent):
That can execute code and provide feedback to the other agents.
"""

name: str = "User"
profile: str = "Human"

desc: str = (
"A human admin. Interact with the planner to discuss the plan. "
"Plan execution needs to be approved by this admin."
profile: ProfileConfig = ProfileConfig(
name="User",
role="Human",
description=(
"A human admin. Interact with the planner to discuss the plan. "
"Plan execution needs to be approved by this admin."
),
)

is_human: bool = True
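Note: a typical driver for the refactored UserProxyAgent might look like the sketch below, assuming the bind/build/initiate_chat flow used in DB-GPT's agent examples of this era; my_manager and question are placeholders, not part of the diff:

from dbgpt.agent import AgentContext, AgentMemory, UserProxyAgent

async def run_demo(my_manager, question: str):
    context = AgentContext(conv_id="demo_conv")
    user_proxy = await UserProxyAgent().bind(context).bind(AgentMemory()).build()
    # The human proxy kicks off the conversation and acts as the reviewer
    # who approves plan execution.
    await user_proxy.initiate_chat(
        recipient=my_manager,
        reviewer=user_proxy,
        message=question,
    )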
@@ -1,9 +1,10 @@
"""Indicator Assistant Agent."""
import logging
from typing import List

from ..actions.indicator_action import IndicatorAction
import logging

from ..core.base_agent import ConversableAgent
from ..core.profile import DynConfig, ProfileConfig
from .actions.indicator_action import IndicatorAction

logger = logging.getLogger(__name__)

@@ -11,29 +12,48 @@ logger = logging.getLogger(__name__)
class IndicatorAssistantAgent(ConversableAgent):
"""Indicator Assistant Agent."""

name = "Indicator"
profile: str = "Indicator"
goal: str = (
"Summarize answer summaries based on user questions from provided "
"resource information or from historical conversation memories."
)

constraints: List[str] = [
"Prioritize the summary of answers to user questions from the improved resource"
" text. If no relevant information is found, summarize it from the historical"
" dialogue memory given. It is forbidden to make up your own.",
"You need to first detect user's question that you need to answer with your "
"summarization.",
"Extract the provided text content used for summarization.",
"Then you need to summarize the extracted text content.",
"Output the content of summarization ONLY related to user's question. The "
"output language must be the same to user's question language.",
"If you think the provided text content is not related to user questions at "
"all, ONLY output 'Did not find the information you want.'!!.",
]
desc: str = (
"You can summarize provided text content according to user's questions "
"and output the summarization."
profile: ProfileConfig = ProfileConfig(
name=DynConfig(
"Indicator",
category="agent",
key="dbgpt_agent_expand_indicator_assistant_agent_profile_name",
),
role=DynConfig(
"Indicator",
category="agent",
key="dbgpt_agent_expand_indicator_assistant_agent_profile_role",
),
goal=DynConfig(
"Summarize answer summaries based on user questions from provided "
"resource information or from historical conversation memories.",
category="agent",
key="dbgpt_agent_expand_indicator_assistant_agent_profile_goal",
),
constraints=DynConfig(
[
"Prioritize the summary of answers to user questions from the "
"improved resource text. If no relevant information is found, "
"summarize it from the historical dialogue memory given. It is "
"forbidden to make up your own.",
"You need to first detect user's question that you need to answer "
"with your summarization.",
"Extract the provided text content used for summarization.",
"Then you need to summarize the extracted text content.",
"Output the content of summarization ONLY related to user's question. "
"The output language must be the same to user's question language.",
"If you think the provided text content is not related to user "
"questions at all, ONLY output 'Did not find the information you "
"want.'!!.",
],
category="agent",
key="dbgpt_agent_expand_indicator_assistant_agent_profile_constraints",
),
desc=DynConfig(
"You can summarize provided text content according to user's questions "
"and output the summarization.",
category="agent",
key="dbgpt_agent_expand_indicator_assistant_agent_profile_desc",
),
)

def __init__(self, **kwargs):
1
dbgpt/agent/expand/actions/__init__.py
Normal file
@@ -0,0 +1 @@
"""Actions of expand Agents."""
@@ -1,4 +1,5 @@
"""Chart Action for SQL execution and rendering."""

import json
import logging
from typing import Optional
@@ -6,9 +7,9 @@ from typing import Optional
from dbgpt._private.pydantic import BaseModel, Field, model_to_json
from dbgpt.vis.tags.vis_chart import Vis, VisChart

from ..resource.resource_api import AgentResource, ResourceType
from ..resource.resource_db_api import ResourceDbClient
from .action import Action, ActionOutput
from ...core.action.base import Action, ActionOutput
from ...resource.resource_api import AgentResource, ResourceType
from ...resource.resource_db_api import ResourceDbClient

logger = logging.getLogger(__name__)
@@ -1,4 +1,5 @@
"""Code Action Module."""

import logging
from typing import Optional, Union

@@ -6,8 +7,8 @@ from dbgpt.util.code_utils import UNKNOWN, execute_code, extract_code, infer_lan
from dbgpt.util.utils import colored
from dbgpt.vis.tags.vis_code import Vis, VisCode

from ..resource.resource_api import AgentResource
from .action import Action, ActionOutput
from ...core.action.base import Action, ActionOutput
from ...resource.resource_api import AgentResource

logger = logging.getLogger(__name__)

@@ -73,7 +74,13 @@ class CodeAction(Action[None]):
if not self.render_protocol:
raise NotImplementedError("The render_protocol should be implemented.")
view = await self.render_protocol.display(content=param)
return ActionOutput(is_exe_success=exit_success, content=content, view=view)
return ActionOutput(
is_exe_success=exit_success,
content=content,
view=view,
thoughts=ai_message,
observations=content,
)
except Exception as e:
logger.exception("Code Action Run Failed!")
return ActionOutput(
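Note: the pattern above recurs throughout this commit: ActionOutput now carries thoughts and observations so the memory module can record them via Role.save_to_memory. A sketch of what a custom action's return might look like under the refactored fields (the EchoAction class and its simplified run signature are hypothetical, not from the diff):

from dbgpt.agent.core.action.base import Action, ActionOutput

class EchoAction(Action[None]):
    async def run(self, ai_message: str, **kwargs) -> ActionOutput:
        result = ai_message.upper()
        return ActionOutput(
            is_exe_success=True,
            content=result,       # what downstream agents consume
            thoughts=ai_message,  # the model's reasoning, saved to memory
            observations=result,  # what the memory module records as observed
        )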
@@ -1,4 +1,5 @@
"""Dashboard Action Module."""

import json
import logging
from typing import List, Optional
@@ -6,9 +7,9 @@ from typing import List, Optional
from dbgpt._private.pydantic import BaseModel, Field, model_to_dict
from dbgpt.vis.tags.vis_dashboard import Vis, VisDashboard

from ..resource.resource_api import AgentResource, ResourceType
from ..resource.resource_db_api import ResourceDbClient
from .action import Action, ActionOutput
from ...core.action.base import Action, ActionOutput
from ...resource.resource_api import AgentResource, ResourceType
from ...resource.resource_db_api import ResourceDbClient

logger = logging.getLogger(__name__)
@@ -7,9 +7,9 @@
from dbgpt._private.pydantic import BaseModel, Field
from dbgpt.vis.tags.vis_plugin import Vis, VisPlugin

from ..core.schema import Status
from ..resource.resource_api import AgentResource, ResourceType
from .action import Action, ActionOutput
from ...core.action.base import Action, ActionOutput
from ...core.schema import Status
from ...resource.resource_api import AgentResource, ResourceType

logger = logging.getLogger(__name__)
@@ -1,4 +1,5 @@
"""Plugin Action Module."""

import json
import logging
from typing import Optional
@@ -6,11 +7,11 @@ from typing import Optional
from dbgpt._private.pydantic import BaseModel, Field
from dbgpt.vis.tags.vis_plugin import Vis, VisPlugin

from ..core.schema import Status
from ..plugin.generator import PluginPromptGenerator
from ..resource.resource_api import AgentResource, ResourceType
from ..resource.resource_plugin_api import ResourcePluginClient
from .action import Action, ActionOutput
from ...core.action.base import Action, ActionOutput
from ...core.schema import Status
from ...plugin.generator import PluginPromptGenerator
from ...resource.resource_api import AgentResource, ResourceType
from ...resource.resource_plugin_api import ResourcePluginClient

logger = logging.getLogger(__name__)

@@ -144,7 +145,10 @@ class PluginAction(Action[PluginInput]):
view = await self.render_protocol.display(content=plugin_param)

return ActionOutput(
is_exe_success=response_success, content=tool_result, view=view
is_exe_success=response_success,
content=tool_result,
view=view,
observations=tool_result,
)
except Exception as e:
logger.exception("Tool Action Run Failed!")
@@ -1,12 +1,14 @@
"""Code Assistant Agent."""
from typing import List, Optional, Tuple

from typing import Optional, Tuple

from dbgpt.core import ModelMessageRoleType
from dbgpt.util.string_utils import str_to_bool

from ..actions.code_action import CodeAction
from ..core.agent import AgentMessage
from ..core.base_agent import ConversableAgent
from ..core.profile import DynConfig, ProfileConfig
from .actions.code_action import CodeAction

CHECK_RESULT_SYSTEM_MESSAGE = (
"You are an expert in analyzing the results of task execution. Your responsibility "
@@ -42,54 +44,75 @@ CHECK_RESULT_SYSTEM_MESSAGE = (
class CodeAssistantAgent(ConversableAgent):
"""Code Assistant Agent."""

name: str = "Turing"
profile: str = "CodeEngineer"
goal: str = (
"Solve tasks using your coding and language skills.\n"
"In the following cases, suggest python code (in a python coding block) or "
"shell script (in a sh coding block) for the user to execute.\n"
" 1. When you need to collect info, use the code to output the info you "
"need, for example, browse or search the web, download/read a file, print the "
"content of a webpage or a file, get the current date/time, check the "
"operating system. After sufficient info is printed and the task is ready to be"
" solved based on your language skill, you can solve the task by yourself.\n"
" 2. When you need to perform some task with code, use the code to perform "
"the task and output the result. Finish the task smartly."
)
constraints: List[str] = [
"The user cannot provide any other feedback or perform any other action beyond"
" executing the code you suggest. The user can't modify your code. So do not "
"suggest incomplete code which requires users to modify. Don't use a code block"
" if it's not intended to be executed by the user.Don't ask users to copy and "
"paste results. Instead, the 'Print' function must be used for output when "
"relevant.",
"When using code, you must indicate the script type in the code block. Please "
"don't include multiple code blocks in one response.",
"If you want the user to save the code in a file before executing it, put "
"# filename: <filename> inside the code block as the first line.",
"If you receive user input that indicates an error in the code execution, fix "
"the error and output the complete code again. It is recommended to use the "
"complete code rather than partial code or code changes. If the error cannot be"
" fixed, or the task is not resolved even after the code executes successfully,"
" analyze the problem, revisit your assumptions, gather additional information"
" you need from historical conversation records, and consider trying a "
"different approach.",
"Unless necessary, give priority to solving problems with python code. If it "
"involves downloading files or storing data locally, please use 'Print' to "
"output the full file path of the stored data and a brief introduction to the "
"data.",
"The output content of the 'print' function will be passed to other LLM agents "
"as dependent data. Please control the length of the output content of the "
"'print' function. The 'print' function only outputs part of the key data "
"information that is relied on, and is as concise as possible.",
"The code is executed without user participation. It is forbidden to use "
"methods that will block the process or need to be shut down, such as the "
"plt.show() method of matplotlib.pyplot as plt.",
"It is prohibited to fabricate non-existent data to achieve goals.",
]
desc: str = (
"Can independently write and execute python/shell code to solve various"
" problems"
profile: ProfileConfig = ProfileConfig(
name=DynConfig(
"Turing",
category="agent",
key="dbgpt_agent_expand_code_assistant_agent_profile_name",
),
role=DynConfig(
"CodeEngineer",
category="agent",
key="dbgpt_agent_expand_code_assistant_agent_profile_role",
),
goal=DynConfig(
"Solve tasks using your coding and language skills.\n"
"In the following cases, suggest python code (in a python coding block) or "
"shell script (in a sh coding block) for the user to execute.\n"
" 1. When you need to collect info, use the code to output the info you "
"need, for example, browse or search the web, download/read a file, print "
"the content of a webpage or a file, get the current date/time, check the "
"operating system. After sufficient info is printed and the task is ready "
"to be solved based on your language skill, you can solve the task by "
"yourself.\n"
" 2. When you need to perform some task with code, use the code to "
"perform the task and output the result. Finish the task smartly.",
category="agent",
key="dbgpt_agent_expand_code_assistant_agent_profile_goal",
),
constraints=DynConfig(
[
"The user cannot provide any other feedback or perform any other "
"action beyond executing the code you suggest. The user can't modify "
"your code. So do not suggest incomplete code which requires users to "
"modify. Don't use a code block if it's not intended to be executed "
"by the user.Don't ask users to copy and paste results. Instead, "
"the 'Print' function must be used for output when relevant.",
"When using code, you must indicate the script type in the code block. "
"Please don't include multiple code blocks in one response.",
"If you want the user to save the code in a file before executing it, "
"put # filename: <filename> inside the code block as the first line.",
"If you receive user input that indicates an error in the code "
"execution, fix the error and output the complete code again. It is "
"recommended to use the complete code rather than partial code or "
"code changes. If the error cannot be fixed, or the task is not "
"resolved even after the code executes successfully, analyze the "
"problem, revisit your assumptions, gather additional information you "
"need from historical conversation records, and consider trying a "
"different approach.",
"Unless necessary, give priority to solving problems with python "
"code. If it involves downloading files or storing data locally, "
"please use 'Print' to output the full file path of the stored data "
"and a brief introduction to the data.",
"The output content of the 'print' function will be passed to other "
"LLM agents as dependent data. Please control the length of the "
"output content of the 'print' function. The 'print' function only "
"outputs part of the key data information that is relied on, "
"and is as concise as possible.",
"The code is executed without user participation. It is forbidden to "
"use methods that will block the process or need to be shut down, "
"such as the plt.show() method of matplotlib.pyplot as plt.",
"It is prohibited to fabricate non-existent data to achieve goals.",
],
category="agent",
key="dbgpt_agent_expand_code_assistant_agent_profile_constraints",
),
desc=DynConfig(
"Can independently write and execute python/shell code to solve various"
" problems",
category="agent",
key="dbgpt_agent_expand_code_assistant_agent_profile_desc",
),
)

def __init__(self, **kwargs):
@@ -1,37 +1,54 @@
 """Dashboard Assistant Agent."""
 
-from typing import List
-
-from ..actions.dashboard_action import DashboardAction
 from ..core.agent import AgentMessage
 from ..core.base_agent import ConversableAgent
+from ..core.profile import DynConfig, ProfileConfig
 from ..resource.resource_db_api import ResourceDbClient
+from .actions.dashboard_action import DashboardAction
 
 
 class DashboardAssistantAgent(ConversableAgent):
     """Dashboard Assistant Agent."""
 
-    name: str = "Visionary"
-
-    profile: str = "Reporter"
-    goal: str = (
-        "Read the provided historical messages, collect various analysis SQLs "
-        "from them, and assemble them into professional reports."
-    )
-    constraints: List[str] = [
-        "You are only responsible for collecting and sorting out the analysis SQL that"
-        " already exists in historical messages, and do not generate any analysis sql "
-        "yourself.",
-        "In order to build a report with rich display types, you can appropriately "
-        "adjust the display type of the charts you collect so that you can build a "
-        "better report. Of course, you can choose from the following available "
-        "display types: {display_type}",
-        "Please read and completely collect all analysis sql in the historical "
-        "conversation, and do not omit or modify the content of the analysis sql.",
-    ]
-    desc: str = (
-        "Observe and organize various analysis results and construct "
-        "professional reports"
+    profile: ProfileConfig = ProfileConfig(
+        name=DynConfig(
+            "Visionary",
+            category="agent",
+            key="dbgpt_agent_expand_dashboard_assistant_agent_profile_name",
+        ),
+        role=DynConfig(
+            "Reporter",
+            category="agent",
+            key="dbgpt_agent_expand_dashboard_assistant_agent_profile_role",
+        ),
+        goal=DynConfig(
+            "Read the provided historical messages, collect various analysis SQLs "
+            "from them, and assemble them into professional reports.",
+            category="agent",
+            key="dbgpt_agent_expand_dashboard_assistant_agent_profile_goal",
+        ),
+        constraints=DynConfig(
+            [
+                "You are only responsible for collecting and sorting out the analysis "
+                "SQL that already exists in historical messages, and do not generate "
+                "any analysis sql yourself.",
+                "In order to build a report with rich display types, you can "
+                "appropriately adjust the display type of the charts you collect so "
+                "that you can build a better report. Of course, you can choose from "
+                "the following available display types: {{ display_type }}",
+                "Please read and completely collect all analysis sql in the "
+                "historical conversation, and do not omit or modify the content of "
+                "the analysis sql.",
+            ],
+            category="agent",
+            key="dbgpt_agent_expand_dashboard_assistant_agent_profile_constraints",
+        ),
+        desc=DynConfig(
+            "Observe and organize various analysis results and construct "
+            "professional reports",
+            category="agent",
+            key="dbgpt_agent_expand_dashboard_assistant_agent_profile_desc",
+        ),
     )
 
     def __init__(self, **kwargs):
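Note: the hunk above is representative of the pattern applied to every expand agent in this refactor: the flat name/profile/goal/constraints/desc class attributes become a single ProfileConfig whose fields are DynConfig values carrying a default plus a dynamic-configuration lookup key. A minimal sketch of the new style, based only on the imports visible in this diff; the agent class and config keys below are hypothetical, not part of the commit:

from dbgpt.agent.core.base_agent import ConversableAgent
from dbgpt.agent.core.profile import DynConfig, ProfileConfig


class MyReporterAgent(ConversableAgent):
    """Hypothetical agent using the declarative profile style shown above."""

    profile: ProfileConfig = ProfileConfig(
        # Each field carries a hard-coded default plus a lookup key, so the
        # value can be overridden through dynamic configuration at runtime.
        name=DynConfig(
            "MyReporter",
            category="agent",
            key="my_app_reporter_agent_profile_name",  # hypothetical key
        ),
        role=DynConfig(
            "Reporter",
            category="agent",
            key="my_app_reporter_agent_profile_role",  # hypothetical key
        ),
    )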
@@ -2,14 +2,15 @@
 
 import json
 import logging
-from typing import List, Optional, Tuple, cast
+from typing import Optional, Tuple, cast
 
-from ..actions.action import ActionOutput
-from ..actions.chart_action import ChartAction
+from ..core.action.base import ActionOutput
 from ..core.agent import AgentMessage
 from ..core.base_agent import ConversableAgent
+from ..core.profile import DynConfig, ProfileConfig
 from ..resource.resource_api import ResourceType
 from ..resource.resource_db_api import ResourceDbClient
+from .actions.chart_action import ChartAction
 
 logger = logging.getLogger(__name__)
 
@@ -17,31 +18,53 @@ logger = logging.getLogger(__name__)
 class DataScientistAgent(ConversableAgent):
     """Data Scientist Agent."""
 
-    name: str = "Edgar"
-    profile: str = "DataScientist"
-    goal: str = (
-        "Use correct {dialect} SQL to analyze and solve tasks based on the data"
-        " structure information of the database given in the resource."
-    )
-    constraints: List[str] = [
-        "Please check the generated SQL carefully. Please strictly abide by the data "
-        "structure definition given. It is prohibited to use non-existent fields and "
-        "data values. Do not use fields from table A to table B. You can perform "
-        "multi-table related queries.",
-        "If the data and fields that need to be analyzed in the target are in different"
-        " tables, it is recommended to use multi-table correlation queries first, and "
-        "pay attention to the correlation between multiple table structures.",
-        "It is forbidden to construct data by yourself as a query condition. If you "
-        "want to query a specific field, if the value of the field is provided, then "
-        "you can perform a group statistical query on the field.",
-        "Please select an appropriate one from the supported display methods for data "
-        "display. If no suitable display type is found, table display is used by "
-        "default. Supported display types: \n {display_type}",
-    ]
-    desc: str = (
-        "Use database resources to conduct data analysis, analyze SQL, and "
-        "provide recommended rendering methods."
+    profile: ProfileConfig = ProfileConfig(
+        name=DynConfig(
+            "Edgar",
+            category="agent",
+            key="dbgpt_agent_expand_dashboard_assistant_agent_profile_name",
+        ),
+        role=DynConfig(
+            "DataScientist",
+            category="agent",
+            key="dbgpt_agent_expand_dashboard_assistant_agent_profile_role",
+        ),
+        goal=DynConfig(
+            "Use correct {{ dialect }} SQL to analyze and solve tasks based on the data"
+            " structure information of the database given in the resource.",
+            category="agent",
+            key="dbgpt_agent_expand_dashboard_assistant_agent_profile_goal",
+        ),
+        constraints=DynConfig(
+            [
+                "Please check the generated SQL carefully. Please strictly abide by "
+                "the data structure definition given. It is prohibited to use "
+                "non-existent fields and data values. Do not use fields from table A "
+                "to table B. You can perform multi-table related queries.",
+                "If the data and fields that need to be analyzed in the target are in "
+                "different tables, it is recommended to use multi-table correlation "
+                "queries first, and pay attention to the correlation between multiple "
+                "table structures.",
+                "It is forbidden to construct data by yourself as a query condition. "
+                "If you want to query a specific field, if the value of the field is "
+                "provided, then you can perform a group statistical query on the "
+                "field.",
+                "Please select an appropriate one from the supported display methods "
+                "for data display. If no suitable display type is found, "
+                "table display is used by default. Supported display types: \n"
+                "{{ display_type }}",
+            ],
+            category="agent",
+            key="dbgpt_agent_expand_dashboard_assistant_agent_profile_constraints",
+        ),
+        desc=DynConfig(
+            "Use database resources to conduct data analysis, analyze SQL, and provide "
+            "recommended rendering methods.",
+            category="agent",
+            key="dbgpt_agent_expand_dashboard_assistant_agent_profile_desc",
+        ),
    )
 
     max_retry_count: int = 5
 
     def __init__(self, **kwargs):
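Note the placeholder syntax change in this file: {dialect} and {display_type} become {{ dialect }} and {{ display_type }}, meaning the profile strings are no longer raw str.format patterns but templates rendered by the new profile module. A sketch of the intended behaviour, assuming Jinja2-style semantics for the doubled braces (the actual rendering engine is outside this diff):

from jinja2 import Template

# The goal template from DataScientistAgent above; context values such as
# the SQL dialect are supplied by the agent's bound resources at runtime.
goal_template = (
    "Use correct {{ dialect }} SQL to analyze and solve tasks based on the data"
    " structure information of the database given in the resource."
)
print(Template(goal_template).render(dialect="MySQL"))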
@@ -1,12 +1,14 @@
 """Plugin Assistant Agent."""
-import logging
-from typing import Any, Dict, List, Optional
 
-from ..actions.plugin_action import PluginAction
+import logging
+from typing import Any, Dict, Optional
+
 from ..core.base_agent import ConversableAgent
+from ..core.profile import DynConfig, ProfileConfig
 from ..plugin.generator import PluginPromptGenerator
 from ..resource.resource_api import ResourceType
 from ..resource.resource_plugin_api import ResourcePluginClient
+from .actions.plugin_action import PluginAction
 
 logger = logging.getLogger(__name__)
 
@@ -16,23 +18,42 @@ class PluginAssistantAgent(ConversableAgent):
 
     plugin_generator: Optional[PluginPromptGenerator] = None
 
-    name: str = "LuBan"
-    profile: str = "ToolExpert"
-    goal: str = (
-        "Read and understand the tool information given in the resources below to "
-        "understand their capabilities and how to use them,and choosing the right tools"
-        " to achieve the user's goals."
-    )
-    constraints: List[str] = [
-        "Please read the parameter definition of the tool carefully and extract the "
-        "specific parameters required to execute the tool from the user goal.",
-        "Please output the selected tool name and specific parameter information in "
-        "json format according to the following required format. If there is an "
-        "example, please refer to the sample format output.",
-    ]
-    desc: str = (
-        "You can use the following tools to complete the task objectives, tool "
-        "information: {tool_infos}"
+    profile: ProfileConfig = ProfileConfig(
+        name=DynConfig(
+            "LuBan",
+            category="agent",
+            key="dbgpt_agent_expand_plugin_assistant_agent_name",
+        ),
+        role=DynConfig(
+            "ToolExpert",
+            category="agent",
+            key="dbgpt_agent_expand_plugin_assistant_agent_role",
+        ),
+        goal=DynConfig(
+            "Read and understand the tool information given in the resources "
+            "below to understand their capabilities and how to use them,and choosing "
+            "the right tools to achieve the user's goals.",
+            category="agent",
+            key="dbgpt_agent_expand_plugin_assistant_agent_goal",
+        ),
+        constraints=DynConfig(
+            [
+                "Please read the parameter definition of the tool carefully and extract"
+                " the specific parameters required to execute the tool from the user "
+                "goal.",
+                "Please output the selected tool name and specific parameter "
+                "information in json format according to the following required format."
+                " If there is an example, please refer to the sample format output.",
+            ],
+            category="agent",
+            key="dbgpt_agent_expand_plugin_assistant_agent_constraints",
+        ),
+        desc=DynConfig(
+            "You can use the following tools to complete the task objectives, "
+            "tool information: {tool_infos}",
+            category="agent",
+            key="dbgpt_agent_expand_plugin_assistant_agent_desc",
+        ),
     )
 
     def __init__(self, **kwargs):
@@ -40,14 +61,14 @@ class PluginAssistantAgent(ConversableAgent):
         super().__init__(**kwargs)
         self._init_actions([PluginAction])
 
-    @property
-    def introduce(self, **kwargs) -> str:
-        """Introduce the agent."""
-        if not self.plugin_generator:
-            raise ValueError("PluginGenerator is not loaded.")
-        return self.desc.format(
-            tool_infos=self.plugin_generator.generate_commands_string()
-        )
+    # @property
+    # def introduce(self, **kwargs) -> str:
+    #     """Introduce the agent."""
+    #     if not self.plugin_generator:
+    #         raise ValueError("PluginGenerator is not loaded.")
+    #     return self.desc.format(
+    #         tool_infos=self.plugin_generator.generate_commands_string()
+    #     )
 
     async def preload_resource(self):
         """Preload the resource."""
@@ -1,4 +1,5 @@
 """Retrieve Summary Assistant Agent."""
+
 import glob
 import json
 import logging
@@ -9,9 +10,10 @@ from urllib.parse import urlparse
 from dbgpt.configs.model_config import PILOT_PATH
 from dbgpt.core import ModelMessageRoleType
 
-from ..actions.action import Action, ActionOutput
+from ..core.action.base import Action, ActionOutput
 from ..core.agent import Agent, AgentMessage, AgentReviewInfo
 from ..core.base_agent import ConversableAgent
+from ..core.profile import ProfileConfig
 from ..resource.resource_api import AgentResource
 from ..util.cmp import cmp_string_equal
 
@@ -86,18 +88,7 @@ class RetrieveSummaryAssistantAgent(ConversableAgent):
     including suggesting python code blocks and debugging.
     """
 
-    goal = (
-        "You're an extraction expert. You need to extract Please complete this task "
-        "step by step following instructions below:\n"
-        " 1. You need to first ONLY extract user's question that you need to answer "
-        "without ANY file paths and URLs. \n"
-        " 2. Extract the provided file paths and URLs.\n"
-        " 3. Construct the extracted file paths and URLs as a list of strings.\n"
-        " 4. ONLY output the extracted results with the following json format: "
-        "{response}."
-    )
-
-    PROMPT_QA = (
+    PROMPT_QA: str = (
         "You are a great summary writer to summarize the provided text content "
         "according to user questions.\n"
         "User's Question is: {input_question}\n\n"
@@ -118,7 +109,7 @@ class RetrieveSummaryAssistantAgent(ConversableAgent):
         "If the provided text content CAN NOT ANSWER user's question, ONLY output "
         "'NO RELATIONSHIP.UPDATE TEXT CONTENT.'!!."
     )
-    CHECK_RESULT_SYSTEM_MESSAGE = (
+    CHECK_RESULT_SYSTEM_MESSAGE: str = (
         "You are an expert in analyzing the results of a summary task."
         "Your responsibility is to check whether the summary results can summarize the "
         "input provided by the user, and then make a judgment. You need to answer "
@@ -131,20 +122,30 @@ class RetrieveSummaryAssistantAgent(ConversableAgent):
         "not summarized. TERMINATE"
     )
 
-    DEFAULT_DESCRIBE = (
+    DEFAULT_DESCRIBE: str = (
         "Summarize provided content according to user's questions and "
         "the provided file paths."
     )
 
-    name = "RetrieveSummarizer"
-    desc = DEFAULT_DESCRIBE
+    profile: ProfileConfig = ProfileConfig(
+        name="RetrieveSummarizer",
+        role="Assistant",
+        goal="You're an extraction expert. You need to extract Please complete this "
+        "task step by step following instructions below:\n"
+        " 1. You need to first ONLY extract user's question that you need to answer "
+        "without ANY file paths and URLs. \n"
+        " 2. Extract the provided file paths and URLs.\n"
+        " 3. Construct the extracted file paths and URLs as a list of strings.\n"
+        " 4. ONLY output the extracted results with the following json format: "
+        "{{ response }}.",
+        desc=DEFAULT_DESCRIBE,
+    )
 
     chunk_token_size: int = 4000
     chunk_mode: str = "multi_lines"
 
-    _model = "gpt-3.5-turbo-16k"
-    _max_tokens = _get_max_tokens(_model)
-    context_max_tokens = _max_tokens * 0.8
+    _model: str = "gpt-3.5-turbo-16k"
+    _max_tokens: int = _get_max_tokens(_model)
+    context_max_tokens: int = int(_max_tokens * 0.8)
 
     def __init__(
         self,
@@ -174,12 +175,14 @@ class RetrieveSummaryAssistantAgent(ConversableAgent):
         reply_message: AgentMessage = self._init_reply_message(
             received_message=received_message
         )
         await self._system_message_assembly(
             received_message.content, reply_message.context
         )
         # 1.Think about how to do things
         llm_reply, model_name = await self.thinking(
-            self._load_thinking_messages(received_message, sender, rely_messages)
+            await self._load_thinking_messages(
+                received_message,
+                sender,
+                rely_messages,
+                context=reply_message.get_dict_context(),
+            )
         )
 
         if not llm_reply:
@@ -454,16 +457,16 @@ class RetrieveSummaryAssistantAgent(ConversableAgent):
                         " set to False."
                     )
                     must_break_at_empty_line = False
-            chunks.append(prev) if len(
-                prev
-            ) > 10 else None  # don't add chunks less than 10 characters
+            (
+                chunks.append(prev) if len(prev) > 10 else None
+            )  # don't add chunks less than 10 characters
             lines = lines[cnt:]
             lines_tokens = lines_tokens[cnt:]
             sum_tokens = sum(lines_tokens)
         text_to_chunk = "\n".join(lines)
-        chunks.append(text_to_chunk) if len(
-            text_to_chunk
-        ) > 10 else None  # don't add chunks less than 10 characters
+        (
+            chunks.append(text_to_chunk) if len(text_to_chunk) > 10 else None
+        )  # don't add chunks less than 10 characters
         return chunks
 
     def _extract_text_from_pdf(self, file: str) -> str:
@@ -1,9 +1,10 @@
 """Summary Assistant Agent."""
-import logging
-from typing import List
 
-from ..actions.blank_action import BlankAction
+import logging
+
+from ..core.action.blank_action import BlankAction
 from ..core.base_agent import ConversableAgent
+from ..core.profile import DynConfig, ProfileConfig
 
 logger = logging.getLogger(__name__)
 
@@ -11,29 +12,48 @@ logger = logging.getLogger(__name__)
 class SummaryAssistantAgent(ConversableAgent):
     """Summary Assistant Agent."""
 
-    name: str = "Aristotle"
-    profile: str = "Summarizer"
-    goal: str = (
-        "Summarize answer summaries based on user questions from provided "
-        "resource information or from historical conversation memories."
-    )
-
-    constraints: List[str] = [
-        "Prioritize the summary of answers to user questions from the improved resource"
-        " text. If no relevant information is found, summarize it from the historical "
-        "dialogue memory given. It is forbidden to make up your own.",
-        "You need to first detect user's question that you need to answer with your"
-        " summarization.",
-        "Extract the provided text content used for summarization.",
-        "Then you need to summarize the extracted text content.",
-        "Output the content of summarization ONLY related to user's question. The "
-        "output language must be the same to user's question language.",
-        "If you think the provided text content is not related to user questions at "
-        "all, ONLY output 'Did not find the information you want.'!!.",
-    ]
-    desc: str = (
-        "You can summarize provided text content according to user's questions"
-        " and output the summarization."
+    profile: ProfileConfig = ProfileConfig(
+        name=DynConfig(
+            "Aristotle",
+            category="agent",
+            key="dbgpt_agent_expand_summary_assistant_agent_profile_name",
+        ),
+        role=DynConfig(
+            "Summarizer",
+            category="agent",
+            key="dbgpt_agent_expand_summary_assistant_agent_profile_role",
+        ),
+        goal=DynConfig(
+            "Summarize answer summaries based on user questions from provided "
+            "resource information or from historical conversation memories.",
+            category="agent",
+            key="dbgpt_agent_expand_summary_assistant_agent_profile_goal",
+        ),
+        constraints=DynConfig(
+            [
+                "Prioritize the summary of answers to user questions from the improved "
+                "resource text. If no relevant information is found, summarize it from "
+                "the historical dialogue memory given. It is forbidden to make up your "
+                "own.",
+                "You need to first detect user's question that you need to answer with "
+                "your summarization.",
+                "Extract the provided text content used for summarization.",
+                "Then you need to summarize the extracted text content.",
+                "Output the content of summarization ONLY related to user's question. "
+                "The output language must be the same to user's question language.",
+                "If you think the provided text content is not related to user "
+                "questions at all, ONLY output 'Did not find the information you "
+                "want.'!!.",
+            ],
+            category="agent",
+            key="dbgpt_agent_expand_summary_assistant_agent_profile_constraints",
+        ),
+        desc=DynConfig(
+            "You can summarize provided text content according to user's questions"
+            " and output the summarization.",
+            category="agent",
+            key="dbgpt_agent_expand_summary_assistant_agent_profile_desc",
+        ),
     )
 
     def __init__(self, **kwargs):
@@ -1 +0,0 @@
-"""Memory module for agents."""
@@ -1,138 +0,0 @@
-"""Planner Agent."""
-
-from typing import Any, Dict, List
-
-from dbgpt._private.pydantic import Field
-
-from ..core.agent import AgentMessage
-from ..core.base_agent import ConversableAgent
-from .plan_action import PlanAction
-
-
-class PlannerAgent(ConversableAgent):
-    """Planner Agent.
-
-    Planner agent, realizing task goal planning decomposition through LLM.
-    """
-
-    agents: List[ConversableAgent] = Field(default_factory=list)
-
-    profile: str = "Planner"
-    goal_zh: str = (
-        "理解下面每个智能体(agent)和他们的能力,使用给出的资源,通过协调智能体来解决"
-        "用户问题。 请发挥你LLM的知识和理解能力,理解用户问题的意图和目标,生成一个可以在没有用户帮助"
-        "下,由智能体协作完成目标的任务计划。"
-    )
-    goal: str = (
-        "Understand each of the following intelligent agents and their "
-        "capabilities, using the provided resources, solve user problems by "
-        "coordinating intelligent agents. Please utilize your LLM's knowledge "
-        "and understanding ability to comprehend the intent and goals of the "
-        "user's problem, generating a task plan that can be completed through"
-        " the collaboration of intelligent agents without user assistance."
-    )
-    expand_prompt_zh: str = "可用智能体(agent):\n {agents}"
-    expand_prompt: str = "Available Intelligent Agents:\n {agents}"
-
-    constraints_zh: List[str] = [
-        "任务计划的每个步骤都应该是为了推进解决用户目标而存在,不要生成无意义的任务步骤,确保每个步骤内目标明确内容完整。",
-        "关注任务计划每个步骤的依赖关系和逻辑,被依赖步骤要考虑被依赖的数据,是否能基于当前目标得到,如果不能请在目标中提示要生成被依赖数据。",
-        "每个步骤都是一个独立可完成的目标,一定要确保逻辑和信息完整,不要出现类似:"
-        "'Analyze the retrieved issues data'这样目标不明确,不知道具体要分析啥内容的步骤",
-        "请确保只使用上面提到的智能体,并且可以只使用其中需要的部分,严格根据描述能力和限制分配给合适的步骤,每个智能体都可以重复使用。",
-        "根据用户目标的实际需要使用提供的资源来协助生成计划步骤,不要使用不需要的资源。",
-        "每个步骤最好只使用一种资源完成一个子目标,如果当前目标可以分解为同类型的多个子任务,可以生成相互不依赖的并行任务。",
-        "数据资源可以被合适的智能体加载使用,不用考虑数据资源的加载链接问题",
-        "尽量合并有顺序依赖的连续相同步骤,如果用户目标无拆分必要,可以生成内容为用户目标的单步任务。",
-        "仔细检查计划,确保计划完整的包含了用户问题所涉及的所有信息,并且最终能完成目标,确认每个步骤是否包含了需要用到的资源信息,如URL、资源名等. ",
-    ]
-    constraints: List[str] = [
-        "Every step of the task plan should exist to advance towards solving the user's"
-        " goals. Do not generate meaningless task steps; ensure that each step has a "
-        "clear goal and its content is complete.",
-        "Pay attention to the dependencies and logic of each step in the task plan. "
-        "For the steps that are depended upon, consider the data they depend on and "
-        "whether it can be obtained based on the current goal. If it cannot be obtained"
-        ", please indicate in the goal that the dependent data needs to be generated.",
-        "Each step must be an independently achievable goal. Ensure that the logic and"
-        " information are complete. Avoid steps with unclear objectives, like "
-        "'Analyze the retrieved issues data,' where it's unclear what specific content"
-        " needs to be analyzed.",
-        "Please ensure that only the intelligent agents mentioned above are used, and"
-        " you may use only the necessary parts of them. Allocate them to appropriate "
-        "steps strictly based on their described capabilities and limitations. Each "
-        "intelligent agent can be reused.",
-        "Utilize the provided resources to assist in generating the plan steps "
-        "according to the actual needs of the user's goals. Do not use unnecessary "
-        "resources.",
-        "Each step should ideally use only one type of resource to accomplish a "
-        "sub-goal. If the current goal can be broken down into multiple subtasks of the"
-        " same type, you can create mutually independent parallel tasks.",
-        "Data resources can be loaded and utilized by the appropriate intelligent "
-        "agents without the need to consider the issues related to data loading links.",
-        "Try to merge continuous steps that have sequential dependencies. If the "
-        "user's goal does not require splitting, you can create a single-step task with"
-        " content that is the user's goal.",
-        "Carefully review the plan to ensure it comprehensively covers all information"
-        " involved in the user's problem and can ultimately achieve the goal. Confirm"
-        " whether each step includes the necessary resource information, such as URLs,"
-        " resource names, etc.",
-    ]
-    desc_zh: str = "你是一个任务规划专家!可以协调智能体,分配资源完成复杂的任务目标。"
-    desc: str = (
-        "You are a task planning expert! You can coordinate intelligent agents"
-        " and allocate resources to achieve complex task goals."
-    )
-
-    examples: str = """
-    user:help me build a sales report summarizing our key metrics and trends
-    assistants:[
-        {{
-            "serial_number": "1",
-            "agent": "DataScientist",
-            "content": "Retrieve total sales, average sales, and number of transactions grouped by "product_category"'.",
-            "rely": ""
-        }},
-        {{
-            "serial_number": "2",
-            "agent": "DataScientist",
-            "content": "Retrieve monthly sales and transaction number trends.",
-            "rely": ""
-        }},
-        {{
-            "serial_number": "3",
-            "agent": "Reporter",
-            "content": "Integrate analytical data into the format required to build sales reports.",
-            "rely": "1,2"
-        }}
-    ]
-    """ # noqa: E501
-
-    def __init__(self, **kwargs):
-        """Create a new PlannerAgent instance."""
-        super().__init__(**kwargs)
-        self._init_actions([PlanAction])
-
-    def _init_reply_message(self, received_message: AgentMessage):
-        reply_message = super()._init_reply_message(received_message)
-        reply_message.context = {
-            "agents": "\n".join(
-                [f"- {item.profile}:{item.desc}" for item in self.agents]
-            ),
-        }
-        return reply_message
-
-    def bind_agents(self, agents: List[ConversableAgent]) -> ConversableAgent:
-        """Bind the agents to the planner agent."""
-        self.agents = agents
-        for agent in self.agents:
-            if agent.resources and len(agent.resources) > 0:
-                self.resources.extend(agent.resources)
-        return self
-
-    def prepare_act_param(self) -> Dict[str, Any]:
-        """Prepare the parameters for the act method."""
-        return {
-            "context": self.not_null_agent_context,
-            "plans_memory": self.memory.plans_memory,
-        }
@@ -70,7 +70,7 @@ def execute_ai_response_json(
 
 def execute_command(
     command_name: str,
-    arguments,
+    arguments: Dict[str, Any],
     plugin_generator: PluginPromptGenerator,
 ) -> Any:
     """Execute the command and return the result.
@@ -78,6 +78,7 @@ def execute_command(
     Args:
         command_name (str): The name of the command to execute
        arguments (dict): The arguments for the command
+        plugin_generator (PluginPromptGenerator): The plugin generator
 
     Returns:
         str: The result of the command
@@ -103,18 +104,23 @@ def execute_command(
     else:
         for command in plugin_generator.commands:
             if (
-                command_name == command["label"].lower()
-                or command_name == command["name"].lower()
+                command_name == command.label.lower()
+                or command_name == command.name.lower()
             ):
                 try:
-                    # 删除非定义参数
+                    # Delete non-defined parameters
                    diff_ags = list(
-                        set(arguments.keys()).difference(set(command["args"].keys()))
+                        set(arguments.keys()).difference(set(command.args.keys()))
                    )
                    for arg_name in diff_ags:
                        del arguments[arg_name]
                    print(str(arguments))
-                    return command["function"](**arguments)
+                    func = command.function
+                    if not func:
+                        raise ExecutionCommandException(
+                            f"Function not found for command: {command_name}"
+                        )
+                    return func(**arguments)
                 except Exception as e:
                     raise ExecutionCommandException(f"Execution error: {str(e)}")
     raise NotCommandException("Invalid command: " + command_name)
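The rewrite above replaces dict subscripting (command["label"], command["function"]) with attribute access, because plugin commands are now CommandEntry models (introduced in dbgpt/agent/plugin/generator.py below), and a missing function now raises an explicit ExecutionCommandException rather than failing on a None call. A small sketch of the new access pattern; the echo command itself is hypothetical:

from dbgpt.agent.plugin.generator import CommandEntry

entry = CommandEntry(
    label="Echo text",
    name="echo",
    args={"text": "<string>"},
    function=lambda text: text,
)

# Old style: entry["name"], entry["function"](**arguments)
# New style: attribute access, with an explicit check before calling.
if not entry.function:
    raise ValueError(f"Function not found for command: {entry.name}")
print(entry.function(text="hello"))  # -> hello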
@@ -1,5 +1,34 @@
 """A module for generating custom prompt strings."""
-from typing import Any, Callable, Dict, List, Optional
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional
 
+from dbgpt._private.pydantic import BaseModel, Field
+
+if TYPE_CHECKING:
+    from .commands.command_manage import CommandRegistry
+
+
+class CommandEntry(BaseModel):
+    """CommandEntry class.
+
+    A class for storing information about a command.
+    """
+
+    label: str = Field(
+        ...,
+        description="The label of the command.",
+    )
+    name: str = Field(
+        ...,
+        description="The name of the command.",
+    )
+    args: Dict[str, Any] = Field(
+        default_factory=dict,
+        description="A dictionary containing argument names and their values.",
+    )
+    function: Optional[Callable] = Field(
+        None,
+        description="A callable function to be called when the command is executed.",
+    )
+
 
 class PluginPromptGenerator:
@@ -9,7 +38,7 @@ class PluginPromptGenerator:
     resources, and performance evaluations.
     """
 
-    def __init__(self) -> None:
+    def __init__(self):
         """Create a new PromptGenerator object.
 
         Initialize the PromptGenerator object with empty lists of constraints,
@@ -17,11 +46,44 @@ class PluginPromptGenerator:
         """
         from .commands.command_manage import CommandRegistry
 
-        self.constraints: List[str] = []
-        self.commands: List[Dict[str, Any]] = []
-        self.resources: List[str] = []
-        self.performance_evaluation: List[str] = []
-        self.command_registry: CommandRegistry = CommandRegistry()
+        self._constraints: List[str] = []
+        self._commands: List[CommandEntry] = []
+        self._resources: List[str] = []
+        self._performance_evaluation: List[str] = []
+        self._command_registry: CommandRegistry = CommandRegistry()
+
+    @property
+    def constraints(self) -> List[str]:
+        """Return the list of constraints."""
+        return self._constraints
+
+    @property
+    def commands(self) -> List[CommandEntry]:
+        """Return the list of commands."""
+        return self._commands
+
+    @property
+    def resources(self) -> List[str]:
+        """Return the list of resources."""
+        return self._resources
+
+    @property
+    def performance_evaluation(self) -> List[str]:
+        """Return the list of performance evaluations."""
+        return self._performance_evaluation
+
+    @property
+    def command_registry(self) -> "CommandRegistry":
+        """Return the command registry."""
+        return self._command_registry
+
+    def set_command_registry(self, command_registry: "CommandRegistry") -> None:
+        """Set the command registry.
+
+        Args:
+            command_registry: CommandRegistry
+        """
+        self._command_registry = command_registry
 
     def add_constraint(self, constraint: str) -> None:
         """Add a constraint to the constraints list.
@@ -29,13 +91,13 @@ class PluginPromptGenerator:
         Args:
             constraint (str): The constraint to be added.
         """
-        self.constraints.append(constraint)
+        self._constraints.append(constraint)
 
     def add_command(
         self,
         command_label: str,
         command_name: str,
-        args=None,
+        args: Optional[Dict[str, Any]] = None,
         function: Optional[Callable] = None,
     ) -> None:
         """Add a command to the commands.
@@ -55,16 +117,15 @@ class PluginPromptGenerator:
 
         command_args = {arg_key: arg_value for arg_key, arg_value in args.items()}
 
-        command = {
-            "label": command_label,
-            "name": command_name,
-            "args": command_args,
-            "function": function,
-        }
+        command = CommandEntry(
+            label=command_label,
+            name=command_name,
+            args=command_args,
+            function=function,
+        )
+        self._commands.append(command)
-
-        self.commands.append(command)
 
-    def _generate_command_string(self, command: Dict[str, Any]) -> str:
+    def _generate_command_string(self, command: CommandEntry) -> str:
         """
         Generate a formatted string representation of a command.
 
@@ -75,9 +136,9 @@ class PluginPromptGenerator:
             str: The formatted command string.
         """
         args_string = ", ".join(
-            f'"{key}": "{value}"' for key, value in command["args"].items()
+            f'"{key}": "{value}"' for key, value in command.args.items()
        )
-        return f'"{command["name"]}": {command["label"]} , args: {args_string}'
+        return f'"{command.name}": {command.label} , args: {args_string}'
 
     def add_resource(self, resource: str) -> None:
         """
@@ -86,7 +147,7 @@ class PluginPromptGenerator:
         Args:
             resource (str): The resource to be added.
         """
-        self.resources.append(resource)
+        self._resources.append(resource)
 
     def add_performance_evaluation(self, evaluation: str) -> None:
         """
@@ -95,7 +156,7 @@ class PluginPromptGenerator:
         Args:
             evaluation (str): The evaluation item to be added.
         """
-        self.performance_evaluation.append(evaluation)
+        self._performance_evaluation.append(evaluation)
 
     def _generate_numbered_list(self, items: List[Any], item_type="list") -> str:
         """
@@ -111,10 +172,10 @@ class PluginPromptGenerator:
         """
         if item_type == "command":
             command_strings = []
-            if self.command_registry:
+            if self._command_registry:
                 command_strings += [
                     str(item)
-                    for item in self.command_registry.commands.values()
+                    for item in self._command_registry.commands.values()
                     if item.enabled
                 ]
             # terminate command is added manually
@@ -125,4 +186,4 @@ class PluginPromptGenerator:
 
     def generate_commands_string(self) -> str:
         """Return a formatted string representation of the commands list."""
-        return f"{self._generate_numbered_list(self.commands, item_type='command')}"
+        return f"{self._generate_numbered_list(self._commands, item_type='command')}"
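With these hunks the generator keeps its state in private attributes exposed through read-only properties, and add_command stores validated CommandEntry models instead of plain dicts. A minimal usage sketch based only on the methods visible above; the echo command is hypothetical:

from dbgpt.agent.plugin.generator import PluginPromptGenerator

generator = PluginPromptGenerator()
generator.add_constraint("Only call tools that are listed below.")
generator.add_command(
    command_label="Echo text",
    command_name="echo",
    args={"text": "<string>"},
    function=lambda text: text,
)

# commands is now a read-only property returning CommandEntry models.
assert generator.commands[0].name == "echo"
# Renders the numbered tool list that is injected into the prompt.
print(generator.generate_commands_string())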
@@ -1,7 +1,7 @@
 """Resource plugin client API."""
 import logging
 import os
-from typing import List, Optional, Union, cast
+from typing import Any, Dict, List, Optional, Union, cast
 
 from ..plugin.commands.command_manage import execute_command
 from ..plugin.generator import PluginPromptGenerator
@@ -48,8 +48,8 @@ class ResourcePluginClient(ResourceClient):
     async def execute_command(
         self,
         command_name: str,
-        arguments: Optional[dict],
-        plugin_generator: Optional[PluginPromptGenerator],
+        arguments: Dict[str, Any],
+        plugin_generator: PluginPromptGenerator,
     ):
         """Execute the command."""
         if plugin_generator is None: