Mirror of https://github.com/csunny/DB-GPT.git, synced 2025-09-03 18:17:45 +00:00
refactor(agent): Agent modular refactoring (#1487)
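This refactor replaces the loose identity fields on Role (profile, name, goal, constraints, desc as plain string/list fields) with a single ProfileConfig object plus read-only properties that delegate to it. For illustration only (not part of this diff), a minimal sketch of how a role is declared after the change; the class name and field values are hypothetical, the field names follow how the properties below read them, and the exact ProfileConfig constructor should be treated as an assumption:

# Hypothetical subclass of the Role class defined in the diff below.
from dbgpt.agent.core.profile import ProfileConfig  # same module as "from .profile import ..." below

class SummarizerRole(Role):
    profile: ProfileConfig = ProfileConfig(
        name="Summarizer Sam",          # replaces the old `name: str` field
        role="Summarizer",              # replaces the old `profile: str` field
        goal="Condense retrieved documents into a short answer.",
        constraints=["Only use the provided resources."],
        desc="A role that summarizes resource content for the user.",
    )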
@@ -1,112 +1,93 @@
"""Role class for role-based conversation."""

from abc import ABC
from typing import List, Optional
from typing import Any, Dict, List, Optional, Set

from jinja2.meta import find_undeclared_variables
from jinja2.sandbox import SandboxedEnvironment

from dbgpt._private.pydantic import BaseModel, ConfigDict, Field

from .action.base import ActionOutput
from .memory.agent_memory import AgentMemory, AgentMemoryFragment
from .memory.llm import LLMImportanceScorer, LLMInsightExtractor
from .profile import Profile, ProfileConfig


class Role(ABC, BaseModel):
    """Role class for role-based conversation."""

    model_config = ConfigDict(arbitrary_types_allowed=True)

    profile: str = ""
    name: str = ""
    resource_introduction: str = ""
    goal: str = ""

    expand_prompt: str = ""
    profile: ProfileConfig = Field(
        ...,
        description="The profile of the role.",
    )
    memory: AgentMemory = Field(default_factory=AgentMemory)

    fixed_subgoal: Optional[str] = Field(None, description="Fixed subgoal")

    constraints: List[str] = Field(default_factory=list, description="Constraints")
    examples: str = ""
    desc: str = ""
    language: str = "en"
    is_human: bool = False
    is_team: bool = False

    def prompt_template(
    template_env: SandboxedEnvironment = Field(default_factory=SandboxedEnvironment)

    async def build_prompt(
        self,
        specified_prompt: Optional[str] = None,
        question: Optional[str] = None,
        is_system: bool = True,
        most_recent_memories: Optional[str] = None,
        **kwargs
    ) -> str:
        """Return the prompt template for the role.

        Args:
            specified_prompt (str, optional): The specified prompt. Defaults to None.

        Returns:
            str: The prompt template.
        """
        if specified_prompt:
            return specified_prompt

        expand_prompt = self.expand_prompt if len(self.expand_prompt) > 0 else ""
        examples_prompt = (
            "You can refer to the following examples:\n"
            if len(self.examples) > 0
            else ""
        )
        prompt_template = (
            self.system_prompt_template if is_system else self.user_prompt_template
        )
        examples = self.examples if len(self.examples) > 0 else ""
        template = (
            f"{self.role_prompt}\n"
            "Please think step by step to achieve the goal. You can use the resources "
            "given below. At the same time, please strictly abide by the constraints "
            "and specifications in IMPORTANT REMINDER.\n\n"
            f"{{resource_prompt}}\n\n"
            f"{expand_prompt}\n\n"
            "*** IMPORTANT REMINDER ***\n"
            f"{self.language_require_prompt}\n"
            f"{self.constraints_prompt}\n"
            f"{examples_prompt}{examples}\n\n"
            f"{{out_schema}}"
        )
        return template
        template_vars = self._get_template_variables(prompt_template)
        _sub_render_keys = {"role", "name", "goal", "expand_prompt", "constraints"}
        pass_vars = {
            "role": self.role,
            "name": self.name,
            "goal": self.goal,
            "expand_prompt": self.expand_prompt,
            "language": self.language,
            "constraints": self.constraints,
            "most_recent_memories": (
                most_recent_memories if most_recent_memories else None
            ),
            "examples": self.examples,
            # "out_schema": out_schema if out_schema else None,
            # "resource_prompt": resource_prompt if resource_prompt else None,
            "question": question,
        }
        resource_vars = await self.generate_resource_variables(question)
        pass_vars.update(resource_vars)
        pass_vars.update(kwargs)
        filtered_data = {
            key: pass_vars[key] for key in template_vars if key in pass_vars
        }
        for key in filtered_data.keys():
            value = filtered_data[key]
            if key in _sub_render_keys and value:
                if isinstance(value, str):
                    # Render the sub-template
                    filtered_data[key] = self._render_template(value, **pass_vars)
                elif isinstance(value, list):
                    for i, item in enumerate(value):
                        if isinstance(item, str):
                            value[i] = self._render_template(item, **pass_vars)
        return self._render_template(prompt_template, **filtered_data)
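A hedged sketch of a call site for build_prompt, assuming a constructed role instance; the question text and the idea of feeding read_memories output in as most_recent_memories are illustrative, not taken from this diff:

import asyncio

async def render_system_prompt(role: Role) -> str:
    # Retrieve memories first, then let build_prompt render the Jinja template
    # from the role's profile; template variables missing from pass_vars
    # are simply left out and render as empty strings.
    question = "Which customers ordered last week?"
    memories = await role.read_memories(question)
    return await role.build_prompt(
        question=question,
        is_system=True,
        most_recent_memories=memories,
    )

# asyncio.run(render_system_prompt(my_role))  # `my_role` is a hypothetical instance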

    @property
    def role_prompt(self) -> str:
        """Return the role prompt.

        You are a {self.profile}, named {self.name}, your goal is {self.goal}.

        Returns:
            str: The role prompt.
        """
        profile_prompt = f"You are a {self.profile},"
        name_prompt = f"named {self.name}," if len(self.name) > 0 else ""
        goal_prompt = f"your goal is {self.goal}"
        prompt = f"""{profile_prompt}{name_prompt}{goal_prompt}"""
        return prompt

    @property
    def constraints_prompt(self) -> str:
        """Return the constraints prompt.

        Returns:
            str: The constraints prompt.
        """
        if len(self.constraints) > 0:
            return "\n".join(
                f"{i + 1}. {item}" for i, item in enumerate(self.constraints)
            )
        return ""

    @property
    def language_require_prompt(self) -> str:
        """Return the language requirement prompt.

        Returns:
            str: The language requirement prompt.
        """
        if self.language == "zh":
            return "Please answer in simplified Chinese."
        else:
            return "Please answer in English."

    @property
    def introduce(self) -> str:
        """Introduce the role."""
        return self.desc
    async def generate_resource_variables(
        self, question: Optional[str] = None
    ) -> Dict[str, Any]:
        """Generate the resource variables."""
        return {}
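The base implementation returns an empty dict; a subclass can override it to inject extra variables (for example a resource description or an output schema) that build_prompt then exposes to the prompt template. A rough sketch, with a hypothetical class name and values, reusing the typing imports at the top of this file:

class DataAnalystRole(Role):  # hypothetical subclass
    async def generate_resource_variables(
        self, question: Optional[str] = None
    ) -> Dict[str, Any]:
        # Whatever is returned here is merged into pass_vars in build_prompt,
        # so a template can reference {{ resource_prompt }} or {{ out_schema }}.
        return {
            "resource_prompt": "Database: sales_db (tables: orders, users)",
            "out_schema": "Respond with JSON containing `thought` and `sql`.",
        }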

    def identity_check(self) -> None:
        """Check the identity of the role."""

@@ -114,12 +95,123 @@ class Role(ABC, BaseModel):

    def get_name(self) -> str:
        """Get the name of the role."""
        return self.name
        return self.current_profile.get_name()

    def get_profile(self) -> str:
        """Get the profile of the role."""
        return self.profile
    @property
    def current_profile(self) -> Profile:
        """Return the current profile."""
        profile = self.profile.create_profile()
        return profile
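With current_profile in place, the old string fields become read-only views over the Profile object created from ProfileConfig (see the properties that follow). Assuming the hypothetical SummarizerRole sketched near the top of this page can be constructed with defaults, access looks roughly like:

role = SummarizerRole()                       # hypothetical instance
profile = role.current_profile                # a Profile built via profile.create_profile()
assert role.name == profile.get_name()        # the property delegates to the Profile
assert role.get_name() == role.name           # the legacy accessor now routes the same way
print(role.role, role.goal, role.constraints)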

    def get_describe(self) -> str:
        """Get the description of the role."""
        return self.desc
    @property
    def name(self) -> str:
        """Return the name of the role."""
        return self.current_profile.get_name()

    @property
    def examples(self) -> Optional[str]:
        """Return the examples of the role."""
        return self.current_profile.get_examples()

    @property
    def role(self) -> str:
        """Return the role of the role."""
        return self.current_profile.get_role()

    @property
    def goal(self) -> Optional[str]:
        """Return the goal of the role."""
        return self.current_profile.get_goal()

    @property
    def constraints(self) -> Optional[List[str]]:
        """Return the constraints of the role."""
        return self.current_profile.get_constraints()

    @property
    def desc(self) -> Optional[str]:
        """Return the description of the role."""
        return self.current_profile.get_description()

    @property
    def expand_prompt(self) -> Optional[str]:
        """Return the expand prompt of the role."""
        return self.current_profile.get_expand_prompt()

    @property
    def system_prompt_template(self) -> str:
        """Return the current system prompt template."""
        return self.current_profile.get_system_prompt_template()

    @property
    def user_prompt_template(self) -> str:
        """Return the current user prompt template."""
        return self.current_profile.get_user_prompt_template()

    @property
    def save_memory_template(self) -> str:
        """Return the current save memory template."""
        return self.current_profile.get_save_memory_template()

    def _get_template_variables(self, template: str) -> Set[str]:
        parsed_content = self.template_env.parse(template)
        return find_undeclared_variables(parsed_content)

    def _render_template(self, template: str, **kwargs):
        r_template = self.template_env.from_string(template)
        return r_template.render(**kwargs)
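These two helpers are plain jinja2 calls: parse the template in the sandboxed environment, collect the undeclared variables, and render with keyword arguments. A standalone sketch of the same mechanics, including the sub-template rendering that build_prompt applies to keys such as goal or constraints (the template text and values here are hypothetical):

from jinja2.meta import find_undeclared_variables
from jinja2.sandbox import SandboxedEnvironment

env = SandboxedEnvironment()
prompt_template = "You are {{ role }}. Goal: {{ goal }}"

template_vars = find_undeclared_variables(env.parse(prompt_template))  # {"role", "goal"}
pass_vars = {"role": "Summarizer", "goal": "Answer {{ question }} briefly", "unused": "ignored"}
filtered = {k: v for k, v in pass_vars.items() if k in template_vars}   # drops "unused"

# "goal" is itself a small Jinja template, so render it first (the _sub_render_keys step).
filtered["goal"] = env.from_string(filtered["goal"]).render(question="What is DB-GPT?")
print(env.from_string(prompt_template).render(**filtered))
# -> You are Summarizer. Goal: Answer What is DB-GPT? briefly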

    @property
    def memory_importance_scorer(self) -> Optional[LLMImportanceScorer]:
        """Create the memory importance scorer.

        The memory importance scorer is used to score the importance of a memory
        fragment.
        """
        return None

    @property
    def memory_insight_extractor(self) -> Optional[LLMInsightExtractor]:
        """Create the memory insight extractor.

        The memory insight extractor is used to extract a high-level insight from a
        memory fragment.
        """
        return None

    async def read_memories(
        self,
        question: str,
    ) -> str:
        """Read the memories from the memory."""
        memories = await self.memory.read(question)
        recent_messages = [m.raw_observation for m in memories]
        return "".join(recent_messages)

    async def save_to_memory(
        self,
        question: str,
        ai_message: str,
        action_output: Optional[ActionOutput] = None,
        check_pass: bool = True,
        check_fail_reason: Optional[str] = None,
    ) -> None:
        """Save the question and the result of the action to memory."""
        if not action_output:
            raise ValueError("Action output is required to save to memory.")

        mem_thoughts = action_output.thoughts or ai_message
        observation = action_output.observations or action_output.content
        if not check_pass and check_fail_reason:
            observation += "\n" + check_fail_reason

        memory_map = {
            "question": question,
            "thought": mem_thoughts,
            "action": action_output.action,
            "observation": observation,
        }
        save_memory_template = self.save_memory_template
        memory_content = self._render_template(save_memory_template, **memory_map)
        fragment = AgentMemoryFragment(memory_content)
        await self.memory.write(fragment)
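save_to_memory renders the role's save-memory template with the question, thought, action, and observation, then writes the result as an AgentMemoryFragment. A rough illustration of that rendering step with a hypothetical template and values (the real template comes from the role's Profile):

from jinja2.sandbox import SandboxedEnvironment

save_memory_template = (
    "Question: {{ question }}\n"
    "Thought: {{ thought }}\n"
    "Action: {{ action }}\n"
    "Observation: {{ observation }}"
)
memory_content = SandboxedEnvironment().from_string(save_memory_template).render(
    question="Which customers ordered last week?",
    thought="I will query the orders table.",
    action="run_sql",
    observation="42 rows returned",
)
# The method above wraps a string like this in AgentMemoryFragment and awaits memory.write().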