mirror of https://github.com/csunny/DB-GPT.git
synced 2025-09-04 18:40:10 +00:00
refactor(agent): Agent modular refactoring (#1487)
@@ -1 +1,22 @@
"""Core Module for the Agent."""
"""Core Module for the Agent.

There are four modules in the DB-GPT agent core, according to the paper
`A survey on large language model based autonomous agents
<https://link.springer.com/article/10.1007/s11704-024-40231-1>`
by `Lei Wang, Chen Ma, Xueyang Feng, et al.`:

1. Profiling Module: The profiling module aims to indicate the profiles of the agent
roles.

2. Memory Module: It stores information perceived from the environment and leverages
the recorded memories to facilitate future actions.

3. Planning Module: When faced with a complex task, humans tend to deconstruct it into
simpler subtasks and solve them individually. The planning module aims to empower the
agents with such human capability, which is expected to make the agent behave more
reasonably, powerfully, and reliably.

4. Action Module: The action module is responsible for translating the agent's
decisions into specific outcomes. This module is located at the most downstream
position and directly interacts with the environment.
"""
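For orientation, these four modules correspond to sub-packages touched by this refactor. A minimal import sketch (illustrative only; the profiling module is exposed via ProfileConfig, and the planning module is not shown in this excerpt):

    from dbgpt.agent.core.action.base import Action, ActionOutput   # Action module
    from dbgpt.agent.core.memory.agent_memory import AgentMemory    # Memory module
    from dbgpt.agent.core.profile import ProfileConfig              # Profiling module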
dbgpt/agent/core/action/__init__.py (new file, 20 lines)
@@ -0,0 +1,20 @@
"""Action Module.

The action module is responsible for translating the agent's decisions into specific
outcomes. This module is located at the most downstream position and directly interacts
with the environment. It is influenced by the profile, memory, and planning modules.

The Goal Of The Action Module:
--------
1. Task Completion: Complete specific tasks, write a function in software development,
and make an iron pick in the game.

2. Communication: Communicate with other agents.

3. Environment exploration: Explore unfamiliar environments to expand its perception
and strike a balance between exploring and exploiting.
"""

from .base import Action, ActionOutput  # noqa: F401
from .blank_action import BlankAction  # noqa: F401
dbgpt/agent/core/action/base.py (new file, 182 lines)
@@ -0,0 +1,182 @@
"""Base Action class for defining agent actions."""

import json
from abc import ABC, abstractmethod
from typing import (
    Any,
    Dict,
    Generic,
    List,
    Optional,
    Type,
    TypeVar,
    Union,
    cast,
    get_args,
    get_origin,
)

from dbgpt._private.pydantic import (
    BaseModel,
    field_default,
    field_description,
    model_fields,
    model_to_dict,
    model_validator,
)
from dbgpt.util.json_utils import find_json_objects
from dbgpt.vis.base import Vis

from ...resource.resource_api import AgentResource, ResourceType
from ...resource.resource_loader import ResourceLoader

T = TypeVar("T", bound=Union[BaseModel, List[BaseModel], None])

JsonMessageType = Union[Dict[str, Any], List[Dict[str, Any]]]


class ActionOutput(BaseModel):
    """Action output model."""

    content: str
    is_exe_success: bool = True
    view: Optional[str] = None
    resource_type: Optional[str] = None
    resource_value: Optional[Any] = None
    action: Optional[str] = None
    thoughts: Optional[str] = None
    observations: Optional[str] = None

    @model_validator(mode="before")
    @classmethod
    def pre_fill(cls, values: Any) -> Any:
        """Pre-fill the values."""
        if not isinstance(values, dict):
            return values
        is_exe_success = values.get("is_exe_success", True)
        if not is_exe_success and "observations" not in values:
            values["observations"] = values.get("content")
        return values

    @classmethod
    def from_dict(
        cls: Type["ActionOutput"], param: Optional[Dict]
    ) -> Optional["ActionOutput"]:
        """Convert dict to ActionOutput object."""
        if not param:
            return None
        return cls.parse_obj(param)

    def to_dict(self) -> Dict[str, Any]:
        """Convert the object to a dictionary."""
        return model_to_dict(self)


class Action(ABC, Generic[T]):
    """Base Action class for defining agent actions."""

    def __init__(self):
        """Create an action."""
        self.resource_loader: Optional[ResourceLoader] = None

    def init_resource_loader(self, resource_loader: Optional[ResourceLoader]):
        """Initialize the resource loader."""
        self.resource_loader = resource_loader

    @property
    def resource_need(self) -> Optional[ResourceType]:
        """Return the resource type needed for the action."""
        return None

    @property
    def render_protocol(self) -> Optional[Vis]:
        """Return the render protocol."""
        return None

    def render_prompt(self) -> Optional[str]:
        """Return the render prompt."""
        if self.render_protocol is None:
            return None
        else:
            return self.render_protocol.render_prompt()

    def _create_example(
        self,
        model_type: Union[Type[BaseModel], List[Type[BaseModel]]],
    ) -> Optional[Union[Dict[str, Any], List[Dict[str, Any]]]]:
        if model_type is None:
            return None
        origin = get_origin(model_type)
        args = get_args(model_type)
        if origin is None:
            example = {}
            single_model_type = cast(Type[BaseModel], model_type)
            for field_name, field in model_fields(single_model_type).items():
                description = field_description(field)
                default_value = field_default(field)
                if description:
                    example[field_name] = description
                elif default_value:
                    example[field_name] = default_value
                else:
                    example[field_name] = ""
            return example
        elif origin is list or origin is List:
            element_type = cast(Type[BaseModel], args[0])
            if issubclass(element_type, BaseModel):
                list_example = self._create_example(element_type)
                typed_list_example = cast(Dict[str, Any], list_example)
                return [typed_list_example]
            else:
                raise TypeError("List elements must be BaseModel subclasses")
        else:
            raise ValueError(
                f"Model type {model_type} is not an instance of BaseModel."
            )

    @property
    def out_model_type(self) -> Optional[Union[Type[T], List[Type[T]]]]:
        """Return the output model type."""
        return None

    @property
    def ai_out_schema(self) -> Optional[str]:
        """Return the AI output schema."""
        if self.out_model_type is None:
            return None

        json_format_data = json.dumps(
            self._create_example(self.out_model_type), indent=2, ensure_ascii=False
        )
        return f"""Please response in the following json format:
{json_format_data}
Make sure the response is correct json and can be parsed by Python json.loads.
"""

    def _ai_message_2_json(self, ai_message: str) -> JsonMessageType:
        json_objects = find_json_objects(ai_message)
        json_count = len(json_objects)
        if json_count != 1:
            raise ValueError("Unable to obtain valid output.")
        return json_objects[0]

    def _input_convert(self, ai_message: str, cls: Type[T]) -> T:
        json_result = self._ai_message_2_json(ai_message)
        if get_origin(cls) == list:
            inner_type = get_args(cls)[0]
            typed_cls = cast(Type[BaseModel], inner_type)
            return [typed_cls.parse_obj(item) for item in json_result]  # type: ignore
        else:
            typed_cls = cast(Type[BaseModel], cls)
            return typed_cls.parse_obj(json_result)

    @abstractmethod
    async def run(
        self,
        ai_message: str,
        resource: Optional[AgentResource] = None,
        rely_action_out: Optional[ActionOutput] = None,
        need_vis_render: bool = True,
        **kwargs,
    ) -> ActionOutput:
        """Perform the action."""
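To make the pattern above concrete, here is a hedged sketch of a custom Action (the SqlOutput model, its field names, and the error handling are hypothetical; only Action, ActionOutput and the pydantic helpers come from the file above):

    from typing import Optional

    from dbgpt._private.pydantic import BaseModel, Field
    from dbgpt.agent.core.action.base import Action, ActionOutput


    class SqlOutput(BaseModel):
        """Hypothetical structured output the LLM is asked to produce."""

        thoughts: str = Field(..., description="Reasoning about the user question")
        sql: str = Field(..., description="A single executable SQL statement")


    class SqlAction(Action[SqlOutput]):
        @property
        def out_model_type(self):
            # Drives `ai_out_schema`: an example JSON object is generated from the
            # field descriptions of SqlOutput and appended to the system prompt.
            return SqlOutput

        async def run(
            self,
            ai_message: str,
            resource=None,
            rely_action_out: Optional[ActionOutput] = None,
            need_vis_render: bool = True,
            **kwargs,
        ) -> ActionOutput:
            try:
                # `_input_convert` extracts the single JSON object from the LLM reply
                # and validates it against SqlOutput.
                param: SqlOutput = self._input_convert(ai_message, SqlOutput)
            except Exception:
                return ActionOutput(is_exe_success=False, content="Invalid JSON output")
            return ActionOutput(is_exe_success=True, content=param.sql, view=param.sql)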
dbgpt/agent/core/action/blank_action.py (new file, 40 lines)
@@ -0,0 +1,40 @@
"""Blank Action for the Agent."""

import logging
from typing import Optional

from ...resource.resource_api import AgentResource
from .base import Action, ActionOutput

logger = logging.getLogger(__name__)


class BlankAction(Action):
    """Blank action class."""

    def __init__(self):
        """Create a blank action."""
        super().__init__()

    @property
    def ai_out_schema(self) -> Optional[str]:
        """Return the AI output schema."""
        return None

    async def run(
        self,
        ai_message: str,
        resource: Optional[AgentResource] = None,
        rely_action_out: Optional[ActionOutput] = None,
        need_vis_render: bool = True,
        **kwargs,
    ) -> ActionOutput:
        """Perform the action.

        Just return the AI message.
        """
        return ActionOutput(
            is_exe_success=True,
            content=ai_message,
            view=ai_message,
        )
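A quick, illustrative driver for the class above (not part of the diff): BlankAction simply echoes the model reply back as a successful ActionOutput.

    import asyncio

    from dbgpt.agent.core.action.blank_action import BlankAction


    async def _demo() -> None:
        out = await BlankAction().run(ai_message="hello from the model")
        # BlankAction performs no side effects; content and view are the raw reply.
        assert out.is_exe_success and out.content == "hello from the model"


    asyncio.run(_demo())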
@@ -9,9 +9,9 @@ from typing import Any, Dict, List, Optional, Tuple, Union
from dbgpt.core import LLMClient
from dbgpt.util.annotations import PublicAPI

from ..actions.action import ActionOutput
from ..memory.gpts_memory import GptsMemory
from ..resource.resource_loader import ResourceLoader
from .action.base import ActionOutput
from .memory.agent_memory import AgentMemory


class Agent(ABC):
@@ -160,17 +160,20 @@ class Agent(ABC):
        verification result.
        """

    @property
    @abstractmethod
    def get_name(self) -> str:
        """Return name of the agent."""
    def name(self) -> str:
        """Return the name of the agent."""

    @property
    @abstractmethod
    def get_profile(self) -> str:
        """Return profile of the agent."""
    def role(self) -> str:
        """Return the role of the agent."""

    @property
    @abstractmethod
    def get_describe(self) -> str:
        """Return describe of the agent."""
    def desc(self) -> Optional[str]:
        """Return the description of the agent."""


@dataclasses.dataclass
@@ -204,7 +207,7 @@ class AgentGenerateContext:
    rely_messages: List[AgentMessage] = dataclasses.field(default_factory=list)
    final: Optional[bool] = True

    memory: Optional[GptsMemory] = None
    memory: Optional[AgentMemory] = None
    agent_context: Optional[AgentContext] = None
    resource_loader: Optional[ResourceLoader] = None
    llm_client: Optional[LLMClient] = None
@@ -302,3 +305,9 @@ class AgentMessage:
            role=self.role,
            success=self.success,
        )

    def get_dict_context(self) -> Dict[str, Any]:
        """Return the context as a dictionary."""
        if isinstance(self.context, dict):
            return self.context
        return {}
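In short, this hunk replaces the get_name()/get_profile()/get_describe() methods with the name/role/desc properties, so callers drop the call parentheses. An illustrative before/after for any Agent instance:

    # before: agent.get_name(), agent.get_profile(), agent.get_describe()
    # after:
    label = f"{agent.name} ({agent.role}): {agent.desc or ''}"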
@@ -18,7 +18,7 @@ def participant_roles(agents: List[Agent]) -> str:
    # Default to all agents registered
    roles = []
    for agent in agents:
        roles.append(f"{agent.get_name()}: {agent.get_describe()}")
        roles.append(f"{agent.name}: {agent.desc}")
    return "\n".join(roles)


@@ -34,13 +34,13 @@ def mentioned_agents(message_content: str, agents: List[Agent]) -> Dict:
    mentions = dict()
    for agent in agents:
        regex = (
            r"(?<=\W)" + re.escape(agent.get_name()) + r"(?=\W)"
            r"(?<=\W)" + re.escape(agent.name) + r"(?=\W)"
        )  # Finds agent mentions, taking word boundaries into account
        count = len(
            re.findall(regex, " " + message_content + " ")
        )  # Pad the message to help with matching
        if count > 0:
            mentions[agent.get_name()] = count
            mentions[agent.name] = count
    return mentions


@@ -84,7 +84,7 @@ class AgentManager(BaseComponent):
    ) -> str:
        """Register an agent."""
        inst = cls()
        profile = inst.get_profile()
        profile = inst.role
        if profile in self._agents and (
            profile in self._core_agents or not ignore_duplicate
        ):
@@ -110,13 +110,13 @@ class AgentManager(BaseComponent):

    def get_describe_by_name(self, name: str) -> str:
        """Return the description of an agent by name."""
        return self._agents[name][1].desc
        return self._agents[name][1].desc or ""

    def all_agents(self) -> Dict[str, str]:
        """Return a dictionary of all registered agents and their descriptions."""
        result = {}
        for name, value in self._agents.items():
            result[name] = value[1].desc
            result[name] = value[1].desc or ""
        return result

    def list_agents(self):
@@ -125,7 +125,7 @@ class AgentManager(BaseComponent):
        for name, value in self._agents.items():
            result.append(
                {
                    "name": value[1].profile,
                    "name": value[1].role,
                    "desc": value[1].goal,
                }
            )
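For reference, the mention-counting regex used above behaves like this (illustrative strings; mentioned_agents itself iterates over Agent objects):

    import re

    name = "Reporter"
    regex = r"(?<=\W)" + re.escape(name) + r"(?=\W)"
    # Padding with spaces lets mentions at the start or end of the message match too.
    message = "Ask Reporter to summarize, then ping Reporter."
    count = len(re.findall(regex, " " + message + " "))
    assert count == 2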
@@ -3,7 +3,7 @@
import asyncio
import json
import logging
from typing import Any, Dict, List, Optional, Tuple, Type, Union, cast
from typing import Any, Dict, List, Optional, Tuple, Type, cast

from dbgpt._private.pydantic import ConfigDict, Field
from dbgpt.core import LLMClient, ModelMessageRoleType
@@ -11,14 +11,15 @@ from dbgpt.util.error_types import LLMChatError
from dbgpt.util.tracer import SpanType, root_tracer
from dbgpt.util.utils import colored

from ..actions.action import Action, ActionOutput
from ..memory.base import GptsMessage
from ..memory.gpts_memory import GptsMemory
from ..resource.resource_api import AgentResource, ResourceClient
from ..resource.resource_loader import ResourceLoader
from ..util.llm.llm import LLMConfig, LLMStrategyType
from ..util.llm.llm_client import AIWrapper
from .action.base import Action, ActionOutput
from .agent import Agent, AgentContext, AgentMessage, AgentReviewInfo
from .llm.llm import LLMConfig, LLMStrategyType
from .llm.llm_client import AIWrapper
from .memory.agent_memory import AgentMemory
from .memory.gpts.base import GptsMessage
from .memory.gpts.gpts_memory import GptsMemory
from .role import Role

logger = logging.getLogger(__name__)
@@ -33,26 +34,16 @@ class ConversableAgent(Role, Agent):
    actions: List[Action] = Field(default_factory=list)
    resources: List[AgentResource] = Field(default_factory=list)
    llm_config: Optional[LLMConfig] = None
    memory: GptsMemory = Field(default_factory=GptsMemory)
    resource_loader: Optional[ResourceLoader] = None
    max_retry_count: int = 3
    consecutive_auto_reply_counter: int = 0
    llm_client: Optional[AIWrapper] = None
    oai_system_message: List[Dict] = Field(default_factory=list)

    def __init__(self, **kwargs):
        """Create a new agent."""
        Role.__init__(self, **kwargs)
        Agent.__init__(self)

    def init_system_message(self) -> None:
        """Initialize the system message."""
        content = self.prompt_template()
        # TODO: Don't modify the original data, need to be optimized
        self.oai_system_message = [
            {"content": content, "role": ModelMessageRoleType.SYSTEM}
        ]

    def check_available(self) -> None:
        """Check if the agent is available.

@@ -63,7 +54,7 @@ class ConversableAgent(Role, Agent):
        # check run context
        if self.agent_context is None:
            raise ValueError(
                f"{self.name}[{self.profile}] Missing context in which agent is "
                f"{self.name}[{self.role}] Missing context in which agent is "
                f"running!"
            )

@@ -90,20 +81,20 @@ class ConversableAgent(Role, Agent):
                and action.resource_need not in have_resource_types
            ):
                raise ValueError(
                    f"{self.name}[{self.profile}] Missing resources required for "
                    f"{self.name}[{self.role}] Missing resources required for "
                    "runtime!"
                )
        else:
            if not self.is_human and not self.is_team:
                raise ValueError(
                    f"This agent {self.name}[{self.profile}] is missing action modules."
                    f"This agent {self.name}[{self.role}] is missing action modules."
                )
        # llm check
        if not self.is_human and (
            self.llm_config is None or self.llm_config.llm_client is None
        ):
            raise ValueError(
                f"{self.name}[{self.profile}] Model configuration is missing or model "
                f"{self.name}[{self.role}] Model configuration is missing or model "
                "service is unavailable!"
            )

@@ -161,14 +152,19 @@ class ConversableAgent(Role, Agent):
        for action in self.actions:
            action.init_resource_loader(self.resource_loader)

        # Initialize system messages
        self.init_system_message()

        # Initialize LLM Server
        if not self.is_human:
            if not self.llm_config or not self.llm_config.llm_client:
                raise ValueError("LLM client is not initialized!")
            self.llm_client = AIWrapper(llm_client=self.llm_config.llm_client)
            self.memory.initialize(
                self.name,
                self.llm_config.llm_client,
                importance_scorer=self.memory_importance_scorer,
                insight_extractor=self.memory_insight_extractor,
            )
            # Clone the memory structure
            self.memory = self.memory.structure_clone()
        return self

    def bind(self, target: Any) -> "ConversableAgent":
@@ -176,7 +172,7 @@ class ConversableAgent(Role, Agent):
        if isinstance(target, LLMConfig):
            self.llm_config = target
        elif isinstance(target, GptsMemory):
            self.memory = target
            raise ValueError("GptsMemory is not supported!")
        elif isinstance(target, AgentContext):
            self.agent_context = target
        elif isinstance(target, ResourceLoader):
@@ -186,6 +182,8 @@ class ConversableAgent(Role, Agent):
            self.actions.extend(target)
        elif _is_list_of_type(target, AgentResource):
            self.resources = target
        elif isinstance(target, AgentMemory):
            self.memory = target
        return self

    async def send(
@@ -200,9 +198,9 @@ class ConversableAgent(Role, Agent):
        with root_tracer.start_span(
            "agent.send",
            metadata={
                "sender": self.get_name(),
                "recipient": recipient.get_name(),
                "reviewer": reviewer.get_name() if reviewer else None,
                "sender": self.name,
                "recipient": recipient.name,
                "reviewer": reviewer.name if reviewer else None,
                "agent_message": message.to_dict(),
                "request_reply": request_reply,
                "is_recovery": is_recovery,
@@ -230,9 +228,9 @@ class ConversableAgent(Role, Agent):
        with root_tracer.start_span(
            "agent.receive",
            metadata={
                "sender": sender.get_name(),
                "recipient": self.get_name(),
                "reviewer": reviewer.get_name() if reviewer else None,
                "sender": sender.name,
                "recipient": self.name,
                "reviewer": reviewer.name if reviewer else None,
                "agent_message": message.to_dict(),
                "request_reply": request_reply,
                "silent": silent,
@@ -271,14 +269,14 @@ class ConversableAgent(Role, Agent):
        root_span = root_tracer.start_span(
            "agent.generate_reply",
            metadata={
                "sender": sender.get_name(),
                "recipient": self.get_name(),
                "reviewer": reviewer.get_name() if reviewer else None,
                "sender": sender.name,
                "recipient": self.name,
                "reviewer": reviewer.name if reviewer else None,
                "received_message": received_message.to_dict(),
                "conv_uid": self.not_null_agent_context.conv_id,
                "rely_messages": [msg.to_dict() for msg in rely_messages]
                if rely_messages
                else None,
                "rely_messages": (
                    [msg.to_dict() for msg in rely_messages] if rely_messages else None
                ),
            },
        )

@@ -295,18 +293,6 @@ class ConversableAgent(Role, Agent):
            )
            span.metadata["reply_message"] = reply_message.to_dict()

        with root_tracer.start_span(
            "agent.generate_reply._system_message_assembly",
            metadata={
                "reply_message": reply_message.to_dict(),
            },
        ) as span:
            # assemble system message
            await self._system_message_assembly(
                received_message.content, reply_message.context
            )
            span.metadata["assembled_system_messages"] = self.oai_system_message

        fail_reason = None
        current_retry_counter = 0
        is_success = True
@@ -325,8 +311,11 @@ class ConversableAgent(Role, Agent):
                        retry_message, self, reviewer, request_reply=False
                    )

                thinking_messages = self._load_thinking_messages(
                    received_message, sender, rely_messages
                thinking_messages = await self._load_thinking_messages(
                    received_message,
                    sender,
                    rely_messages,
                    context=reply_message.get_dict_context(),
                )
                with root_tracer.start_span(
                    "agent.generate_reply.thinking",
@@ -345,7 +334,7 @@ class ConversableAgent(Role, Agent):

                with root_tracer.start_span(
                    "agent.generate_reply.review",
                    metadata={"llm_reply": llm_reply, "censored": self.get_name()},
                    metadata={"llm_reply": llm_reply, "censored": self.name},
                ) as span:
                    # 2.Review whether what is being done is legal
                    approve, comments = await self.review(llm_reply, self)
@@ -361,8 +350,8 @@ class ConversableAgent(Role, Agent):
                    "agent.generate_reply.act",
                    metadata={
                        "llm_reply": llm_reply,
                        "sender": sender.get_name(),
                        "reviewer": reviewer.get_name() if reviewer else None,
                        "sender": sender.name,
                        "reviewer": reviewer.name if reviewer else None,
                        "act_extent_param": act_extent_param,
                    },
                ) as span:
@@ -383,8 +372,8 @@ class ConversableAgent(Role, Agent):
                    "agent.generate_reply.verify",
                    metadata={
                        "llm_reply": llm_reply,
                        "sender": sender.get_name(),
                        "reviewer": reviewer.get_name() if reviewer else None,
                        "sender": sender.name,
                        "reviewer": reviewer.name if reviewer else None,
                    },
                ) as span:
                    # 4.Reply information verification
@@ -394,6 +383,9 @@ class ConversableAgent(Role, Agent):
                    is_success = check_pass
                    span.metadata["check_pass"] = check_pass
                    span.metadata["reason"] = reason

                question: str = received_message.content or ""
                ai_message: str = llm_reply or ""
                # 5.Optimize wrong answers myself
                if not check_pass:
                    current_retry_counter += 1
@@ -403,7 +395,20 @@ class ConversableAgent(Role, Agent):
                            reply_message, sender, reviewer, request_reply=False
                        )
                    fail_reason = reason
                    await self.save_to_memory(
                        question=question,
                        ai_message=ai_message,
                        action_output=act_out,
                        check_pass=check_pass,
                        check_fail_reason=fail_reason,
                    )
                else:
                    await self.save_to_memory(
                        question=question,
                        ai_message=ai_message,
                        action_output=act_out,
                        check_pass=check_pass,
                    )
                    break
            reply_message.success = is_success
            return reply_message
@@ -437,8 +442,6 @@ class ConversableAgent(Role, Agent):
        try:
            if prompt:
                llm_messages = _new_system_message(prompt) + llm_messages
            else:
                llm_messages = self.oai_system_message + llm_messages

            if not self.llm_client:
                raise ValueError("LLM client is not initialized!")
@@ -491,9 +494,9 @@ class ConversableAgent(Role, Agent):
            "agent.act.run",
            metadata={
                "message": message,
                "sender": sender.get_name() if sender else None,
                "recipient": self.get_name(),
                "reviewer": reviewer.get_name() if reviewer else None,
                "sender": sender.name if sender else None,
                "recipient": self.name,
                "reviewer": reviewer.name if reviewer else None,
                "need_resource": need_resource.to_dict() if need_resource else None,
                "rely_action_out": last_out.to_dict() if last_out else None,
                "conv_uid": self.not_null_agent_context.conv_id,
@@ -563,9 +566,9 @@ class ConversableAgent(Role, Agent):
            "agent.initiate_chat",
            span_type=SpanType.AGENT,
            metadata={
                "sender": self.get_name(),
                "recipient": recipient.get_name(),
                "reviewer": reviewer.get_name() if reviewer else None,
                "sender": self.name,
                "recipient": recipient.name,
                "reviewer": reviewer.name if reviewer else None,
                "agent_message": agent_message.to_dict(),
                "conv_uid": self.not_null_agent_context.conv_id,
            },
@@ -612,21 +615,27 @@ class ConversableAgent(Role, Agent):

        gpts_message: GptsMessage = GptsMessage(
            conv_id=self.not_null_agent_context.conv_id,
            sender=sender.get_profile(),
            receiver=self.profile,
            sender=sender.role,
            receiver=self.role,
            role=role,
            rounds=self.consecutive_auto_reply_counter,
            current_goal=oai_message.get("current_goal", None),
            content=oai_message.get("content", None),
            context=json.dumps(oai_message["context"], ensure_ascii=False)
            if "context" in oai_message
            else None,
            review_info=json.dumps(oai_message["review_info"], ensure_ascii=False)
            if "review_info" in oai_message
            else None,
            action_report=json.dumps(oai_message["action_report"], ensure_ascii=False)
            if "action_report" in oai_message
            else None,
            context=(
                json.dumps(oai_message["context"], ensure_ascii=False)
                if "context" in oai_message
                else None
            ),
            review_info=(
                json.dumps(oai_message["review_info"], ensure_ascii=False)
                if "review_info" in oai_message
                else None
            ),
            action_report=(
                json.dumps(oai_message["action_report"], ensure_ascii=False)
                if "action_report" in oai_message
                else None
            ),
            model_name=oai_message.get("model_name", None),
        )

@@ -643,10 +652,10 @@ class ConversableAgent(Role, Agent):
    def _print_received_message(self, message: AgentMessage, sender: Agent):
        # print the message received
        print("\n", "-" * 80, flush=True, sep="")
        _print_name = self.name if self.name else self.profile
        _print_name = self.name if self.name else self.role
        print(
            colored(
                sender.get_name() if sender.get_name() else sender.get_profile(),
                sender.name if sender.name else sender.role,
                "yellow",
            ),
            "(to",
@@ -660,7 +669,7 @@ class ConversableAgent(Role, Agent):

        review_info = message.review_info
        if review_info:
            name = sender.get_name() if sender.get_name() else sender.get_profile()
            name = sender.name if sender.name else sender.role
            pass_msg = "Pass" if review_info.approve else "Reject"
            review_msg = f"{pass_msg}({review_info.comments})"
            approve_print = f">>>>>>>>{name} Review info: \n{review_msg}"
@@ -668,7 +677,7 @@ class ConversableAgent(Role, Agent):

        action_report = message.action_report
        if action_report:
            name = sender.get_name() if sender.get_name() else sender.get_profile()
            name = sender.name if sender.name else sender.role
            action_msg = (
                "execution succeeded"
                if action_report["is_exe_success"]
@@ -690,42 +699,32 @@ class ConversableAgent(Role, Agent):

        self._print_received_message(message, sender)

    async def _system_message_assembly(
        self, question: Optional[str], context: Optional[Union[str, Dict]] = None
    ):
        # system message
        self.init_system_message()
        if len(self.oai_system_message) > 0:
            resource_prompt_list = []
            for item in self.resources:
                resource_client = self.not_null_resource_loader.get_resource_api(
                    item.type, ResourceClient
    async def generate_resource_variables(
        self, question: Optional[str] = None
    ) -> Dict[str, Any]:
        """Generate the resource variables."""
        resource_prompt_list = []
        for item in self.resources:
            resource_client = self.not_null_resource_loader.get_resource_api(
                item.type, ResourceClient
            )
            if not resource_client:
                raise ValueError(
                    f"Resource {item.type}:{item.value} missing resource loader"
                    f" implementation,unable to read resources!"
                )
                if not resource_client:
                    raise ValueError(
                        f"Resource {item.type}:{item.value} missing resource loader"
                        f" implementation,unable to read resources!"
                    )
            resource_prompt_list.append(
                await resource_client.get_resource_prompt(item, question)
            )
            if context is None or not isinstance(context, dict):
                context = {}
                resource_prompt_list.append(
                    await resource_client.get_resource_prompt(item, question)
                )

        resource_prompt = ""
        if len(resource_prompt_list) > 0:
            resource_prompt = "RESOURCES:" + "\n".join(resource_prompt_list)
            resource_prompt = ""
            if len(resource_prompt_list) > 0:
                resource_prompt = "RESOURCES:" + "\n".join(resource_prompt_list)

        out_schema: Optional[str] = ""
        if self.actions and len(self.actions) > 0:
            out_schema = self.actions[0].ai_out_schema
            for message in self.oai_system_message:
                new_content = message["content"].format(
                    resource_prompt=resource_prompt,
                    out_schema=out_schema,
                    **context,
                )
                message["content"] = new_content
            out_schema: Optional[str] = ""
            if self.actions and len(self.actions) > 0:
                out_schema = self.actions[0].ai_out_schema
        return {"resource_prompt": resource_prompt, "out_schema": out_schema}

    def _excluded_models(
        self,
@@ -774,7 +773,7 @@ class ConversableAgent(Role, Agent):
        else:
            raise ValueError("No model service available!")
        except Exception as e:
            logger.error(f"{self.profile} get next llm failed!{str(e)}")
            logger.error(f"{self.role} get next llm failed!{str(e)}")
            raise ValueError(f"Failed to allocate model service,{str(e)}!")

    def _init_reply_message(self, received_message: AgentMessage) -> AgentMessage:
@@ -803,9 +802,9 @@ class ConversableAgent(Role, Agent):
            if item.role:
                role = item.role
            else:
                if item.receiver == self.profile:
                if item.receiver == self.role:
                    role = ModelMessageRoleType.HUMAN
                elif item.sender == self.profile:
                elif item.sender == self.role:
                    role = ModelMessageRoleType.AI
                else:
                    continue
@@ -825,14 +824,80 @@ class ConversableAgent(Role, Agent):
                AgentMessage(
                    content=content,
                    role=role,
                    context=json.loads(item.context)
                    if item.context is not None
                    else None,
                    context=(
                        json.loads(item.context) if item.context is not None else None
                    ),
                )
            )
        return oai_messages

    def _load_thinking_messages(
    async def _load_thinking_messages(
        self,
        received_message: AgentMessage,
        sender: Agent,
        rely_messages: Optional[List[AgentMessage]] = None,
        context: Optional[Dict[str, Any]] = None,
    ) -> List[AgentMessage]:
        observation = received_message.content
        if not observation:
            raise ValueError("The received message content is empty!")
        memories = await self.read_memories(observation)
        reply_message_str = ""
        if context is None:
            context = {}
        if rely_messages:
            copied_rely_messages = [m.copy() for m in rely_messages]
            # When directly relying on historical messages, use the execution result
            # content as a dependency
            for message in copied_rely_messages:
                action_report: Optional[ActionOutput] = ActionOutput.from_dict(
                    message.action_report
                )
                if action_report:
                    # TODO: Modify in-place, need to be optimized
                    message.content = action_report.content
                if message.name != self.role:
                    # TODO, use name
                    # Rely messages are not from the current agent
                    if message.role == ModelMessageRoleType.HUMAN:
                        reply_message_str += f"Question: {message.content}\n"
                    elif message.role == ModelMessageRoleType.AI:
                        reply_message_str += f"Observation: {message.content}\n"
        if reply_message_str:
            memories += "\n" + reply_message_str

        system_prompt = await self.build_prompt(
            question=observation,
            is_system=True,
            most_recent_memories=memories,
            **context,
        )
        user_prompt = await self.build_prompt(
            question=observation,
            is_system=False,
            most_recent_memories=memories,
            **context,
        )

        agent_messages = []
        if system_prompt:
            agent_messages.append(
                AgentMessage(
                    content=system_prompt,
                    role=ModelMessageRoleType.SYSTEM,
                )
            )
        if user_prompt:
            agent_messages.append(
                AgentMessage(
                    content=user_prompt,
                    role=ModelMessageRoleType.HUMAN,
                )
            )

        return agent_messages

    def _old_load_thinking_messages(
        self,
        received_message: AgentMessage,
        sender: Agent,
@@ -846,8 +911,8 @@ class ConversableAgent(Role, Agent):
        with root_tracer.start_span(
            "agent._load_thinking_messages",
            metadata={
                "sender": sender.get_name(),
                "recipient": self.get_name(),
                "sender": sender.name,
                "recipient": self.name,
                "conv_uid": self.not_null_agent_context.conv_id,
                "current_goal": current_goal,
            },
@@ -855,8 +920,8 @@ class ConversableAgent(Role, Agent):
            # Get historical information from the memory
            memory_messages = self.memory.message_memory.get_between_agents(
                self.not_null_agent_context.conv_id,
                self.profile,
                sender.get_profile(),
                self.role,
                sender.role,
                current_goal,
            )
            span.metadata["memory_messages"] = [
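One practical consequence of the bind() change above: binding a raw GptsMemory now raises ValueError, and an AgentMemory is bound instead. A hedged wiring sketch (MyAgent, agent_context, llm_config and resource_loader are placeholders, and the build() coroutine is assumed from ConversableAgent):

    from dbgpt.agent.core.memory.agent_memory import AgentMemory

    agent_memory = AgentMemory()  # defaults: ShortTermMemory(buffer_size=5) + GptsMemory

    agent = (
        await MyAgent()              # placeholder ConversableAgent subclass
        .bind(agent_context)         # placeholder AgentContext
        .bind(llm_config)            # placeholder LLMConfig
        .bind(resource_loader)       # placeholder ResourceLoader
        .bind(agent_memory)          # new in this refactor: AgentMemory is a bind target
        .build()                     # assumed build() coroutine, not shown in this hunk
    )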
@@ -1,13 +1,14 @@
"""Base classes for managing a group of agents in a team chat."""

import logging
from typing import Dict, List, Optional, Tuple, Union
from typing import Any, Dict, List, Optional, Tuple, Union

from dbgpt._private.pydantic import BaseModel, ConfigDict, Field

from ..actions.action import ActionOutput
from .action.base import ActionOutput
from .agent import Agent, AgentMessage
from .base_agent import ConversableAgent
from .profile import ProfileConfig

logger = logging.getLogger(__name__)

@@ -86,7 +87,7 @@ class Team(BaseModel):
    @property
    def agent_names(self) -> List[str]:
        """Return the names of the agents in the group chat."""
        return [agent.get_profile() for agent in self.agents]
        return [agent.role for agent in self.agents]

    def agent_by_name(self, name: str) -> Agent:
        """Return the agent with a given name."""
@@ -121,10 +122,14 @@ class ManagerAgent(ConversableAgent, Team):

    model_config = ConfigDict(arbitrary_types_allowed=True)

    profile: str = "TeamManager"
    goal: str = "manage all hired intelligent agents to complete mission objectives"
    constraints: List[str] = []
    desc: str = goal
    profile: ProfileConfig = ProfileConfig(
        name="ManagerAgent",
        profile="TeamManager",
        goal="manage all hired intelligent agents to complete mission objectives",
        constraints=[],
        desc="manage all hired intelligent agents to complete mission objectives",
    )

    is_team: bool = True

    # The management agent does not need to retry the exception. The actual execution
@@ -149,6 +154,16 @@ class ManagerAgent(ConversableAgent, Team):
        self.messages.append(message.to_llm_message())
        return message.content, None

    async def _load_thinking_messages(
        self,
        received_message: AgentMessage,
        sender: Agent,
        rely_messages: Optional[List[AgentMessage]] = None,
        context: Optional[Dict[str, Any]] = None,
    ) -> List[AgentMessage]:
        """Load messages for thinking."""
        return [AgentMessage(content=received_message.content)]

    async def act(
        self,
        message: Optional[str],
@@ -1 +0,0 @@
"""LLM for agents."""
@@ -1,113 +0,0 @@
"""LLM module."""
import logging
from collections import defaultdict
from enum import Enum
from typing import Any, Dict, List, Optional, Type

from dbgpt._private.pydantic import BaseModel, ConfigDict, Field
from dbgpt.core import LLMClient, ModelMetadata, ModelRequest

logger = logging.getLogger(__name__)


def _build_model_request(input_value: Dict) -> ModelRequest:
    """Build model request from input value.

    Args:
        input_value(str or dict): input value

    Returns:
        ModelRequest: model request, pass to llm client
    """
    parm = {
        "model": input_value.get("model"),
        "messages": input_value.get("messages"),
        "temperature": input_value.get("temperature", None),
        "max_new_tokens": input_value.get("max_new_tokens", None),
        "stop": input_value.get("stop", None),
        "stop_token_ids": input_value.get("stop_token_ids", None),
        "context_len": input_value.get("context_len", None),
        "echo": input_value.get("echo", None),
        "span_id": input_value.get("span_id", None),
    }

    return ModelRequest(**parm)


class LLMStrategyType(Enum):
    """LLM strategy type."""

    Priority = "priority"
    Auto = "auto"
    Default = "default"


class LLMStrategy:
    """LLM strategy base class."""

    def __init__(self, llm_client: LLMClient, context: Optional[str] = None):
        """Create an LLMStrategy instance."""
        self._llm_client = llm_client
        self._context = context

    @property
    def type(self) -> LLMStrategyType:
        """Return the strategy type."""
        return LLMStrategyType.Default

    def _excluded_models(
        self,
        all_models: List[ModelMetadata],
        excluded_models: List[str],
        need_uses: Optional[List[str]] = None,
    ):
        if not need_uses:
            need_uses = []
        can_uses = []
        for item in all_models:
            if item.model in need_uses and item.model not in excluded_models:
                can_uses.append(item)
        return can_uses

    async def next_llm(self, excluded_models: Optional[List[str]] = None):
        """Return next available llm model name.

        Args:
            excluded_models(List[str]): excluded models

        Returns:
            str: Next available llm model name
        """
        if not excluded_models:
            excluded_models = []
        try:
            all_models = await self._llm_client.models()
            available_llms = self._excluded_models(all_models, excluded_models, None)
            if available_llms and len(available_llms) > 0:
                return available_llms[0].model
            else:
                raise ValueError("No model service available!")

        except Exception as e:
            logger.error(f"{self.type} get next llm failed!{str(e)}")
            raise ValueError(f"Failed to allocate model service,{str(e)}!")


llm_strategies: Dict[LLMStrategyType, List[Type[LLMStrategy]]] = defaultdict(list)


def register_llm_strategy(
    llm_strategy_type: LLMStrategyType, strategy: Type[LLMStrategy]
):
    """Register llm strategy."""
    llm_strategies[llm_strategy_type].append(strategy)


class LLMConfig(BaseModel):
    """LLM configuration."""

    model_config = ConfigDict(arbitrary_types_allowed=True)

    llm_client: Optional[LLMClient] = Field(default_factory=LLMClient)
    llm_strategy: LLMStrategyType = Field(default=LLMStrategyType.Default)
    strategy_context: Optional[Any] = None
@@ -1,183 +0,0 @@
"""AIWrapper for LLM."""
import json
import logging
import traceback
from typing import Callable, Dict, Optional, Union

from dbgpt.core import LLMClient
from dbgpt.core.interface.output_parser import BaseOutputParser
from dbgpt.util.error_types import LLMChatError
from dbgpt.util.tracer import root_tracer

from .llm import _build_model_request

logger = logging.getLogger(__name__)


class AIWrapper:
    """AIWrapper for LLM."""

    cache_path_root: str = ".cache"
    extra_kwargs = {
        "cache_seed",
        "filter_func",
        "allow_format_str_template",
        "context",
        "llm_model",
    }

    def __init__(
        self, llm_client: LLMClient, output_parser: Optional[BaseOutputParser] = None
    ):
        """Create an AIWrapper instance."""
        self.llm_echo = False
        self.model_cache_enable = False
        self._llm_client = llm_client
        self._output_parser = output_parser or BaseOutputParser(is_stream_out=False)

    @classmethod
    def instantiate(
        cls,
        template: Optional[Union[str, Callable]] = None,
        context: Optional[Dict] = None,
        allow_format_str_template: Optional[bool] = False,
    ):
        """Instantiate the template with the context."""
        if not context or template is None:
            return template
        if isinstance(template, str):
            return template.format(**context) if allow_format_str_template else template
        return template(context)

    def _construct_create_params(self, create_config: Dict, extra_kwargs: Dict) -> Dict:
        """Prime the create_config with additional_kwargs."""
        # Validate the config
        prompt = create_config.get("prompt")
        messages = create_config.get("messages")
        if prompt is None and messages is None:
            raise ValueError(
                "Either prompt or messages should be in create config but not both."
            )

        context = extra_kwargs.get("context")
        if context is None:
            # No need to instantiate if no context is provided.
            return create_config
        # Instantiate the prompt or messages
        allow_format_str_template = extra_kwargs.get("allow_format_str_template", False)
        # Make a copy of the config
        params = create_config.copy()
        if prompt is not None:
            # Instantiate the prompt
            params["prompt"] = self.instantiate(
                prompt, context, allow_format_str_template
            )
        elif context and messages and isinstance(messages, list):
            # Instantiate the messages
            params["messages"] = [
                {
                    **m,
                    "content": self.instantiate(
                        m["content"], context, allow_format_str_template
                    ),
                }
                if m.get("content")
                else m
                for m in messages
            ]
        return params

    def _separate_create_config(self, config):
        """Separate the config into create_config and extra_kwargs."""
        create_config = {k: v for k, v in config.items() if k not in self.extra_kwargs}
        extra_kwargs = {k: v for k, v in config.items() if k in self.extra_kwargs}
        return create_config, extra_kwargs

    def _get_key(self, config):
        """Get a unique identifier of a configuration.

        Args:
            config (dict or list): A configuration.

        Returns:
            tuple: A unique identifier which can be used as a key for a dict.
        """
        non_cache_key = ["api_key", "base_url", "api_type", "api_version"]
        copied = False
        for key in non_cache_key:
            if key in config:
                config, copied = config.copy() if not copied else config, True
                config.pop(key)
        return json.dumps(config, sort_keys=True, ensure_ascii=False)

    async def create(self, **config) -> Optional[str]:
        """Create a response from the input config."""
        # merge the input config with the i-th config in the config list
        full_config = {**config}
        # separate the config into create_config and extra_kwargs
        create_config, extra_kwargs = self._separate_create_config(full_config)

        # construct the create params
        params = self._construct_create_params(create_config, extra_kwargs)
        filter_func = extra_kwargs.get("filter_func")
        context = extra_kwargs.get("context")
        llm_model = extra_kwargs.get("llm_model")
        try:
            response = await self._completions_create(llm_model, params)
        except LLMChatError as e:
            logger.debug(f"{llm_model} generate failed!{str(e)}")
            raise e
        else:
            pass_filter = filter_func is None or filter_func(
                context=context, response=response
            )
            if pass_filter:
                # Return the response if it passes the filter
                return response
            else:
                return None

    def _get_span_metadata(self, payload: Dict) -> Dict:
        metadata = {k: v for k, v in payload.items()}

        metadata["messages"] = list(
            map(lambda m: m if isinstance(m, dict) else m.dict(), metadata["messages"])
        )
        return metadata

    def _llm_messages_convert(self, params):
        gpts_messages = params["messages"]
        # TODO

        return gpts_messages

    async def _completions_create(self, llm_model, params) -> str:
        payload = {
            "model": llm_model,
            "prompt": params.get("prompt"),
            "messages": self._llm_messages_convert(params),
            "temperature": float(params.get("temperature")),
            "max_new_tokens": int(params.get("max_new_tokens")),
            "echo": self.llm_echo,
        }
        logger.info(f"Request: \n{payload}")
        span = root_tracer.start_span(
            "Agent.llm_client.no_streaming_call",
            metadata=self._get_span_metadata(payload),
        )
        payload["span_id"] = span.span_id
        payload["model_cache_enable"] = self.model_cache_enable
        try:
            model_request = _build_model_request(payload)
            model_output = await self._llm_client.generate(model_request.copy())
            parsed_output = self._output_parser.parse_model_nostream_resp(
                model_output, "#########################"
            )
            return parsed_output
        except Exception as e:
            logger.error(
                f"Call LLMClient error, {str(e)}, detail: {traceback.format_exc()}"
            )
            raise LLMChatError(original_exception=e) from e
        finally:
            span.end()
@@ -1 +0,0 @@
"""LLM strategy module."""
@@ -1,37 +0,0 @@
"""Priority strategy for LLM."""

import json
import logging
from typing import List, Optional

from ..llm import LLMStrategy, LLMStrategyType

logger = logging.getLogger(__name__)


class LLMStrategyPriority(LLMStrategy):
    """Priority strategy for llm model service."""

    @property
    def type(self) -> LLMStrategyType:
        """Return the strategy type."""
        return LLMStrategyType.Priority

    async def next_llm(self, excluded_models: Optional[List[str]] = None) -> str:
        """Return next available llm model name."""
        try:
            if not excluded_models:
                excluded_models = []
            all_models = await self._llm_client.models()
            if not self._context:
                raise ValueError("No context provided for priority strategy!")
            priority: List[str] = json.loads(self._context)
            can_uses = self._excluded_models(all_models, excluded_models, priority)
            if can_uses and len(can_uses) > 0:
                return can_uses[0].model
            else:
                raise ValueError("No model service available!")

        except Exception as e:
            logger.error(f"{self.type} get next llm failed!{str(e)}")
            raise ValueError(f"Failed to allocate model service,{str(e)}!")
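Putting the two relocated modules above together, an illustrative configuration sketch (my_llm_client is a placeholder for any dbgpt.core.LLMClient implementation; how strategy_context reaches the strategy instance is not shown in this excerpt):

    import json

    register_llm_strategy(LLMStrategyType.Priority, LLMStrategyPriority)

    llm_config = LLMConfig(
        llm_client=my_llm_client,                # placeholder LLMClient
        llm_strategy=LLMStrategyType.Priority,
        strategy_context=json.dumps(["gpt-4", "vicuna-13b-v1.5"]),  # priority order
    )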
dbgpt/agent/core/memory/__init__.py (new file, 16 lines)
@@ -0,0 +1,16 @@
"""Memory module for the agent."""

from .agent_memory import AgentMemory, AgentMemoryFragment  # noqa: F401
from .base import (  # noqa: F401
    ImportanceScorer,
    InsightExtractor,
    InsightMemoryFragment,
    Memory,
    MemoryFragment,
    SensoryMemory,
    ShortTermMemory,
)
from .hybrid import HybridMemory  # noqa: F401
from .llm import LLMImportanceScorer, LLMInsightExtractor  # noqa: F401
from .long_term import LongTermMemory, LongTermRetriever  # noqa: F401
from .short_term import EnhancedShortTermMemory  # noqa: F401
dbgpt/agent/core/memory/agent_memory.py (new file, 282 lines)
@@ -0,0 +1,282 @@
|
||||
"""Agent memory module."""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Callable, List, Optional, Type, cast
|
||||
|
||||
from dbgpt.core import LLMClient
|
||||
from dbgpt.util.annotations import immutable, mutable
|
||||
from dbgpt.util.id_generator import new_id
|
||||
|
||||
from .base import (
|
||||
DiscardedMemoryFragments,
|
||||
ImportanceScorer,
|
||||
InsightExtractor,
|
||||
Memory,
|
||||
MemoryFragment,
|
||||
ShortTermMemory,
|
||||
WriteOperation,
|
||||
)
|
||||
from .gpts import GptsMemory, GptsMessageMemory, GptsPlansMemory
|
||||
|
||||
|
||||
class AgentMemoryFragment(MemoryFragment):
|
||||
"""Default memory fragment for agent memory."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
observation: str,
|
||||
embeddings: Optional[List[float]] = None,
|
||||
memory_id: Optional[int] = None,
|
||||
importance: Optional[float] = None,
|
||||
last_accessed_time: Optional[datetime] = None,
|
||||
is_insight: bool = False,
|
||||
):
|
||||
"""Create a memory fragment."""
|
||||
if not memory_id:
|
||||
# Generate a new memory id, we use snowflake id generator here.
|
||||
memory_id = new_id()
|
||||
self.observation = observation
|
||||
self._embeddings = embeddings
|
||||
self.memory_id: int = cast(int, memory_id)
|
||||
self._importance: Optional[float] = importance
|
||||
self._last_accessed_time: Optional[datetime] = last_accessed_time
|
||||
self._is_insight = is_insight
|
||||
|
||||
@property
|
||||
def id(self) -> int:
|
||||
"""Return the memory id."""
|
||||
return self.memory_id
|
||||
|
||||
@property
|
||||
def raw_observation(self) -> str:
|
||||
"""Return the raw observation."""
|
||||
return self.observation
|
||||
|
||||
@property
|
||||
def embeddings(self) -> Optional[List[float]]:
|
||||
"""Return the embeddings of the memory fragment."""
|
||||
return self._embeddings
|
||||
|
||||
def update_embeddings(self, embeddings: List[float]) -> None:
|
||||
"""Update the embeddings of the memory fragment.
|
||||
|
||||
Args:
|
||||
embeddings(List[float]): embeddings
|
||||
"""
|
||||
self._embeddings = embeddings
|
||||
|
||||
def calculate_current_embeddings(
|
||||
self, embedding_func: Callable[[List[str]], List[List[float]]]
|
||||
) -> List[float]:
|
||||
"""Calculate the embeddings of the memory fragment.
|
||||
|
||||
Args:
|
||||
embedding_func(Callable[[List[str]], List[List[float]]]): Function to
|
||||
compute embeddings
|
||||
|
||||
Returns:
|
||||
List[float]: Embeddings of the memory fragment
|
||||
"""
|
||||
embeddings = embedding_func([self.observation])
|
||||
return embeddings[0]
|
||||
|
||||
@property
|
||||
def is_insight(self) -> bool:
|
||||
"""Return whether the memory fragment is an insight.
|
||||
|
||||
Returns:
|
||||
bool: Whether the memory fragment is an insight
|
||||
"""
|
||||
return self._is_insight
|
||||
|
||||
@property
|
||||
def importance(self) -> Optional[float]:
|
||||
"""Return the importance of the memory fragment.
|
||||
|
||||
Returns:
|
||||
Optional[float]: Importance of the memory fragment
|
||||
"""
|
||||
return self._importance
|
||||
|
||||
def update_importance(self, importance: float) -> Optional[float]:
|
||||
"""Update the importance of the memory fragment.
|
||||
|
||||
Args:
|
||||
importance(float): Importance of the memory fragment
|
||||
|
||||
Returns:
|
||||
Optional[float]: Old importance
|
||||
"""
|
||||
old_importance = self._importance
|
||||
self._importance = importance
|
||||
return old_importance
|
||||
|
||||
@property
|
||||
def last_accessed_time(self) -> Optional[datetime]:
|
||||
"""Return the last accessed time of the memory fragment.
|
||||
|
||||
Used to determine the least recently used memory fragment.
|
||||
|
||||
Returns:
|
||||
Optional[datetime]: Last accessed time
|
||||
"""
|
||||
return self._last_accessed_time
|
||||
|
||||
def update_accessed_time(self, now: datetime) -> Optional[datetime]:
|
||||
"""Update the last accessed time of the memory fragment.
|
||||
|
||||
Args:
|
||||
now(datetime): Current time
|
||||
|
||||
Returns:
|
||||
Optional[datetime]: Old last accessed time
|
||||
"""
|
||||
old_time = self._last_accessed_time
|
||||
self._last_accessed_time = now
|
||||
return old_time
|
||||
|
||||
@classmethod
|
||||
def build_from(
|
||||
cls: Type["AgentMemoryFragment"],
|
||||
observation: str,
|
||||
embeddings: Optional[List[float]] = None,
|
||||
memory_id: Optional[int] = None,
|
||||
importance: Optional[float] = None,
|
||||
is_insight: bool = False,
|
||||
last_accessed_time: Optional[datetime] = None,
|
||||
**kwargs
|
||||
) -> "AgentMemoryFragment":
|
||||
"""Build a memory fragment from the given parameters."""
|
||||
return cls(
|
||||
observation=observation,
|
||||
embeddings=embeddings,
|
||||
memory_id=memory_id,
|
||||
importance=importance,
|
||||
last_accessed_time=last_accessed_time,
|
||||
is_insight=is_insight,
|
||||
)
|
||||
|
||||
def copy(self: "AgentMemoryFragment") -> "AgentMemoryFragment":
|
||||
"""Return a copy of the memory fragment."""
|
||||
return AgentMemoryFragment.build_from(
|
||||
observation=self.observation,
|
||||
embeddings=self._embeddings,
|
||||
memory_id=self.memory_id,
|
||||
importance=self.importance,
|
||||
last_accessed_time=self.last_accessed_time,
|
||||
is_insight=self.is_insight,
|
||||
)
|
||||
|
||||
|
||||
class AgentMemory(Memory[AgentMemoryFragment]):
|
||||
"""Agent memory."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
memory: Optional[Memory[AgentMemoryFragment]] = None,
|
||||
importance_scorer: Optional[ImportanceScorer[AgentMemoryFragment]] = None,
|
||||
insight_extractor: Optional[InsightExtractor[AgentMemoryFragment]] = None,
|
||||
gpts_memory: Optional[GptsMemory] = None,
|
||||
):
|
||||
"""Create an agent memory.
|
||||
|
||||
Args:
|
||||
memory(Memory[AgentMemoryFragment]): Memory to store fragments
|
||||
importance_scorer(ImportanceScorer[AgentMemoryFragment]): Scorer to
|
||||
calculate the importance of memory fragments
|
||||
insight_extractor(InsightExtractor[AgentMemoryFragment]): Extractor to
|
||||
extract insights from memory fragments
|
||||
gpts_memory(GptsMemory): Memory to store GPTs related information
|
||||
"""
|
||||
if not memory:
|
||||
memory = ShortTermMemory(buffer_size=5)
|
||||
if not gpts_memory:
|
||||
gpts_memory = GptsMemory()
|
||||
self.memory: Memory[AgentMemoryFragment] = cast(
|
||||
Memory[AgentMemoryFragment], memory
|
||||
)
|
||||
self.importance_scorer = importance_scorer
|
||||
self.insight_extractor = insight_extractor
|
||||
self.gpts_memory = gpts_memory
|
||||
|
||||
@immutable
|
||||
def structure_clone(
|
||||
self: "AgentMemory", now: Optional[datetime] = None
|
||||
) -> "AgentMemory":
|
||||
"""Return a structure clone of the memory.
|
||||
|
||||
        The gpts_memory is not cloned; it will be shared across the whole agent memory.
|
||||
"""
|
||||
m = AgentMemory(
|
||||
memory=self.memory.structure_clone(now),
|
||||
importance_scorer=self.importance_scorer,
|
||||
insight_extractor=self.insight_extractor,
|
||||
gpts_memory=self.gpts_memory,
|
||||
)
|
||||
m._copy_from(self)
|
||||
return m
|
||||
|
||||
@mutable
|
||||
def initialize(
|
||||
self,
|
||||
name: Optional[str] = None,
|
||||
llm_client: Optional[LLMClient] = None,
|
||||
importance_scorer: Optional[ImportanceScorer[AgentMemoryFragment]] = None,
|
||||
insight_extractor: Optional[InsightExtractor[AgentMemoryFragment]] = None,
|
||||
real_memory_fragment_class: Optional[Type[AgentMemoryFragment]] = None,
|
||||
) -> None:
|
||||
"""Initialize the memory."""
|
||||
self.memory.initialize(
|
||||
name=name,
|
||||
llm_client=llm_client,
|
||||
importance_scorer=importance_scorer or self.importance_scorer,
|
||||
insight_extractor=insight_extractor or self.insight_extractor,
|
||||
real_memory_fragment_class=real_memory_fragment_class
|
||||
or AgentMemoryFragment,
|
||||
)
|
||||
|
||||
@mutable
|
||||
async def write(
|
||||
self,
|
||||
memory_fragment: AgentMemoryFragment,
|
||||
now: Optional[datetime] = None,
|
||||
op: WriteOperation = WriteOperation.ADD,
|
||||
) -> Optional[DiscardedMemoryFragments[AgentMemoryFragment]]:
|
||||
"""Write a memory fragment to the memory."""
|
||||
        return await self.memory.write(memory_fragment, now, op)
|
||||
|
||||
@immutable
|
||||
async def read(
|
||||
self,
|
||||
observation: str,
|
||||
alpha: Optional[float] = None,
|
||||
beta: Optional[float] = None,
|
||||
gamma: Optional[float] = None,
|
||||
) -> List[AgentMemoryFragment]:
|
||||
"""Read memory fragments related to the observation.
|
||||
|
||||
Args:
|
||||
observation(str): Observation
|
||||
alpha(float): Importance weight
|
||||
beta(float): Time weight
|
||||
gamma(float): Randomness weight
|
||||
|
||||
Returns:
|
||||
List[AgentMemoryFragment]: List of memory fragments
|
||||
"""
|
||||
return await self.memory.read(observation, alpha, beta, gamma)
|
||||
|
||||
@mutable
|
||||
async def clear(self) -> List[AgentMemoryFragment]:
|
||||
"""Clear the memory."""
|
||||
return await self.memory.clear()
|
||||
|
||||
@property
|
||||
def plans_memory(self) -> GptsPlansMemory:
|
||||
"""Return the plan memory."""
|
||||
return self.gpts_memory.plans_memory
|
||||
|
||||
@property
|
||||
def message_memory(self) -> GptsMessageMemory:
|
||||
"""Return the message memory."""
|
||||
return self.gpts_memory.message_memory
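A minimal usage sketch of the agent memory above (illustrative only; the agent name and observation text are made up, and the calls are assumed to run inside an async function):

# Illustrative sketch, not part of this module: write one observation and read it back.
memory = AgentMemory()  # defaults to ShortTermMemory(buffer_size=5) and a fresh GptsMemory
memory.initialize(name="demo-agent")
fragment = AgentMemoryFragment.build_from("The user asked for last month's sales report.")
await memory.write(fragment)
related_fragments = await memory.read("sales report")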
|
776
dbgpt/agent/core/memory/base.py
Normal file
@@ -0,0 +1,776 @@
|
||||
"""Memory for agent.
|
||||
|
||||
Human memory follows a general progression from sensory memory that registers
|
||||
perceptual inputs, to short-term memory that maintains information transiently, to
|
||||
long-term memory that consolidates information over extended periods.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from abc import ABC, abstractmethod
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
Generic,
|
||||
List,
|
||||
Optional,
|
||||
Tuple,
|
||||
Type,
|
||||
TypeVar,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
|
||||
from dbgpt.core import LLMClient
|
||||
from dbgpt.util.annotations import PublicAPI, immutable, mutable
|
||||
|
||||
T = TypeVar("T", bound="MemoryFragment")
|
||||
M = TypeVar("M", bound="Memory")
|
||||
|
||||
|
||||
class WriteOperation(str, Enum):
|
||||
"""Write operation."""
|
||||
|
||||
ADD = "add"
|
||||
RETRIEVAL = "retrieval"
|
||||
|
||||
|
||||
@PublicAPI(stability="beta")
|
||||
class MemoryFragment(ABC):
|
||||
"""Memory fragment interface.
|
||||
|
||||
    It is the interface of a memory fragment, the basic unit of memory. It contains
    the basic information of a memory, such as the observation, importance, whether
    it is an insight, and the last accessed time.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def build_from(
|
||||
cls: Type[T],
|
||||
observation: str,
|
||||
embeddings: Optional[List[float]] = None,
|
||||
memory_id: Optional[int] = None,
|
||||
importance: Optional[float] = None,
|
||||
is_insight: bool = False,
|
||||
last_accessed_time: Optional[datetime] = None,
|
||||
**kwargs,
|
||||
) -> T:
|
||||
"""Build a memory fragment from memory id and observation.
|
||||
|
||||
Args:
|
||||
observation(str): Observation
|
||||
embeddings(List[float], optional): Embeddings of the memory fragment.
|
||||
memory_id(int): Memory id
|
||||
importance(float): Importance
|
||||
is_insight(bool): Whether the memory fragment is an insight
|
||||
last_accessed_time(datetime): Last accessed time
|
||||
|
||||
Returns:
|
||||
MemoryFragment: Memory fragment
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def id(self) -> int:
|
||||
"""Return the id of the memory fragment.
|
||||
|
||||
        Commonly, the id is generated by the Snowflake algorithm, so the timestamp of
        when the memory fragment was created can be parsed from it.
|
||||
|
||||
Returns:
|
||||
int: id
|
||||
"""
|
||||
|
||||
@property
|
||||
def metadata(self) -> Dict[str, Any]:
|
||||
"""Return the metadata of the memory fragment.
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: Metadata
|
||||
"""
|
||||
return {}
|
||||
|
||||
@property
|
||||
def importance(self) -> Optional[float]:
|
||||
"""Return the importance of the memory fragment.
|
||||
|
||||
        It should be noted that the importance only reflects the characteristics of the
        memory itself.
|
||||
|
||||
Returns:
|
||||
Optional[float]: importance, None means the importance is not available.
|
||||
"""
|
||||
return None
|
||||
|
||||
@abstractmethod
|
||||
def update_importance(self, importance: float) -> Optional[float]:
|
||||
"""Update the importance of the memory fragment.
|
||||
|
||||
Args:
|
||||
importance(float): importance
|
||||
|
||||
Returns:
|
||||
Optional[float]: importance
|
||||
"""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def raw_observation(self) -> str:
|
||||
"""Return the raw observation.
|
||||
|
||||
        The raw observation is the original observation data. It can be an observation
        from the environment or an observation after executing an action.
|
||||
|
||||
Returns:
|
||||
str: raw observation
|
||||
"""
|
||||
|
||||
@property
|
||||
def embeddings(self) -> Optional[List[float]]:
|
||||
"""Return the embeddings of the memory fragment.
|
||||
|
||||
Returns:
|
||||
Optional[List[float]]: embeddings
|
||||
"""
|
||||
return None
|
||||
|
||||
@abstractmethod
|
||||
def update_embeddings(self, embeddings: List[float]) -> None:
|
||||
"""Update the embeddings of the memory fragment.
|
||||
|
||||
Args:
|
||||
embeddings(List[float]): embeddings
|
||||
"""
|
||||
|
||||
def calculate_current_embeddings(
|
||||
self, embedding_func: Callable[[List[str]], List[List[float]]]
|
||||
) -> List[float]:
|
||||
"""Calculate the embeddings of the memory fragment.
|
||||
|
||||
Args:
|
||||
embedding_func(Callable[[List[str]], List[List[float]]]): Function to
|
||||
compute embeddings
|
||||
|
||||
Returns:
|
||||
List[float]: Embeddings of the memory fragment
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def is_insight(self) -> bool:
|
||||
"""Return whether the memory fragment is an insight.
|
||||
|
||||
Returns:
|
||||
bool: whether the memory fragment is an insight.
|
||||
"""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def last_accessed_time(self) -> Optional[datetime]:
|
||||
"""Return the last accessed time of the memory fragment.
|
||||
|
||||
Returns:
|
||||
Optional[datetime]: last accessed time
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def update_accessed_time(self, now: datetime) -> Optional[datetime]:
|
||||
"""Update the last accessed time of the memory fragment.
|
||||
|
||||
Args:
|
||||
now(datetime): The current time
|
||||
|
||||
Returns:
|
||||
Optional[datetime]: The last accessed time
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def copy(self: T) -> T:
|
||||
"""Copy the memory fragment."""
|
||||
|
||||
def reduce(self, memory_fragments: List[T], **kwargs) -> T:
|
||||
"""Reduce memory fragments to a single memory fragment.
|
||||
|
||||
Args:
|
||||
memory_fragments(List[T]): Memory fragments
|
||||
|
||||
Returns:
|
||||
T: The reduced memory fragment
|
||||
"""
|
||||
obs = []
|
||||
for memory_fragment in memory_fragments:
|
||||
obs.append(memory_fragment.raw_observation)
|
||||
new_observation = ";".join(obs)
|
||||
return self.current_class.build_from(new_observation, **kwargs) # type: ignore
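For example (hypothetical values), reducing two fragments whose raw observations are "it rained" and "the roads were wet" produces a single fragment whose observation is "it rained;the roads were wet", since the observations are joined with ";".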
|
||||
|
||||
@property
|
||||
def current_class(self: T) -> Type[T]:
|
||||
"""Return the current class."""
|
||||
return self.__class__
|
||||
|
||||
|
||||
class InsightMemoryFragment(Generic[T]):
|
||||
"""Insight memory fragment.
|
||||
|
||||
Insight memory fragment is a memory fragment that contains insights.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
original_memory_fragment: Union[T, List[T]],
|
||||
insights: Union[List[T], List[str]],
|
||||
):
|
||||
"""Create an insight memory fragment.
|
||||
|
||||
Insight is also a memory fragment.
|
||||
"""
|
||||
if insights and isinstance(insights[0], str):
|
||||
mf = (
|
||||
original_memory_fragment[0]
|
||||
if isinstance(original_memory_fragment, list)
|
||||
else original_memory_fragment
|
||||
)
|
||||
insights = [
|
||||
mf.current_class.build_from(i, is_insight=True) for i in insights # type: ignore # noqa
|
||||
]
|
||||
self._original_memory_fragment = original_memory_fragment
|
||||
self._insights: List[T] = cast(List[T], insights)
|
||||
|
||||
@property
|
||||
def original_memory_fragment(self) -> Union[T, List[T]]:
|
||||
"""Return the original memory fragment."""
|
||||
return self._original_memory_fragment
|
||||
|
||||
@property
|
||||
def insights(self) -> List[T]:
|
||||
"""Return the insights."""
|
||||
return self._insights
|
||||
|
||||
|
||||
class DiscardedMemoryFragments(Generic[T]):
|
||||
"""Discarded memory fragments.
|
||||
|
||||
Sometimes, we need to discard some memory fragments, there are following cases:
|
||||
    1. Memory duplicated: the same/similar action is executed multiple times and the
    same/similar observation is received from the environment.
|
||||
2. Memory overflow. The memory is full and the new memory fragment needs to be
|
||||
written.
|
||||
3. The memory fragment is not important enough.
|
||||
4. Simulation of forgetting mechanism.
|
||||
|
||||
The discarded memory fragments may be transferred to another memory.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
discarded_memory_fragments: List[T],
|
||||
discarded_insights: Optional[List[InsightMemoryFragment[T]]] = None,
|
||||
):
|
||||
"""Create a discarded memory fragments."""
|
||||
if discarded_insights is None:
|
||||
discarded_insights = []
|
||||
self._discarded_memory_fragments = discarded_memory_fragments
|
||||
self._discarded_insights = discarded_insights
|
||||
|
||||
@property
|
||||
def discarded_memory_fragments(self) -> List[T]:
|
||||
"""Return the discarded memory fragments."""
|
||||
return self._discarded_memory_fragments
|
||||
|
||||
@property
|
||||
def discarded_insights(self) -> List[InsightMemoryFragment[T]]:
|
||||
"""Return the discarded insights."""
|
||||
return self._discarded_insights
|
||||
|
||||
|
||||
class InsightExtractor(ABC, Generic[T]):
|
||||
"""Insight extractor interface.
|
||||
|
||||
Obtain high-level insights from memories.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def extract_insights(
|
||||
self,
|
||||
memory_fragment: T,
|
||||
llm_client: Optional[LLMClient] = None,
|
||||
) -> InsightMemoryFragment[T]:
|
||||
"""Extract insights from memory fragments.
|
||||
|
||||
Args:
|
||||
memory_fragment(T): Memory fragment
|
||||
llm_client(Optional[LLMClient]): LLM client
|
||||
|
||||
Returns:
|
||||
InsightMemoryFragment: The insights of the memory fragment.
|
||||
"""
|
||||
|
||||
|
||||
class ImportanceScorer(ABC, Generic[T]):
|
||||
"""Importance scorer interface.
|
||||
|
||||
Score the importance of memories.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def score_importance(
|
||||
self,
|
||||
memory_fragment: T,
|
||||
llm_client: Optional[LLMClient] = None,
|
||||
) -> float:
|
||||
"""Score the importance of memory fragment.
|
||||
|
||||
Args:
|
||||
memory_fragment(T): Memory fragment.
|
||||
llm_client(Optional[LLMClient]): LLM client
|
||||
|
||||
Returns:
|
||||
float: The importance of the memory fragment.
|
||||
"""
|
||||
|
||||
|
||||
@PublicAPI(stability="beta")
|
||||
class Memory(ABC, Generic[T]):
|
||||
"""Memory interface."""
|
||||
|
||||
name: Optional[str] = None
|
||||
llm_client: Optional[LLMClient] = None
|
||||
importance_scorer: Optional[ImportanceScorer] = None
|
||||
insight_extractor: Optional[InsightExtractor] = None
|
||||
_real_memory_fragment_class: Optional[Type[T]] = None
|
||||
importance_weight: float = 0.15
|
||||
|
||||
@mutable
|
||||
def initialize(
|
||||
self,
|
||||
name: Optional[str] = None,
|
||||
llm_client: Optional[LLMClient] = None,
|
||||
importance_scorer: Optional[ImportanceScorer] = None,
|
||||
insight_extractor: Optional[InsightExtractor] = None,
|
||||
real_memory_fragment_class: Optional[Type[T]] = None,
|
||||
) -> None:
|
||||
"""Initialize memory.
|
||||
|
||||
Some agent may need to initialize memory before using it.
|
||||
"""
|
||||
self.name = name
|
||||
self.llm_client = llm_client
|
||||
self.importance_scorer = importance_scorer
|
||||
self.insight_extractor = insight_extractor
|
||||
self._real_memory_fragment_class = real_memory_fragment_class
|
||||
|
||||
@abstractmethod
|
||||
@immutable
|
||||
def structure_clone(self: M, now: Optional[datetime] = None) -> M:
|
||||
"""Return a structure clone of the memory.
|
||||
|
||||
Sometimes, we need to clone the structure of the memory, but not the content.
|
||||
|
||||
        There are some cases:

        1. When we need to reset the memory, we can use this method to create a new
        one, and the new memory has the same structure as the old one.
        2. When we create a new agent, the new agent can have the same memory
        structure as the planner.
|
||||
|
||||
Args:
|
||||
now(Optional[datetime]): The current time
|
||||
|
||||
Returns:
|
||||
M: The structure clone of the memory
|
||||
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@mutable
|
||||
def _copy_from(self, memory: "Memory") -> None:
|
||||
"""Copy memory from another memory.
|
||||
|
||||
Args:
|
||||
memory(Memory): Another memory
|
||||
"""
|
||||
self.name = memory.name
|
||||
self.llm_client = memory.llm_client
|
||||
self.importance_scorer = memory.importance_scorer
|
||||
self.insight_extractor = memory.insight_extractor
|
||||
self._real_memory_fragment_class = memory._real_memory_fragment_class
|
||||
|
||||
@abstractmethod
|
||||
@mutable
|
||||
async def write(
|
||||
self,
|
||||
memory_fragment: T,
|
||||
now: Optional[datetime] = None,
|
||||
op: WriteOperation = WriteOperation.ADD,
|
||||
) -> Optional[DiscardedMemoryFragments[T]]:
|
||||
"""Write a memory fragment to memory.
|
||||
|
||||
Two situations need to be noted here:
|
||||
        1. Memory duplicated: the same/similar action is executed multiple times and
        the same/similar observation is received from the environment.

        2. Memory overflow: the memory is full and the new memory fragment needs to be
        written to memory; the common strategy is to discard some memory fragments.
|
||||
|
||||
Args:
|
||||
memory_fragment(T): Memory fragment
|
||||
now(Optional[datetime]): The current time
|
||||
op(WriteOperation): Write operation
|
||||
|
||||
Returns:
|
||||
Optional[DiscardedMemoryFragments]: The discarded memory fragments, None
|
||||
means no memory fragments are discarded.
|
||||
"""
|
||||
|
||||
@mutable
|
||||
async def write_batch(
|
||||
self, memory_fragments: List[T], now: Optional[datetime] = None
|
||||
) -> Optional[DiscardedMemoryFragments[T]]:
|
||||
"""Write a batch of memory fragments to memory.
|
||||
|
||||
Args:
|
||||
memory_fragments(List[T]): Memory fragments
|
||||
now(Optional[datetime]): The current time
|
||||
|
||||
Returns:
|
||||
Optional[DiscardedMemoryFragments]: The discarded memory fragments, None
|
||||
means no memory fragments are discarded.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
@immutable
|
||||
async def read(
|
||||
self,
|
||||
observation: str,
|
||||
alpha: Optional[float] = None,
|
||||
beta: Optional[float] = None,
|
||||
gamma: Optional[float] = None,
|
||||
) -> List[T]:
|
||||
r"""Read memory fragments by observation.
|
||||
|
||||
        Usually, there are three commonly used criteria for information extraction,
        that is, recency, relevance, and importance.
|
||||
|
||||
Memories that are more recent, relevant, and important are more likely to be
|
||||
extracted. Formally, we conclude the following equation from existing
|
||||
literature for memory information extraction:
|
||||
|
||||
.. math::
|
||||
|
||||
            m^* = \arg\max_{m \in M} \alpha s^{\text{rec}}(q, m) + \\
                \beta s^{\text{rel}}(q, m) + \gamma s^{\text{imp}}(m), \tag{1}
|
||||
|
||||
Args:
|
||||
observation(str): observation(Query)
|
||||
alpha(float, optional): Recency coefficient. Default is None.
|
||||
beta(float, optional): Relevance coefficient. Default is None.
|
||||
gamma(float, optional): Importance coefficient. Default is None.
|
||||
|
||||
Returns:
|
||||
List[T]: memory fragments
|
||||
"""
|
||||
|
||||
@immutable
|
||||
async def reflect(self, memory_fragments: List[T]) -> List[T]:
|
||||
"""Reflect memory fragments by observation.
|
||||
|
||||
Args:
|
||||
memory_fragments(List[T]): memory fragments to be reflected.
|
||||
|
||||
Returns:
|
||||
List[T]: memory fragments after reflection.
|
||||
"""
|
||||
return memory_fragments
|
||||
|
||||
@immutable
|
||||
async def handle_duplicated(
|
||||
self, memory_fragments: List[T], new_memory_fragments: List[T]
|
||||
) -> List[T]:
|
||||
"""Handle duplicated memory fragments.
|
||||
|
||||
Args:
|
||||
memory_fragments(List[T]): Existing memory fragments
|
||||
new_memory_fragments(List[T]): New memory fragments
|
||||
|
||||
Returns:
|
||||
List[T]: The new memory fragments after handling duplicated memory
|
||||
fragments.
|
||||
"""
|
||||
return memory_fragments + new_memory_fragments
|
||||
|
||||
@mutable
|
||||
async def handle_overflow(
|
||||
self, memory_fragments: List[T]
|
||||
) -> Tuple[List[T], List[T]]:
|
||||
"""Handle memory overflow.
|
||||
|
||||
Args:
|
||||
memory_fragments(List[T]): Existing memory fragments
|
||||
|
||||
Returns:
|
||||
Tuple[List[T], List[T]]: The memory fragments after handling overflow and
|
||||
the discarded memory fragments.
|
||||
"""
|
||||
return memory_fragments, []
|
||||
|
||||
@abstractmethod
|
||||
@mutable
|
||||
async def clear(self) -> List[T]:
|
||||
"""Clear all memory fragments.
|
||||
|
||||
Returns:
|
||||
List[T]: The all cleared memory fragments.
|
||||
"""
|
||||
|
||||
@immutable
|
||||
async def get_insights(
|
||||
self, memory_fragments: List[T]
|
||||
) -> List[InsightMemoryFragment[T]]:
|
||||
"""Get insights from memory fragments.
|
||||
|
||||
Args:
|
||||
memory_fragments(List[T]): Memory fragments
|
||||
|
||||
Returns:
|
||||
List[InsightMemoryFragment]: The insights of the memory fragments.
|
||||
"""
|
||||
if not self.insight_extractor:
|
||||
return []
|
||||
        # Obtain insights from the memory fragments in parallel
|
||||
tasks = []
|
||||
for memory_fragment in memory_fragments:
|
||||
tasks.append(
|
||||
self.insight_extractor.extract_insights(
|
||||
memory_fragment, self.llm_client
|
||||
)
|
||||
)
|
||||
insights = await asyncio.gather(*tasks)
|
||||
result = []
|
||||
for insight in insights:
|
||||
if not insight:
|
||||
continue
|
||||
result.append(insight)
|
||||
if len(result) != len(insights):
|
||||
raise ValueError(
|
||||
"The number of insights is not equal to the number of memory fragments."
|
||||
)
|
||||
return result
|
||||
|
||||
@immutable
|
||||
async def score_memory_importance(self, memory_fragments: List[T]) -> List[float]:
|
||||
"""Score the importance of memory fragments.
|
||||
|
||||
Args:
|
||||
memory_fragments(List[T]): Memory fragments
|
||||
|
||||
Returns:
|
||||
List[float]: The importance of memory fragments.
|
||||
"""
|
||||
if not self.importance_scorer:
|
||||
return [5 * self.importance_weight for _ in memory_fragments]
|
||||
tasks = []
|
||||
for memory_fragment in memory_fragments:
|
||||
tasks.append(
|
||||
self.importance_scorer.score_importance(
|
||||
memory_fragment, self.llm_client
|
||||
)
|
||||
)
|
||||
result = []
|
||||
for importance in await asyncio.gather(*tasks):
|
||||
real_score = importance * self.importance_weight
|
||||
result.append(real_score)
|
||||
return result
|
||||
|
||||
@property
|
||||
@immutable
|
||||
def real_memory_fragment_class(self) -> Type[T]:
|
||||
"""Return the real memory fragment class."""
|
||||
if not self._real_memory_fragment_class:
|
||||
raise ValueError("The real memory fragment class is not set.")
|
||||
return self._real_memory_fragment_class
|
||||
|
||||
|
||||
class SensoryMemory(Memory, Generic[T]):
|
||||
"""Sensory memory."""
|
||||
|
||||
importance_weight: float = 0.9
|
||||
threshold_to_short_term: float = 0.1
|
||||
|
||||
def __init__(self, buffer_size: int = 0):
|
||||
"""Create a sensory memory."""
|
||||
self._buffer_size = buffer_size
|
||||
self._fragments: List[T] = []
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
def structure_clone(
|
||||
self: "SensoryMemory[T]", now: Optional[datetime] = None
|
||||
) -> "SensoryMemory[T]":
|
||||
"""Return a structure clone of the memory."""
|
||||
m: SensoryMemory[T] = SensoryMemory(buffer_size=self._buffer_size)
|
||||
m._copy_from(self)
|
||||
return m
|
||||
|
||||
@mutable
|
||||
async def write(
|
||||
self,
|
||||
memory_fragment: T,
|
||||
now: Optional[datetime] = None,
|
||||
op: WriteOperation = WriteOperation.ADD,
|
||||
) -> Optional[DiscardedMemoryFragments[T]]:
|
||||
"""Write a memory fragment to sensory memory."""
|
||||
fragments = await self.handle_duplicated(self._fragments, [memory_fragment])
|
||||
discarded_fragments: List[T] = []
|
||||
if len(fragments) > self._buffer_size:
|
||||
fragments, discarded_fragments = await self.handle_overflow(fragments)
|
||||
|
||||
async with self._lock:
|
||||
await self.clear()
|
||||
self._fragments = fragments
|
||||
if not discarded_fragments:
|
||||
return None
|
||||
return DiscardedMemoryFragments(discarded_fragments, [])
|
||||
|
||||
@immutable
|
||||
async def read(
|
||||
self,
|
||||
observation: str,
|
||||
alpha: Optional[float] = None,
|
||||
beta: Optional[float] = None,
|
||||
gamma: Optional[float] = None,
|
||||
) -> List[T]:
|
||||
"""Read memory fragments by observation."""
|
||||
return self._fragments
|
||||
|
||||
@mutable
|
||||
async def handle_overflow(
|
||||
self, memory_fragments: List[T]
|
||||
) -> Tuple[List[T], List[T]]:
|
||||
"""Handle memory overflow.
|
||||
|
||||
        For sensory memory, the overflow strategy is to transfer the memory fragments
        whose importance reaches ``threshold_to_short_term`` to short-term memory and
        drop the rest.
|
||||
|
||||
Args:
|
||||
memory_fragments(List[T]): Existing memory fragments
|
||||
|
||||
Returns:
|
||||
Tuple[List[T], List[T]]: The memory fragments after handling overflow and
|
||||
the discarded memory fragments, the discarded memory fragments should
|
||||
be transferred to short-term memory.
|
||||
"""
|
||||
scores = await self.score_memory_importance(memory_fragments)
|
||||
result = []
|
||||
for i, memory in enumerate(memory_fragments):
|
||||
if scores[i] >= self.threshold_to_short_term:
|
||||
memory.update_importance(scores[i])
|
||||
result.append(memory)
|
||||
return [], result
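For intuition: when no importance scorer is configured, `score_memory_importance` returns `5 * importance_weight = 5 * 0.9 = 4.5` for every fragment, which clears the `threshold_to_short_term` of `0.1`, so by default every overflowed sensory fragment is handed to short-term memory.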
|
||||
|
||||
@mutable
|
||||
async def clear(self) -> List[T]:
|
||||
"""Clear all memory fragments."""
|
||||
# async with self._lock:
|
||||
fragments = self._fragments
|
||||
self._fragments = []
|
||||
return fragments
|
||||
|
||||
|
||||
class ShortTermMemory(Memory, Generic[T]):
|
||||
"""Short term memory.
|
||||
|
||||
All memories are stored in computer memory.
|
||||
"""
|
||||
|
||||
def __init__(self, buffer_size: int = 5):
|
||||
"""Create a short-term memory."""
|
||||
self._buffer_size = buffer_size
|
||||
self._fragments: List[T] = []
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
def structure_clone(
|
||||
self: "ShortTermMemory[T]", now: Optional[datetime] = None
|
||||
) -> "ShortTermMemory[T]":
|
||||
"""Return a structure clone of the memory."""
|
||||
m: ShortTermMemory[T] = ShortTermMemory(buffer_size=self._buffer_size)
|
||||
m._copy_from(self)
|
||||
return m
|
||||
|
||||
@mutable
|
||||
async def write(
|
||||
self,
|
||||
memory_fragment: T,
|
||||
now: Optional[datetime] = None,
|
||||
op: WriteOperation = WriteOperation.ADD,
|
||||
) -> Optional[DiscardedMemoryFragments[T]]:
|
||||
"""Write a memory fragment to short-term memory.
|
||||
|
||||
Args:
|
||||
memory_fragment(T): New memory fragment
|
||||
now(Optional[datetime]): The current time
|
||||
op(WriteOperation): Write operation
|
||||
|
||||
Returns:
|
||||
Optional[DiscardedMemoryFragments]: The discarded memory fragments, None
|
||||
means no memory fragments are discarded. The discarded memory fragments
|
||||
should be transferred and stored in long-term memory.
|
||||
"""
|
||||
fragments = await self.handle_duplicated(self._fragments, [memory_fragment])
|
||||
|
||||
async with self._lock:
|
||||
await self.clear()
|
||||
self._fragments = fragments
|
||||
discarded_memories = await self.transfer_to_long_term(memory_fragment)
|
||||
fragments, discarded_fragments = await self.handle_overflow(self._fragments)
|
||||
self._fragments = fragments
|
||||
return discarded_memories
|
||||
|
||||
@immutable
|
||||
async def read(
|
||||
self,
|
||||
observation: str,
|
||||
alpha: Optional[float] = None,
|
||||
beta: Optional[float] = None,
|
||||
gamma: Optional[float] = None,
|
||||
) -> List[T]:
|
||||
"""Read memory fragments by observation."""
|
||||
return self._fragments
|
||||
|
||||
@mutable
|
||||
async def transfer_to_long_term(
|
||||
self, memory_fragment: T
|
||||
) -> Optional[DiscardedMemoryFragments[T]]:
|
||||
"""Transfer the oldest memories to long-term memory.
|
||||
|
||||
This is a very simple strategy, just transfer the oldest memories to long-term
|
||||
memory.
|
||||
"""
|
||||
        if len(self._fragments) > self._buffer_size:
            overflow_cnt = len(self._fragments) - self._buffer_size
            # Capture the oldest memories before truncating the buffer; they will be
            # transferred to long-term memory
            overflow_fragments = self._fragments[:overflow_cnt]
            # Just keep the most recent memories in short-term memory
            self._fragments = self._fragments[overflow_cnt:]
            insights = await self.get_insights(overflow_fragments)
            return DiscardedMemoryFragments(overflow_fragments, insights)
|
||||
else:
|
||||
return None
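A small sketch of the overflow behavior above (illustrative only; the fragment texts are made up, `AgentMemoryFragment` is imported from the agent memory module, and the snippet is assumed to run inside an async function):

# With buffer_size=2, the third write pushes the oldest fragment out of the buffer.
stm: ShortTermMemory = ShortTermMemory(buffer_size=2)
discarded = None
for text in ["saw a red car", "heard a dog bark", "found a key"]:
    discarded = await stm.write(AgentMemoryFragment.build_from(text))
# After the loop, discarded.discarded_memory_fragments holds the overflowed (oldest)
# fragment, ready to be persisted in long-term memory.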
|
||||
|
||||
@mutable
|
||||
async def clear(self) -> List[T]:
|
||||
"""Clear all memory fragments."""
|
||||
# async with self._lock:
|
||||
fragments = self._fragments
|
||||
self._fragments = []
|
||||
return fragments
|
||||
|
||||
@property
|
||||
@immutable
|
||||
def short_term_memories(self) -> List[T]:
|
||||
"""Return short-term memories."""
|
||||
return self._fragments
|
19
dbgpt/agent/core/memory/gpts/__init__.py
Normal file
@@ -0,0 +1,19 @@
|
||||
"""Memory module for GPTS messages and plans.
|
||||
|
||||
It stores the messages and plans generated by multiple agents in the conversation.
|
||||
|
||||
It is different from the agent memory as it is a formatted structure to store the
|
||||
messages and plans, and it can be stored in a database or a file.
|
||||
"""
|
||||
|
||||
from .base import ( # noqa: F401
|
||||
GptsMessage,
|
||||
GptsMessageMemory,
|
||||
GptsPlan,
|
||||
GptsPlansMemory,
|
||||
)
|
||||
from .default_gpts_memory import ( # noqa: F401
|
||||
DefaultGptsMessageMemory,
|
||||
DefaultGptsPlansMemory,
|
||||
)
|
||||
from .gpts_memory import GptsMemory # noqa: F401
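A hedged usage sketch of the GPTs memory (the conversation id, agent names, and message content below are made up for illustration):

# Store one message exchanged between two agents and query the conversation history.
gpts_memory = GptsMemory()  # defaults to the in-memory pandas-backed implementations
gpts_memory.message_memory.append(
    GptsMessage(
        conv_id="conv-001",
        sender="Planner",
        receiver="Coder",
        role="assistant",
        content="Write a function that sums a list of numbers.",
        rounds=1,
    )
)
history = gpts_memory.message_memory.get_by_conv_id("conv-001")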
|
250
dbgpt/agent/core/memory/gpts/base.py
Normal file
@@ -0,0 +1,250 @@
|
||||
"""Base memory interface for agents."""
|
||||
|
||||
import dataclasses
|
||||
from abc import ABC, abstractmethod
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from ...schema import Status
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class GptsPlan:
|
||||
"""Gpts plan."""
|
||||
|
||||
conv_id: str
|
||||
sub_task_num: int
|
||||
sub_task_content: Optional[str]
|
||||
sub_task_title: Optional[str] = None
|
||||
sub_task_agent: Optional[str] = None
|
||||
resource_name: Optional[str] = None
|
||||
rely: Optional[str] = None
|
||||
agent_model: Optional[str] = None
|
||||
retry_times: int = 0
|
||||
max_retry_times: int = 5
|
||||
state: Optional[str] = Status.TODO.value
|
||||
result: Optional[str] = None
|
||||
|
||||
@staticmethod
|
||||
def from_dict(d: Dict[str, Any]) -> "GptsPlan":
|
||||
"""Create a GptsPlan object from a dictionary."""
|
||||
return GptsPlan(
|
||||
conv_id=d["conv_id"],
|
||||
sub_task_num=d["sub_task_num"],
|
||||
sub_task_content=d["sub_task_content"],
|
||||
sub_task_agent=d["sub_task_agent"],
|
||||
resource_name=d["resource_name"],
|
||||
rely=d["rely"],
|
||||
agent_model=d["agent_model"],
|
||||
retry_times=d["retry_times"],
|
||||
max_retry_times=d["max_retry_times"],
|
||||
state=d["state"],
|
||||
result=d["result"],
|
||||
)
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Return a dictionary representation of the GptsPlan object."""
|
||||
return dataclasses.asdict(self)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class GptsMessage:
|
||||
"""Gpts message."""
|
||||
|
||||
conv_id: str
|
||||
sender: str
|
||||
|
||||
receiver: str
|
||||
role: str
|
||||
content: str
|
||||
rounds: Optional[int]
|
||||
current_goal: Optional[str] = None
|
||||
context: Optional[str] = None
|
||||
review_info: Optional[str] = None
|
||||
action_report: Optional[str] = None
|
||||
model_name: Optional[str] = None
|
||||
created_at: datetime = dataclasses.field(default_factory=datetime.utcnow)
|
||||
updated_at: datetime = dataclasses.field(default_factory=datetime.utcnow)
|
||||
|
||||
@staticmethod
|
||||
def from_dict(d: Dict[str, Any]) -> "GptsMessage":
|
||||
"""Create a GptsMessage object from a dictionary."""
|
||||
return GptsMessage(
|
||||
conv_id=d["conv_id"],
|
||||
sender=d["sender"],
|
||||
receiver=d["receiver"],
|
||||
role=d["role"],
|
||||
content=d["content"],
|
||||
rounds=d["rounds"],
|
||||
model_name=d["model_name"],
|
||||
current_goal=d["current_goal"],
|
||||
context=d["context"],
|
||||
review_info=d["review_info"],
|
||||
action_report=d["action_report"],
|
||||
created_at=d["created_at"],
|
||||
updated_at=d["updated_at"],
|
||||
)
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Return a dictionary representation of the GptsMessage object."""
|
||||
return dataclasses.asdict(self)
|
||||
|
||||
|
||||
class GptsPlansMemory(ABC):
|
||||
"""Gpts plans memory interface."""
|
||||
|
||||
@abstractmethod
|
||||
def batch_save(self, plans: List[GptsPlan]) -> None:
|
||||
"""Save plans in batch.
|
||||
|
||||
Args:
|
||||
            plans(List[GptsPlan]): List of plans generated by the planner
|
||||
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def get_by_conv_id(self, conv_id: str) -> List[GptsPlan]:
|
||||
"""Get plans by conv_id.
|
||||
|
||||
Args:
|
||||
conv_id: conversation id
|
||||
|
||||
Returns:
|
||||
List[GptsPlan]: List of planning steps
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def get_by_conv_id_and_num(
|
||||
self, conv_id: str, task_nums: List[int]
|
||||
) -> List[GptsPlan]:
|
||||
"""Get plans by conv_id and task number.
|
||||
|
||||
Args:
|
||||
conv_id(str): conversation id
|
||||
task_nums(List[int]): List of sequence numbers of plans in the same
|
||||
conversation
|
||||
|
||||
Returns:
|
||||
List[GptsPlan]: List of planning steps
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def get_todo_plans(self, conv_id: str) -> List[GptsPlan]:
|
||||
"""Get unfinished planning steps.
|
||||
|
||||
Args:
|
||||
conv_id(str): Conversation id
|
||||
|
||||
Returns:
|
||||
List[GptsPlan]: List of planning steps
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def complete_task(self, conv_id: str, task_num: int, result: str) -> None:
|
||||
"""Set the planning step to complete.
|
||||
|
||||
Args:
|
||||
conv_id(str): conversation id
|
||||
task_num(int): Planning step num
|
||||
result(str): Plan step results
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def update_task(
|
||||
self,
|
||||
conv_id: str,
|
||||
task_num: int,
|
||||
state: str,
|
||||
retry_times: int,
|
||||
agent: Optional[str] = None,
|
||||
model: Optional[str] = None,
|
||||
result: Optional[str] = None,
|
||||
) -> None:
|
||||
"""Update planning step information.
|
||||
|
||||
Args:
|
||||
conv_id(str): conversation id
|
||||
task_num(int): Planning step num
|
||||
state(str): the status to update to
|
||||
retry_times(int): Latest number of retries
|
||||
agent(str): Agent's name
|
||||
model(str): Model name
|
||||
result(str): Plan step results
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def remove_by_conv_id(self, conv_id: str) -> None:
|
||||
"""Remove plan by conversation id.
|
||||
|
||||
Args:
|
||||
conv_id(str): conversation id
|
||||
"""
|
||||
|
||||
|
||||
class GptsMessageMemory(ABC):
|
||||
"""Gpts message memory interface."""
|
||||
|
||||
@abstractmethod
|
||||
def append(self, message: GptsMessage) -> None:
|
||||
"""Add a message.
|
||||
|
||||
Args:
|
||||
message(GptsMessage): Message object
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def get_by_agent(self, conv_id: str, agent: str) -> Optional[List[GptsMessage]]:
|
||||
"""Return all messages of the agent in the conversation.
|
||||
|
||||
Args:
|
||||
conv_id(str): Conversation id
|
||||
agent(str): Agent's name
|
||||
|
||||
Returns:
|
||||
List[GptsMessage]: List of messages
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def get_between_agents(
|
||||
self,
|
||||
conv_id: str,
|
||||
agent1: str,
|
||||
agent2: str,
|
||||
current_goal: Optional[str] = None,
|
||||
) -> List[GptsMessage]:
|
||||
"""Get messages between two agents.
|
||||
|
||||
Query information related to an agent
|
||||
|
||||
Args:
|
||||
conv_id(str): Conversation id
|
||||
agent1(str): Agent1's name
|
||||
agent2(str): Agent2's name
|
||||
current_goal(str): Current goal
|
||||
|
||||
Returns:
|
||||
List[GptsMessage]: List of messages
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def get_by_conv_id(self, conv_id: str) -> List[GptsMessage]:
|
||||
"""Return all messages in the conversation.
|
||||
|
||||
Query messages by conv id.
|
||||
|
||||
Args:
|
||||
conv_id(str): Conversation id
|
||||
Returns:
|
||||
List[GptsMessage]: List of messages
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def get_last_message(self, conv_id: str) -> Optional[GptsMessage]:
|
||||
"""Return the last message in the conversation.
|
||||
|
||||
Args:
|
||||
conv_id(str): Conversation id
|
||||
|
||||
Returns:
|
||||
GptsMessage: The last message in the conversation
|
||||
"""
|
149
dbgpt/agent/core/memory/gpts/default_gpts_memory.py
Normal file
@@ -0,0 +1,149 @@
|
||||
"""Default memory for storing plans and messages."""
|
||||
|
||||
from dataclasses import fields
|
||||
from typing import List, Optional
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from ...schema import Status
|
||||
from .base import GptsMessage, GptsMessageMemory, GptsPlan, GptsPlansMemory
|
||||
|
||||
|
||||
class DefaultGptsPlansMemory(GptsPlansMemory):
|
||||
"""Default memory for storing plans."""
|
||||
|
||||
def __init__(self):
|
||||
"""Create a memory to store plans."""
|
||||
self.df = pd.DataFrame(columns=[field.name for field in fields(GptsPlan)])
|
||||
|
||||
    def batch_save(self, plans: List[GptsPlan]) -> None:
|
||||
"""Save plans in batch."""
|
||||
new_rows = pd.DataFrame([item.to_dict() for item in plans])
|
||||
self.df = pd.concat([self.df, new_rows], ignore_index=True)
|
||||
|
||||
def get_by_conv_id(self, conv_id: str) -> List[GptsPlan]:
|
||||
"""Get plans by conv_id."""
|
||||
result = self.df.query(f"conv_id==@conv_id") # noqa: F541
|
||||
plans = []
|
||||
for row in result.itertuples(index=False, name=None):
|
||||
row_dict = dict(zip(self.df.columns, row))
|
||||
plans.append(GptsPlan.from_dict(row_dict))
|
||||
return plans
|
||||
|
||||
def get_by_conv_id_and_num(
|
||||
self, conv_id: str, task_nums: List[int]
|
||||
) -> List[GptsPlan]:
|
||||
"""Get plans by conv_id and task number."""
|
||||
task_nums_int = [int(num) for num in task_nums] # noqa:F841
|
||||
        result = self.df.query(
            "conv_id == @conv_id and sub_task_num in @task_nums_int"
        )
|
||||
plans = []
|
||||
for row in result.itertuples(index=False, name=None):
|
||||
row_dict = dict(zip(self.df.columns, row))
|
||||
plans.append(GptsPlan.from_dict(row_dict))
|
||||
return plans
|
||||
|
||||
def get_todo_plans(self, conv_id: str) -> List[GptsPlan]:
|
||||
"""Get unfinished planning steps."""
|
||||
todo_states = [Status.TODO.value, Status.RETRYING.value] # noqa: F841
|
||||
result = self.df.query(f"conv_id==@conv_id and state in @todo_states") # noqa
|
||||
plans = []
|
||||
for row in result.itertuples(index=False, name=None):
|
||||
row_dict = dict(zip(self.df.columns, row))
|
||||
plans.append(GptsPlan.from_dict(row_dict))
|
||||
return plans
|
||||
|
||||
def complete_task(self, conv_id: str, task_num: int, result: str):
|
||||
"""Set the planning step to complete."""
|
||||
condition = (self.df["conv_id"] == conv_id) & (
|
||||
self.df["sub_task_num"] == task_num
|
||||
)
|
||||
self.df.loc[condition, "state"] = Status.COMPLETE.value
|
||||
self.df.loc[condition, "result"] = result
|
||||
|
||||
def update_task(
|
||||
self,
|
||||
conv_id: str,
|
||||
task_num: int,
|
||||
state: str,
|
||||
retry_times: int,
|
||||
agent: Optional[str] = None,
|
||||
model: Optional[str] = None,
|
||||
result: Optional[str] = None,
|
||||
):
|
||||
"""Update the state of the planning step."""
|
||||
condition = (self.df["conv_id"] == conv_id) & (
|
||||
self.df["sub_task_num"] == task_num
|
||||
)
|
||||
self.df.loc[condition, "state"] = state
|
||||
self.df.loc[condition, "retry_times"] = retry_times
|
||||
self.df.loc[condition, "result"] = result
|
||||
|
||||
if agent:
|
||||
self.df.loc[condition, "sub_task_agent"] = agent
|
||||
|
||||
if model:
|
||||
self.df.loc[condition, "agent_model"] = model
|
||||
|
||||
def remove_by_conv_id(self, conv_id: str):
|
||||
"""Remove all plans in the conversation."""
|
||||
self.df.drop(self.df[self.df["conv_id"] == conv_id].index, inplace=True)
|
||||
|
||||
|
||||
class DefaultGptsMessageMemory(GptsMessageMemory):
|
||||
"""Default memory for storing messages."""
|
||||
|
||||
def __init__(self):
|
||||
"""Create a memory to store messages."""
|
||||
self.df = pd.DataFrame(columns=[field.name for field in fields(GptsMessage)])
|
||||
|
||||
def append(self, message: GptsMessage):
|
||||
"""Append a message to the memory."""
|
||||
self.df.loc[len(self.df)] = message.to_dict()
|
||||
|
||||
def get_by_agent(self, conv_id: str, agent: str) -> Optional[List[GptsMessage]]:
|
||||
"""Get all messages sent or received by the agent in the conversation."""
|
||||
        result = self.df.query(
            "conv_id == @conv_id and (sender == @agent or receiver == @agent)"
        )
|
||||
messages = []
|
||||
for row in result.itertuples(index=False, name=None):
|
||||
row_dict = dict(zip(self.df.columns, row))
|
||||
messages.append(GptsMessage.from_dict(row_dict))
|
||||
return messages
|
||||
|
||||
def get_between_agents(
|
||||
self,
|
||||
conv_id: str,
|
||||
agent1: str,
|
||||
agent2: str,
|
||||
current_goal: Optional[str] = None,
|
||||
) -> List[GptsMessage]:
|
||||
"""Get all messages between two agents in the conversation."""
|
||||
        if current_goal:
            result = self.df.query(
                "conv_id == @conv_id and ((sender == @agent1 and receiver == @agent2)"
                " or (sender == @agent2 and receiver == @agent1))"
                " and current_goal == @current_goal"
            )
        else:
            result = self.df.query(
                "conv_id == @conv_id and ((sender == @agent1 and receiver == @agent2)"
                " or (sender == @agent2 and receiver == @agent1))"
            )
|
||||
messages = []
|
||||
for row in result.itertuples(index=False, name=None):
|
||||
row_dict = dict(zip(self.df.columns, row))
|
||||
messages.append(GptsMessage.from_dict(row_dict))
|
||||
return messages
|
||||
|
||||
def get_by_conv_id(self, conv_id: str) -> List[GptsMessage]:
|
||||
"""Get all messages in the conversation."""
|
||||
result = self.df.query(f"conv_id==@conv_id") # noqa: F541
|
||||
messages = []
|
||||
for row in result.itertuples(index=False, name=None):
|
||||
row_dict = dict(zip(self.df.columns, row))
|
||||
messages.append(GptsMessage.from_dict(row_dict))
|
||||
return messages
|
||||
|
||||
def get_last_message(self, conv_id: str) -> Optional[GptsMessage]:
|
||||
"""Get the last message in the conversation."""
|
||||
return None
|
190
dbgpt/agent/core/memory/gpts/gpts_memory.py
Normal file
@@ -0,0 +1,190 @@
|
||||
"""GPTs memory."""
|
||||
|
||||
import json
|
||||
from collections import OrderedDict, defaultdict
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from dbgpt.vis.client import VisAgentMessages, VisAgentPlans, vis_client
|
||||
|
||||
from ...action.base import ActionOutput
|
||||
from .base import GptsMessage, GptsMessageMemory, GptsPlansMemory
|
||||
from .default_gpts_memory import DefaultGptsMessageMemory, DefaultGptsPlansMemory
|
||||
|
||||
NONE_GOAL_PREFIX: str = "none_goal_count_"
|
||||
|
||||
|
||||
class GptsMemory:
|
||||
"""GPTs memory."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
plans_memory: Optional[GptsPlansMemory] = None,
|
||||
message_memory: Optional[GptsMessageMemory] = None,
|
||||
):
|
||||
"""Create a memory to store plans and messages."""
|
||||
self._plans_memory: GptsPlansMemory = (
|
||||
plans_memory if plans_memory is not None else DefaultGptsPlansMemory()
|
||||
)
|
||||
self._message_memory: GptsMessageMemory = (
|
||||
message_memory if message_memory is not None else DefaultGptsMessageMemory()
|
||||
)
|
||||
|
||||
@property
|
||||
def plans_memory(self) -> GptsPlansMemory:
|
||||
"""Return the plans memory."""
|
||||
return self._plans_memory
|
||||
|
||||
@property
|
||||
def message_memory(self) -> GptsMessageMemory:
|
||||
"""Return the message memory."""
|
||||
return self._message_memory
|
||||
|
||||
async def _message_group_vis_build(self, message_group):
|
||||
if not message_group:
|
||||
return ""
|
||||
num: int = 0
|
||||
last_goal = next(reversed(message_group))
|
||||
last_goal_messages = message_group[last_goal]
|
||||
|
||||
last_goal_message = last_goal_messages[-1]
|
||||
vis_items = []
|
||||
|
||||
plan_temps = []
|
||||
for key, value in message_group.items():
|
||||
num = num + 1
|
||||
if key.startswith(NONE_GOAL_PREFIX):
|
||||
vis_items.append(await self._messages_to_plan_vis(plan_temps))
|
||||
plan_temps = []
|
||||
num = 0
|
||||
vis_items.append(await self._messages_to_agents_vis(value))
|
||||
else:
|
||||
num += 1
|
||||
plan_temps.append(
|
||||
{
|
||||
"name": key,
|
||||
"num": num,
|
||||
"status": "complete",
|
||||
"agent": value[0].receiver if value else "",
|
||||
"markdown": await self._messages_to_agents_vis(value),
|
||||
}
|
||||
)
|
||||
|
||||
if len(plan_temps) > 0:
|
||||
vis_items.append(await self._messages_to_plan_vis(plan_temps))
|
||||
vis_items.append(await self._messages_to_agents_vis([last_goal_message]))
|
||||
return "\n".join(vis_items)
|
||||
|
||||
async def _plan_vis_build(self, plan_group: dict[str, list]):
|
||||
num: int = 0
|
||||
plan_items = []
|
||||
for key, value in plan_group.items():
|
||||
num = num + 1
|
||||
plan_items.append(
|
||||
{
|
||||
"name": key,
|
||||
"num": num,
|
||||
"status": "complete",
|
||||
"agent": value[0].receiver if value else "",
|
||||
"markdown": await self._messages_to_agents_vis(value),
|
||||
}
|
||||
)
|
||||
return await self._messages_to_plan_vis(plan_items)
|
||||
|
||||
async def one_chat_completions_v2(self, conv_id: str):
|
||||
"""Generate a visualization of the conversation."""
|
||||
messages = self.message_memory.get_by_conv_id(conv_id=conv_id)
|
||||
temp_group: Dict[str, List[GptsMessage]] = OrderedDict()
|
||||
none_goal_count = 1
|
||||
count: int = 0
|
||||
for message in messages:
|
||||
count = count + 1
|
||||
if count == 1:
|
||||
continue
|
||||
current_goal = message.current_goal
|
||||
|
||||
last_goal = next(reversed(temp_group)) if temp_group else None
|
||||
if last_goal:
|
||||
last_goal_messages = temp_group[last_goal]
|
||||
if current_goal:
|
||||
if current_goal == last_goal:
|
||||
last_goal_messages.append(message)
|
||||
else:
|
||||
temp_group[current_goal] = [message]
|
||||
else:
|
||||
temp_group[f"{NONE_GOAL_PREFIX}{none_goal_count}"] = [message]
|
||||
none_goal_count += 1
|
||||
else:
|
||||
if current_goal:
|
||||
temp_group[current_goal] = [message]
|
||||
else:
|
||||
temp_group[f"{NONE_GOAL_PREFIX}{none_goal_count}"] = [message]
|
||||
none_goal_count += 1
|
||||
|
||||
return await self._message_group_vis_build(temp_group)
|
||||
|
||||
async def one_chat_completions(self, conv_id: str):
|
||||
"""Generate a visualization of the conversation."""
|
||||
messages = self.message_memory.get_by_conv_id(conv_id=conv_id)
|
||||
temp_group: Dict[str, List[GptsMessage]] = defaultdict(list)
|
||||
temp_messages = []
|
||||
vis_items = []
|
||||
count: int = 0
|
||||
for message in messages:
|
||||
count = count + 1
|
||||
if count == 1:
|
||||
continue
|
||||
if not message.current_goal or len(message.current_goal) <= 0:
|
||||
if len(temp_group) > 0:
|
||||
vis_items.append(await self._plan_vis_build(temp_group))
|
||||
temp_group.clear()
|
||||
|
||||
temp_messages.append(message)
|
||||
else:
|
||||
if len(temp_messages) > 0:
|
||||
vis_items.append(await self._messages_to_agents_vis(temp_messages))
|
||||
temp_messages.clear()
|
||||
|
||||
last_goal = message.current_goal
|
||||
temp_group[last_goal].append(message)
|
||||
|
||||
if len(temp_group) > 0:
|
||||
vis_items.append(await self._plan_vis_build(temp_group))
|
||||
temp_group.clear()
|
||||
if len(temp_messages) > 0:
|
||||
vis_items.append(await self._messages_to_agents_vis(temp_messages, True))
|
||||
temp_messages.clear()
|
||||
|
||||
return "\n".join(vis_items)
|
||||
|
||||
async def _messages_to_agents_vis(
|
||||
self, messages: List[GptsMessage], is_last_message: bool = False
|
||||
):
|
||||
if messages is None or len(messages) <= 0:
|
||||
return ""
|
||||
messages_view = []
|
||||
for message in messages:
|
||||
action_report_str = message.action_report
|
||||
view_info = message.content
|
||||
if action_report_str and len(action_report_str) > 0:
|
||||
action_out = ActionOutput.from_dict(json.loads(action_report_str))
|
||||
if action_out is not None and (
|
||||
action_out.is_exe_success or is_last_message
|
||||
):
|
||||
view = action_out.view
|
||||
view_info = view if view else action_out.content
|
||||
|
||||
messages_view.append(
|
||||
{
|
||||
"sender": message.sender,
|
||||
"receiver": message.receiver,
|
||||
"model": message.model_name,
|
||||
"markdown": view_info,
|
||||
}
|
||||
)
|
||||
        vis_component = vis_client.get(VisAgentMessages.vis_tag())
        return await vis_component.display(content=messages_view)
|
||||
|
||||
async def _messages_to_plan_vis(self, messages: List[Dict]):
|
||||
if messages is None or len(messages) <= 0:
|
||||
return ""
|
||||
return await vis_client.get(VisAgentPlans.vis_tag()).display(content=messages)
|
288
dbgpt/agent/core/memory/hybrid.py
Normal file
@@ -0,0 +1,288 @@
|
||||
"""Hybrid memory module.
|
||||
|
||||
This structure explicitly models the human short-term and long-term memories. The
|
||||
short-term memory temporarily buffers recent perceptions, while long-term memory
|
||||
consolidates important information over time.
|
||||
"""
|
||||
|
||||
import os.path
|
||||
from concurrent.futures import Executor, ThreadPoolExecutor
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING, Generic, List, Optional, Tuple, Type
|
||||
|
||||
from dbgpt.core import Embeddings, LLMClient
|
||||
from dbgpt.util.annotations import immutable, mutable
|
||||
|
||||
from .base import (
|
||||
DiscardedMemoryFragments,
|
||||
ImportanceScorer,
|
||||
InsightExtractor,
|
||||
Memory,
|
||||
SensoryMemory,
|
||||
ShortTermMemory,
|
||||
T,
|
||||
WriteOperation,
|
||||
)
|
||||
from .long_term import LongTermMemory
|
||||
from .short_term import EnhancedShortTermMemory
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from dbgpt.storage.vector_store.connector import VectorStoreConnector
|
||||
|
||||
|
||||
class HybridMemory(Memory, Generic[T]):
|
||||
"""Hybrid memory for the agent."""
|
||||
|
||||
importance_weight: float = 0.9
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
now: datetime,
|
||||
sensory_memory: SensoryMemory[T],
|
||||
short_term_memory: ShortTermMemory[T],
|
||||
long_term_memory: LongTermMemory[T],
|
||||
default_insight_extractor: Optional[InsightExtractor] = None,
|
||||
default_importance_scorer: Optional[ImportanceScorer] = None,
|
||||
):
|
||||
"""Create a hybrid memory."""
|
||||
self.now = now
|
||||
self._sensory_memory = sensory_memory
|
||||
self._short_term_memory = short_term_memory
|
||||
self._long_term_memory = long_term_memory
|
||||
self._default_insight_extractor = default_insight_extractor
|
||||
self._default_importance_scorer = default_importance_scorer
|
||||
|
||||
def structure_clone(
|
||||
self: "HybridMemory[T]", now: Optional[datetime] = None
|
||||
) -> "HybridMemory[T]":
|
||||
"""Return a structure clone of the memory."""
|
||||
now = now or self.now
|
||||
m = HybridMemory(
|
||||
now=now,
|
||||
sensory_memory=self._sensory_memory.structure_clone(now),
|
||||
short_term_memory=self._short_term_memory.structure_clone(now),
|
||||
long_term_memory=self._long_term_memory.structure_clone(now),
|
||||
)
|
||||
m._copy_from(self)
|
||||
return m
|
||||
|
||||
@classmethod
|
||||
def from_chroma(
|
||||
cls,
|
||||
vstore_name: Optional[str] = "_chroma_agent_memory_",
|
||||
vstore_path: Optional[str] = None,
|
||||
embeddings: Optional[Embeddings] = None,
|
||||
executor: Optional[Executor] = None,
|
||||
now: Optional[datetime] = None,
|
||||
sensory_memory: Optional[SensoryMemory[T]] = None,
|
||||
short_term_memory: Optional[ShortTermMemory[T]] = None,
|
||||
long_term_memory: Optional[LongTermMemory[T]] = None,
|
||||
**kwargs
|
||||
):
|
||||
"""Create a hybrid memory from Chroma vector store."""
|
||||
from dbgpt.configs.model_config import DATA_DIR
|
||||
from dbgpt.storage.vector_store.chroma_store import ChromaVectorConfig
|
||||
from dbgpt.storage.vector_store.connector import VectorStoreConnector
|
||||
|
||||
if not embeddings:
|
||||
from dbgpt.rag.embedding import DefaultEmbeddingFactory
|
||||
|
||||
embeddings = DefaultEmbeddingFactory.openai()
|
||||
|
||||
vstore_path = vstore_path or os.path.join(DATA_DIR, "agent_memory")
|
||||
|
||||
vector_store_connector = VectorStoreConnector.from_default(
|
||||
vector_store_type="Chroma",
|
||||
embedding_fn=embeddings,
|
||||
vector_store_config=ChromaVectorConfig(
|
||||
name=vstore_name,
|
||||
persist_path=vstore_path,
|
||||
),
|
||||
)
|
||||
return cls.from_vstore(
|
||||
vector_store_connector=vector_store_connector,
|
||||
embeddings=embeddings,
|
||||
executor=executor,
|
||||
now=now,
|
||||
sensory_memory=sensory_memory,
|
||||
short_term_memory=short_term_memory,
|
||||
long_term_memory=long_term_memory,
|
||||
**kwargs
|
||||
)
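An illustrative way to wire the hybrid memory into an agent memory (the store name and path are assumptions; without an explicit `embeddings` argument, `from_chroma` falls back to `DefaultEmbeddingFactory.openai()`, which requires OpenAI-compatible credentials):

# Sketch: a hybrid memory backed by a local Chroma store, used as the agent's memory.
hybrid_memory = HybridMemory.from_chroma(
    vstore_name="demo_agent_memory",
    vstore_path="/tmp/agent_memory",
)
agent_memory = AgentMemory(memory=hybrid_memory)  # AgentMemory from the agent memory module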
|
||||
|
||||
@classmethod
|
||||
def from_vstore(
|
||||
cls,
|
||||
vector_store_connector: "VectorStoreConnector",
|
||||
embeddings: Optional[Embeddings] = None,
|
||||
executor: Optional[Executor] = None,
|
||||
now: Optional[datetime] = None,
|
||||
sensory_memory: Optional[SensoryMemory[T]] = None,
|
||||
short_term_memory: Optional[ShortTermMemory[T]] = None,
|
||||
long_term_memory: Optional[LongTermMemory[T]] = None,
|
||||
**kwargs
|
||||
):
|
||||
"""Create a hybrid memory from vector store."""
|
||||
if not embeddings:
|
||||
embeddings = vector_store_connector.current_embeddings
|
||||
if not executor:
|
||||
executor = ThreadPoolExecutor()
|
||||
if not now:
|
||||
now = datetime.now()
|
||||
|
||||
if not sensory_memory:
|
||||
sensory_memory = SensoryMemory()
|
||||
if not short_term_memory:
|
||||
if not embeddings:
|
||||
raise ValueError("embeddings is required.")
|
||||
short_term_memory = EnhancedShortTermMemory(embeddings, executor)
|
||||
if not long_term_memory:
|
||||
long_term_memory = LongTermMemory(
|
||||
executor,
|
||||
vector_store_connector,
|
||||
now=now,
|
||||
)
|
||||
return cls(now, sensory_memory, short_term_memory, long_term_memory, **kwargs)
|
||||
|
||||
def initialize(
|
||||
self,
|
||||
name: Optional[str] = None,
|
||||
llm_client: Optional[LLMClient] = None,
|
||||
importance_scorer: Optional[ImportanceScorer[T]] = None,
|
||||
insight_extractor: Optional[InsightExtractor[T]] = None,
|
||||
real_memory_fragment_class: Optional[Type[T]] = None,
|
||||
) -> None:
|
||||
"""Initialize the memory.
|
||||
|
||||
It will initialize all the memories.
|
||||
"""
|
||||
memories = [
|
||||
self._sensory_memory,
|
||||
self._short_term_memory,
|
||||
self._long_term_memory,
|
||||
]
|
||||
kwargs = {
|
||||
"name": name,
|
||||
"llm_client": llm_client,
|
||||
"importance_scorer": importance_scorer,
|
||||
"insight_extractor": insight_extractor,
|
||||
"real_memory_fragment_class": real_memory_fragment_class,
|
||||
}
|
||||
for memory in memories:
|
||||
memory.initialize(**kwargs)
|
||||
super().initialize(**kwargs)
|
||||
|
||||
@mutable
|
||||
async def write(
|
||||
self,
|
||||
memory_fragment: T,
|
||||
now: Optional[datetime] = None,
|
||||
op: WriteOperation = WriteOperation.ADD,
|
||||
) -> Optional[DiscardedMemoryFragments[T]]:
|
||||
"""Write a memory fragment to the memory."""
|
||||
# First write to sensory memory
|
||||
sen_discarded_memories = await self._sensory_memory.write(memory_fragment)
|
||||
if not sen_discarded_memories:
|
||||
return None
|
||||
short_term_discarded_memories = []
|
||||
discarded_memory_fragments = []
|
||||
discarded_insights = []
|
||||
for sen_memory in sen_discarded_memories.discarded_memory_fragments:
|
||||
# Write to short term memory
|
||||
short_discarded_memory = await self._short_term_memory.write(sen_memory)
|
||||
if short_discarded_memory:
|
||||
short_term_discarded_memories.append(short_discarded_memory)
|
||||
discarded_memory_fragments.extend(
|
||||
short_discarded_memory.discarded_memory_fragments
|
||||
)
|
||||
for insight in short_discarded_memory.discarded_insights:
|
||||
# Just keep the first insight
|
||||
discarded_insights.append(insight.insights[0])
|
||||
# Obtain the importance of insights
|
||||
insight_scores = await self.score_memory_importance(discarded_insights)
|
||||
        # Update the importance of each insight
|
||||
for i, ins in enumerate(discarded_insights):
|
||||
ins.update_importance(insight_scores[i])
|
||||
all_memories = discarded_memory_fragments + discarded_insights
|
||||
if self._long_term_memory:
|
||||
# Write to long term memory
|
||||
await self._long_term_memory.write_batch(all_memories, self.now)
|
||||
return None
|
||||
|
||||
@immutable
|
||||
async def read(
|
||||
self,
|
||||
observation: str,
|
||||
alpha: Optional[float] = None,
|
||||
beta: Optional[float] = None,
|
||||
gamma: Optional[float] = None,
|
||||
) -> List[T]:
|
||||
"""Read memories from the memory."""
|
||||
(
|
||||
retrieved_long_term_memories,
|
||||
short_term_discarded_memories,
|
||||
) = await self.fetch_memories(observation, self._short_term_memory)
|
||||
|
||||
await self.save_memories_after_retrieval(short_term_discarded_memories)
|
||||
return retrieved_long_term_memories
|
||||
|
||||
@immutable
|
||||
async def fetch_memories(
|
||||
self,
|
||||
observation: str,
|
||||
short_term_memory: Optional[ShortTermMemory[T]] = None,
|
||||
) -> Tuple[List[T], List[DiscardedMemoryFragments[T]]]:
|
||||
"""Fetch memories from long term memory.
|
||||
|
||||
If short_term_memory is provided, write the fetched memories to the short term
|
||||
memory.
|
||||
"""
|
||||
retrieved_long_term_memories = await self._long_term_memory.fetch_memories(
|
||||
observation
|
||||
)
|
||||
if not short_term_memory:
|
||||
return retrieved_long_term_memories, []
|
||||
short_term_discarded_memories: List[DiscardedMemoryFragments[T]] = []
|
||||
discarded_memory_fragments: List[T] = []
|
||||
for ltm in retrieved_long_term_memories:
|
||||
short_discarded_memory = await short_term_memory.write(
|
||||
ltm, op=WriteOperation.RETRIEVAL
|
||||
)
|
||||
if short_discarded_memory:
|
||||
short_term_discarded_memories.append(short_discarded_memory)
|
||||
discarded_memory_fragments.extend(
|
||||
short_discarded_memory.discarded_memory_fragments
|
||||
)
|
||||
for stm in short_term_memory.short_term_memories:
|
||||
retrieved_long_term_memories.append(
|
||||
stm.current_class.build_from(
|
||||
observation=stm.raw_observation,
|
||||
importance=stm.importance,
|
||||
)
|
||||
)
|
||||
return retrieved_long_term_memories, short_term_discarded_memories
|
||||
|
||||
async def save_memories_after_retrieval(
|
||||
self, fragments: List[DiscardedMemoryFragments[T]]
|
||||
):
|
||||
"""Save memories after retrieval."""
|
||||
discarded_memory_fragments = []
|
||||
discarded_memory_insights: List[T] = []
|
||||
for f in fragments:
|
||||
discarded_memory_fragments.extend(f.discarded_memory_fragments)
|
||||
for fi in f.discarded_insights:
|
||||
discarded_memory_insights.append(fi.insights[0])
|
||||
insights_importance = await self.score_memory_importance(
|
||||
discarded_memory_insights
|
||||
)
|
||||
for i, ins in enumerate(discarded_memory_insights):
|
||||
ins.update_importance(insights_importance[i])
|
||||
all_memories = discarded_memory_fragments + discarded_memory_insights
|
||||
await self._long_term_memory.write_batch(all_memories, self.now)
|
||||
|
||||
async def clear(self) -> List[T]:
|
||||
"""Clear the memory.
|
||||
|
||||
# TODO
|
||||
"""
|
||||
return []
|
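The write path above cascades one observation from sensory memory into short-term memory, scores whatever insights fall out along the way, and batches everything into long-term memory. A minimal usage sketch, assuming the composite memory class assembled here (named elsewhere in this refactor; not shown in this hunk) and a ready memory fragment:

# Hypothetical sketch: `memory` is the composite memory built above and
# `fragment` a ready memory fragment; write()/read() follow the code above.
async def remember_and_recall(memory, fragment):
    # One observation flows sensory -> short-term -> (scored) -> long-term.
    await memory.write(fragment)
    related = await memory.read(observation="Which tool should I use for analysis?")
    for m in related:
        print(m.raw_observation, m.importance)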
174
dbgpt/agent/core/memory/llm.py
Normal file
@@ -0,0 +1,174 @@
"""LLM Utility For Agent Memory."""

import re
from typing import List, Optional, Union

from dbgpt._private.pydantic import BaseModel
from dbgpt.core import (
    ChatPromptTemplate,
    HumanPromptTemplate,
    LLMClient,
    ModelMessage,
    ModelRequest,
)

from .base import ImportanceScorer, InsightExtractor, InsightMemoryFragment, T


class BaseLLMCaller(BaseModel):
    """Base class for LLM caller."""

    prompt: str = ""
    model: Optional[str] = None

    async def call_llm(
        self,
        prompt: Union[ChatPromptTemplate, str],
        llm_client: Optional[LLMClient] = None,
        **kwargs,
    ) -> str:
        """Call LLM client to generate response.

        Args:
            llm_client(LLMClient): LLM client
            prompt(ChatPromptTemplate): prompt
            **kwargs: other keyword arguments

        Returns:
            str: response
        """
        if not llm_client:
            raise ValueError("LLM client is required.")
        if isinstance(prompt, str):
            prompt = ChatPromptTemplate(
                messages=[HumanPromptTemplate.from_template(prompt)]
            )
        model = self.model
        if not model:
            model = await self.get_model(llm_client)
        prompt_kwargs = {}
        prompt_kwargs.update(kwargs)
        pass_kwargs = {
            k: v for k, v in prompt_kwargs.items() if k in prompt.input_variables
        }
        messages = prompt.format_messages(**pass_kwargs)
        model_messages = ModelMessage.from_base_messages(messages)
        model_request = ModelRequest.build_request(model, messages=model_messages)
        model_output = await llm_client.generate(model_request)
        if not model_output.success:
            raise ValueError("Call LLM failed.")
        return model_output.text

    async def get_model(self, llm_client: LLMClient) -> str:
        """Get the model.

        Args:
            llm_client(LLMClient): LLM client

        Returns:
            str: model
        """
        models = await llm_client.models()
        if not models:
            raise ValueError("No models available.")
        self.model = models[0].model
        return self.model

    @staticmethod
    def _parse_list(text: str) -> List[str]:
        """Parse a newline-separated string into a list of strings.

        1. First, split by newline
        2. Remove whitespace from each line
        """
        lines = re.split(r"\n", text.strip())
        lines = [line for line in lines if line.strip()]  # remove empty lines
        # Use regular expression to remove the numbers and dots at the beginning of
        # each line
        return [re.sub(r"^\s*\d+\.\s*", "", line).strip() for line in lines]

    @staticmethod
    def _parse_number(text: str, importance_weight: Optional[float] = None) -> float:
        """Parse a number from a string."""
        match = re.search(r"^\D*(\d+)", text)
        if match:
            score = float(match.group(1))
            if importance_weight is not None:
                score = (score / 10) * importance_weight
            return score
        else:
            return 0.0


class LLMInsightExtractor(BaseLLMCaller, InsightExtractor[T]):
    """LLM Insight Extractor.

    Get high-level insights from memories.
    """

    prompt: str = (
        "There are some memories: {content}\nCan you infer from the "
        "above memories the high-level insight for this person's character? The insight"
        " needs to be significantly different from the content and structure of the "
        "original memories. Respond in one sentence.\n\n"
        "Results:"
    )

    async def extract_insights(
        self,
        memory_fragment: T,
        llm_client: Optional[LLMClient] = None,
    ) -> InsightMemoryFragment[T]:
        """Extract insights from memory fragments.

        Args:
            memory_fragment(T): Memory fragment
            llm_client(Optional[LLMClient]): LLM client

        Returns:
            InsightMemoryFragment: The insights of the memory fragment.
        """
        insights_str: str = await self.call_llm(
            self.prompt, llm_client, content=memory_fragment.raw_observation
        )
        insights_list = self._parse_list(insights_str)
        return InsightMemoryFragment(memory_fragment, insights_list)


class LLMImportanceScorer(BaseLLMCaller, ImportanceScorer[T]):
    """LLM Importance Scorer.

    Score the importance of memories.
    """

    prompt: str = (
        "Please give an importance score between 1 to 10 for the following "
        "observation. Higher score indicates the observation is more important. More "
        "rules that should be followed are:"
        "\n(1): Learning experience of a certain skill is important"
        "\n(2): The occurrence of a particular event is important"
        "\n(3): User thoughts and emotions matter"
        "\n(4): More informative indicates more important."
        "Please respond with a single integer."
        "\nObservation:{content}"
        "\nRating:"
    )

    async def score_importance(
        self,
        memory_fragment: T,
        llm_client: Optional[LLMClient] = None,
    ) -> float:
        """Score the importance of memory fragments.

        Args:
            memory_fragment(T): Memory fragment
            llm_client(Optional[LLMClient]): LLM client

        Returns:
            float: The importance score of the memory fragment.
        """
        score: str = await self.call_llm(
            self.prompt, llm_client, content=memory_fragment.raw_observation
        )
        return self._parse_number(score)
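`_parse_number` falls back to 0.0 when the LLM answer contains no digits and optionally rescales a 1-10 rating by `importance_weight`. A minimal sketch of scoring one fragment with the scorer above, where `llm_client` and `fragment` are placeholders configured elsewhere:

# Hypothetical sketch: `llm_client` and `fragment` are placeholders; only
# LLMImportanceScorer and score_importance() come from the code above.
async def score_one(llm_client, fragment) -> float:
    scorer = LLMImportanceScorer()  # uses the default rating prompt
    score = await scorer.score_importance(fragment, llm_client=llm_client)
    return score  # e.g. 7.0 if the LLM answered "7"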
192
dbgpt/agent/core/memory/long_term.py
Normal file
@@ -0,0 +1,192 @@
"""Long-term memory module."""

from concurrent.futures import Executor
from datetime import datetime
from typing import Generic, List, Optional

from dbgpt.core import Chunk
from dbgpt.rag.retriever.time_weighted import TimeWeightedEmbeddingRetriever
from dbgpt.storage.vector_store.connector import VectorStoreConnector
from dbgpt.storage.vector_store.filters import MetadataFilters
from dbgpt.util.annotations import immutable, mutable
from dbgpt.util.executor_utils import blocking_func_to_async

from .base import DiscardedMemoryFragments, Memory, T, WriteOperation

_FORGET_PLACEHOLDER = "[FORGET]"
_MERGE_PLACEHOLDER = "[MERGE]"
_METADATA_BUFFER_IDX = "buffer_idx"
_METADATA_LAST_ACCESSED_AT = "last_accessed_at"
_METADAT_IMPORTANCE = "importance"


class LongTermRetriever(TimeWeightedEmbeddingRetriever):
    """Long-term retriever."""

    def __init__(self, now: datetime, **kwargs):
        """Create a long-term retriever."""
        self.now = now
        super().__init__(**kwargs)

    @mutable
    def _retrieve(
        self, query: str, filters: Optional[MetadataFilters] = None
    ) -> List[Chunk]:
        """Retrieve memories."""
        current_time = self.now
        docs_and_scores = {
            doc.metadata[_METADATA_BUFFER_IDX]: (doc, self.default_salience)
            # Calculate for all memories.
            for doc in self.memory_stream
        }
        # If a doc is considered salient, update the salience score
        docs_and_scores.update(self.get_salient_docs(query))
        rescored_docs = [
            (doc, self._get_combined_score(doc, relevance, current_time))
            for doc, relevance in docs_and_scores.values()
        ]
        rescored_docs.sort(key=lambda x: x[1], reverse=True)
        result = []
        # Ensure frequently accessed memories aren't forgotten
        retrieved_num = 0
        for doc, _ in rescored_docs:
            if (
                retrieved_num < self._k
                and doc.content.find(_FORGET_PLACEHOLDER) == -1
                and doc.content.find(_MERGE_PLACEHOLDER) == -1
            ):
                retrieved_num += 1
                buffered_doc = self.memory_stream[doc.metadata[_METADATA_BUFFER_IDX]]
                buffered_doc.metadata[_METADATA_LAST_ACCESSED_AT] = current_time
                result.append(buffered_doc)
        return result


class LongTermMemory(Memory, Generic[T]):
    """Long-term memory."""

    importance_weight: float = 0.15

    def __init__(
        self,
        executor: Executor,
        vector_store_connector: VectorStoreConnector,
        now: Optional[datetime] = None,
        reflection_threshold: Optional[float] = None,
    ):
        """Create a long-term memory."""
        self.now = now or datetime.now()
        self.executor = executor
        self.reflecting: bool = False
        self.forgetting: bool = False
        self.reflection_threshold: Optional[float] = reflection_threshold
        self.aggregate_importance: float = 0.0
        self._vector_store_connector = vector_store_connector
        self.memory_retriever = LongTermRetriever(
            now=self.now, vector_store_connector=vector_store_connector
        )

    @immutable
    def structure_clone(
        self: "LongTermMemory[T]", now: Optional[datetime] = None
    ) -> "LongTermMemory[T]":
        """Create a structure clone of the long-term memory."""
        new_name = self.name
        if not new_name:
            raise ValueError("name is required.")
        m: LongTermMemory[T] = LongTermMemory(
            now=now,
            executor=self.executor,
            vector_store_connector=self._vector_store_connector.new_connector(new_name),
            reflection_threshold=self.reflection_threshold,
        )
        m._copy_from(self)
        return m

    @mutable
    async def write(
        self,
        memory_fragment: T,
        now: Optional[datetime] = None,
        op: WriteOperation = WriteOperation.ADD,
    ) -> Optional[DiscardedMemoryFragments[T]]:
        """Write a memory fragment to the memory."""
        importance = memory_fragment.importance
        last_accessed_time = memory_fragment.last_accessed_time
        if importance is None:
            raise ValueError("importance is required.")
        if not self.reflecting:
            self.aggregate_importance += importance

        memory_idx = len(self.memory_retriever.memory_stream)
        document = Chunk(
            page_content="[{}] ".format(memory_idx)
            + str(memory_fragment.raw_observation),
            metadata={
                _METADAT_IMPORTANCE: importance,
                _METADATA_LAST_ACCESSED_AT: last_accessed_time,
            },
        )
        await blocking_func_to_async(
            self.executor,
            self.memory_retriever.load_document,
            [document],
            current_time=now,
        )

        return None

    @mutable
    async def write_batch(
        self, memory_fragments: List[T], now: Optional[datetime] = None
    ) -> Optional[DiscardedMemoryFragments[T]]:
        """Write a batch of memory fragments to the memory."""
        current_datetime = self.now
        if not now:
            raise ValueError("Now time is required.")
        for short_term_memory in memory_fragments:
            short_term_memory.update_accessed_time(now=now)
            await self.write(short_term_memory, now=current_datetime)
        # TODO(fangyinc): Reflect on the memories and get high-level insights.
        # TODO(fangyinc): Forget memories that are not important.
        return None

    @immutable
    async def read(
        self,
        observation: str,
        alpha: Optional[float] = None,
        beta: Optional[float] = None,
        gamma: Optional[float] = None,
    ) -> List[T]:
        """Read memory fragments related to the observation."""
        return await self.fetch_memories(observation=observation, now=self.now)

    @immutable
    async def fetch_memories(
        self, observation: str, now: Optional[datetime] = None
    ) -> List[T]:
        """Fetch memories related to the observation."""
        # TODO: Mock now?
        retrieved_memories = []
        retrieved_list = await blocking_func_to_async(
            self.executor,
            self.memory_retriever.retrieve,
            observation,
        )
        for retrieved_chunk in retrieved_list:
            retrieved_memories.append(
                self.real_memory_fragment_class.build_from(
                    observation=retrieved_chunk.content,
                    importance=retrieved_chunk.metadata[_METADAT_IMPORTANCE],
                )
            )
        return retrieved_memories

    @mutable
    async def clear(self) -> List[T]:
        """Clear the memory.

        TODO: Implement this method.
        """
        return []
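`LongTermMemory` stores each fragment as a `Chunk` whose metadata carries its importance and last-access time, and the time-weighted retriever above skips anything tagged with the forget/merge placeholders. A minimal sketch of constructing and filling it, assuming a thread-pool executor and an already-configured `VectorStoreConnector`, and that `initialize()` accepts the same kwargs used by the composite memory earlier in this diff:

# Hypothetical sketch: `connector`, `fragment` and `fragment_cls` are
# placeholders; write()/fetch_memories() follow the code above.
from concurrent.futures import ThreadPoolExecutor

async def store_and_fetch(connector, fragment, fragment_cls):
    memory = LongTermMemory(
        executor=ThreadPoolExecutor(max_workers=4),
        vector_store_connector=connector,
    )
    memory.initialize(name="demo", real_memory_fragment_class=fragment_cls)
    await memory.write(fragment)  # fragment.importance must already be set
    return await memory.fetch_memories("database tuning tips")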
203
dbgpt/agent/core/memory/short_term.py
Normal file
@@ -0,0 +1,203 @@
"""Short term memory module."""

import random
from concurrent.futures import Executor
from datetime import datetime
from typing import Dict, List, Optional, Tuple

from dbgpt.core import Embeddings
from dbgpt.util.annotations import immutable, mutable
from dbgpt.util.executor_utils import blocking_func_to_async
from dbgpt.util.similarity_util import cosine_similarity, sigmoid_function

from .base import (
    DiscardedMemoryFragments,
    InsightMemoryFragment,
    ShortTermMemory,
    T,
    WriteOperation,
)


class EnhancedShortTermMemory(ShortTermMemory[T]):
    """Enhanced short term memory."""

    def __init__(
        self,
        embeddings: Embeddings,
        executor: Executor,
        buffer_size: int = 2,
        enhance_similarity_threshold: float = 0.7,
        enhance_threshold: int = 3,
    ):
        """Initialize enhanced short term memory."""
        super().__init__(buffer_size=buffer_size)
        self._executor = executor
        self._embeddings = embeddings
        self.short_embeddings: List[List[float]] = []
        self.enhance_cnt: List[int] = [0 for _ in range(self._buffer_size)]
        self.enhance_memories: List[List[T]] = [[] for _ in range(self._buffer_size)]
        self.enhance_similarity_threshold = enhance_similarity_threshold
        self.enhance_threshold = enhance_threshold

    @immutable
    def structure_clone(
        self: "EnhancedShortTermMemory[T]", now: Optional[datetime] = None
    ) -> "EnhancedShortTermMemory[T]":
        """Return a structure clone of the memory."""
        m: EnhancedShortTermMemory[T] = EnhancedShortTermMemory(
            embeddings=self._embeddings,
            executor=self._executor,
            buffer_size=self._buffer_size,
            enhance_similarity_threshold=self.enhance_similarity_threshold,
            enhance_threshold=self.enhance_threshold,
        )
        m._copy_from(self)
        return m

    @mutable
    async def write(
        self,
        memory_fragment: T,
        now: Optional[datetime] = None,
        op: WriteOperation = WriteOperation.ADD,
    ) -> Optional[DiscardedMemoryFragments[T]]:
        """Write memory fragment to short term memory.

        Reference: https://github.com/RUC-GSAI/YuLan-Rec/blob/main/agents/recagent_memory.py#L336 # noqa
        """
        # Calculate current embeddings of current memory fragment
        memory_fragment_embeddings = await blocking_func_to_async(
            self._executor,
            memory_fragment.calculate_current_embeddings,
            self._embeddings.embed_documents,
        )
        memory_fragment.update_embeddings(memory_fragment_embeddings)
        for idx, memory_embedding in enumerate(self.short_embeddings):
            similarity = await blocking_func_to_async(
                self._executor,
                cosine_similarity,
                memory_embedding,
                memory_fragment_embeddings,
            )
            # Sigmoid probability, transform similarity to [0, 1]
            sigmoid_prob: float = await blocking_func_to_async(
                self._executor, sigmoid_function, similarity
            )
            if (
                sigmoid_prob >= self.enhance_similarity_threshold
                and random.random() < sigmoid_prob
            ):
                self.enhance_cnt[idx] += 1
                self.enhance_memories[idx].append(memory_fragment)
        discard_memories = await self.transfer_to_long_term(memory_fragment)
        if op == WriteOperation.ADD:
            self._fragments.append(memory_fragment)
            self.short_embeddings.append(memory_fragment_embeddings)
            await self.handle_overflow(self._fragments)
        return discard_memories

    @mutable
    async def transfer_to_long_term(
        self, memory_fragment: T
    ) -> Optional[DiscardedMemoryFragments[T]]:
        """Transfer memory fragment to long term memory."""
        transfer_flag = False
        existing_memory = [True for _ in range(len(self.short_term_memories))]

        enhance_memories: List[T] = []
        to_get_insight_memories: List[T] = []
        for idx, memory in enumerate(self.short_term_memories):
            # if exceed the enhancement threshold
            if (
                self.enhance_cnt[idx] >= self.enhance_threshold
                and existing_memory[idx] is True
            ):
                existing_memory[idx] = False
                transfer_flag = True
                # Collect the enhanced short-term memories
                content = [memory]
                # do not repeatedly add observation memory to summary, so use [:-1].
                for enhance_memory in self.enhance_memories[idx][:-1]:
                    content.append(enhance_memory)
                content.append(memory_fragment)
                # Merge the enhanced memories to single memory
                merged_enhance_memory: T = memory.reduce(
                    content, merged_memory=memory.importance
                )
                to_get_insight_memories.append(merged_enhance_memory)
                enhance_memories.append(merged_enhance_memory)
        # Get insights for the every enhanced memory
        enhance_insights: List[InsightMemoryFragment] = await self.get_insights(
            to_get_insight_memories
        )

        if transfer_flag:
            # re-construct the indexes of short-term memories after removing summarized
            # memories
            new_memories: List[T] = []
            new_embeddings: List[List[float]] = []
            new_enhance_memories: List[List[T]] = [[] for _ in range(self._buffer_size)]
            new_enhance_cnt: List[int] = [0 for _ in range(self._buffer_size)]
            for idx, memory in enumerate(self.short_term_memories):
                if existing_memory[idx]:
                    # Keep the not-enhanced memories in the new buffers
                    new_enhance_memories[len(new_memories)] = self.enhance_memories[idx]
                    new_enhance_cnt[len(new_memories)] = self.enhance_cnt[idx]
                    new_memories.append(memory)
                    new_embeddings.append(self.short_embeddings[idx])
            self._fragments = new_memories
            self.short_embeddings = new_embeddings
            self.enhance_memories = new_enhance_memories
            self.enhance_cnt = new_enhance_cnt
        return DiscardedMemoryFragments(enhance_memories, enhance_insights)

    @mutable
    async def handle_overflow(
        self, memory_fragments: List[T]
    ) -> Tuple[List[T], List[T]]:
        """Handle overflow of short term memory.

        Discard the least important memory fragment if the buffer size exceeds.
        """
        if len(self.short_term_memories) > self._buffer_size:
            id2fragments: Dict[int, Dict] = {}
            for idx in range(len(self.short_term_memories) - 1):
                # Not discard the last one
                memory = self.short_term_memories[idx]
                id2fragments[idx] = {
                    "enhance_count": self.enhance_cnt[idx],
                    "importance": memory.importance,
                }
            # Sort by importance and enhance count, first discard the least important
            sorted_ids = sorted(
                id2fragments.keys(),
                key=lambda x: (
                    id2fragments[x]["importance"],
                    id2fragments[x]["enhance_count"],
                ),
            )
            pop_id = sorted_ids[0]
            pop_raw_observation = self.short_term_memories[pop_id].raw_observation
            self.enhance_cnt.pop(pop_id)
            self.enhance_cnt.append(0)
            self.enhance_memories.pop(pop_id)
            self.enhance_memories.append([])

            discard_memory = self._fragments.pop(pop_id)
            self.short_embeddings.pop(pop_id)

            # remove the discard_memory from other short-term memory's enhanced list
            for idx in range(len(self.short_term_memories)):
                current_enhance_memories: List[T] = self.enhance_memories[idx]
                to_remove_idx = []
                for i, ehf in enumerate(current_enhance_memories):
                    if ehf.raw_observation == pop_raw_observation:
                        to_remove_idx.append(i)
                for i in to_remove_idx:
                    current_enhance_memories.pop(i)
                self.enhance_cnt[idx] -= len(to_remove_idx)

            return memory_fragments, [discard_memory]
        return memory_fragments, []
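The enhanced short-term memory counts how often incoming fragments are semantically similar to what is already buffered; once a slot crosses `enhance_threshold`, the fragments are merged and handed off as `DiscardedMemoryFragments` for long-term storage. A minimal construction sketch, where `embeddings` is any `dbgpt.core.Embeddings` implementation configured elsewhere:

# Hypothetical sketch: `embeddings` is a placeholder; the constructor
# arguments mirror the code above.
from concurrent.futures import ThreadPoolExecutor

def build_short_term_memory(embeddings):
    return EnhancedShortTermMemory(
        embeddings=embeddings,
        executor=ThreadPoolExecutor(max_workers=2),
        buffer_size=5,                     # keep at most 5 fragments buffered
        enhance_similarity_threshold=0.7,  # sigmoid(similarity) needed to count as similar
        enhance_threshold=3,               # merge after 3 similar observations
    )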
36
dbgpt/agent/core/plan/__init__.py
Normal file
@@ -0,0 +1,36 @@
"""Plan module for the agent."""

from .awel.agent_operator import (  # noqa: F401
    AgentDummyTrigger,
    AWELAgentOperator,
    WrappedAgentOperator,
)
from .awel.agent_operator_resource import (  # noqa: F401
    AWELAgent,
    AWELAgentConfig,
    AWELAgentResource,
)
from .awel.team_awel_layout import (  # noqa: F401
    AWELTeamContext,
    DefaultAWELLayoutManager,
    WrappedAWELLayoutManager,
)
from .plan_action import PlanAction, PlanInput  # noqa: F401
from .planner_agent import PlannerAgent  # noqa: F401
from .team_auto_plan import AutoPlanChatManager  # noqa: F401

__all__ = [
    "PlanAction",
    "PlanInput",
    "PlannerAgent",
    "AutoPlanChatManager",
    "AWELAgent",
    "AWELAgentConfig",
    "AWELAgentResource",
    "AWELTeamContext",
    "DefaultAWELLayoutManager",
    "WrappedAWELLayoutManager",
    "AgentDummyTrigger",
    "AWELAgentOperator",
    "WrappedAgentOperator",
]
4
dbgpt/agent/core/plan/awel/__init__.py
Normal file
@@ -0,0 +1,4 @@
"""External planner.

Use AWEL as the external planner.
"""
311
dbgpt/agent/core/plan/awel/agent_operator.py
Normal file
@@ -0,0 +1,311 @@
"""Agent Operator for AWEL."""

from abc import ABC
from typing import List, Optional, Type

from dbgpt.core.awel import MapOperator
from dbgpt.core.awel.flow import (
    IOField,
    OperatorCategory,
    OperatorType,
    Parameter,
    ViewMetadata,
)
from dbgpt.core.awel.trigger.base import Trigger
from dbgpt.core.interface.message import ModelMessageRoleType

# TODO: Don't dependent on MixinLLMOperator
from dbgpt.model.operators.llm_operator import MixinLLMOperator

from ....util.llm.llm import LLMConfig
from ...agent import Agent, AgentGenerateContext, AgentMessage
from ...agent_manage import get_agent_manager
from ...base_agent import ConversableAgent
from .agent_operator_resource import AWELAgent


class BaseAgentOperator:
    """The abstract operator for an Agent."""

    SHARE_DATA_KEY_MODEL_NAME = "share_data_key_agent_name"

    def __init__(self, agent: Optional[Agent] = None):
        """Create an AgentOperator."""
        self._agent = agent

    @property
    def agent(self) -> Agent:
        """Return the Agent."""
        if not self._agent:
            raise ValueError("agent is not set")
        return self._agent


class WrappedAgentOperator(
    BaseAgentOperator, MapOperator[AgentGenerateContext, AgentGenerateContext], ABC
):
    """The Agent operator.

    Wrap the agent and trigger the agent to generate a reply.
    """

    def __init__(self, agent: Agent, **kwargs):
        """Create a WrappedAgentOperator."""
        super().__init__(agent=agent)
        MapOperator.__init__(self, **kwargs)

    async def map(self, input_value: AgentGenerateContext) -> AgentGenerateContext:
        """Trigger agent to generate a reply."""
        now_rely_messages: List[AgentMessage] = []
        if not input_value.message:
            raise ValueError("The message is empty.")
        input_message = input_value.message.copy()

        # Isolate the message delivery mechanism and pass it to the operator
        _goal = self.agent.name if self.agent.name else self.agent.role
        current_goal = f"[{_goal}]:"

        if input_message.content:
            current_goal += input_message.content
        input_message.current_goal = current_goal

        # What was received was the User message
        human_message = input_message.copy()
        human_message.role = ModelMessageRoleType.HUMAN
        now_rely_messages.append(human_message)

        # Send a message (no reply required) and pass the message content
        now_message = input_message
        if input_value.rely_messages and len(input_value.rely_messages) > 0:
            now_message = input_value.rely_messages[-1]
        if not input_value.sender:
            raise ValueError("The sender is empty.")
        await input_value.sender.send(
            now_message, self.agent, input_value.reviewer, False
        )

        agent_reply_message = await self.agent.generate_reply(
            received_message=input_message,
            sender=input_value.sender,
            reviewer=input_value.reviewer,
            rely_messages=input_value.rely_messages,
        )
        is_success = agent_reply_message.success

        if not is_success:
            raise ValueError(
                f"The task failed at step {self.agent.role} and the attempt "
                f"to repair it failed. The final reason for "
                f"failure:{agent_reply_message.content}!"
            )

        # What is sent is an AI message
        ai_message = agent_reply_message.copy()
        ai_message.role = ModelMessageRoleType.AI

        now_rely_messages.append(ai_message)

        # Handle user goals and outcome dependencies
        return AgentGenerateContext(
            message=input_message,
            sender=self.agent,
            reviewer=input_value.reviewer,
            # Default single step transfer of information
            rely_messages=now_rely_messages,
            silent=input_value.silent,
        )


class AWELAgentOperator(
    MixinLLMOperator, MapOperator[AgentGenerateContext, AgentGenerateContext]
):
    """The Agent operator for AWEL."""

    metadata = ViewMetadata(
        label="AWEL Agent Operator",
        name="agent_operator",
        category=OperatorCategory.AGENT,
        description="The Agent operator.",
        parameters=[
            Parameter.build_from(
                "Agent",
                "awel_agent",
                AWELAgent,
                description="The dbgpt agent.",
            ),
        ],
        inputs=[
            IOField.build_from(
                "Agent Operator Request",
                "agent_operator_request",
                AgentGenerateContext,
                "The Agent Operator request.",
            )
        ],
        outputs=[
            IOField.build_from(
                "Agent Operator Output",
                "agent_operator_output",
                AgentGenerateContext,
                description="The Agent Operator output.",
            )
        ],
    )

    def __init__(self, awel_agent: AWELAgent, **kwargs):
        """Create an AgentOperator."""
        MixinLLMOperator.__init__(self)
        MapOperator.__init__(self, **kwargs)
        self.awel_agent = awel_agent

    async def map(
        self,
        input_value: AgentGenerateContext,
    ) -> AgentGenerateContext:
        """Trigger agent to generate a reply."""
        if not input_value.message:
            raise ValueError("The message is empty.")
        input_message = input_value.message.copy()
        agent = await self.get_agent(input_value)
        if agent.fixed_subgoal and len(agent.fixed_subgoal) > 0:
            # Isolate the message delivery mechanism and pass it to the operator
            current_goal = f"[{agent.name if agent.name else agent.role}]:"
            if agent.fixed_subgoal:
                current_goal += agent.fixed_subgoal
            input_message.current_goal = current_goal
            input_message.content = agent.fixed_subgoal
        else:
            # Isolate the message delivery mechanism and pass it to the operator
            current_goal = f"[{agent.name if agent.name else agent.role}]:"
            if input_message.content:
                current_goal += input_message.content
            input_message.current_goal = current_goal

        now_rely_messages: List[AgentMessage] = []
        # What was received was the User message
        human_message = input_message.copy()
        human_message.role = ModelMessageRoleType.HUMAN
        now_rely_messages.append(human_message)

        # Send a message (no reply required) and pass the message content
        now_message = input_message
        if input_value.rely_messages and len(input_value.rely_messages) > 0:
            now_message = input_value.rely_messages[-1]
        sender = input_value.sender
        if not sender:
            raise ValueError("The sender is empty.")
        await sender.send(now_message, agent, input_value.reviewer, False)

        agent_reply_message = await agent.generate_reply(
            received_message=input_message,
            sender=sender,
            reviewer=input_value.reviewer,
            rely_messages=input_value.rely_messages,
        )

        is_success = agent_reply_message.success

        if not is_success:
            raise ValueError(
                f"The task failed at step {agent.role} and the attempt to "
                f"repair it failed. The final reason for "
                f"failure:{agent_reply_message.content}!"
            )

        # What is sent is an AI message
        ai_message: AgentMessage = agent_reply_message.copy()
        ai_message.role = ModelMessageRoleType.AI
        now_rely_messages.append(ai_message)

        # Handle user goals and outcome dependencies
        return AgentGenerateContext(
            message=input_message,
            sender=agent,
            reviewer=input_value.reviewer,
            # Default single step transfer of information
            rely_messages=now_rely_messages,
            silent=input_value.silent,
            memory=input_value.memory.structure_clone() if input_value.memory else None,
            agent_context=input_value.agent_context,
            resource_loader=input_value.resource_loader,
            llm_client=input_value.llm_client,
            round_index=agent.consecutive_auto_reply_counter,
        )

    async def get_agent(
        self,
        input_value: AgentGenerateContext,
    ) -> ConversableAgent:
        """Build the agent."""
        # agent build
        agent_cls: Type[ConversableAgent] = get_agent_manager().get_by_name(
            self.awel_agent.agent_profile
        )
        llm_config = self.awel_agent.llm_config

        if not llm_config:
            if input_value.llm_client:
                llm_config = LLMConfig(llm_client=input_value.llm_client)
            else:
                llm_config = LLMConfig(llm_client=self.llm_client)
        else:
            if not llm_config.llm_client:
                if input_value.llm_client:
                    llm_config.llm_client = input_value.llm_client
                else:
                    llm_config.llm_client = self.llm_client

        kwargs = {}
        if self.awel_agent.role_name:
            kwargs["name"] = self.awel_agent.role_name
        if self.awel_agent.fixed_subgoal:
            kwargs["fixed_subgoal"] = self.awel_agent.fixed_subgoal

        agent = (
            await agent_cls(**kwargs)
            .bind(input_value.memory)
            .bind(llm_config)
            .bind(input_value.agent_context)
            .bind(self.awel_agent.resources)
            .bind(input_value.resource_loader)
            .build()
        )

        return agent


class AgentDummyTrigger(Trigger):
    """Dummy trigger for AWEL.

    A placeholder input node for agent DAGs; unlike an HTTP trigger, it is not
    triggered by a request.
    """

    metadata = ViewMetadata(
        label="Agent Trigger",
        name="agent_trigger",
        category=OperatorCategory.AGENT,
        operator_type=OperatorType.INPUT,
        description="Trigger your workflow by agent",
        inputs=[],
        parameters=[],
        outputs=[
            IOField.build_from(
                "Agent Operator Context",
                "agent_operator_context",
                AgentGenerateContext,
                description="The Agent Operator output.",
            )
        ],
    )

    def __init__(
        self,
        **kwargs,
    ) -> None:
        """Initialize an AgentDummyTrigger."""
        super().__init__(**kwargs)

    async def trigger(self, **kwargs) -> None:
        """Trigger the DAG. Not used in AgentDummyTrigger."""
        raise NotImplementedError("Dummy trigger does not support trigger.")
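`AWELAgentOperator` resolves the concrete agent lazily in `get_agent`, binding memory, LLM config, context, resources and loader from the incoming `AgentGenerateContext`. A rough wiring sketch, assuming the "CodeEngineer" profile name and the `>>` chaining of AWEL operators (both assumptions for this hunk); the two extra keyword arguments are needed because `AWELAgent.pre_fill` below pops them unconditionally:

# Hypothetical sketch: "CodeEngineer" is a placeholder profile name.
from dbgpt.core.awel import DAG

with DAG("awel_agent_demo") as dag:
    trigger = AgentDummyTrigger()
    coder = AWELAgentOperator(
        awel_agent=AWELAgent(
            agent_profile="CodeEngineer", agent_resource=None, agent_llm_Config=None
        )
    )
    trigger >> coder
# The DAG is later invoked with an AgentGenerateContext carrying the user message.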
209
dbgpt/agent/core/plan/awel/agent_operator_resource.py
Normal file
@@ -0,0 +1,209 @@
"""The AWEL Agent Operator Resource."""

from typing import Any, Dict, List, Optional

from dbgpt._private.pydantic import BaseModel, ConfigDict, Field, model_validator
from dbgpt.core import LLMClient
from dbgpt.core.awel.flow import (
    FunctionDynamicOptions,
    OptionValue,
    Parameter,
    ResourceCategory,
    register_resource,
)

from ....resource.resource_api import AgentResource, ResourceType
from ....util.llm.llm import LLMConfig, LLMStrategyType
from ...agent_manage import get_agent_manager


@register_resource(
    label="AWEL Agent Resource",
    name="agent_operator_resource",
    description="The Agent Resource.",
    category=ResourceCategory.AGENT,
    parameters=[
        Parameter.build_from(
            label="Agent Resource Type",
            name="agent_resource_type",
            type=str,
            optional=True,
            default=None,
            options=[
                OptionValue(label=item.name, name=item.value, value=item.value)
                for item in ResourceType
            ],
        ),
        Parameter.build_from(
            label="Agent Resource Name",
            name="agent_resource_name",
            type=str,
            optional=True,
            default=None,
            description="The agent resource name.",
        ),
        Parameter.build_from(
            label="Agent Resource Value",
            name="agent_resource_value",
            type=str,
            optional=True,
            default=None,
            description="The agent resource value.",
        ),
    ],
    alias=[
        "dbgpt.serve.agent.team.layout.agent_operator_resource.AwelAgentResource",
        "dbgpt.agent.plan.awel.agent_operator_resource.AWELAgentResource",
    ],
)
class AWELAgentResource(AgentResource):
    """AWEL Agent Resource."""

    @model_validator(mode="before")
    @classmethod
    def pre_fill(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Pre fill the agent ResourceType."""
        if not isinstance(values, dict):
            return values
        name = values.pop("agent_resource_name")
        type = values.pop("agent_resource_type")
        value = values.pop("agent_resource_value")

        values["name"] = name
        values["type"] = ResourceType(type)
        values["value"] = value

        return values


@register_resource(
    label="AWEL Agent LLM Config",
    name="agent_operator_llm_config",
    description="The Agent LLM Config.",
    category=ResourceCategory.AGENT,
    parameters=[
        Parameter.build_from(
            "LLM Client",
            "llm_client",
            LLMClient,
            optional=True,
            default=None,
            description="The LLM Client.",
        ),
        Parameter.build_from(
            label="Agent LLM Strategy",
            name="llm_strategy",
            type=str,
            optional=True,
            default=None,
            options=[
                OptionValue(label=item.name, name=item.value, value=item.value)
                for item in LLMStrategyType
            ],
            description="The Agent LLM Strategy.",
        ),
        Parameter.build_from(
            label="Agent LLM Strategy Value",
            name="strategy_context",
            type=str,
            optional=True,
            default=None,
            description="The agent LLM Strategy Value.",
        ),
    ],
    alias=[
        "dbgpt.serve.agent.team.layout.agent_operator_resource.AwelAgentConfig",
        "dbgpt.agent.plan.awel.agent_operator_resource.AWELAgentConfig",
    ],
)
class AWELAgentConfig(LLMConfig):
    """AWEL Agent Config."""

    pass


def _agent_resource_option_values() -> List[OptionValue]:
    return [
        OptionValue(label=item["name"], name=item["name"], value=item["name"])
        for item in get_agent_manager().list_agents()
    ]


@register_resource(
    label="AWEL Layout Agent",
    name="agent_operator_agent",
    description="The Agent to build the Agent Operator.",
    category=ResourceCategory.AGENT,
    parameters=[
        Parameter.build_from(
            label="Agent Profile",
            name="agent_profile",
            type=str,
            description="Which agent to use.",
            options=FunctionDynamicOptions(func=_agent_resource_option_values),
        ),
        Parameter.build_from(
            label="Role Name",
            name="role_name",
            type=str,
            optional=True,
            default=None,
            description="The agent role name.",
        ),
        Parameter.build_from(
            label="Fixed Goal",
            name="fixed_subgoal",
            type=str,
            optional=True,
            default=None,
            description="The agent fixed goal.",
        ),
        Parameter.build_from(
            label="Agent Resource",
            name="agent_resource",
            type=AWELAgentResource,
            optional=True,
            default=None,
            description="The agent resource.",
        ),
        Parameter.build_from(
            label="Agent LLM Config",
            name="agent_llm_Config",
            type=AWELAgentConfig,
            optional=True,
            default=None,
            description="The agent llm config.",
        ),
    ],
    alias=[
        "dbgpt.serve.agent.team.layout.agent_operator_resource.AwelAgent",
        "dbgpt.agent.plan.awel.agent_operator_resource.AWELAgent",
    ],
)
class AWELAgent(BaseModel):
    """AWEL Agent."""

    model_config = ConfigDict(arbitrary_types_allowed=True)

    agent_profile: str
    role_name: Optional[str] = None
    llm_config: Optional[LLMConfig] = None
    resources: List[AgentResource] = Field(default_factory=list)
    fixed_subgoal: Optional[str] = None

    @model_validator(mode="before")
    @classmethod
    def pre_fill(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Pre fill the agent ResourceType."""
        if not isinstance(values, dict):
            return values
        resource = values.pop("agent_resource")
        llm_config = values.pop("agent_llm_Config")

        if resource is not None:
            values["resources"] = [resource]

        if llm_config is not None:
            values["llm_config"] = llm_config

        return values
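The two `pre_fill` validators translate the flat field names exposed in the flow editor (`agent_resource_name`, `agent_resource_type`, `agent_resource_value`) into the nested `AgentResource` fields. A rough sketch of what `AWELAgentResource` accepts, with placeholder values; the type string must be one of the `ResourceType` enum values and "internet" is an assumption here:

# Hypothetical sketch: all values are placeholders; the keyword names match
# what AWELAgentResource.pre_fill pops above.
resource = AWELAgentResource(
    agent_resource_type="internet",              # becomes type=ResourceType("internet")
    agent_resource_name="web_search",            # becomes name="web_search"
    agent_resource_value="https://example.com",  # becomes value="https://example.com"
)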
268
dbgpt/agent/core/plan/awel/team_awel_layout.py
Normal file
@@ -0,0 +1,268 @@
"""The manager of the team for the AWEL layout."""

import logging
from abc import ABC, abstractmethod
from typing import Optional, cast

from dbgpt._private.config import Config
from dbgpt._private.pydantic import (
    BaseModel,
    ConfigDict,
    Field,
    model_to_dict,
    validator,
)
from dbgpt.core.awel import DAG
from dbgpt.core.awel.dag.dag_manager import DAGManager

from ...action.base import ActionOutput
from ...agent import Agent, AgentGenerateContext, AgentMessage
from ...base_team import ManagerAgent
from ...profile import DynConfig, ProfileConfig
from .agent_operator import AWELAgentOperator, WrappedAgentOperator

logger = logging.getLogger(__name__)


class AWELTeamContext(BaseModel):
    """The context of the team for the AWEL layout."""

    dag_id: str = Field(
        ...,
        description="The unique id of dag",
        examples=["flow_dag_testflow_66d8e9d6-f32e-4540-a5bd-ea0648145d0e"],
    )
    uid: str = Field(
        default=None,
        description="The unique id of flow",
        examples=["66d8e9d6-f32e-4540-a5bd-ea0648145d0e"],
    )
    name: Optional[str] = Field(
        default=None,
        description="The name of dag",
    )
    label: Optional[str] = Field(
        default=None,
        description="The label of dag",
    )
    version: Optional[str] = Field(
        default=None,
        description="The version of dag",
    )
    description: Optional[str] = Field(
        default=None,
        description="The description of dag",
    )
    editable: bool = Field(
        default=False,
        description="Whether the dag is editable",
        examples=[True, False],
    )
    state: Optional[str] = Field(
        default=None,
        description="The state of dag",
    )
    user_name: Optional[str] = Field(
        default=None,
        description="The owner of current dag",
    )
    sys_code: Optional[str] = Field(
        default=None,
        description="The system code of current dag",
    )
    flow_category: Optional[str] = Field(
        default="common",
        description="The flow category of current dag",
    )

    def to_dict(self):
        """Convert the object to a dictionary."""
        return model_to_dict(self)


class AWELBaseManager(ManagerAgent, ABC):
    """AWEL base manager."""

    model_config = ConfigDict(arbitrary_types_allowed=True)

    profile: ProfileConfig = ProfileConfig(
        name="AWELBaseManager",
        role=DynConfig(
            "PlanManager", category="agent", key="dbgpt_agent_plan_awel_profile_name"
        ),
        goal=DynConfig(
            "Promote and solve user problems according to the process arranged "
            "by AWEL.",
            category="agent",
            key="dbgpt_agent_plan_awel_profile_goal",
        ),
        desc=DynConfig(
            "Promote and solve user problems according to the process arranged "
            "by AWEL.",
            category="agent",
            key="dbgpt_agent_plan_awel_profile_desc",
        ),
    )

    async def _a_process_received_message(self, message: AgentMessage, sender: Agent):
        """Process the received message."""
        pass

    @abstractmethod
    def get_dag(self) -> DAG:
        """Get the DAG of the manager."""

    async def act(
        self,
        message: Optional[str],
        sender: Optional[Agent] = None,
        reviewer: Optional[Agent] = None,
        **kwargs,
    ) -> Optional[ActionOutput]:
        """Perform the action."""
        try:
            agent_dag = self.get_dag()
            last_node: AWELAgentOperator = cast(
                AWELAgentOperator, agent_dag.leaf_nodes[0]
            )

            start_message_context: AgentGenerateContext = AgentGenerateContext(
                message=AgentMessage(content=message, current_goal=message),
                sender=sender,
                reviewer=reviewer,
                memory=self.memory.structure_clone(),
                agent_context=self.agent_context,
                resource_loader=self.resource_loader,
                llm_client=self.not_null_llm_config.llm_client,
            )
            final_generate_context: AgentGenerateContext = await last_node.call(
                call_data=start_message_context
            )
            last_message = final_generate_context.rely_messages[-1]

            last_agent = await last_node.get_agent(final_generate_context)
            if final_generate_context.round_index is not None:
                last_agent.consecutive_auto_reply_counter = (
                    final_generate_context.round_index
                )
            if not sender:
                raise ValueError("sender is required!")
            await last_agent.send(
                last_message, sender, start_message_context.reviewer, False
            )

            view_message: Optional[str] = None
            if last_message.action_report:
                view_message = last_message.action_report.get("view", None)

            return ActionOutput(
                content=last_message.content,
                view=view_message,
            )
        except Exception as e:
            logger.exception(f"DAG run failed!{str(e)}")

            return ActionOutput(
                is_exe_success=False,
                content=f"Failed to complete goal! {str(e)}",
            )


class WrappedAWELLayoutManager(AWELBaseManager):
    """The manager of the team for the AWEL layout.

    Receives a DAG or builds a DAG from the agents.
    """

    model_config = ConfigDict(arbitrary_types_allowed=True)

    dag: Optional[DAG] = Field(None, description="The DAG of the manager")

    def get_dag(self) -> DAG:
        """Get the DAG of the manager."""
        if self.dag:
            return self.dag
        conv_id = self.not_null_agent_context.conv_id
        last_node: Optional[WrappedAgentOperator] = None
        with DAG(
            f"layout_agents_{self.not_null_agent_context.gpts_app_name}_{conv_id}"
        ) as dag:
            for agent in self.agents:
                now_node = WrappedAgentOperator(agent=agent)
                if not last_node:
                    last_node = now_node
                else:
                    last_node >> now_node
                    last_node = now_node
        self.dag = dag
        return dag

    async def act(
        self,
        message: Optional[str],
        sender: Optional[Agent] = None,
        reviewer: Optional[Agent] = None,
        **kwargs,
    ) -> Optional[ActionOutput]:
        """Perform the action."""
        try:
            dag = self.get_dag()
            last_node: WrappedAgentOperator = cast(
                WrappedAgentOperator, dag.leaf_nodes[0]
            )
            start_message_context: AgentGenerateContext = AgentGenerateContext(
                message=AgentMessage(content=message, current_goal=message),
                sender=self,
                reviewer=reviewer,
            )
            final_generate_context: AgentGenerateContext = await last_node.call(
                call_data=start_message_context
            )
            last_message = final_generate_context.rely_messages[-1]

            last_agent = last_node.agent
            await last_agent.send(
                last_message,
                self,
                start_message_context.reviewer,
                False,
            )

            view_message: Optional[str] = None
            if last_message.action_report:
                view_message = last_message.action_report.get("view", None)

            return ActionOutput(
                content=last_message.content,
                view=view_message,
            )
        except Exception as e:
            logger.exception(f"DAG run failed!{str(e)}")

            return ActionOutput(
                is_exe_success=False,
                content=f"Failed to complete goal! {str(e)}",
            )


class DefaultAWELLayoutManager(AWELBaseManager):
    """The manager of the team for the AWEL layout."""

    model_config = ConfigDict(arbitrary_types_allowed=True)

    dag: AWELTeamContext = Field(...)

    @validator("dag")
    def check_dag(cls, value):
        """Check the DAG of the manager."""
        assert value is not None and value != "", "dag must not be empty"
        return value

    def get_dag(self) -> DAG:
        """Get the DAG of the manager."""
        cfg = Config()
        _dag_manager = DAGManager.get_instance(cfg.SYSTEM_APP)  # type: ignore
        agent_dag: Optional[DAG] = _dag_manager.get_dag(alias_name=self.dag.uid)
        if agent_dag is None:
            raise ValueError(f"The configured flow cannot be found![{self.dag.name}]")
        return agent_dag
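`WrappedAWELLayoutManager.get_dag` chains one `WrappedAgentOperator` per hired agent, so the team executes as a fixed pipeline in hiring order. A rough usage sketch, assuming the manager and agents have already been bound and built the way other conversable agents in this refactor are, and that `hire()` is provided by the `ManagerAgent` base (both assumptions for this hunk):

# Hypothetical sketch: binding/building of the manager and agents is omitted;
# hire() is assumed from ManagerAgent, act() follows the code above.
async def run_fixed_pipeline(manager, coder, reporter, user_proxy):
    manager.hire([coder, reporter])  # pipeline order: coder -> reporter
    return await manager.act(
        message="Summarize last month's sales", sender=user_proxy
    )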
139
dbgpt/agent/core/plan/plan_action.py
Normal file
@@ -0,0 +1,139 @@
"""Plan Action."""

import logging
from typing import List, Optional

from dbgpt._private.pydantic import BaseModel, Field
from dbgpt.vis.tags.vis_agent_plans import Vis, VisAgentPlans

from ...resource.resource_api import AgentResource
from ..action.base import Action, ActionOutput
from ..agent import AgentContext
from ..memory.gpts.base import GptsPlan
from ..memory.gpts.gpts_memory import GptsPlansMemory
from ..schema import Status

logger = logging.getLogger(__name__)


class PlanInput(BaseModel):
    """Plan input model."""

    serial_number: int = Field(
        0,
        description="Number of sub-tasks",
    )
    agent: str = Field(..., description="The agent name to complete current task")
    content: str = Field(
        ...,
        description="The task content of current step, make sure it can be executed by"
        " agent",
    )
    rely: str = Field(
        ...,
        description="The rely task number(serial_number), e.g. 1,2,3, empty if no rely",
    )


class PlanAction(Action[List[PlanInput]]):
    """Plan action class."""

    def __init__(self, **kwargs):
        """Create a plan action."""
        super().__init__()
        self._render_protocol = VisAgentPlans()

    @property
    def render_protocol(self) -> Optional[Vis]:
        """Return the render protocol."""
        return self._render_protocol

    @property
    def out_model_type(self):
        """Output model type."""
        return List[PlanInput]

    async def run(
        self,
        ai_message: str,
        resource: Optional[AgentResource] = None,
        rely_action_out: Optional[ActionOutput] = None,
        need_vis_render: bool = True,
        **kwargs,
    ) -> ActionOutput:
        """Run the plan action."""
        context: AgentContext = kwargs["context"]
        plans_memory: GptsPlansMemory = kwargs["plans_memory"]
        try:
            param: List[PlanInput] = self._input_convert(ai_message, List[PlanInput])
        except Exception as e:
            logger.exception((str(e)))
            return ActionOutput(
                is_exe_success=False,
                content="The requested correctly structured answer could not be found.",
            )
        fail_reason = ""

        try:
            response_success = True
            plan_objects = []
            try:
                for item in param:
                    plan = GptsPlan(
                        conv_id=context.conv_id,
                        sub_task_num=item.serial_number,
                        sub_task_content=item.content,
                    )
                    plan.resource_name = ""
                    plan.max_retry_times = context.max_retry_round
                    plan.sub_task_agent = item.agent
                    plan.sub_task_title = item.content
                    plan.rely = item.rely
                    plan.retry_times = 0
                    plan.state = Status.TODO.value
                    plan_objects.append(plan)

                plans_memory.remove_by_conv_id(context.conv_id)
                plans_memory.batch_save(plan_objects)

            except Exception as e:
                logger.exception(str(e))
                fail_reason = (
                    f"The generated plan cannot be stored, reason: {str(e)}."
                    f" Please check whether it is a problem with the plan content. "
                    f"If so, please regenerate the correct plan. If not, please return"
                    f" 'TERMINATE'."
                )
                response_success = False

            if response_success:
                plan_content = []
                mk_plans = []
                for item in param:
                    plan_content.append(
                        {
                            "name": item.content,
                            "num": item.serial_number,
                            "status": Status.TODO.value,
                            "agent": item.agent,
                            "rely": item.rely,
                            "markdown": "",
                        }
                    )
                    mk_plans.append(
                        f"- {item.serial_number}.{item.content}[{item.agent}]"
                    )

                view = "\n".join(mk_plans)
                return ActionOutput(
                    is_exe_success=True,
                    content=ai_message,
                    view=view,
                )
            else:
                raise ValueError(fail_reason)
        except Exception as e:
            logger.exception("Plan Action Run Failed!")
            return ActionOutput(
                is_exe_success=False, content=f"Plan action run failed!{str(e)}"
            )
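`PlanAction` converts the LLM reply into `List[PlanInput]` before persisting each step as a `GptsPlan`, so the model is expected to answer with a JSON array shaped like the sketch below (the step contents and agent names are illustrative placeholders):

# Hypothetical reply that PlanAction._input_convert would parse into PlanInput items.
ai_message = """[
    {"serial_number": 1, "agent": "DataScientist",
     "content": "Query monthly sales totals from the orders table.", "rely": ""},
    {"serial_number": 2, "agent": "Reporter",
     "content": "Summarize the monthly totals into a short report.", "rely": "1"}
]"""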
165
dbgpt/agent/core/plan/planner_agent.py
Normal file
165
dbgpt/agent/core/plan/planner_agent.py
Normal file
@@ -0,0 +1,165 @@
|
||||
"""Planner Agent."""
|
||||
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from dbgpt._private.pydantic import Field
|
||||
|
||||
from ..agent import AgentMessage
|
||||
from ..base_agent import ConversableAgent
|
||||
from ..plan.plan_action import PlanAction
|
||||
from ..profile import DynConfig, ProfileConfig
|
||||
|
||||
|
||||
class PlannerAgent(ConversableAgent):
|
||||
"""Planner Agent.
|
||||
|
||||
Planner agent, realizing task goal planning decomposition through LLM.
|
||||
"""
|
||||
|
||||
agents: List[ConversableAgent] = Field(default_factory=list)
|
||||
|
||||
profile: ProfileConfig = ProfileConfig(
|
||||
name=DynConfig(
|
||||
"Planner",
|
||||
category="agent",
|
||||
key="dbgpt_agent_plan_planner_agent_profile_name",
|
||||
),
|
||||
role=DynConfig(
|
||||
"Planner",
|
||||
category="agent",
|
||||
key="dbgpt_agent_plan_planner_agent_profile_role",
|
||||
),
|
||||
goal=DynConfig(
|
||||
"Understand each of the following intelligent agents and their "
|
||||
"capabilities, using the provided resources, solve user problems by "
|
||||
"coordinating intelligent agents. Please utilize your LLM's knowledge "
|
||||
"and understanding ability to comprehend the intent and goals of the "
|
||||
"user's problem, generating a task plan that can be completed through"
|
||||
" the collaboration of intelligent agents without user assistance.",
|
||||
category="agent",
|
||||
key="dbgpt_agent_plan_planner_agent_profile_goal",
|
||||
),
|
||||
expand_prompt=DynConfig(
|
||||
"Available Intelligent Agents:\n {{ agents }}",
|
||||
category="agent",
|
||||
key="dbgpt_agent_plan_planner_agent_profile_expand_prompt",
|
||||
),
|
||||
constraints=DynConfig(
|
||||
[
|
||||
"Every step of the task plan should exist to advance towards solving "
|
||||
"the user's goals. Do not generate meaningless task steps; ensure "
|
||||
"that each step has a clear goal and its content is complete.",
|
||||
"Pay attention to the dependencies and logic of each step in the task "
|
||||
"plan. For the steps that are depended upon, consider the data they "
|
||||
"depend on and whether it can be obtained based on the current goal. "
|
||||
"If it cannot be obtained, please indicate in the goal that the "
|
||||
"dependent data needs to be generated.",
|
||||
"Each step must be an independently achievable goal. Ensure that the "
|
||||
"logic and information are complete. Avoid steps with unclear "
|
||||
"objectives, like 'Analyze the retrieved issues data,' where it's "
|
||||
"unclear what specific content needs to be analyzed.",
|
||||
"Please ensure that only the intelligent agents mentioned above are "
|
||||
"used, and you may use only the necessary parts of them. Allocate "
|
||||
"them to appropriate steps strictly based on their described "
|
||||
"capabilities and limitations. Each intelligent agent can be reused.",
|
||||
"Utilize the provided resources to assist in generating the plan "
|
||||
"steps according to the actual needs of the user's goals. Do not use "
|
||||
"unnecessary resources.",
|
||||
"Each step should ideally use only one type of resource to accomplish "
|
||||
"a sub-goal. If the current goal can be broken down into multiple "
|
||||
"subtasks of the same type, you can create mutually independent "
|
||||
"parallel tasks.",
|
||||
"Data resources can be loaded and utilized by the appropriate "
|
||||
"intelligent agents without the need to consider the issues related "
|
||||
"to data loading links.",
|
||||
"Try to merge continuous steps that have sequential dependencies. If "
|
||||
"the user's goal does not require splitting, you can create a "
|
||||
"single-step task with content that is the user's goal.",
|
||||
"Carefully review the plan to ensure it comprehensively covers all "
|
||||
"information involved in the user's problem and can ultimately "
|
||||
"achieve the goal. Confirm whether each step includes the necessary "
|
||||
"resource information, such as URLs, resource names, etc.",
|
||||
],
|
||||
category="agent",
|
||||
key="dbgpt_agent_plan_planner_agent_profile_constraints",
|
||||
),
|
||||
desc=DynConfig(
|
||||
"You are a task planning expert! You can coordinate intelligent agents"
|
||||
" and allocate resources to achieve complex task goals.",
|
||||
category="agent",
|
||||
key="dbgpt_agent_plan_planner_agent_profile_desc",
|
||||
),
|
||||
examples=DynConfig(
|
||||
"""
|
||||
user:help me build a sales report summarizing our key metrics and trends
|
||||
assistants:[
|
||||
{{
|
||||
"serial_number": "1",
|
||||
"agent": "DataScientist",
|
||||
"content": "Retrieve total sales, average sales, and number of transactions grouped by "product_category"'.",
|
||||
"rely": ""
|
||||
}},
|
||||
{{
|
||||
"serial_number": "2",
|
||||
"agent": "DataScientist",
|
||||
"content": "Retrieve monthly sales and transaction number trends.",
|
||||
"rely": ""
|
||||
}},
|
||||
{{
|
||||
"serial_number": "3",
|
||||
"agent": "Reporter",
|
||||
"content": "Integrate analytical data into the format required to build sales reports.",
|
||||
"rely": "1,2"
|
||||
}}
|
||||
]""", # noqa: E501
|
||||
category="agent",
|
||||
key="dbgpt_agent_plan_planner_agent_profile_examples",
|
||||
),
|
||||
)
|
||||
_goal_zh: str = (
|
||||
"理解下面每个智能体(agent)和他们的能力,使用给出的资源,通过协调智能体来解决"
|
||||
"用户问题。 请发挥你LLM的知识和理解能力,理解用户问题的意图和目标,生成一个可以在没有用户帮助"
|
||||
"下,由智能体协作完成目标的任务计划。"
|
||||
)
|
||||
_expand_prompt_zh: str = "可用智能体(agent):\n {{ agents }}"
|
||||
|
||||
_constraints_zh: List[str] = [
|
||||
"任务计划的每个步骤都应该是为了推进解决用户目标而存在,不要生成无意义的任务步骤,确保每个步骤内目标明确内容完整。",
|
||||
"关注任务计划每个步骤的依赖关系和逻辑,被依赖步骤要考虑被依赖的数据,是否能基于当前目标得到,如果不能请在目标中提示要生成被依赖数据。",
|
||||
"每个步骤都是一个独立可完成的目标,一定要确保逻辑和信息完整,不要出现类似:"
|
||||
"'Analyze the retrieved issues data'这样目标不明确,不知道具体要分析啥内容的步骤",
|
||||
"请确保只使用上面提到的智能体,并且可以只使用其中需要的部分,严格根据描述能力和限制分配给合适的步骤,每个智能体都可以重复使用。",
|
||||
"根据用户目标的实际需要使用提供的资源来协助生成计划步骤,不要使用不需要的资源。",
|
||||
"每个步骤最好只使用一种资源完成一个子目标,如果当前目标可以分解为同类型的多个子任务,可以生成相互不依赖的并行任务。",
|
||||
"数据资源可以被合适的智能体加载使用,不用考虑数据资源的加载链接问题",
|
||||
"尽量合并有顺序依赖的连续相同步骤,如果用户目标无拆分必要,可以生成内容为用户目标的单步任务。",
|
||||
"仔细检查计划,确保计划完整的包含了用户问题所涉及的所有信息,并且最终能完成目标,确认每个步骤是否包含了需要用到的资源信息,如URL、资源名等. ",
|
||||
]
|
||||
_desc_zh: str = "你是一个任务规划专家!可以协调智能体,分配资源完成复杂的任务目标。"
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
"""Create a new PlannerAgent instance."""
|
||||
super().__init__(**kwargs)
|
||||
self._init_actions([PlanAction])
|
||||
|
||||
def _init_reply_message(self, received_message: AgentMessage):
|
||||
reply_message = super()._init_reply_message(received_message)
|
||||
reply_message.context = {
|
||||
"agents": "\n".join([f"- {item.role}:{item.desc}" for item in self.agents]),
|
||||
}
|
||||
return reply_message
|
||||
|
||||
def bind_agents(self, agents: List[ConversableAgent]) -> ConversableAgent:
|
||||
"""Bind the agents to the planner agent."""
|
||||
self.agents = agents
|
||||
for agent in self.agents:
|
||||
if agent.resources and len(agent.resources) > 0:
|
||||
self.resources.extend(agent.resources)
|
||||
return self
|
||||
|
||||
def prepare_act_param(self) -> Dict[str, Any]:
|
||||
"""Prepare the parameters for the act method."""
|
||||
return {
|
||||
"context": self.not_null_agent_context,
|
||||
"plans_memory": self.memory.plans_memory,
|
||||
}
|
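A rough usage sketch for this agent, mirroring how AutoPlanChatManager constructs it later in this diff. The memory, context, LLM config, resource loader, and worker agents are assumed to be provided by the surrounding application.

from dbgpt.agent.core.plan.planner_agent import PlannerAgent


async def build_planner(memory, agent_context, llm_config, resource_loader, workers):
    # Bind the shared runtime objects, then the worker agents whose role/desc
    # lines are exposed to the prompt through the {{ agents }} variable.
    planner = (
        await PlannerAgent()
        .bind(memory)
        .bind(agent_context)
        .bind(llm_config)
        .bind(resource_loader)
        .bind_agents(workers)
        .build()
    )
    return planner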
312
dbgpt/agent/core/plan/team_auto_plan.py
Normal file
@@ -0,0 +1,312 @@
|
||||
"""Auto plan chat manager agent."""
|
||||
|
||||
import logging
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
from dbgpt.core.interface.message import ModelMessageRoleType
|
||||
|
||||
from ..action.base import ActionOutput
|
||||
from ..agent import Agent, AgentMessage
|
||||
from ..agent_manage import mentioned_agents, participant_roles
|
||||
from ..base_agent import ConversableAgent
|
||||
from ..base_team import ManagerAgent
|
||||
from ..memory.gpts.base import GptsPlan
|
||||
from ..plan.planner_agent import PlannerAgent
|
||||
from ..profile import DynConfig, ProfileConfig
|
||||
from ..schema import Status
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AutoPlanChatManager(ManagerAgent):
|
||||
"""A chat manager agent that can manage a team chat of multiple agents."""
|
||||
|
||||
profile: ProfileConfig = ProfileConfig(
|
||||
name=DynConfig(
|
||||
"AutoPlanChatManager",
|
||||
category="agent",
|
||||
key="dbgpt_agent_plan_team_auto_plan_profile_name",
|
||||
),
|
||||
role=DynConfig(
|
||||
"PlanManager",
|
||||
category="agent",
|
||||
key="dbgpt_agent_plan_team_auto_plan_profile_role",
|
||||
),
|
||||
goal=DynConfig(
|
||||
"Advance the task plan generated by the planning agent. If the plan "
|
||||
"does not pre-allocate an agent, it needs to be coordinated with the "
|
||||
"appropriate agent to complete.",
|
||||
category="agent",
|
||||
key="dbgpt_agent_plan_team_auto_plan_profile_goal",
|
||||
),
|
||||
desc=DynConfig(
|
||||
"Advance the task plan generated by the planning agent.",
|
||||
category="agent",
|
||||
key="dbgpt_agent_plan_team_auto_plan_profile_desc",
|
||||
),
|
||||
)
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
"""Create a new AutoPlanChatManager instance."""
|
||||
super().__init__(**kwargs)
|
||||
|
||||
async def process_rely_message(
|
||||
self, conv_id: str, now_plan: GptsPlan, speaker: Agent
|
||||
):
|
||||
"""Process the dependent message."""
|
||||
rely_prompt = None
|
||||
rely_messages: List[Dict] = []
|
||||
|
||||
if now_plan.rely and len(now_plan.rely) > 0:
|
||||
rely_tasks_list = now_plan.rely.split(",")
|
||||
rely_tasks_list_int = [int(i) for i in rely_tasks_list]
|
||||
rely_tasks = self.memory.plans_memory.get_by_conv_id_and_num(
|
||||
conv_id, rely_tasks_list_int
|
||||
)
|
||||
if rely_tasks:
|
||||
rely_prompt = (
|
||||
"Read the result data of the dependent steps in the above"
|
||||
" historical message to complete the current goal:"
|
||||
)
|
||||
for rely_task in rely_tasks:
|
||||
rely_messages.append(
|
||||
{
|
||||
"content": rely_task.sub_task_content,
|
||||
"role": ModelMessageRoleType.HUMAN,
|
||||
"name": rely_task.sub_task_agent,
|
||||
}
|
||||
)
|
||||
rely_messages.append(
|
||||
{
|
||||
"content": rely_task.result,
|
||||
"role": ModelMessageRoleType.AI,
|
||||
"name": rely_task.sub_task_agent,
|
||||
}
|
||||
)
|
||||
return rely_prompt, rely_messages
|
||||
|
||||
def select_speaker_msg(self, agents: List[Agent]) -> str:
|
||||
"""Return the message for selecting the next speaker."""
|
||||
agent_names = [agent.name for agent in agents]
|
||||
return (
|
||||
"You are in a role play game. The following roles are available:\n"
|
||||
f" {participant_roles(agents)}.\n"
|
||||
" Read the following conversation.\n"
|
||||
f" Then select the next role from {agent_names} to play.\n"
|
||||
" The role can be selected repeatedly.Only return the role."
|
||||
)
|
||||
|
||||
async def select_speaker(
|
||||
self,
|
||||
last_speaker: Agent,
|
||||
selector: Agent,
|
||||
now_goal_context: Optional[str] = None,
|
||||
pre_allocated: Optional[str] = None,
|
||||
) -> Tuple[Agent, Optional[str]]:
|
||||
"""Select the next speaker."""
|
||||
agents = self.agents
|
||||
|
||||
if pre_allocated:
|
||||
# Preselect speakers
|
||||
logger.info(f"Preselect speakers:{pre_allocated}")
|
||||
name = pre_allocated
|
||||
model = None
|
||||
else:
|
||||
# auto speaker selection
|
||||
# TODO: the selector's a_thinking has been overwritten and cannot be used here.
|
||||
agent_names = [agent.name for agent in agents]
|
||||
fina_name, model = await selector.thinking(
|
||||
messages=[
|
||||
AgentMessage(
|
||||
role=ModelMessageRoleType.HUMAN,
|
||||
content="Read and understand the following task content and"
|
||||
" assign the appropriate role to complete the task.\n"
|
||||
f"Task content: {now_goal_context},\n"
|
||||
f"Select the role from: {agent_names},\n"
|
||||
f"Please only return the role, such as: {agents[0].name}",
|
||||
)
|
||||
],
|
||||
prompt=self.select_speaker_msg(agents),
|
||||
)
|
||||
if not fina_name:
|
||||
raise ValueError("Unable to select next speaker!")
|
||||
else:
|
||||
name = fina_name
|
||||
|
||||
# If exactly one agent is mentioned, use it. Otherwise, leave the OAI response
|
||||
# unmodified
|
||||
mentions = mentioned_agents(name, agents)
|
||||
if len(mentions) == 1:
|
||||
name = next(iter(mentions))
|
||||
else:
|
||||
logger.warning(
|
||||
"GroupChat select_speaker failed to resolve the next speaker's name. "
|
||||
f"This is because the speaker selection OAI call returned:\n{name}"
|
||||
)
|
||||
|
||||
# Return the result
|
||||
try:
|
||||
return self.agent_by_name(name), model
|
||||
except Exception as e:
|
||||
logger.exception(f"auto select speaker failed!{str(e)}")
|
||||
raise ValueError("Unable to select next speaker!")
|
||||
|
||||
async def act(
|
||||
self,
|
||||
message: Optional[str],
|
||||
sender: Optional[Agent] = None,
|
||||
reviewer: Optional[Agent] = None,
|
||||
**kwargs,
|
||||
) -> Optional[ActionOutput]:
|
||||
"""Perform an action based on the received message."""
|
||||
if not sender:
|
||||
return ActionOutput(
|
||||
is_exe_success=False,
|
||||
content="The sender cannot be empty!",
|
||||
)
|
||||
speaker: Agent = sender
|
||||
final_message = message
|
||||
for i in range(self.max_round):
|
||||
if not self.memory:
|
||||
return ActionOutput(
|
||||
is_exe_success=False,
|
||||
content="The memory cannot be empty!",
|
||||
)
|
||||
plans = self.memory.plans_memory.get_by_conv_id(
|
||||
self.not_null_agent_context.conv_id
|
||||
)
|
||||
|
||||
if not plans or len(plans) <= 0:
|
||||
if i > 3:
|
||||
return ActionOutput(
|
||||
is_exe_success=False,
|
||||
content="Retrying 3 times based on current application "
|
||||
"resources still fails to build a valid plan!",
|
||||
)
|
||||
planner: ConversableAgent = (
|
||||
await PlannerAgent()
|
||||
.bind(self.memory)
|
||||
.bind(self.agent_context)
|
||||
.bind(self.llm_config)
|
||||
.bind(self.resource_loader)
|
||||
.bind_agents(self.agents)
|
||||
.build()
|
||||
)
|
||||
|
||||
plan_message = await planner.generate_reply(
|
||||
received_message=AgentMessage.from_llm_message(
|
||||
{"content": message}
|
||||
),
|
||||
sender=self,
|
||||
reviewer=reviewer,
|
||||
)
|
||||
await planner.send(
|
||||
message=plan_message, recipient=self, request_reply=False
|
||||
)
|
||||
else:
|
||||
todo_plans = [
|
||||
plan
|
||||
for plan in plans
|
||||
if plan.state in [Status.TODO.value, Status.RETRYING.value]
|
||||
]
|
||||
if not todo_plans or len(todo_plans) <= 0:
|
||||
# The plan has been fully executed and a success message is sent
|
||||
# to the user.
|
||||
# complete
|
||||
return ActionOutput(
|
||||
is_exe_success=True,
|
||||
content=final_message, # work results message
|
||||
)
|
||||
else:
|
||||
try:
|
||||
now_plan: GptsPlan = todo_plans[0]
|
||||
current_goal_message = AgentMessage(
|
||||
content=now_plan.sub_task_content,
|
||||
current_goal=now_plan.sub_task_content,
|
||||
context={
|
||||
"plan_task": now_plan.sub_task_content,
|
||||
"plan_task_num": now_plan.sub_task_num,
|
||||
},
|
||||
)
|
||||
# select the next speaker
|
||||
speaker, model = await self.select_speaker(
|
||||
speaker,
|
||||
self,
|
||||
now_plan.sub_task_content,
|
||||
now_plan.sub_task_agent,
|
||||
)
|
||||
# Tell the speaker the dependent history information
|
||||
rely_prompt, rely_messages = await self.process_rely_message(
|
||||
conv_id=self.not_null_agent_context.conv_id,
|
||||
now_plan=now_plan,
|
||||
speaker=speaker,
|
||||
)
|
||||
if rely_prompt:
|
||||
current_goal_message.content = (
|
||||
rely_prompt + current_goal_message.content
|
||||
)
|
||||
|
||||
await self.send(
|
||||
message=current_goal_message,
|
||||
recipient=speaker,
|
||||
reviewer=reviewer,
|
||||
request_reply=False,
|
||||
)
|
||||
agent_reply_message = await speaker.generate_reply(
|
||||
received_message=current_goal_message,
|
||||
sender=self,
|
||||
reviewer=reviewer,
|
||||
rely_messages=AgentMessage.from_messages(rely_messages),
|
||||
)
|
||||
is_success = agent_reply_message.success
|
||||
reply_message = agent_reply_message.to_llm_message()
|
||||
await speaker.send(
|
||||
agent_reply_message, self, reviewer, request_reply=False
|
||||
)
|
||||
|
||||
plan_result = ""
|
||||
final_message = reply_message["content"]
|
||||
if is_success:
|
||||
if reply_message:
|
||||
action_report = agent_reply_message.action_report
|
||||
if action_report:
|
||||
plan_result = action_report.get("content", "")
|
||||
final_message = action_report["view"]
|
||||
|
||||
# The reply from the agent assigned to the current plan step passed
# verification, so mark the step as completed.
|
||||
self.memory.plans_memory.complete_task(
|
||||
self.not_null_agent_context.conv_id,
|
||||
now_plan.sub_task_num,
|
||||
plan_result,
|
||||
)
|
||||
else:
|
||||
plan_result = reply_message["content"]
|
||||
self.memory.plans_memory.update_task(
|
||||
self.not_null_agent_context.conv_id,
|
||||
now_plan.sub_task_num,
|
||||
Status.FAILED.value,
|
||||
now_plan.retry_times + 1,
|
||||
speaker.name,
|
||||
"",
|
||||
plan_result,
|
||||
)
|
||||
return ActionOutput(
|
||||
is_exe_success=False, content=plan_result
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
f"An exception was encountered during the execution of the"
|
||||
f" current plan step.{str(e)}"
|
||||
)
|
||||
return ActionOutput(
|
||||
is_exe_success=False,
|
||||
content=f"An exception was encountered during the execution"
|
||||
f" of the current plan step.{str(e)}",
|
||||
)
|
||||
return ActionOutput(
|
||||
is_exe_success=False,
|
||||
content=f"Maximum number of dialogue rounds exceeded.{self.max_round}",
|
||||
)
|
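To make the rely handling above concrete, here is a small self-contained sketch of how a `rely` value such as "1,2" is expanded into dependent history messages. The finished-step data is hypothetical; the real implementation reads GptsPlan records from plans_memory.

from dbgpt.core.interface.message import ModelMessageRoleType

# Hypothetical completed sub-tasks standing in for GptsPlan rows:
# plan number -> (sub-task content, assigned agent, result).
finished_steps = {
    1: ("Retrieve monthly sales", "DataScientist", "12 rows of monthly totals"),
    2: ("Retrieve transaction counts", "DataScientist", "12 rows of monthly counts"),
}

rely = "1,2"
rely_messages = []
for num in (int(i) for i in rely.split(",")):
    content, agent, result = finished_steps[num]
    # One HUMAN message with the sub-task and one AI message with its result,
    # mirroring process_rely_message above.
    rely_messages.append(
        {"content": content, "role": ModelMessageRoleType.HUMAN, "name": agent}
    )
    rely_messages.append(
        {"content": result, "role": ModelMessageRoleType.AI, "name": agent}
    )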
31
dbgpt/agent/core/profile/__init__.py
Normal file
@@ -0,0 +1,31 @@
"""Profiling module.

Autonomous agents typically perform tasks by assuming specific roles, such as coders,
teachers and domain experts.

The profiling module aims to indicate the profiles of the agent roles, which are
usually written into the prompt to influence the LLM behaviors.

Agent profiles typically encompass basic information such as age, gender, and career,
as well as psychological information reflecting the personality of the agent, and
social information detailing the relationships between agents.

The choice of profile information depends heavily on the application scenario.

How to create a profile:
1. Handcrafting method
2. LLM-generation method
3. Dataset alignment method
"""

from dbgpt.util.configure import DynConfig  # noqa: F401

from .base import (  # noqa: F401
    CompositeProfileFactory,
    DatasetProfileFactory,
    DefaultProfile,
    LLMProfileFactory,
    Profile,
    ProfileConfig,
    ProfileFactory,
)
|
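As a minimal sketch of the handcrafting method listed in the docstring above: all field values are invented, and only the class and field names come from this module.

from dbgpt.agent.core.profile import DefaultProfile, ProfileConfig

# Hand-written profile with explicit values.
profile = DefaultProfile(
    name="Aristotle",
    role="Summarizer",
    goal="Summarize the user's question into a short, accurate answer.",
    constraints=["Answer in at most three sentences."],
)

# The equivalent configuration object; create_profile() resolves it into a
# Profile (a DefaultProfile here, because no factory is configured).
config = ProfileConfig(name="Aristotle", role="Summarizer")
assert config.create_profile().get_role() == "Summarizer"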
413
dbgpt/agent/core/profile/base.py
Normal file
@@ -0,0 +1,413 @@
|
||||
"""Profile module."""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Optional
|
||||
|
||||
import cachetools
|
||||
|
||||
from dbgpt._private.pydantic import BaseModel, ConfigDict, Field, model_validator
|
||||
from dbgpt.util.configure import ConfigInfo, DynConfig
|
||||
|
||||
VALID_TEMPLATE_KEYS = {
|
||||
"role",
|
||||
"name",
|
||||
"goal",
|
||||
"resource_prompt",
|
||||
"expand_prompt",
|
||||
"language",
|
||||
"constraints",
|
||||
"examples",
|
||||
"out_schema",
|
||||
"most_recent_memories",
|
||||
"question",
|
||||
}
|
||||
|
||||
_DEFAULT_SYSTEM_TEMPLATE = """
|
||||
You are a {{ role }}, {% if name %}named {{ name }}, {% endif %}your goal is {{ goal }}.
|
||||
Please think step by step to achieve the goal. You can use the resources given below.
|
||||
At the same time, please strictly abide by the constraints and specifications in IMPORTANT REMINDER.
|
||||
{% if resource_prompt %} {{ resource_prompt }} {% endif %}
|
||||
{% if expand_prompt %} {{ expand_prompt }} {% endif %}
|
||||
|
||||
*** IMPORTANT REMINDER ***
|
||||
{% if language == 'zh' %}
|
||||
Please answer in simplified Chinese.
|
||||
{% else %}
|
||||
Please answer in English.
|
||||
{% endif %}
|
||||
|
||||
{% if constraints %}
|
||||
{% for constraint in constraints %}
|
||||
{{ loop.index }}. {{ constraint }}
|
||||
{% endfor %}
|
||||
{% endif %}
|
||||
|
||||
{% if examples %}
|
||||
You can refer to the following examples:
|
||||
{{ examples }}
|
||||
{% endif %}
|
||||
|
||||
{% if out_schema %} {{ out_schema }} {% endif %}
|
||||
""" # noqa
|
||||
|
||||
_DEFAULT_USER_TEMPLATE = """
|
||||
{% if most_recent_memories %}
|
||||
Most recent observations:
|
||||
{{ most_recent_memories }}
|
||||
{% endif %}
|
||||
|
||||
{% if question %}
|
||||
Question: {{ question }}
|
||||
{% endif %}
|
||||
"""
|
||||
|
||||
_DEFAULT_SAVE_MEMORY_TEMPLATE = """
|
||||
{% if question %}Question: {{ question }} {% endif %}
|
||||
{% if thought %}Thought: {{ thought }} {% endif %}
|
||||
{% if action %}Action: {{ action }} {% endif %}
|
||||
{% if observation %}Observation: {{ observation }} {% endif %}
|
||||
"""
|
||||
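# Illustration (not part of this file): a minimal sketch of how the templates
# above behave when rendered with Jinja2, using the same SandboxedEnvironment
# that the Role class relies on. The variable values are made up.
from jinja2.sandbox import SandboxedEnvironment

_example_env = SandboxedEnvironment()
_example_prompt = _example_env.from_string(_DEFAULT_SYSTEM_TEMPLATE).render(
    role="Summarizer",
    goal="summarize the user's question",
    language="en",
    constraints=["Answer in at most three sentences."],
)
# Optional variables that are not supplied (resource_prompt, expand_prompt,
# examples, out_schema, ...) render as empty sections thanks to the `{% if %}`
# guards in the template.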
|
||||
|
||||
class Profile(ABC):
|
||||
"""Profile interface."""
|
||||
|
||||
@abstractmethod
|
||||
def get_name(self) -> str:
|
||||
"""Return the name of current agent."""
|
||||
|
||||
@abstractmethod
|
||||
def get_role(self) -> str:
|
||||
"""Return the role of current agent."""
|
||||
|
||||
def get_goal(self) -> Optional[str]:
|
||||
"""Return the goal of current agent."""
|
||||
return None
|
||||
|
||||
def get_constraints(self) -> Optional[List[str]]:
|
||||
"""Return the constraints of current agent."""
|
||||
return None
|
||||
|
||||
def get_description(self) -> Optional[str]:
|
||||
"""Return the description of current agent.
|
||||
|
||||
It will not be used to generate prompt.
|
||||
"""
|
||||
return None
|
||||
|
||||
def get_expand_prompt(self) -> Optional[str]:
|
||||
"""Return the expand prompt of current agent."""
|
||||
return None
|
||||
|
||||
def get_examples(self) -> Optional[str]:
|
||||
"""Return the examples of current agent."""
|
||||
return None
|
||||
|
||||
@abstractmethod
|
||||
def get_system_prompt_template(self) -> str:
|
||||
"""Return the prompt template of current agent."""
|
||||
|
||||
@abstractmethod
|
||||
def get_user_prompt_template(self) -> str:
|
||||
"""Return the user prompt template of current agent."""
|
||||
|
||||
@abstractmethod
|
||||
def get_save_memory_template(self) -> str:
|
||||
"""Return the save memory template of current agent."""
|
||||
|
||||
|
||||
class DefaultProfile(BaseModel, Profile):
|
||||
"""Default profile."""
|
||||
|
||||
name: str = Field("", description="The name of the agent.")
|
||||
role: str = Field("", description="The role of the agent.")
|
||||
goal: Optional[str] = Field(None, description="The goal of the agent.")
|
||||
constraints: Optional[List[str]] = Field(
|
||||
None, description="The constraints of the agent."
|
||||
)
|
||||
|
||||
desc: Optional[str] = Field(
|
||||
None, description="The description of the agent, not used to generate prompt."
|
||||
)
|
||||
|
||||
expand_prompt: Optional[str] = Field(
|
||||
None, description="The expand prompt of the agent."
|
||||
)
|
||||
|
||||
examples: Optional[str] = Field(
|
||||
None, description="The examples of the agent prompt."
|
||||
)
|
||||
|
||||
system_prompt_template: str = Field(
|
||||
_DEFAULT_SYSTEM_TEMPLATE, description="The system prompt template of the agent."
|
||||
)
|
||||
user_prompt_template: str = Field(
|
||||
_DEFAULT_USER_TEMPLATE, description="The user prompt template of the agent."
|
||||
)
|
||||
|
||||
save_memory_template: str = Field(
|
||||
_DEFAULT_SAVE_MEMORY_TEMPLATE,
|
||||
description="The save memory template of the agent.",
|
||||
)
|
||||
|
||||
def get_name(self) -> str:
|
||||
"""Return the name of current agent."""
|
||||
return self.name
|
||||
|
||||
def get_role(self) -> str:
|
||||
"""Return the role of current agent."""
|
||||
return self.role
|
||||
|
||||
def get_goal(self) -> Optional[str]:
|
||||
"""Return the goal of current agent."""
|
||||
return self.goal
|
||||
|
||||
def get_constraints(self) -> Optional[List[str]]:
|
||||
"""Return the constraints of current agent."""
|
||||
return self.constraints
|
||||
|
||||
def get_description(self) -> Optional[str]:
|
||||
"""Return the description of current agent.
|
||||
|
||||
It will not be used to generate prompt.
|
||||
"""
|
||||
return self.desc
|
||||
|
||||
def get_expand_prompt(self) -> Optional[str]:
|
||||
"""Return the expand prompt of current agent."""
|
||||
return self.expand_prompt
|
||||
|
||||
def get_examples(self) -> Optional[str]:
|
||||
"""Return the examples of current agent."""
|
||||
return self.examples
|
||||
|
||||
def get_system_prompt_template(self) -> str:
|
||||
"""Return the prompt template of current agent."""
|
||||
return self.system_prompt_template
|
||||
|
||||
def get_user_prompt_template(self) -> str:
|
||||
"""Return the user prompt template of current agent."""
|
||||
return self.user_prompt_template
|
||||
|
||||
def get_save_memory_template(self) -> str:
|
||||
"""Return the save memory template of current agent."""
|
||||
return self.save_memory_template
|
||||
|
||||
|
||||
class ProfileFactory:
|
||||
"""Profile factory interface.
|
||||
|
||||
It is used to create a profile.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def create_profile(
|
||||
self,
|
||||
profile_id: int,
|
||||
name: Optional[str] = None,
|
||||
role: Optional[str] = None,
|
||||
goal: Optional[str] = None,
|
||||
prefer_prompt_language: Optional[str] = None,
|
||||
prefer_model: Optional[str] = None,
|
||||
) -> Optional[Profile]:
|
||||
"""Create a profile."""
|
||||
|
||||
|
||||
class LLMProfileFactory(ProfileFactory):
|
||||
"""Create a profile by LLM.
|
||||
|
||||
Based on LLM automatic generation, it usually specifies the rules of the generation
|
||||
configuration first, clarifies the composition and attributes of the agent
|
||||
configuration in the target population, and then gives a small number of samples,
|
||||
and finally LLM generates the configuration of all agents.
|
||||
"""
|
||||
|
||||
def create_profile(
|
||||
self,
|
||||
profile_id: int,
|
||||
name: Optional[str] = None,
|
||||
role: Optional[str] = None,
|
||||
goal: Optional[str] = None,
|
||||
prefer_prompt_language: Optional[str] = None,
|
||||
prefer_model: Optional[str] = None,
|
||||
) -> Optional[Profile]:
|
||||
"""Create a profile by LLM.
|
||||
|
||||
TODO: Implement this method.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class DatasetProfileFactory(ProfileFactory):
|
||||
"""Create a profile by dataset.
|
||||
|
||||
Use existing data sets to generate agent configurations.
|
||||
|
||||
In some cases, the data set contains a large amount of information about real people
|
||||
, first organize the information about real people in the data set into a natural
|
||||
language prompt, which is then used to generate the agent configuration.
|
||||
"""
|
||||
|
||||
def create_profile(
|
||||
self,
|
||||
profile_id: int,
|
||||
name: Optional[str] = None,
|
||||
role: Optional[str] = None,
|
||||
goal: Optional[str] = None,
|
||||
prefer_prompt_language: Optional[str] = None,
|
||||
prefer_model: Optional[str] = None,
|
||||
) -> Optional[Profile]:
|
||||
"""Create a profile by dataset.
|
||||
|
||||
TODO: Implement this method.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class CompositeProfileFactory(ProfileFactory):
|
||||
"""Create a profile by combining multiple profile factories."""
|
||||
|
||||
def __init__(self, factories: List[ProfileFactory]):
|
||||
"""Create a composite profile factory."""
|
||||
self.factories = factories
|
||||
|
||||
def create_profile(
|
||||
self,
|
||||
profile_id: int,
|
||||
name: Optional[str] = None,
|
||||
role: Optional[str] = None,
|
||||
goal: Optional[str] = None,
|
||||
prefer_prompt_language: Optional[str] = None,
|
||||
prefer_model: Optional[str] = None,
|
||||
) -> Optional[Profile]:
|
||||
"""Create a profile by combining multiple profile factories.
|
||||
|
||||
TODO: Implement this method.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class ProfileConfig(BaseModel):
|
||||
"""Profile configuration.
|
||||
|
||||
If factory is not specified, name and role must be specified.
|
||||
If factory is specified and name and role are also specified, the factory will be
|
||||
preferred.
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
profile_id: int = Field(0, description="The profile ID.")
|
||||
name: str | ConfigInfo | None = DynConfig(..., description="The name of the agent.")
|
||||
role: str | ConfigInfo | None = DynConfig(..., description="The role of the agent.")
|
||||
goal: str | ConfigInfo | None = DynConfig(None, description="The goal.")
|
||||
constraints: List[str] | ConfigInfo | None = DynConfig(None, is_list=True)
|
||||
desc: str | ConfigInfo | None = DynConfig(
|
||||
None, description="The description of the agent."
|
||||
)
|
||||
expand_prompt: str | ConfigInfo | None = DynConfig(
|
||||
None, description="The expand prompt."
|
||||
)
|
||||
examples: str | ConfigInfo | None = DynConfig(None, description="The examples.")
|
||||
|
||||
system_prompt_template: str | ConfigInfo | None = DynConfig(
|
||||
_DEFAULT_SYSTEM_TEMPLATE, description="The prompt template."
|
||||
)
|
||||
user_prompt_template: str | ConfigInfo | None = DynConfig(
|
||||
_DEFAULT_USER_TEMPLATE, description="The user prompt template."
|
||||
)
|
||||
save_memory_template: str | ConfigInfo | None = DynConfig(
|
||||
_DEFAULT_SAVE_MEMORY_TEMPLATE, description="The save memory template."
|
||||
)
|
||||
factory: ProfileFactory | None = Field(None, description="The profile factory.")
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_before(cls, values):
|
||||
"""Check before validation."""
|
||||
if isinstance(values, dict):
|
||||
return values
|
||||
if values["factory"] is None:
|
||||
if values["name"] is None:
|
||||
raise ValueError("name must be specified if factory is not specified")
|
||||
if values["role"] is None:
|
||||
raise ValueError("role must be specified if factory is not specified")
|
||||
return values
|
||||
|
||||
@cachetools.cached(cachetools.TTLCache(maxsize=100, ttl=10))
|
||||
def create_profile(
|
||||
self,
|
||||
profile_id: Optional[int] = None,
|
||||
prefer_prompt_language: Optional[str] = None,
|
||||
prefer_model: Optional[str] = None,
|
||||
) -> Profile:
|
||||
"""Create a profile.
|
||||
|
||||
If factory is specified, use the factory to create the profile.
|
||||
"""
|
||||
factory_profile = None
|
||||
if profile_id is None:
|
||||
profile_id = self.profile_id
|
||||
name = self.name
|
||||
role = self.role
|
||||
goal = self.goal
|
||||
constraints = self.constraints
|
||||
desc = self.desc
|
||||
expand_prompt = self.expand_prompt
|
||||
system_prompt_template = self.system_prompt_template
|
||||
user_prompt_template = self.user_prompt_template
|
||||
save_memory_template = self.save_memory_template
|
||||
examples = self.examples
|
||||
call_args = {
|
||||
"prefer_prompt_language": prefer_prompt_language,
|
||||
"prefer_model": prefer_model,
|
||||
}
|
||||
if isinstance(name, ConfigInfo):
|
||||
name = name.query(**call_args)
|
||||
if isinstance(role, ConfigInfo):
|
||||
role = role.query(**call_args)
|
||||
if isinstance(goal, ConfigInfo):
|
||||
goal = goal.query(**call_args)
|
||||
if isinstance(constraints, ConfigInfo):
|
||||
constraints = constraints.query(**call_args)
|
||||
if isinstance(desc, ConfigInfo):
|
||||
desc = desc.query(**call_args)
|
||||
if isinstance(expand_prompt, ConfigInfo):
|
||||
expand_prompt = expand_prompt.query(**call_args)
|
||||
if isinstance(examples, ConfigInfo):
|
||||
examples = examples.query(**call_args)
|
||||
if isinstance(system_prompt_template, ConfigInfo):
|
||||
system_prompt_template = system_prompt_template.query(**call_args)
|
||||
if isinstance(user_prompt_template, ConfigInfo):
|
||||
user_prompt_template = user_prompt_template.query(**call_args)
|
||||
if isinstance(save_memory_template, ConfigInfo):
|
||||
save_memory_template = save_memory_template.query(**call_args)
|
||||
|
||||
if self.factory is not None:
|
||||
factory_profile = self.factory.create_profile(
|
||||
profile_id,
|
||||
name,
|
||||
role,
|
||||
goal,
|
||||
prefer_prompt_language,
|
||||
prefer_model,
|
||||
)
|
||||
|
||||
if factory_profile is not None:
|
||||
return factory_profile
|
||||
return DefaultProfile(
|
||||
name=name,
|
||||
role=role,
|
||||
goal=goal,
|
||||
constraints=constraints,
|
||||
desc=desc,
|
||||
expand_prompt=expand_prompt,
|
||||
examples=examples,
|
||||
system_prompt_template=system_prompt_template,
|
||||
user_prompt_template=user_prompt_template,
|
||||
save_memory_template=save_memory_template,
|
||||
)
|
||||
|
||||
def __hash__(self):
|
||||
"""Return the hash value."""
|
||||
return hash(self.profile_id)
|
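A short sketch of how a ProfileConfig resolves DynConfig values into a concrete profile. The keys and defaults below are illustrative, in the same style PlannerAgent uses earlier in this diff.

from dbgpt.agent.core.profile import ProfileConfig
from dbgpt.util.configure import DynConfig

config = ProfileConfig(
    name=DynConfig("Reporter", category="agent", key="example_profile_name"),
    role=DynConfig("Reporter", category="agent", key="example_profile_role"),
    goal=DynConfig(
        "Summarize analysis results.", category="agent", key="example_profile_goal"
    ),
)
# create_profile() queries each ConfigInfo (typically falling back to the
# inline defaults when nothing overrides the key) and, since no factory is
# set, returns a DefaultProfile carrying the resolved values.
profile = config.create_profile()
print(profile.get_name(), profile.get_goal())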
@@ -1,112 +1,93 @@
|
||||
"""Role class for role-based conversation."""
|
||||
|
||||
from abc import ABC
|
||||
from typing import List, Optional
|
||||
from typing import Any, Dict, List, Optional, Set
|
||||
|
||||
from jinja2.meta import find_undeclared_variables
|
||||
from jinja2.sandbox import SandboxedEnvironment
|
||||
|
||||
from dbgpt._private.pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from .action.base import ActionOutput
|
||||
from .memory.agent_memory import AgentMemory, AgentMemoryFragment
|
||||
from .memory.llm import LLMImportanceScorer, LLMInsightExtractor
|
||||
from .profile import Profile, ProfileConfig
|
||||
|
||||
|
||||
class Role(ABC, BaseModel):
|
||||
"""Role class for role-based conversation."""
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
profile: str = ""
|
||||
name: str = ""
|
||||
resource_introduction: str = ""
|
||||
goal: str = ""
|
||||
|
||||
expand_prompt: str = ""
|
||||
profile: ProfileConfig = Field(
|
||||
...,
|
||||
description="The profile of the role.",
|
||||
)
|
||||
memory: AgentMemory = Field(default_factory=AgentMemory)
|
||||
|
||||
fixed_subgoal: Optional[str] = Field(None, description="Fixed subgoal")
|
||||
|
||||
constraints: List[str] = Field(default_factory=list, description="Constraints")
|
||||
examples: str = ""
|
||||
desc: str = ""
|
||||
language: str = "en"
|
||||
is_human: bool = False
|
||||
is_team: bool = False
|
||||
|
||||
def prompt_template(
|
||||
template_env: SandboxedEnvironment = Field(default_factory=SandboxedEnvironment)
|
||||
|
||||
async def build_prompt(
|
||||
self,
|
||||
specified_prompt: Optional[str] = None,
|
||||
question: Optional[str] = None,
|
||||
is_system: bool = True,
|
||||
most_recent_memories: Optional[str] = None,
|
||||
**kwargs
|
||||
) -> str:
|
||||
"""Return the prompt template for the role.
|
||||
|
||||
Args:
|
||||
specified_prompt (str, optional): The specified prompt. Defaults to None.
|
||||
|
||||
Returns:
|
||||
str: The prompt template.
|
||||
"""
|
||||
if specified_prompt:
|
||||
return specified_prompt
|
||||
|
||||
expand_prompt = self.expand_prompt if len(self.expand_prompt) > 0 else ""
|
||||
examples_prompt = (
|
||||
"You can refer to the following examples:\n"
|
||||
if len(self.examples) > 0
|
||||
else ""
|
||||
prompt_template = (
|
||||
self.system_prompt_template if is_system else self.user_prompt_template
|
||||
)
|
||||
examples = self.examples if len(self.examples) > 0 else ""
|
||||
template = (
|
||||
f"{self.role_prompt}\n"
|
||||
"Please think step by step to achieve the goal. You can use the resources "
|
||||
"given below. At the same time, please strictly abide by the constraints "
|
||||
"and specifications in IMPORTANT REMINDER.\n\n"
|
||||
f"{{resource_prompt}}\n\n"
|
||||
f"{expand_prompt}\n\n"
|
||||
"*** IMPORTANT REMINDER ***\n"
|
||||
f"{self.language_require_prompt}\n"
|
||||
f"{self.constraints_prompt}\n"
|
||||
f"{examples_prompt}{examples}\n\n"
|
||||
f"{{out_schema}}"
|
||||
)
|
||||
return template
|
||||
template_vars = self._get_template_variables(prompt_template)
|
||||
_sub_render_keys = {"role", "name", "goal", "expand_prompt", "constraints"}
|
||||
pass_vars = {
|
||||
"role": self.role,
|
||||
"name": self.name,
|
||||
"goal": self.goal,
|
||||
"expand_prompt": self.expand_prompt,
|
||||
"language": self.language,
|
||||
"constraints": self.constraints,
|
||||
"most_recent_memories": (
|
||||
most_recent_memories if most_recent_memories else None
|
||||
),
|
||||
"examples": self.examples,
|
||||
# "out_schema": out_schema if out_schema else None,
|
||||
# "resource_prompt": resource_prompt if resource_prompt else None,
|
||||
"question": question,
|
||||
}
|
||||
resource_vars = await self.generate_resource_variables(question)
|
||||
pass_vars.update(resource_vars)
|
||||
pass_vars.update(kwargs)
|
||||
filtered_data = {
|
||||
key: pass_vars[key] for key in template_vars if key in pass_vars
|
||||
}
|
||||
for key in filtered_data.keys():
|
||||
value = filtered_data[key]
|
||||
if key in _sub_render_keys and value:
|
||||
if isinstance(value, str):
|
||||
# Render the sub-template
|
||||
filtered_data[key] = self._render_template(value, **pass_vars)
|
||||
elif isinstance(value, list):
|
||||
for i, item in enumerate(value):
|
||||
if isinstance(item, str):
|
||||
value[i] = self._render_template(item, **pass_vars)
|
||||
return self._render_template(prompt_template, **filtered_data)
|
||||
|
||||
@property
|
||||
def role_prompt(self) -> str:
|
||||
"""Return the role prompt.
|
||||
|
||||
You are a {self.profile}, named {self.name}, your goal is {self.goal}.
|
||||
|
||||
Returns:
|
||||
str: The role prompt.
|
||||
"""
|
||||
profile_prompt = f"You are a {self.profile},"
|
||||
name_prompt = f"named {self.name}," if len(self.name) > 0 else ""
|
||||
goal_prompt = f"your goal is {self.goal}"
|
||||
prompt = f"""{profile_prompt}{name_prompt}{goal_prompt}"""
|
||||
return prompt
|
||||
|
||||
@property
|
||||
def constraints_prompt(self) -> str:
|
||||
"""Return the constraints prompt.
|
||||
|
||||
Return:
|
||||
str: The constraints prompt.
|
||||
"""
|
||||
if len(self.constraints) > 0:
|
||||
return "\n".join(
|
||||
f"{i + 1}. {item}" for i, item in enumerate(self.constraints)
|
||||
)
|
||||
return ""
|
||||
|
||||
@property
|
||||
def language_require_prompt(self) -> str:
|
||||
"""Return the language requirement prompt.
|
||||
|
||||
Returns:
|
||||
str: The language requirement prompt.
|
||||
"""
|
||||
if self.language == "zh":
|
||||
return "Please answer in simplified Chinese."
|
||||
else:
|
||||
return "Please answer in English."
|
||||
|
||||
@property
|
||||
def introduce(self) -> str:
|
||||
"""Introduce the role."""
|
||||
return self.desc
|
||||
async def generate_resource_variables(
|
||||
self, question: Optional[str] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""Generate the resource variables."""
|
||||
return {}
|
||||
|
||||
def identity_check(self) -> None:
|
||||
"""Check the identity of the role."""
|
||||
@@ -114,12 +95,123 @@ class Role(ABC, BaseModel):
|
||||
|
||||
def get_name(self) -> str:
|
||||
"""Get the name of the role."""
|
||||
return self.name
|
||||
return self.current_profile.get_name()
|
||||
|
||||
def get_profile(self) -> str:
|
||||
"""Get the profile of the role."""
|
||||
return self.profile
|
||||
@property
|
||||
def current_profile(self) -> Profile:
|
||||
"""Return the current profile."""
|
||||
profile = self.profile.create_profile()
|
||||
return profile
|
||||
|
||||
def get_describe(self) -> str:
|
||||
"""Get the describe of the role."""
|
||||
return self.desc
|
||||
@property
|
||||
def name(self) -> str:
|
||||
"""Return the name of the role."""
|
||||
return self.current_profile.get_name()
|
||||
|
||||
@property
|
||||
def examples(self) -> Optional[str]:
|
||||
"""Return the examples of the role."""
|
||||
return self.current_profile.get_examples()
|
||||
|
||||
@property
|
||||
def role(self) -> str:
|
||||
"""Return the role of the role."""
|
||||
return self.current_profile.get_role()
|
||||
|
||||
@property
|
||||
def goal(self) -> Optional[str]:
|
||||
"""Return the goal of the role."""
|
||||
return self.current_profile.get_goal()
|
||||
|
||||
@property
|
||||
def constraints(self) -> Optional[List[str]]:
|
||||
"""Return the constraints of the role."""
|
||||
return self.current_profile.get_constraints()
|
||||
|
||||
@property
|
||||
def desc(self) -> Optional[str]:
|
||||
"""Return the description of the role."""
|
||||
return self.current_profile.get_description()
|
||||
|
||||
@property
|
||||
def expand_prompt(self) -> Optional[str]:
|
||||
"""Return the expand prompt of the role."""
|
||||
return self.current_profile.get_expand_prompt()
|
||||
|
||||
@property
|
||||
def system_prompt_template(self) -> str:
|
||||
"""Return the current system prompt template."""
|
||||
return self.current_profile.get_system_prompt_template()
|
||||
|
||||
@property
|
||||
def user_prompt_template(self) -> str:
|
||||
"""Return the current user prompt template."""
|
||||
return self.current_profile.get_user_prompt_template()
|
||||
|
||||
@property
|
||||
def save_memory_template(self) -> str:
|
||||
"""Return the current save memory template."""
|
||||
return self.current_profile.get_save_memory_template()
|
||||
|
||||
def _get_template_variables(self, template: str) -> Set[str]:
|
||||
parsed_content = self.template_env.parse(template)
|
||||
return find_undeclared_variables(parsed_content)
|
||||
|
||||
def _render_template(self, template: str, **kwargs):
|
||||
r_template = self.template_env.from_string(template)
|
||||
return r_template.render(**kwargs)
|
||||
|
||||
@property
|
||||
def memory_importance_scorer(self) -> Optional[LLMImportanceScorer]:
|
||||
"""Create the memory importance scorer.
|
||||
|
||||
The memory importance scorer is used to score the importance of a memory
|
||||
fragment.
|
||||
"""
|
||||
return None
|
||||
|
||||
@property
|
||||
def memory_insight_extractor(self) -> Optional[LLMInsightExtractor]:
|
||||
"""Create the memory insight extractor.
|
||||
|
||||
The memory insight extractor is used to extract a high-level insight from a
|
||||
memory fragment.
|
||||
"""
|
||||
return None
|
||||
|
||||
async def read_memories(
|
||||
self,
|
||||
question: str,
|
||||
) -> str:
|
||||
"""Read the memories from the memory."""
|
||||
memories = await self.memory.read(question)
|
||||
recent_messages = [m.raw_observation for m in memories]
|
||||
return "".join(recent_messages)
|
||||
|
||||
async def save_to_memory(
|
||||
self,
|
||||
question: str,
|
||||
ai_message: str,
|
||||
action_output: Optional[ActionOutput] = None,
|
||||
check_pass: bool = True,
|
||||
check_fail_reason: Optional[str] = None,
|
||||
) -> None:
|
||||
"""Save the role to the memory."""
|
||||
if not action_output:
|
||||
raise ValueError("Action output is required to save to memory.")
|
||||
|
||||
mem_thoughts = action_output.thoughts or ai_message
|
||||
observation = action_output.observations or action_output.content
|
||||
if not check_pass and check_fail_reason:
|
||||
observation += "\n" + check_fail_reason
|
||||
|
||||
memory_map = {
|
||||
"question": question,
|
||||
"thought": mem_thoughts,
|
||||
"action": action_output.action,
|
||||
"observation": observation,
|
||||
}
|
||||
save_memory_template = self.save_memory_template
|
||||
memory_content = self._render_template(save_memory_template, **memory_map)
|
||||
fragment = AgentMemoryFragment(memory_content)
|
||||
await self.memory.write(fragment)
|
||||
|
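A small sketch of the Jinja2 plumbing that Role relies on above: undeclared variables are collected from a template first, and only matching values are passed to the renderer. The template string here is a shortened stand-in for save_memory_template.

from jinja2.meta import find_undeclared_variables
from jinja2.sandbox import SandboxedEnvironment

env = SandboxedEnvironment()
template = (
    "{% if question %}Question: {{ question }} {% endif %}"
    "{% if thought %}Thought: {{ thought }} {% endif %}"
)

# Same approach as Role._get_template_variables and Role._render_template.
variables = find_undeclared_variables(env.parse(template))  # {"question", "thought"}
content = env.from_string(template).render(
    question="What is DB-GPT?", thought="Summarize the project in one sentence."
)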
@@ -1,5 +1,6 @@
"""A proxy agent for the user."""
from .base_agent import ConversableAgent
from .profile import ProfileConfig


class UserProxyAgent(ConversableAgent):
@@ -8,12 +9,13 @@ class UserProxyAgent(ConversableAgent):
    That can execute code and provide feedback to the other agents.
    """

    name: str = "User"
    profile: str = "Human"

    desc: str = (
        "A human admin. Interact with the planner to discuss the plan. "
        "Plan execution needs to be approved by this admin."
    profile: ProfileConfig = ProfileConfig(
        name="User",
        role="Human",
        description=(
            "A human admin. Interact with the planner to discuss the plan. "
            "Plan execution needs to be approved by this admin."
        ),
    )

    is_human: bool = True