feat(agent): dbgpts support agent (#1417)

This commit is contained in:
Fangyin Cheng
2024-04-14 23:32:01 +08:00
committed by GitHub
parent 53438a368b
commit 2e2e120ace
12 changed files with 335 additions and 60 deletions

View File

@@ -50,6 +50,15 @@ def create_template(
definition_type,
working_directory,
)
elif dbgpts_type == "agent":
_create_agent_template(
name,
mod_name,
dbgpts_type,
base_metadata,
definition_type,
working_directory,
)
else:
raise ValueError(f"Invalid dbgpts type: {dbgpts_type}")
@@ -111,6 +120,31 @@ def _create_operator_template(
_write_manifest_file(working_directory, name, mod_name)
def _create_agent_template(
name: str,
mod_name: str,
dbgpts_type: str,
base_metadata: dict,
definition_type: str,
working_directory: str,
):
json_dict = {
"agent": base_metadata,
"python_config": {},
"json_config": {},
}
if definition_type != "python":
raise click.ClickException(
f"Unsupported definition type: {definition_type} for dbgpts type: "
f"{dbgpts_type}"
)
_create_poetry_project(working_directory, name)
_write_dbgpts_toml(working_directory, name, json_dict)
_write_agent_init_file(working_directory, name, mod_name)
_write_manifest_file(working_directory, name, mod_name)
def _create_poetry_project(working_directory: str, name: str):
"""Create a new poetry project"""
@@ -207,3 +241,122 @@ class HelloWorldOperator(MapOperator[str, str]):
"""
with open(init_file, "w") as f:
f.write(f'"""{name} operator package"""\n{content}')
def _write_agent_init_file(working_directory: str, name: str, mod_name: str):
"""Write the agent __init__.py file"""
init_file = Path(working_directory) / name / mod_name / "__init__.py"
content = """
import asyncio
from typing import Optional, Tuple
from dbgpt.agent import (
AgentMessage,
Action,
ActionOutput,
AgentResource,
ConversableAgent,
)
from dbgpt.agent.util import cmp_string_equal
_HELLO_WORLD = "Hello world"
class HelloWorldSpeakerAgent(ConversableAgent):
name: str = "Hodor"
profile: str = "HelloWorldSpeaker"
goal: str = f"answer any question from user with '{_HELLO_WORLD}'"
desc: str = f"You can answer any question from user with '{_HELLO_WORLD}'"
constraints: list[str] = [
"You can only answer with '{fix_message}'",
"You can't use any other words",
]
examples: str = (
f"user: What's your name?\\nassistant: {_HELLO_WORLD}\\n\\n",
f"user: What's the weather today?\\nassistant: {_HELLO_WORLD}\\n\\n",
f"user: Can you help me?\\nassistant: {_HELLO_WORLD}\\n\\n",
f"user: Please tell me a joke.\\nassistant: {_HELLO_WORLD}\\n\\n",
f"user: Please answer me without '{_HELLO_WORLD}'.\\nassistant: {_HELLO_WORLD}"
"\\n\\n",
)
def __init__(self, **kwargs):
super().__init__(**kwargs)
self._init_actions([HelloWorldAction])
def _init_reply_message(self, received_message: AgentMessage) -> AgentMessage:
reply_message = super()._init_reply_message(received_message)
# Fill in the dynamic parameters in the prompt template
reply_message.context = {"fix_message": _HELLO_WORLD}
return reply_message
async def correctness_check(
self, message: AgentMessage
) -> Tuple[bool, Optional[str]]:
action_report = message.action_report
task_result = ""
if action_report:
task_result = action_report.get("content", "")
if not cmp_string_equal(
task_result,
_HELLO_WORLD,
ignore_case=True,
ignore_punctuation=True,
ignore_whitespace=True,
):
return False, f"Please answer with {_HELLO_WORLD}, not '{task_result}'"
return True, None
class HelloWorldAction(Action[None]):
def __init__(self):
super().__init__()
async def run(
self,
ai_message: str,
resource: Optional[AgentResource] = None,
rely_action_out: Optional[ActionOutput] = None,
need_vis_render: bool = True,
**kwargs,
) -> ActionOutput:
return ActionOutput(is_exe_success=True, content=ai_message)
async def _test_agent():
\"\"\"Test the agent.
It will not run in the production environment.
\"\"\"
from dbgpt.model.proxy import OpenAILLMClient
from dbgpt.agent import AgentContext, GptsMemory, UserProxyAgent, LLMConfig
llm_client = OpenAILLMClient(model_alias="gpt-3.5-turbo")
context: AgentContext = AgentContext(conv_id="summarize")
default_memory: GptsMemory = GptsMemory()
speaker = (
await HelloWorldSpeakerAgent()
.bind(context)
.bind(LLMConfig(llm_client=llm_client))
.bind(default_memory)
.build()
)
user_proxy = await UserProxyAgent().bind(default_memory).bind(context).build()
await user_proxy.initiate_chat(
recipient=speaker,
reviewer=user_proxy,
message="What's your name?",
)
print(await default_memory.one_chat_completions("summarize"))
if __name__ == "__main__":
asyncio.run(_test_agent())
"""
with open(init_file, "w") as f:
f.write(f'"""{name} agent package."""\n{content}')