mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-07 20:10:08 +00:00
feat(agent): dbgpts support agent (#1417)
This commit is contained in:
@@ -50,6 +50,15 @@ def create_template(
|
||||
definition_type,
|
||||
working_directory,
|
||||
)
|
||||
elif dbgpts_type == "agent":
|
||||
_create_agent_template(
|
||||
name,
|
||||
mod_name,
|
||||
dbgpts_type,
|
||||
base_metadata,
|
||||
definition_type,
|
||||
working_directory,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Invalid dbgpts type: {dbgpts_type}")
|
||||
|
||||
@@ -111,6 +120,31 @@ def _create_operator_template(
|
||||
_write_manifest_file(working_directory, name, mod_name)
|
||||
|
||||
|
||||
def _create_agent_template(
|
||||
name: str,
|
||||
mod_name: str,
|
||||
dbgpts_type: str,
|
||||
base_metadata: dict,
|
||||
definition_type: str,
|
||||
working_directory: str,
|
||||
):
|
||||
json_dict = {
|
||||
"agent": base_metadata,
|
||||
"python_config": {},
|
||||
"json_config": {},
|
||||
}
|
||||
if definition_type != "python":
|
||||
raise click.ClickException(
|
||||
f"Unsupported definition type: {definition_type} for dbgpts type: "
|
||||
f"{dbgpts_type}"
|
||||
)
|
||||
|
||||
_create_poetry_project(working_directory, name)
|
||||
_write_dbgpts_toml(working_directory, name, json_dict)
|
||||
_write_agent_init_file(working_directory, name, mod_name)
|
||||
_write_manifest_file(working_directory, name, mod_name)
|
||||
|
||||
|
||||
def _create_poetry_project(working_directory: str, name: str):
|
||||
"""Create a new poetry project"""
|
||||
|
||||
@@ -207,3 +241,122 @@ class HelloWorldOperator(MapOperator[str, str]):
|
||||
"""
|
||||
with open(init_file, "w") as f:
|
||||
f.write(f'"""{name} operator package"""\n{content}')
|
||||
|
||||
|
||||
def _write_agent_init_file(working_directory: str, name: str, mod_name: str):
|
||||
"""Write the agent __init__.py file"""
|
||||
|
||||
init_file = Path(working_directory) / name / mod_name / "__init__.py"
|
||||
content = """
|
||||
import asyncio
|
||||
from typing import Optional, Tuple
|
||||
|
||||
from dbgpt.agent import (
|
||||
AgentMessage,
|
||||
Action,
|
||||
ActionOutput,
|
||||
AgentResource,
|
||||
ConversableAgent,
|
||||
)
|
||||
from dbgpt.agent.util import cmp_string_equal
|
||||
|
||||
_HELLO_WORLD = "Hello world"
|
||||
|
||||
|
||||
class HelloWorldSpeakerAgent(ConversableAgent):
|
||||
name: str = "Hodor"
|
||||
profile: str = "HelloWorldSpeaker"
|
||||
goal: str = f"answer any question from user with '{_HELLO_WORLD}'"
|
||||
desc: str = f"You can answer any question from user with '{_HELLO_WORLD}'"
|
||||
constraints: list[str] = [
|
||||
"You can only answer with '{fix_message}'",
|
||||
"You can't use any other words",
|
||||
]
|
||||
examples: str = (
|
||||
f"user: What's your name?\\nassistant: {_HELLO_WORLD}\\n\\n",
|
||||
f"user: What's the weather today?\\nassistant: {_HELLO_WORLD}\\n\\n",
|
||||
f"user: Can you help me?\\nassistant: {_HELLO_WORLD}\\n\\n",
|
||||
f"user: Please tell me a joke.\\nassistant: {_HELLO_WORLD}\\n\\n",
|
||||
f"user: Please answer me without '{_HELLO_WORLD}'.\\nassistant: {_HELLO_WORLD}"
|
||||
"\\n\\n",
|
||||
)
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self._init_actions([HelloWorldAction])
|
||||
|
||||
def _init_reply_message(self, received_message: AgentMessage) -> AgentMessage:
|
||||
reply_message = super()._init_reply_message(received_message)
|
||||
# Fill in the dynamic parameters in the prompt template
|
||||
reply_message.context = {"fix_message": _HELLO_WORLD}
|
||||
return reply_message
|
||||
|
||||
async def correctness_check(
|
||||
self, message: AgentMessage
|
||||
) -> Tuple[bool, Optional[str]]:
|
||||
action_report = message.action_report
|
||||
task_result = ""
|
||||
if action_report:
|
||||
task_result = action_report.get("content", "")
|
||||
if not cmp_string_equal(
|
||||
task_result,
|
||||
_HELLO_WORLD,
|
||||
ignore_case=True,
|
||||
ignore_punctuation=True,
|
||||
ignore_whitespace=True,
|
||||
):
|
||||
return False, f"Please answer with {_HELLO_WORLD}, not '{task_result}'"
|
||||
return True, None
|
||||
|
||||
|
||||
class HelloWorldAction(Action[None]):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
async def run(
|
||||
self,
|
||||
ai_message: str,
|
||||
resource: Optional[AgentResource] = None,
|
||||
rely_action_out: Optional[ActionOutput] = None,
|
||||
need_vis_render: bool = True,
|
||||
**kwargs,
|
||||
) -> ActionOutput:
|
||||
return ActionOutput(is_exe_success=True, content=ai_message)
|
||||
|
||||
|
||||
async def _test_agent():
|
||||
\"\"\"Test the agent.
|
||||
|
||||
It will not run in the production environment.
|
||||
\"\"\"
|
||||
from dbgpt.model.proxy import OpenAILLMClient
|
||||
from dbgpt.agent import AgentContext, GptsMemory, UserProxyAgent, LLMConfig
|
||||
|
||||
llm_client = OpenAILLMClient(model_alias="gpt-3.5-turbo")
|
||||
context: AgentContext = AgentContext(conv_id="summarize")
|
||||
|
||||
default_memory: GptsMemory = GptsMemory()
|
||||
|
||||
speaker = (
|
||||
await HelloWorldSpeakerAgent()
|
||||
.bind(context)
|
||||
.bind(LLMConfig(llm_client=llm_client))
|
||||
.bind(default_memory)
|
||||
.build()
|
||||
)
|
||||
|
||||
user_proxy = await UserProxyAgent().bind(default_memory).bind(context).build()
|
||||
await user_proxy.initiate_chat(
|
||||
recipient=speaker,
|
||||
reviewer=user_proxy,
|
||||
message="What's your name?",
|
||||
)
|
||||
print(await default_memory.one_chat_completions("summarize"))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(_test_agent())
|
||||
|
||||
"""
|
||||
with open(init_file, "w") as f:
|
||||
f.write(f'"""{name} agent package."""\n{content}')
|
||||
|
Reference in New Issue
Block a user