mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-08-28 12:51:54 +00:00
feat(agent): dbgpts support agent (#1417)
This commit is contained in:
parent
53438a368b
commit
2e2e120ace
@ -7,6 +7,11 @@ from .core.agent import ( # noqa: F401
|
||||
AgentGenerateContext,
|
||||
AgentMessage,
|
||||
)
|
||||
from .core.agent_manage import ( # noqa: F401
|
||||
AgentManager,
|
||||
get_agent_manager,
|
||||
initialize_agent,
|
||||
)
|
||||
from .core.base_agent import ConversableAgent # noqa: F401
|
||||
from .core.llm.llm import LLMConfig # noqa: F401
|
||||
from .core.schema import PluginStorageType # noqa: F401
|
||||
@ -20,6 +25,9 @@ __ALL__ = [
|
||||
"AgentContext",
|
||||
"AgentGenerateContext",
|
||||
"AgentMessage",
|
||||
"AgentManager",
|
||||
"initialize_agent",
|
||||
"get_agent_manager",
|
||||
"ConversableAgent",
|
||||
"Action",
|
||||
"ActionOutput",
|
||||
|
@ -3,14 +3,12 @@
|
||||
import logging
|
||||
import re
|
||||
from collections import defaultdict
|
||||
from typing import Dict, List, Type
|
||||
from typing import Dict, List, Optional, Set, Tuple, Type, cast
|
||||
|
||||
from dbgpt.component import BaseComponent, ComponentType, SystemApp
|
||||
|
||||
from ..expand.code_assistant_agent import CodeAssistantAgent
|
||||
from ..expand.dashboard_assistant_agent import DashboardAssistantAgent
|
||||
from ..expand.data_scientist_agent import DataScientistAgent
|
||||
from ..expand.plugin_assistant_agent import PluginAssistantAgent
|
||||
from ..expand.summary_assistant_agent import SummaryAssistantAgent
|
||||
from .agent import Agent
|
||||
from .base_agent import ConversableAgent
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -46,62 +44,117 @@ def mentioned_agents(message_content: str, agents: List[Agent]) -> Dict:
|
||||
return mentions
|
||||
|
||||
|
||||
class AgentManager:
|
||||
class AgentManager(BaseComponent):
|
||||
"""Manages the registration and retrieval of agents."""
|
||||
|
||||
def __init__(self):
|
||||
name = ComponentType.AGENT_MANAGER
|
||||
|
||||
def __init__(self, system_app: SystemApp):
|
||||
"""Create a new AgentManager."""
|
||||
self._agents = defaultdict()
|
||||
super().__init__(system_app)
|
||||
self.system_app = system_app
|
||||
self._agents: Dict[
|
||||
str, Tuple[Type[ConversableAgent], ConversableAgent]
|
||||
] = defaultdict()
|
||||
|
||||
def register_agent(self, cls):
|
||||
self._core_agents: Set[str] = set()
|
||||
|
||||
def init_app(self, system_app: SystemApp):
|
||||
"""Initialize the AgentManager."""
|
||||
self.system_app = system_app
|
||||
|
||||
def after_start(self):
|
||||
"""Register all agents."""
|
||||
from ..expand.code_assistant_agent import CodeAssistantAgent
|
||||
from ..expand.dashboard_assistant_agent import DashboardAssistantAgent
|
||||
from ..expand.data_scientist_agent import DataScientistAgent
|
||||
from ..expand.plugin_assistant_agent import PluginAssistantAgent
|
||||
from ..expand.summary_assistant_agent import SummaryAssistantAgent
|
||||
|
||||
core_agents = set()
|
||||
core_agents.add(self.register_agent(CodeAssistantAgent))
|
||||
core_agents.add(self.register_agent(DashboardAssistantAgent))
|
||||
core_agents.add(self.register_agent(DataScientistAgent))
|
||||
core_agents.add(self.register_agent(SummaryAssistantAgent))
|
||||
core_agents.add(self.register_agent(PluginAssistantAgent))
|
||||
self._core_agents = core_agents
|
||||
|
||||
def register_agent(
|
||||
self, cls: Type[ConversableAgent], ignore_duplicate: bool = False
|
||||
) -> str:
|
||||
"""Register an agent."""
|
||||
self._agents[cls().profile] = cls
|
||||
inst = cls()
|
||||
profile = inst.get_profile()
|
||||
if profile in self._agents and (
|
||||
profile in self._core_agents or not ignore_duplicate
|
||||
):
|
||||
raise ValueError(f"Agent:{profile} already register!")
|
||||
self._agents[profile] = (cls, inst)
|
||||
return profile
|
||||
|
||||
def get_by_name(self, name: str) -> Type[Agent]:
|
||||
def get_by_name(self, name: str) -> Type[ConversableAgent]:
|
||||
"""Return an agent by name.
|
||||
|
||||
Args:
|
||||
name (str): The name of the agent to retrieve.
|
||||
|
||||
Returns:
|
||||
Type[Agent]: The agent with the given name.
|
||||
Type[ConversableAgent]: The agent with the given name.
|
||||
|
||||
Raises:
|
||||
ValueError: If the agent with the given name is not registered.
|
||||
"""
|
||||
if name not in self._agents:
|
||||
raise ValueError(f"Agent:{name} not register!")
|
||||
return self._agents[name]
|
||||
return self._agents[name][0]
|
||||
|
||||
def get_describe_by_name(self, name: str) -> str:
|
||||
"""Return the description of an agent by name."""
|
||||
return self._agents[name].desc
|
||||
return self._agents[name][1].desc
|
||||
|
||||
def all_agents(self):
|
||||
"""Return a dictionary of all registered agents and their descriptions."""
|
||||
result = {}
|
||||
for name, cls in self._agents.items():
|
||||
result[name] = cls.desc
|
||||
for name, value in self._agents.items():
|
||||
result[name] = value[1].desc
|
||||
return result
|
||||
|
||||
def list_agents(self):
|
||||
"""Return a list of all registered agents and their descriptions."""
|
||||
result = []
|
||||
for name, cls in self._agents.items():
|
||||
instance = cls()
|
||||
for name, value in self._agents.items():
|
||||
result.append(
|
||||
{
|
||||
"name": instance.profile,
|
||||
"desc": instance.goal,
|
||||
"name": value[1].profile,
|
||||
"desc": value[1].goal,
|
||||
}
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
agent_manager = AgentManager()
|
||||
_SYSTEM_APP: Optional[SystemApp] = None
|
||||
|
||||
agent_manager.register_agent(CodeAssistantAgent)
|
||||
agent_manager.register_agent(DashboardAssistantAgent)
|
||||
agent_manager.register_agent(DataScientistAgent)
|
||||
agent_manager.register_agent(SummaryAssistantAgent)
|
||||
agent_manager.register_agent(PluginAssistantAgent)
|
||||
|
||||
def initialize_agent(system_app: SystemApp):
|
||||
"""Initialize the agent manager."""
|
||||
global _SYSTEM_APP
|
||||
_SYSTEM_APP = system_app
|
||||
agent_manager = AgentManager(system_app)
|
||||
system_app.register_instance(agent_manager)
|
||||
|
||||
|
||||
def get_agent_manager(system_app: Optional[SystemApp] = None) -> AgentManager:
|
||||
"""Return the agent manager.
|
||||
|
||||
Args:
|
||||
system_app (Optional[SystemApp], optional): The system app. Defaults to None.
|
||||
|
||||
Returns:
|
||||
AgentManager: The agent manager.
|
||||
"""
|
||||
if not _SYSTEM_APP:
|
||||
if not system_app:
|
||||
system_app = SystemApp()
|
||||
initialize_agent(system_app)
|
||||
app = system_app or _SYSTEM_APP
|
||||
return AgentManager.get_instance(cast(SystemApp, app))
|
||||
|
@ -25,7 +25,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ConversableAgent(Role, Agent):
|
||||
"""ConversableAgent is a agent that can communicate with other agents."""
|
||||
"""ConversableAgent is an agent that can communicate with other agents."""
|
||||
|
||||
agent_context: Optional[AgentContext] = Field(None, description="Agent context")
|
||||
actions: List[Action] = Field(default_factory=list)
|
||||
|
@ -17,7 +17,7 @@ from dbgpt.core.interface.message import ModelMessageRoleType
|
||||
from dbgpt.model.operators.llm_operator import MixinLLMOperator
|
||||
|
||||
from ...core.agent import Agent, AgentGenerateContext, AgentMessage
|
||||
from ...core.agent_manage import agent_manager
|
||||
from ...core.agent_manage import get_agent_manager
|
||||
from ...core.base_agent import ConversableAgent
|
||||
from ...core.llm.llm import LLMConfig
|
||||
from .agent_operator_resource import AWELAgent
|
||||
@ -244,7 +244,7 @@ class AWELAgentOperator(
|
||||
) -> ConversableAgent:
|
||||
"""Build the agent."""
|
||||
# agent build
|
||||
agent_cls: Type[ConversableAgent] = agent_manager.get_by_name( # type: ignore
|
||||
agent_cls: Type[ConversableAgent] = get_agent_manager().get_by_name(
|
||||
self.awel_agent.agent_profile
|
||||
)
|
||||
llm_config = self.awel_agent.llm_config
|
||||
|
@ -11,7 +11,7 @@ from dbgpt.core.awel.flow import (
|
||||
register_resource,
|
||||
)
|
||||
|
||||
from ...core.agent_manage import agent_manager
|
||||
from ...core.agent_manage import get_agent_manager
|
||||
from ...core.llm.llm import LLMConfig, LLMStrategyType
|
||||
from ...resource.resource_api import AgentResource, ResourceType
|
||||
|
||||
@ -118,7 +118,7 @@ class AWELAgentConfig(LLMConfig):
|
||||
def _agent_resource_option_values() -> List[OptionValue]:
|
||||
return [
|
||||
OptionValue(label=item["name"], name=item["name"], value=item["name"])
|
||||
for item in agent_manager.list_agents()
|
||||
for item in get_agent_manager().list_agents()
|
||||
]
|
||||
|
||||
|
||||
|
@ -47,6 +47,7 @@ def initialize_components(
|
||||
)
|
||||
_initialize_model_cache(system_app)
|
||||
_initialize_awel(system_app, param)
|
||||
_initialize_agent(system_app)
|
||||
_initialize_openapi(system_app)
|
||||
# Register serve apps
|
||||
register_serve_apps(system_app, CFG)
|
||||
@ -78,6 +79,12 @@ def _initialize_awel(system_app: SystemApp, param: WebServerParameters):
|
||||
initialize_awel(system_app, dag_dirs)
|
||||
|
||||
|
||||
def _initialize_agent(system_app: SystemApp):
|
||||
from dbgpt.agent import initialize_agent
|
||||
|
||||
initialize_agent(system_app)
|
||||
|
||||
|
||||
def _initialize_openapi(system_app: SystemApp):
|
||||
from dbgpt.app.openapi.api_v1.editor.service import EditorService
|
||||
|
||||
|
@ -86,6 +86,7 @@ class ComponentType(str, Enum):
|
||||
AWEL_DAG_MANAGER = "dbgpt_awel_dag_manager"
|
||||
UNIFIED_METADATA_DB_MANAGER_FACTORY = "dbgpt_unified_metadata_db_manager_factory"
|
||||
CONNECTOR_MANAGER = "dbgpt_connector_manager"
|
||||
AGENT_MANAGER = "dbgpt_agent_manager"
|
||||
|
||||
|
||||
_EMPTY_DEFAULT_COMPONENT = "_EMPTY_DEFAULT_COMPONENT"
|
||||
|
@ -10,7 +10,7 @@ from fastapi.responses import StreamingResponse
|
||||
|
||||
from dbgpt._private.config import Config
|
||||
from dbgpt.agent.core.agent import Agent, AgentContext
|
||||
from dbgpt.agent.core.agent_manage import agent_manager
|
||||
from dbgpt.agent.core.agent_manage import get_agent_manager
|
||||
from dbgpt.agent.core.base_agent import ConversableAgent
|
||||
from dbgpt.agent.core.llm.llm import LLMConfig, LLMStrategyType
|
||||
from dbgpt.agent.core.schema import Status
|
||||
@ -222,7 +222,9 @@ class MultiAgents(BaseComponent, ABC):
|
||||
self.llm_provider = DefaultLLMClient(worker_manager, auto_convert_message=True)
|
||||
|
||||
for record in gpts_app.details:
|
||||
cls: Type[ConversableAgent] = agent_manager.get_by_name(record.agent_name)
|
||||
cls: Type[ConversableAgent] = get_agent_manager().get_by_name(
|
||||
record.agent_name
|
||||
)
|
||||
llm_config = LLMConfig(
|
||||
llm_client=self.llm_provider,
|
||||
llm_strategy=LLMStrategyType(record.llm_strategy),
|
||||
@ -340,7 +342,7 @@ multi_agents = MultiAgents()
|
||||
async def agents_list():
|
||||
logger.info("agents_list!")
|
||||
try:
|
||||
agents = agent_manager.all_agents()
|
||||
agents = get_agent_manager().all_agents()
|
||||
return Result.succ(agents)
|
||||
except Exception as e:
|
||||
return Result.failed(code="E30001", msg=str(e))
|
||||
|
@ -3,7 +3,7 @@ import logging
|
||||
from fastapi import APIRouter
|
||||
|
||||
from dbgpt._private.config import Config
|
||||
from dbgpt.agent.core.agent_manage import agent_manager
|
||||
from dbgpt.agent.core.agent_manage import get_agent_manager
|
||||
from dbgpt.agent.core.llm.llm import LLMStrategyType
|
||||
from dbgpt.agent.resource.resource_api import ResourceType
|
||||
from dbgpt.app.knowledge.api import knowledge_space_service
|
||||
@ -63,7 +63,7 @@ async def edit(gpts_app: GptsApp):
|
||||
@router.get("/v1/agents/list")
|
||||
async def all_agents():
|
||||
try:
|
||||
return Result.succ(agent_manager.list_agents())
|
||||
return Result.succ(get_agent_manager().list_agents())
|
||||
except Exception as ex:
|
||||
return Result.failed(code="E000X", msg=f"query agents error: {ex}")
|
||||
|
||||
|
@ -3,7 +3,7 @@ import logging
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, cast
|
||||
from typing import Any, Dict, List, Optional, Type, TypeVar, cast
|
||||
|
||||
import schedule
|
||||
import tomlkit
|
||||
@ -18,6 +18,7 @@ from dbgpt.util.dbgpts.base import (
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class BasePackage(BaseModel):
|
||||
@ -68,6 +69,32 @@ class BasePackage(BaseModel):
|
||||
def abs_definition_file(self) -> str:
|
||||
return str(Path(self.path) / self.definition_file)
|
||||
|
||||
@classmethod
|
||||
def load_module_class(
|
||||
cls, values: Dict[str, Any], expected_cls: Type[T]
|
||||
) -> List[Type[T]]:
|
||||
import importlib.resources as pkg_resources
|
||||
|
||||
from dbgpt.core.awel.dag.loader import _load_modules_from_file
|
||||
|
||||
name = values.get("name")
|
||||
root = values.get("root")
|
||||
if not name:
|
||||
raise ValueError("The name is required")
|
||||
if not root:
|
||||
raise ValueError("The root is required")
|
||||
if root not in sys.path:
|
||||
sys.path.append(root)
|
||||
with pkg_resources.path(name, "__init__.py") as path:
|
||||
mods = _load_modules_from_file(str(path), name, show_log=False)
|
||||
all_cls = [_get_classes_from_module(m) for m in mods]
|
||||
module_cls = []
|
||||
for list_cls in all_cls:
|
||||
for c in list_cls:
|
||||
if issubclass(c, expected_cls):
|
||||
module_cls.append(c)
|
||||
return module_cls
|
||||
|
||||
|
||||
class FlowPackage(BasePackage):
|
||||
package_type = "flow"
|
||||
@ -107,24 +134,24 @@ class OperatorPackage(BasePackage):
|
||||
|
||||
@classmethod
|
||||
def build_from(cls, values: Dict[str, Any], ext_dict: Dict[str, Any]):
|
||||
import importlib.resources as pkg_resources
|
||||
|
||||
from dbgpt.core.awel import BaseOperator
|
||||
from dbgpt.core.awel.dag.loader import _load_modules_from_file
|
||||
|
||||
name = values.get("name")
|
||||
root = values.get("root")
|
||||
if root not in sys.path:
|
||||
sys.path.append(root)
|
||||
with pkg_resources.path(name, "__init__.py") as path:
|
||||
mods = _load_modules_from_file(str(path), name, show_log=False)
|
||||
all_cls = [_get_classes_from_module(m) for m in mods]
|
||||
operators = []
|
||||
for list_cls in all_cls:
|
||||
for c in list_cls:
|
||||
if issubclass(c, BaseOperator):
|
||||
operators.append(c)
|
||||
values["operators"] = operators
|
||||
values["operators"] = cls.load_module_class(values, BaseOperator)
|
||||
return cls(**values)
|
||||
|
||||
|
||||
class AgentPackage(BasePackage):
|
||||
package_type = "agent"
|
||||
|
||||
agents: List[type] = Field(
|
||||
default_factory=list, description="The agents of the package"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def build_from(cls, values: Dict[str, Any], ext_dict: Dict[str, Any]):
|
||||
from dbgpt.agent import ConversableAgent
|
||||
|
||||
values["agents"] = cls.load_module_class(values, ConversableAgent)
|
||||
return cls(**values)
|
||||
|
||||
|
||||
@ -153,11 +180,14 @@ def _parse_package_metadata(package: InstalledPackage) -> BasePackage:
|
||||
pkg_dict = {}
|
||||
for key, value in metadata.items():
|
||||
if key == "flow":
|
||||
pkg_dict = value
|
||||
pkg_dict = {k: v for k, v in value.items()}
|
||||
pkg_dict["package_type"] = "flow"
|
||||
elif key == "operator":
|
||||
pkg_dict = {k: v for k, v in value.items()}
|
||||
pkg_dict["package_type"] = "operator"
|
||||
elif key == "agent":
|
||||
pkg_dict = {k: v for k, v in value.items()}
|
||||
pkg_dict["package_type"] = "agent"
|
||||
else:
|
||||
ext_metadata[key] = value
|
||||
pkg_dict["root"] = package.root
|
||||
@ -167,6 +197,8 @@ def _parse_package_metadata(package: InstalledPackage) -> BasePackage:
|
||||
return FlowPackage.build_from(pkg_dict, ext_metadata)
|
||||
elif pkg_dict["package_type"] == "operator":
|
||||
return OperatorPackage.build_from(pkg_dict, ext_metadata)
|
||||
elif pkg_dict["package_type"] == "agent":
|
||||
return AgentPackage.build_from(pkg_dict, ext_metadata)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported package package_type: {pkg_dict['package_type']}"
|
||||
@ -243,6 +275,7 @@ class DBGPTsLoader(BaseComponent):
|
||||
)
|
||||
for package in packages:
|
||||
self._packages[package.name] = package
|
||||
self._register_packages(package)
|
||||
except Exception as e:
|
||||
logger.warning(f"Load dbgpts package error: {e}")
|
||||
|
||||
@ -268,3 +301,16 @@ class DBGPTsLoader(BaseComponent):
|
||||
}
|
||||
panels.append(FlowPanel(**dict_value))
|
||||
return panels
|
||||
|
||||
def _register_packages(self, package: BasePackage):
|
||||
if package.package_type == "agent":
|
||||
from dbgpt.agent import ConversableAgent, get_agent_manager
|
||||
|
||||
agent_manager = get_agent_manager(self._system_app)
|
||||
pkg = cast(AgentPackage, package)
|
||||
for agent_cls in pkg.agents:
|
||||
if issubclass(agent_cls, ConversableAgent):
|
||||
try:
|
||||
agent_manager.register_agent(agent_cls, ignore_duplicate=True)
|
||||
except ValueError as e:
|
||||
logger.warning(f"Register agent {agent_cls} error: {e}")
|
||||
|
@ -73,8 +73,7 @@ def _print_repos():
|
||||
table.add_column(_("Repository"), justify="right", style="cyan", no_wrap=True)
|
||||
table.add_column(_("Path"), justify="right", style="green")
|
||||
for repo, full_path in repos:
|
||||
if full_path.startswith(str(Path.home())):
|
||||
full_path = full_path.replace(str(Path.home()), "~")
|
||||
full_path = _print_path(full_path)
|
||||
table.add_row(repo, full_path)
|
||||
cl.print(table)
|
||||
|
||||
@ -245,7 +244,14 @@ def _copy_and_install(repo: str, name: str, package_path: Path):
|
||||
shutil.copytree(package_path, install_path)
|
||||
cl.info(f"Installing dbgpts '{name}' from {repo}...")
|
||||
os.chdir(install_path)
|
||||
subprocess.run(["poetry", "install"], check=True)
|
||||
subprocess.run(["poetry", "build"], check=True)
|
||||
wheel_files = list(install_path.glob("dist/*.whl"))
|
||||
if not wheel_files:
|
||||
cl.error("No wheel file found after building the package.", exit_code=1)
|
||||
# Install the wheel file using pip
|
||||
wheel_file = wheel_files[0]
|
||||
cl.info(f"Installing dbgpts '{name}' wheel file {_print_path(wheel_file)}...")
|
||||
subprocess.run(["pip", "install", str(wheel_file)], check=True)
|
||||
_write_install_metadata(name, repo, install_path)
|
||||
cl.success(f"Installed dbgpts at {_print_path(install_path)}.")
|
||||
cl.success(f"dbgpts '{name}' installed successfully.")
|
||||
@ -357,7 +363,6 @@ def list_installed_apps():
|
||||
packages.sort(key=lambda x: (x.package, x.package_type, x.repo))
|
||||
for package in packages:
|
||||
str_path = package.root
|
||||
if str_path.startswith(str(Path.home())):
|
||||
str_path = str_path.replace(str(Path.home()), "~")
|
||||
str_path = _print_path(str_path)
|
||||
table.add_row(package.package, package.package_type, package.repo, str_path)
|
||||
cl.print(table)
|
||||
|
@ -50,6 +50,15 @@ def create_template(
|
||||
definition_type,
|
||||
working_directory,
|
||||
)
|
||||
elif dbgpts_type == "agent":
|
||||
_create_agent_template(
|
||||
name,
|
||||
mod_name,
|
||||
dbgpts_type,
|
||||
base_metadata,
|
||||
definition_type,
|
||||
working_directory,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Invalid dbgpts type: {dbgpts_type}")
|
||||
|
||||
@ -111,6 +120,31 @@ def _create_operator_template(
|
||||
_write_manifest_file(working_directory, name, mod_name)
|
||||
|
||||
|
||||
def _create_agent_template(
|
||||
name: str,
|
||||
mod_name: str,
|
||||
dbgpts_type: str,
|
||||
base_metadata: dict,
|
||||
definition_type: str,
|
||||
working_directory: str,
|
||||
):
|
||||
json_dict = {
|
||||
"agent": base_metadata,
|
||||
"python_config": {},
|
||||
"json_config": {},
|
||||
}
|
||||
if definition_type != "python":
|
||||
raise click.ClickException(
|
||||
f"Unsupported definition type: {definition_type} for dbgpts type: "
|
||||
f"{dbgpts_type}"
|
||||
)
|
||||
|
||||
_create_poetry_project(working_directory, name)
|
||||
_write_dbgpts_toml(working_directory, name, json_dict)
|
||||
_write_agent_init_file(working_directory, name, mod_name)
|
||||
_write_manifest_file(working_directory, name, mod_name)
|
||||
|
||||
|
||||
def _create_poetry_project(working_directory: str, name: str):
|
||||
"""Create a new poetry project"""
|
||||
|
||||
@ -207,3 +241,122 @@ class HelloWorldOperator(MapOperator[str, str]):
|
||||
"""
|
||||
with open(init_file, "w") as f:
|
||||
f.write(f'"""{name} operator package"""\n{content}')
|
||||
|
||||
|
||||
def _write_agent_init_file(working_directory: str, name: str, mod_name: str):
|
||||
"""Write the agent __init__.py file"""
|
||||
|
||||
init_file = Path(working_directory) / name / mod_name / "__init__.py"
|
||||
content = """
|
||||
import asyncio
|
||||
from typing import Optional, Tuple
|
||||
|
||||
from dbgpt.agent import (
|
||||
AgentMessage,
|
||||
Action,
|
||||
ActionOutput,
|
||||
AgentResource,
|
||||
ConversableAgent,
|
||||
)
|
||||
from dbgpt.agent.util import cmp_string_equal
|
||||
|
||||
_HELLO_WORLD = "Hello world"
|
||||
|
||||
|
||||
class HelloWorldSpeakerAgent(ConversableAgent):
|
||||
name: str = "Hodor"
|
||||
profile: str = "HelloWorldSpeaker"
|
||||
goal: str = f"answer any question from user with '{_HELLO_WORLD}'"
|
||||
desc: str = f"You can answer any question from user with '{_HELLO_WORLD}'"
|
||||
constraints: list[str] = [
|
||||
"You can only answer with '{fix_message}'",
|
||||
"You can't use any other words",
|
||||
]
|
||||
examples: str = (
|
||||
f"user: What's your name?\\nassistant: {_HELLO_WORLD}\\n\\n",
|
||||
f"user: What's the weather today?\\nassistant: {_HELLO_WORLD}\\n\\n",
|
||||
f"user: Can you help me?\\nassistant: {_HELLO_WORLD}\\n\\n",
|
||||
f"user: Please tell me a joke.\\nassistant: {_HELLO_WORLD}\\n\\n",
|
||||
f"user: Please answer me without '{_HELLO_WORLD}'.\\nassistant: {_HELLO_WORLD}"
|
||||
"\\n\\n",
|
||||
)
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self._init_actions([HelloWorldAction])
|
||||
|
||||
def _init_reply_message(self, received_message: AgentMessage) -> AgentMessage:
|
||||
reply_message = super()._init_reply_message(received_message)
|
||||
# Fill in the dynamic parameters in the prompt template
|
||||
reply_message.context = {"fix_message": _HELLO_WORLD}
|
||||
return reply_message
|
||||
|
||||
async def correctness_check(
|
||||
self, message: AgentMessage
|
||||
) -> Tuple[bool, Optional[str]]:
|
||||
action_report = message.action_report
|
||||
task_result = ""
|
||||
if action_report:
|
||||
task_result = action_report.get("content", "")
|
||||
if not cmp_string_equal(
|
||||
task_result,
|
||||
_HELLO_WORLD,
|
||||
ignore_case=True,
|
||||
ignore_punctuation=True,
|
||||
ignore_whitespace=True,
|
||||
):
|
||||
return False, f"Please answer with {_HELLO_WORLD}, not '{task_result}'"
|
||||
return True, None
|
||||
|
||||
|
||||
class HelloWorldAction(Action[None]):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
async def run(
|
||||
self,
|
||||
ai_message: str,
|
||||
resource: Optional[AgentResource] = None,
|
||||
rely_action_out: Optional[ActionOutput] = None,
|
||||
need_vis_render: bool = True,
|
||||
**kwargs,
|
||||
) -> ActionOutput:
|
||||
return ActionOutput(is_exe_success=True, content=ai_message)
|
||||
|
||||
|
||||
async def _test_agent():
|
||||
\"\"\"Test the agent.
|
||||
|
||||
It will not run in the production environment.
|
||||
\"\"\"
|
||||
from dbgpt.model.proxy import OpenAILLMClient
|
||||
from dbgpt.agent import AgentContext, GptsMemory, UserProxyAgent, LLMConfig
|
||||
|
||||
llm_client = OpenAILLMClient(model_alias="gpt-3.5-turbo")
|
||||
context: AgentContext = AgentContext(conv_id="summarize")
|
||||
|
||||
default_memory: GptsMemory = GptsMemory()
|
||||
|
||||
speaker = (
|
||||
await HelloWorldSpeakerAgent()
|
||||
.bind(context)
|
||||
.bind(LLMConfig(llm_client=llm_client))
|
||||
.bind(default_memory)
|
||||
.build()
|
||||
)
|
||||
|
||||
user_proxy = await UserProxyAgent().bind(default_memory).bind(context).build()
|
||||
await user_proxy.initiate_chat(
|
||||
recipient=speaker,
|
||||
reviewer=user_proxy,
|
||||
message="What's your name?",
|
||||
)
|
||||
print(await default_memory.one_chat_completions("summarize"))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(_test_agent())
|
||||
|
||||
"""
|
||||
with open(init_file, "w") as f:
|
||||
f.write(f'"""{name} agent package."""\n{content}')
|
||||
|
Loading…
Reference in New Issue
Block a user