mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-03 01:54:44 +00:00
refactor(agent): Refactor resource of agents (#1518)
This commit is contained in:
@@ -1,11 +1,6 @@
|
||||
#*******************************************************************#
|
||||
#** DB-GPT - GENERAL SETTINGS **#
|
||||
#*******************************************************************#
|
||||
## DISABLED_COMMAND_CATEGORIES - The list of categories of commands that are disabled. Each of the below are an option:
|
||||
## pilot.commands.query_execute
|
||||
|
||||
## For example, to disable coding related features, uncomment the next line
|
||||
# DISABLED_COMMAND_CATEGORIES=
|
||||
|
||||
#*******************************************************************#
|
||||
#** Webserver Port **#
|
||||
@@ -125,25 +120,6 @@ LOCAL_DB_TYPE=sqlite
|
||||
#*******************************************************************#
|
||||
EXECUTE_LOCAL_COMMANDS=False
|
||||
|
||||
|
||||
|
||||
#*******************************************************************#
|
||||
#** ALLOWLISTED PLUGINS **#
|
||||
#*******************************************************************#
|
||||
|
||||
#ALLOWLISTED_PLUGINS - Sets the listed plugins that are allowed (Example: plugin1,plugin2,plugin3)
|
||||
#DENYLISTED_PLUGINS - Sets the listed plugins that are not allowed (Example: plugin1,plugin2,plugin3)
|
||||
ALLOWLISTED_PLUGINS=
|
||||
DENYLISTED_PLUGINS=
|
||||
|
||||
|
||||
#*******************************************************************#
|
||||
#** CHAT PLUGIN SETTINGS **#
|
||||
#*******************************************************************#
|
||||
# CHAT_MESSAGES_ENABLED - Enable chat messages (Default: False)
|
||||
# CHAT_MESSAGES_ENABLED=False
|
||||
|
||||
|
||||
#*******************************************************************#
|
||||
#** VECTOR STORE SETTINGS **#
|
||||
#*******************************************************************#
|
||||
|
@@ -3,14 +3,11 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from typing import TYPE_CHECKING, List, Optional
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
from dbgpt.util.singleton import Singleton
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from auto_gpt_plugin_template import AutoGPTPluginTemplate
|
||||
|
||||
from dbgpt.agent.plugin import CommandRegistry
|
||||
from dbgpt.component import SystemApp
|
||||
from dbgpt.datasource.manages import ConnectorManager
|
||||
|
||||
@@ -165,14 +162,6 @@ class Config(metaclass=Singleton):
|
||||
from dbgpt.core._private.prompt_registry import PromptTemplateRegistry
|
||||
|
||||
self.prompt_template_registry = PromptTemplateRegistry()
|
||||
### Related configuration of built-in commands
|
||||
self.command_registry: Optional[CommandRegistry] = None
|
||||
|
||||
disabled_command_categories = os.getenv("DISABLED_COMMAND_CATEGORIES")
|
||||
if disabled_command_categories:
|
||||
self.disabled_command_categories = disabled_command_categories.split(",")
|
||||
else:
|
||||
self.disabled_command_categories = []
|
||||
|
||||
self.execute_local_commands = (
|
||||
os.getenv("EXECUTE_LOCAL_COMMANDS", "False").lower() == "true"
|
||||
@@ -180,25 +169,6 @@ class Config(metaclass=Singleton):
|
||||
### message stor file
|
||||
self.message_dir = os.getenv("MESSAGE_HISTORY_DIR", "../../message")
|
||||
|
||||
### The associated configuration parameters of the plug-in control the loading and use of the plug-in
|
||||
|
||||
self.plugins: List["AutoGPTPluginTemplate"] = []
|
||||
self.plugins_openai = [] # type: ignore
|
||||
self.plugins_auto_load = os.getenv("AUTO_LOAD_PLUGIN", "True").lower() == "true"
|
||||
|
||||
self.plugins_git_branch = os.getenv("PLUGINS_GIT_BRANCH", "plugin_dashboard")
|
||||
|
||||
plugins_allowlist = os.getenv("ALLOWLISTED_PLUGINS")
|
||||
if plugins_allowlist:
|
||||
self.plugins_allowlist = plugins_allowlist.split(",")
|
||||
else:
|
||||
self.plugins_allowlist = []
|
||||
|
||||
plugins_denylist = os.getenv("DENYLISTED_PLUGINS")
|
||||
if plugins_denylist:
|
||||
self.plugins_denylist = plugins_denylist.split(",")
|
||||
else:
|
||||
self.plugins_denylist = []
|
||||
### Native SQL Execution Capability Control Configuration
|
||||
self.NATIVE_SQL_CAN_RUN_DDL = (
|
||||
os.getenv("NATIVE_SQL_CAN_RUN_DDL", "True").lower() == "true"
|
||||
|
@@ -19,8 +19,7 @@ from .core.plan import * # noqa: F401, F403
|
||||
from .core.profile import * # noqa: F401, F403
|
||||
from .core.schema import PluginStorageType # noqa: F401
|
||||
from .core.user_proxy_agent import UserProxyAgent # noqa: F401
|
||||
from .resource.resource_api import AgentResource, ResourceType # noqa: F401
|
||||
from .resource.resource_loader import ResourceLoader # noqa: F401
|
||||
from .resource.base import AgentResource, Resource, ResourceType # noqa: F401
|
||||
from .util.llm.llm import LLMConfig # noqa: F401
|
||||
|
||||
__ALL__ = [
|
||||
@@ -38,7 +37,6 @@ __ALL__ = [
|
||||
"GptsMemory",
|
||||
"AgentResource",
|
||||
"ResourceType",
|
||||
"ResourceLoader",
|
||||
"PluginStorageType",
|
||||
"UserProxyAgent",
|
||||
]
|
||||
|
@@ -27,8 +27,7 @@ from dbgpt._private.pydantic import (
|
||||
from dbgpt.util.json_utils import find_json_objects
|
||||
from dbgpt.vis.base import Vis
|
||||
|
||||
from ...resource.resource_api import AgentResource, ResourceType
|
||||
from ...resource.resource_loader import ResourceLoader
|
||||
from ...resource.base import AgentResource, Resource, ResourceType
|
||||
|
||||
T = TypeVar("T", bound=Union[BaseModel, List[BaseModel], None])
|
||||
|
||||
@@ -77,11 +76,11 @@ class Action(ABC, Generic[T]):
|
||||
|
||||
def __init__(self):
|
||||
"""Create an action."""
|
||||
self.resource_loader: Optional[ResourceLoader] = None
|
||||
self.resource: Optional[Resource] = None
|
||||
|
||||
def init_resource_loader(self, resource_loader: Optional[ResourceLoader]):
|
||||
"""Initialize the resource loader."""
|
||||
self.resource_loader = resource_loader
|
||||
def init_resource(self, resource: Optional[Resource]):
|
||||
"""Initialize the resource."""
|
||||
self.resource = resource
|
||||
|
||||
@property
|
||||
def resource_need(self) -> Optional[ResourceType]:
|
||||
|
@@ -3,7 +3,7 @@
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from ...resource.resource_api import AgentResource
|
||||
from ...resource.base import AgentResource
|
||||
from .base import Action, ActionOutput
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
@@ -9,7 +9,6 @@ from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
from dbgpt.core import LLMClient
|
||||
from dbgpt.util.annotations import PublicAPI
|
||||
|
||||
from ..resource.resource_loader import ResourceLoader
|
||||
from .action.base import ActionOutput
|
||||
from .memory.agent_memory import AgentMemory
|
||||
|
||||
@@ -209,7 +208,6 @@ class AgentGenerateContext:
|
||||
|
||||
memory: Optional[AgentMemory] = None
|
||||
agent_context: Optional[AgentContext] = None
|
||||
resource_loader: Optional[ResourceLoader] = None
|
||||
llm_client: Optional[LLMClient] = None
|
||||
|
||||
round_index: Optional[int] = None
|
||||
|
@@ -68,15 +68,15 @@ class AgentManager(BaseComponent):
|
||||
from ..expand.code_assistant_agent import CodeAssistantAgent
|
||||
from ..expand.dashboard_assistant_agent import DashboardAssistantAgent
|
||||
from ..expand.data_scientist_agent import DataScientistAgent
|
||||
from ..expand.plugin_assistant_agent import PluginAssistantAgent
|
||||
from ..expand.summary_assistant_agent import SummaryAssistantAgent
|
||||
from ..expand.tool_assistant_agent import ToolAssistantAgent
|
||||
|
||||
core_agents = set()
|
||||
core_agents.add(self.register_agent(CodeAssistantAgent))
|
||||
core_agents.add(self.register_agent(DashboardAssistantAgent))
|
||||
core_agents.add(self.register_agent(DataScientistAgent))
|
||||
core_agents.add(self.register_agent(SummaryAssistantAgent))
|
||||
core_agents.add(self.register_agent(PluginAssistantAgent))
|
||||
core_agents.add(self.register_agent(ToolAssistantAgent))
|
||||
self._core_agents = core_agents
|
||||
|
||||
def register_agent(
|
||||
|
@@ -3,16 +3,17 @@
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional, Tuple, Type, cast
|
||||
from concurrent.futures import Executor, ThreadPoolExecutor
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, cast
|
||||
|
||||
from dbgpt._private.pydantic import ConfigDict, Field
|
||||
from dbgpt.core import LLMClient, ModelMessageRoleType
|
||||
from dbgpt.util.error_types import LLMChatError
|
||||
from dbgpt.util.executor_utils import blocking_func_to_async
|
||||
from dbgpt.util.tracer import SpanType, root_tracer
|
||||
from dbgpt.util.utils import colored
|
||||
|
||||
from ..resource.resource_api import AgentResource, ResourceClient
|
||||
from ..resource.resource_loader import ResourceLoader
|
||||
from ..resource.base import Resource
|
||||
from ..util.llm.llm import LLMConfig, LLMStrategyType
|
||||
from ..util.llm.llm_client import AIWrapper
|
||||
from .action.base import Action, ActionOutput
|
||||
@@ -32,12 +33,15 @@ class ConversableAgent(Role, Agent):
|
||||
|
||||
agent_context: Optional[AgentContext] = Field(None, description="Agent context")
|
||||
actions: List[Action] = Field(default_factory=list)
|
||||
resources: List[AgentResource] = Field(default_factory=list)
|
||||
resource: Optional[Resource] = Field(None, description="Resource")
|
||||
llm_config: Optional[LLMConfig] = None
|
||||
resource_loader: Optional[ResourceLoader] = None
|
||||
max_retry_count: int = 3
|
||||
consecutive_auto_reply_counter: int = 0
|
||||
llm_client: Optional[AIWrapper] = None
|
||||
executor: Executor = Field(
|
||||
default_factory=lambda: ThreadPoolExecutor(max_workers=1),
|
||||
description="Executor for running tasks",
|
||||
)
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
"""Create a new agent."""
|
||||
@@ -58,27 +62,12 @@ class ConversableAgent(Role, Agent):
|
||||
f"running!"
|
||||
)
|
||||
|
||||
# resource check
|
||||
for resource in self.resources:
|
||||
if (
|
||||
self.resource_loader is None
|
||||
or self.resource_loader.get_resource_api(
|
||||
resource.type, check_instance=False
|
||||
)
|
||||
is None
|
||||
):
|
||||
raise ValueError(
|
||||
f"Resource {resource.type}:{resource.value} missing resource loader"
|
||||
f" implementation,unable to read resources!"
|
||||
)
|
||||
|
||||
# action check
|
||||
if self.actions and len(self.actions) > 0:
|
||||
have_resource_types = [item.type for item in self.resources]
|
||||
for action in self.actions:
|
||||
if (
|
||||
action.resource_need
|
||||
and action.resource_need not in have_resource_types
|
||||
if action.resource_need and (
|
||||
not self.resource
|
||||
or not self.resource.get_resource_by_type(action.resource_need)
|
||||
):
|
||||
raise ValueError(
|
||||
f"{self.name}[{self.role}] Missing resources required for "
|
||||
@@ -112,13 +101,6 @@ class ConversableAgent(Role, Agent):
|
||||
raise ValueError("Agent context is not initialized!")
|
||||
return self.agent_context
|
||||
|
||||
@property
|
||||
def not_null_resource_loader(self) -> ResourceLoader:
|
||||
"""Get the resource loader."""
|
||||
if not self.resource_loader:
|
||||
raise ValueError("Resource loader is not initialized!")
|
||||
return self.resource_loader
|
||||
|
||||
@property
|
||||
def not_null_llm_config(self) -> LLMConfig:
|
||||
"""Get the LLM config."""
|
||||
@@ -134,23 +116,32 @@ class ConversableAgent(Role, Agent):
|
||||
raise ValueError("LLM client is not initialized!")
|
||||
return llm_client
|
||||
|
||||
async def blocking_func_to_async(
|
||||
self, func: Callable[..., Any], *args, **kwargs
|
||||
) -> Any:
|
||||
"""Run a potentially blocking function within an executor."""
|
||||
if not asyncio.iscoroutinefunction(func):
|
||||
return await blocking_func_to_async(self.executor, func, *args, **kwargs)
|
||||
return await func(*args, **kwargs)
|
||||
|
||||
async def preload_resource(self) -> None:
|
||||
"""Preload resources before agent initialization."""
|
||||
pass
|
||||
if self.resource:
|
||||
await self.blocking_func_to_async(self.resource.preload_resource)
|
||||
|
||||
async def build(self) -> "ConversableAgent":
|
||||
"""Build the agent."""
|
||||
# Preload resources
|
||||
await self.preload_resource()
|
||||
# Check if agent is available
|
||||
self.check_available()
|
||||
_language = self.not_null_agent_context.language
|
||||
if _language:
|
||||
self.language = _language
|
||||
|
||||
# Preload resources
|
||||
await self.preload_resource()
|
||||
# Initialize resource loader
|
||||
for action in self.actions:
|
||||
action.init_resource_loader(self.resource_loader)
|
||||
action.init_resource(self.resource)
|
||||
|
||||
# Initialize LLM Server
|
||||
if not self.is_human:
|
||||
@@ -175,13 +166,8 @@ class ConversableAgent(Role, Agent):
|
||||
raise ValueError("GptsMemory is not supported!")
|
||||
elif isinstance(target, AgentContext):
|
||||
self.agent_context = target
|
||||
elif isinstance(target, ResourceLoader):
|
||||
self.resource_loader = target
|
||||
elif isinstance(target, list) and target and len(target) > 0:
|
||||
if _is_list_of_type(target, Action):
|
||||
self.actions.extend(target)
|
||||
elif _is_list_of_type(target, AgentResource):
|
||||
self.resources = target
|
||||
elif isinstance(target, Resource):
|
||||
self.resource = target
|
||||
elif isinstance(target, AgentMemory):
|
||||
self.memory = target
|
||||
return self
|
||||
@@ -480,12 +466,12 @@ class ConversableAgent(Role, Agent):
|
||||
last_out: Optional[ActionOutput] = None
|
||||
for i, action in enumerate(self.actions):
|
||||
# Select the resources required by acton
|
||||
need_resource = None
|
||||
if self.resources and len(self.resources) > 0:
|
||||
for item in self.resources:
|
||||
if item.type == action.resource_need:
|
||||
need_resource = item
|
||||
break
|
||||
if action.resource_need and self.resource:
|
||||
need_resources = self.resource.get_resource_by_type(
|
||||
action.resource_need
|
||||
)
|
||||
else:
|
||||
need_resources = []
|
||||
|
||||
if not message:
|
||||
raise ValueError("The message content is empty!")
|
||||
@@ -497,7 +483,7 @@ class ConversableAgent(Role, Agent):
|
||||
"sender": sender.name if sender else None,
|
||||
"recipient": self.name,
|
||||
"reviewer": reviewer.name if reviewer else None,
|
||||
"need_resource": need_resource.to_dict() if need_resource else None,
|
||||
"need_resource": need_resources[0].name if need_resources else None,
|
||||
"rely_action_out": last_out.to_dict() if last_out else None,
|
||||
"conv_uid": self.not_null_agent_context.conv_id,
|
||||
"action_index": i,
|
||||
@@ -506,7 +492,7 @@ class ConversableAgent(Role, Agent):
|
||||
) as span:
|
||||
last_out = await action.run(
|
||||
ai_message=message,
|
||||
resource=need_resource,
|
||||
resource=None,
|
||||
rely_action_out=last_out,
|
||||
**kwargs,
|
||||
)
|
||||
@@ -703,23 +689,11 @@ class ConversableAgent(Role, Agent):
|
||||
self, question: Optional[str] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""Generate the resource variables."""
|
||||
resource_prompt_list = []
|
||||
for item in self.resources:
|
||||
resource_client = self.not_null_resource_loader.get_resource_api(
|
||||
item.type, ResourceClient
|
||||
resource_prompt = None
|
||||
if self.resource:
|
||||
resource_prompt = await self.resource.get_prompt(
|
||||
lang=self.language, question=question
|
||||
)
|
||||
if not resource_client:
|
||||
raise ValueError(
|
||||
f"Resource {item.type}:{item.value} missing resource loader"
|
||||
f" implementation,unable to read resources!"
|
||||
)
|
||||
resource_prompt_list.append(
|
||||
await resource_client.get_resource_prompt(item, question)
|
||||
)
|
||||
|
||||
resource_prompt = ""
|
||||
if len(resource_prompt_list) > 0:
|
||||
resource_prompt = "RESOURCES:" + "\n".join(resource_prompt_list)
|
||||
|
||||
out_schema: Optional[str] = ""
|
||||
if self.actions and len(self.actions) > 0:
|
||||
|
@@ -17,6 +17,7 @@ from dbgpt.core.interface.message import ModelMessageRoleType
|
||||
# TODO: Don't dependent on MixinLLMOperator
|
||||
from dbgpt.model.operators.llm_operator import MixinLLMOperator
|
||||
|
||||
from ....resource.manage import get_resource_manager
|
||||
from ....util.llm.llm import LLMConfig
|
||||
from ...agent import Agent, AgentGenerateContext, AgentMessage
|
||||
from ...agent_manage import get_agent_manager
|
||||
@@ -228,7 +229,6 @@ class AWELAgentOperator(
|
||||
silent=input_value.silent,
|
||||
memory=input_value.memory.structure_clone() if input_value.memory else None,
|
||||
agent_context=input_value.agent_context,
|
||||
resource_loader=input_value.resource_loader,
|
||||
llm_client=input_value.llm_client,
|
||||
round_index=agent.consecutive_auto_reply_counter,
|
||||
)
|
||||
@@ -262,13 +262,13 @@ class AWELAgentOperator(
|
||||
if self.awel_agent.fixed_subgoal:
|
||||
kwargs["fixed_subgoal"] = self.awel_agent.fixed_subgoal
|
||||
|
||||
resource = get_resource_manager().build_resource(self.awel_agent.resources)
|
||||
agent = (
|
||||
await agent_cls(**kwargs)
|
||||
.bind(input_value.memory)
|
||||
.bind(llm_config)
|
||||
.bind(input_value.agent_context)
|
||||
.bind(self.awel_agent.resources)
|
||||
.bind(input_value.resource_loader)
|
||||
.bind(resource)
|
||||
.build()
|
||||
)
|
||||
|
||||
|
@@ -12,11 +12,17 @@ from dbgpt.core.awel.flow import (
|
||||
register_resource,
|
||||
)
|
||||
|
||||
from ....resource.resource_api import AgentResource, ResourceType
|
||||
from ....resource.base import AgentResource
|
||||
from ....resource.manage import get_resource_manager
|
||||
from ....util.llm.llm import LLMConfig, LLMStrategyType
|
||||
from ...agent_manage import get_agent_manager
|
||||
|
||||
|
||||
def _load_resource_types():
|
||||
resources = get_resource_manager().get_supported_resources()
|
||||
return [OptionValue(label=item, name=item, value=item) for item in resources.keys()]
|
||||
|
||||
|
||||
@register_resource(
|
||||
label="AWEL Agent Resource",
|
||||
name="agent_operator_resource",
|
||||
@@ -29,10 +35,7 @@ from ...agent_manage import get_agent_manager
|
||||
type=str,
|
||||
optional=True,
|
||||
default=None,
|
||||
options=[
|
||||
OptionValue(label=item.name, name=item.value, value=item.value)
|
||||
for item in ResourceType
|
||||
],
|
||||
options=FunctionDynamicOptions(func=_load_resource_types),
|
||||
),
|
||||
Parameter.build_from(
|
||||
label="Agent Resource Name",
|
||||
@@ -70,7 +73,7 @@ class AWELAgentResource(AgentResource):
|
||||
value = values.pop("agent_resource_value")
|
||||
|
||||
values["name"] = name
|
||||
values["type"] = ResourceType(type)
|
||||
values["type"] = type
|
||||
values["value"] = value
|
||||
|
||||
return values
|
||||
|
@@ -132,7 +132,6 @@ class AWELBaseManager(ManagerAgent, ABC):
|
||||
reviewer=reviewer,
|
||||
memory=self.memory.structure_clone(),
|
||||
agent_context=self.agent_context,
|
||||
resource_loader=self.resource_loader,
|
||||
llm_client=self.not_null_llm_config.llm_client,
|
||||
)
|
||||
final_generate_context: AgentGenerateContext = await last_node.call(
|
||||
|
@@ -6,7 +6,7 @@ from typing import List, Optional
|
||||
from dbgpt._private.pydantic import BaseModel, Field
|
||||
from dbgpt.vis.tags.vis_agent_plans import Vis, VisAgentPlans
|
||||
|
||||
from ...resource.resource_api import AgentResource
|
||||
from ...resource.base import AgentResource
|
||||
from ..action.base import Action, ActionOutput
|
||||
from ..agent import AgentContext
|
||||
from ..memory.gpts.base import GptsPlan
|
||||
|
@@ -4,6 +4,7 @@ from typing import Any, Dict, List
|
||||
|
||||
from dbgpt._private.pydantic import Field
|
||||
|
||||
from ...resource.pack import ResourcePack
|
||||
from ..agent import AgentMessage
|
||||
from ..base_agent import ConversableAgent
|
||||
from ..plan.plan_action import PlanAction
|
||||
@@ -152,9 +153,11 @@ assistants:[
|
||||
def bind_agents(self, agents: List[ConversableAgent]) -> ConversableAgent:
|
||||
"""Bind the agents to the planner agent."""
|
||||
self.agents = agents
|
||||
resources = []
|
||||
for agent in self.agents:
|
||||
if agent.resources and len(agent.resources) > 0:
|
||||
self.resources.extend(agent.resources)
|
||||
if agent.resource:
|
||||
resources.append(agent.resource)
|
||||
self.resource = ResourcePack(resources)
|
||||
return self
|
||||
|
||||
def prepare_act_param(self) -> Dict[str, Any]:
|
||||
|
@@ -188,7 +188,6 @@ class AutoPlanChatManager(ManagerAgent):
|
||||
.bind(self.memory)
|
||||
.bind(self.agent_context)
|
||||
.bind(self.llm_config)
|
||||
.bind(self.resource_loader)
|
||||
.bind_agents(self.agents)
|
||||
.build()
|
||||
)
|
||||
|
@@ -2,14 +2,14 @@
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import Optional
|
||||
from typing import List, Optional
|
||||
|
||||
from dbgpt._private.pydantic import BaseModel, Field, model_to_json
|
||||
from dbgpt.vis.tags.vis_chart import Vis, VisChart
|
||||
|
||||
from ...core.action.base import Action, ActionOutput
|
||||
from ...resource.resource_api import AgentResource, ResourceType
|
||||
from ...resource.resource_db_api import ResourceDbClient
|
||||
from ...resource.base import AgentResource, ResourceType
|
||||
from ...resource.database import DBResource
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -69,34 +69,28 @@ class ChartAction(Action[SqlInput]):
|
||||
content="The requested correctly structured answer could not be found.",
|
||||
)
|
||||
try:
|
||||
if not self.resource_loader:
|
||||
raise ValueError("ResourceLoader is not initialized!")
|
||||
resource_db_client: Optional[
|
||||
ResourceDbClient
|
||||
] = self.resource_loader.get_resource_api(
|
||||
self.resource_need, ResourceDbClient
|
||||
)
|
||||
if not resource_db_client:
|
||||
raise ValueError(
|
||||
"There is no implementation class bound to database resource "
|
||||
"execution!"
|
||||
)
|
||||
if not resource:
|
||||
raise ValueError("The data resource is not found!")
|
||||
data_df = await resource_db_client.query_to_df(resource.value, param.sql)
|
||||
if not self.resource_need:
|
||||
raise ValueError("The resource type is not found!")
|
||||
|
||||
if not self.render_protocol:
|
||||
raise ValueError("The rendering protocol is not initialized!")
|
||||
|
||||
db_resources: List[DBResource] = DBResource.from_resource(self.resource)
|
||||
if not db_resources:
|
||||
raise ValueError("The database resource is not found!")
|
||||
|
||||
db = db_resources[0]
|
||||
data_df = await db.query_to_df(param.sql)
|
||||
view = await self.render_protocol.display(
|
||||
chart=json.loads(model_to_json(param)), data_df=data_df
|
||||
)
|
||||
if not self.resource_need:
|
||||
raise ValueError("The resource type is not found!")
|
||||
|
||||
return ActionOutput(
|
||||
is_exe_success=True,
|
||||
content=model_to_json(param),
|
||||
view=view,
|
||||
resource_type=self.resource_need.value,
|
||||
resource_value=resource.value,
|
||||
resource_value=db._db_name,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception("Check your answers, the sql run failed!")
|
||||
|
@@ -8,7 +8,7 @@ from dbgpt.util.utils import colored
|
||||
from dbgpt.vis.tags.vis_code import Vis, VisCode
|
||||
|
||||
from ...core.action.base import Action, ActionOutput
|
||||
from ...resource.resource_api import AgentResource
|
||||
from ...resource.base import AgentResource
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@@ -8,8 +8,8 @@ from dbgpt._private.pydantic import BaseModel, Field, model_to_dict
|
||||
from dbgpt.vis.tags.vis_dashboard import Vis, VisDashboard
|
||||
|
||||
from ...core.action.base import Action, ActionOutput
|
||||
from ...resource.resource_api import AgentResource, ResourceType
|
||||
from ...resource.resource_db_api import ResourceDbClient
|
||||
from ...resource.base import AgentResource, ResourceType
|
||||
from ...resource.database import DBResource
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -83,29 +83,20 @@ class DashboardAction(Action[List[ChartItem]]):
|
||||
)
|
||||
chart_items: List[ChartItem] = input_param
|
||||
try:
|
||||
if not self.resource_loader:
|
||||
raise ValueError("Resource loader is not initialized!")
|
||||
resource_db_client: Optional[
|
||||
ResourceDbClient
|
||||
] = self.resource_loader.get_resource_api(
|
||||
self.resource_need, ResourceDbClient
|
||||
)
|
||||
if not resource_db_client:
|
||||
raise ValueError(
|
||||
"There is no implementation class bound to database resource "
|
||||
"execution!"
|
||||
)
|
||||
db_resources: List[DBResource] = DBResource.from_resource(self.resource)
|
||||
if not db_resources:
|
||||
raise ValueError("The database resource is not found!")
|
||||
|
||||
if not resource:
|
||||
raise ValueError("Resource is not initialized!")
|
||||
db = db_resources[0]
|
||||
|
||||
if not db:
|
||||
raise ValueError("The database resource is not found!")
|
||||
|
||||
chart_params = []
|
||||
for chart_item in chart_items:
|
||||
chart_dict = {}
|
||||
try:
|
||||
sql_df = await resource_db_client.query_to_df(
|
||||
resource.value, chart_item.sql
|
||||
)
|
||||
sql_df = await db.query_to_df(chart_item.sql)
|
||||
chart_dict = chart_item.to_dict()
|
||||
|
||||
chart_dict["data"] = sql_df
|
||||
|
@@ -9,7 +9,7 @@ from dbgpt.vis.tags.vis_plugin import Vis, VisPlugin
|
||||
|
||||
from ...core.action.base import Action, ActionOutput
|
||||
from ...core.schema import Status
|
||||
from ...resource.resource_api import AgentResource, ResourceType
|
||||
from ...resource.base import AgentResource, ResourceType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@@ -9,14 +9,13 @@ from dbgpt.vis.tags.vis_plugin import Vis, VisPlugin
|
||||
|
||||
from ...core.action.base import Action, ActionOutput
|
||||
from ...core.schema import Status
|
||||
from ...plugin.generator import PluginPromptGenerator
|
||||
from ...resource.resource_api import AgentResource, ResourceType
|
||||
from ...resource.resource_plugin_api import ResourcePluginClient
|
||||
from ...resource.base import AgentResource, ResourceType
|
||||
from ...resource.tool.pack import ToolPack
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PluginInput(BaseModel):
|
||||
class ToolInput(BaseModel):
|
||||
"""Plugin input model."""
|
||||
|
||||
tool_name: str = Field(
|
||||
@@ -32,8 +31,8 @@ class PluginInput(BaseModel):
|
||||
thought: str = Field(..., description="Summary of thoughts to the user")
|
||||
|
||||
|
||||
class PluginAction(Action[PluginInput]):
|
||||
"""Plugin action class."""
|
||||
class ToolAction(Action[ToolInput]):
|
||||
"""Tool action class."""
|
||||
|
||||
def __init__(self):
|
||||
"""Create a plugin action."""
|
||||
@@ -43,7 +42,7 @@ class PluginAction(Action[PluginInput]):
|
||||
@property
|
||||
def resource_need(self) -> Optional[ResourceType]:
|
||||
"""Return the resource type needed for the action."""
|
||||
return ResourceType.Plugin
|
||||
return ResourceType.Tool
|
||||
|
||||
@property
|
||||
def render_protocol(self) -> Optional[Vis]:
|
||||
@@ -53,19 +52,19 @@ class PluginAction(Action[PluginInput]):
|
||||
@property
|
||||
def out_model_type(self):
|
||||
"""Return the output model type."""
|
||||
return PluginInput
|
||||
return ToolInput
|
||||
|
||||
@property
|
||||
def ai_out_schema(self) -> Optional[str]:
|
||||
"""Return the AI output schema."""
|
||||
out_put_schema = {
|
||||
"thought": "Summary of thoughts to the user",
|
||||
"tool_name": "The name of a tool that can be used to answer the current "
|
||||
"question or solve the current task.",
|
||||
"args": {
|
||||
"arg name1": "arg value1",
|
||||
"arg name2": "arg value2",
|
||||
},
|
||||
"thought": "Summary of thoughts to the user",
|
||||
}
|
||||
|
||||
return f"""Please response in the following json format:
|
||||
@@ -92,13 +91,8 @@ class PluginAction(Action[PluginInput]):
|
||||
need_vis_render (bool, optional): Whether need visualization rendering.
|
||||
Defaults to True.
|
||||
"""
|
||||
plugin_generator: Optional[PluginPromptGenerator] = kwargs.get(
|
||||
"plugin_generator", None
|
||||
)
|
||||
if not plugin_generator:
|
||||
raise ValueError("No plugin generator found!")
|
||||
try:
|
||||
param: PluginInput = self._input_convert(ai_message, PluginInput)
|
||||
param: ToolInput = self._input_convert(ai_message, ToolInput)
|
||||
except Exception as e:
|
||||
logger.exception((str(e)))
|
||||
return ActionOutput(
|
||||
@@ -107,21 +101,16 @@ class PluginAction(Action[PluginInput]):
|
||||
)
|
||||
|
||||
try:
|
||||
if not self.resource_loader:
|
||||
raise ValueError("No resource_loader found!")
|
||||
resource_plugin_client: Optional[
|
||||
ResourcePluginClient
|
||||
] = self.resource_loader.get_resource_api(
|
||||
self.resource_need, ResourcePluginClient
|
||||
)
|
||||
if not resource_plugin_client:
|
||||
raise ValueError("No implementation of the use of plug-in resources!")
|
||||
tool_packs = ToolPack.from_resource(self.resource)
|
||||
if not tool_packs:
|
||||
raise ValueError("The tool resource is not found!")
|
||||
tool_pack = tool_packs[0]
|
||||
response_success = True
|
||||
status = Status.RUNNING.value
|
||||
err_msg = None
|
||||
try:
|
||||
tool_result = await resource_plugin_client.execute_command(
|
||||
param.tool_name, param.args, plugin_generator
|
||||
tool_result = await tool_pack.async_execute(
|
||||
resource_name=param.tool_name, **param.args
|
||||
)
|
||||
status = Status.COMPLETE.value
|
||||
except Exception as e:
|
||||
@@ -146,9 +135,9 @@ class PluginAction(Action[PluginInput]):
|
||||
|
||||
return ActionOutput(
|
||||
is_exe_success=response_success,
|
||||
content=tool_result,
|
||||
content=str(tool_result),
|
||||
view=view,
|
||||
observations=tool_result,
|
||||
observations=str(tool_result),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception("Tool Action Run Failed!")
|
@@ -1,9 +1,11 @@
|
||||
"""Dashboard Assistant Agent."""
|
||||
|
||||
from typing import List
|
||||
|
||||
from ..core.agent import AgentMessage
|
||||
from ..core.base_agent import ConversableAgent
|
||||
from ..core.profile import DynConfig, ProfileConfig
|
||||
from ..resource.resource_db_api import ResourceDbClient
|
||||
from ..resource.database import DBResource
|
||||
from .actions.dashboard_action import DashboardAction
|
||||
|
||||
|
||||
@@ -58,15 +60,16 @@ class DashboardAssistantAgent(ConversableAgent):
|
||||
|
||||
def _init_reply_message(self, received_message: AgentMessage) -> AgentMessage:
|
||||
reply_message = super()._init_reply_message(received_message)
|
||||
client = self.not_null_resource_loader.get_resource_api(
|
||||
self.actions[0].resource_need, ResourceDbClient
|
||||
)
|
||||
if not client:
|
||||
|
||||
dbs: List[DBResource] = DBResource.from_resource(self.resource)
|
||||
|
||||
if not dbs:
|
||||
raise ValueError(
|
||||
f"Resource type {self.actions[0].resource_need} is not supported."
|
||||
)
|
||||
db = dbs[0]
|
||||
reply_message.context = {
|
||||
"display_type": self.actions[0].render_prompt(),
|
||||
"dialect": client.get_data_type(self.resources[0]),
|
||||
"dialect": db.dialect,
|
||||
}
|
||||
return reply_message
|
||||
|
@@ -2,14 +2,13 @@
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import Optional, Tuple, cast
|
||||
from typing import List, Optional, Tuple, cast
|
||||
|
||||
from ..core.action.base import ActionOutput
|
||||
from ..core.agent import AgentMessage
|
||||
from ..core.base_agent import ConversableAgent
|
||||
from ..core.profile import DynConfig, ProfileConfig
|
||||
from ..resource.resource_api import ResourceType
|
||||
from ..resource.resource_db_api import ResourceDbClient
|
||||
from ..resource.database import DBResource
|
||||
from .actions.chart_action import ChartAction
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -74,18 +73,21 @@ class DataScientistAgent(ConversableAgent):
|
||||
|
||||
def _init_reply_message(self, received_message: AgentMessage) -> AgentMessage:
|
||||
reply_message = super()._init_reply_message(received_message)
|
||||
client = self.not_null_resource_loader.get_resource_api(
|
||||
self.actions[0].resource_need, ResourceDbClient
|
||||
)
|
||||
if not client:
|
||||
reply_message.context = {
|
||||
"display_type": self.actions[0].render_prompt(),
|
||||
"dialect": self.database.dialect,
|
||||
}
|
||||
return reply_message
|
||||
|
||||
@property
|
||||
def database(self) -> DBResource:
|
||||
"""Get the database resource."""
|
||||
dbs: List[DBResource] = DBResource.from_resource(self.resource)
|
||||
if not dbs:
|
||||
raise ValueError(
|
||||
f"Resource type {self.actions[0].resource_need} is not supported."
|
||||
)
|
||||
reply_message.context = {
|
||||
"display_type": self.actions[0].render_prompt(),
|
||||
"dialect": client.get_data_type(self.resources[0]),
|
||||
}
|
||||
return reply_message
|
||||
return dbs[0]
|
||||
|
||||
async def correctness_check(
|
||||
self, message: AgentMessage
|
||||
@@ -112,17 +114,6 @@ class DataScientistAgent(ConversableAgent):
|
||||
"generated is not found.",
|
||||
)
|
||||
try:
|
||||
resource_db_client: Optional[
|
||||
ResourceDbClient
|
||||
] = self.not_null_resource_loader.get_resource_api(
|
||||
ResourceType(action_out.resource_type), ResourceDbClient
|
||||
)
|
||||
if not resource_db_client:
|
||||
return (
|
||||
False,
|
||||
"Please check your answer, the data resource type is not "
|
||||
"supported.",
|
||||
)
|
||||
if not action_out.resource_value:
|
||||
return (
|
||||
False,
|
||||
@@ -130,8 +121,9 @@ class DataScientistAgent(ConversableAgent):
|
||||
"found.",
|
||||
)
|
||||
|
||||
columns, values = await resource_db_client.query(
|
||||
db=action_out.resource_value, sql=sql
|
||||
columns, values = await self.database.query(
|
||||
sql=sql,
|
||||
db=action_out.resource_value,
|
||||
)
|
||||
if not values or len(values) <= 0:
|
||||
return (
|
||||
|
1
dbgpt/agent/expand/resources/__init__.py
Normal file
1
dbgpt/agent/expand/resources/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Expand resources for the agent module."""
|
23
dbgpt/agent/expand/resources/dbgpt_tool.py
Normal file
23
dbgpt/agent/expand/resources/dbgpt_tool.py
Normal file
@@ -0,0 +1,23 @@
|
||||
"""Some internal tools for the DB-GPT project."""
|
||||
|
||||
from typing_extensions import Annotated, Doc
|
||||
|
||||
from ...resource.tool.base import tool
|
||||
|
||||
|
||||
@tool(description="List the supported models in DB-GPT project.")
|
||||
def list_dbgpt_support_models(
|
||||
model_type: Annotated[
|
||||
str, Doc("The model type, LLM(Large Language Model) and EMBEDDING).")
|
||||
] = "LLM",
|
||||
) -> str:
|
||||
"""List the supported models in dbgpt."""
|
||||
from dbgpt.configs.model_config import EMBEDDING_MODEL_CONFIG, LLM_MODEL_CONFIG
|
||||
|
||||
if model_type.lower() == "llm":
|
||||
supports = list(LLM_MODEL_CONFIG.keys())
|
||||
elif model_type.lower() == "embedding":
|
||||
supports = list(EMBEDDING_MODEL_CONFIG.keys())
|
||||
else:
|
||||
raise ValueError(f"Unsupported model type: {model_type}")
|
||||
return "\n\n".join(supports)
|
54
dbgpt/agent/expand/resources/search_tool.py
Normal file
54
dbgpt/agent/expand/resources/search_tool.py
Normal file
@@ -0,0 +1,54 @@
|
||||
"""Search tools for the agent."""
|
||||
|
||||
import re
|
||||
|
||||
from typing_extensions import Annotated, Doc
|
||||
|
||||
from ...resource.tool.base import tool
|
||||
|
||||
|
||||
@tool(
|
||||
description="Baidu search and return the results as a markdown string. Please set "
|
||||
"number of results not less than 8 for rich search results.",
|
||||
)
|
||||
def baidu_search(
|
||||
query: Annotated[str, Doc("The search query.")],
|
||||
num_results: Annotated[int, Doc("The number of search results to return.")] = 8,
|
||||
) -> str:
|
||||
"""Baidu search and return the results as a markdown string.
|
||||
|
||||
Please set number of results not less than 8 for rich search results.
|
||||
"""
|
||||
import requests
|
||||
from bs4 import BeautifulSoup
|
||||
|
||||
headers = {
|
||||
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:112.0) "
|
||||
"Gecko/20100101 Firefox/112.0"
|
||||
}
|
||||
url = f"https://www.baidu.com/s?wd={query}&rn={num_results}"
|
||||
response = requests.get(url, headers=headers)
|
||||
response.encoding = "utf-8"
|
||||
soup = BeautifulSoup(response.text, "html.parser")
|
||||
|
||||
search_results = []
|
||||
for result in soup.find_all("div", class_=re.compile("^result c-container ")):
|
||||
title = result.find("h3", class_="t").get_text()
|
||||
link = result.find("a", href=True)["href"]
|
||||
snippet = result.find("span", class_=re.compile("^content-right_"))
|
||||
if snippet:
|
||||
snippet = snippet.get_text()
|
||||
else:
|
||||
snippet = ""
|
||||
search_results.append({"title": title, "href": link, "snippet": snippet})
|
||||
|
||||
return _search_to_view(search_results)
|
||||
|
||||
|
||||
def _search_to_view(results) -> str:
|
||||
view_results = []
|
||||
for item in results:
|
||||
view_results.append(
|
||||
f"### [{item['title']}]({item['href']})\n{item['snippet']}\n"
|
||||
)
|
||||
return "\n".join(view_results)
|
@@ -14,7 +14,7 @@ from ..core.action.base import Action, ActionOutput
|
||||
from ..core.agent import Agent, AgentMessage, AgentReviewInfo
|
||||
from ..core.base_agent import ConversableAgent
|
||||
from ..core.profile import ProfileConfig
|
||||
from ..resource.resource_api import AgentResource
|
||||
from ..resource.base import AgentResource
|
||||
from ..util.cmp import cmp_string_equal
|
||||
|
||||
try:
|
||||
|
@@ -1,22 +1,16 @@
|
||||
"""Plugin Assistant Agent."""
|
||||
|
||||
import logging
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from ..core.base_agent import ConversableAgent
|
||||
from ..core.profile import DynConfig, ProfileConfig
|
||||
from ..plugin.generator import PluginPromptGenerator
|
||||
from ..resource.resource_api import ResourceType
|
||||
from ..resource.resource_plugin_api import ResourcePluginClient
|
||||
from .actions.plugin_action import PluginAction
|
||||
from .actions.tool_action import ToolAction
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PluginAssistantAgent(ConversableAgent):
|
||||
"""Plugin Assistant Agent."""
|
||||
|
||||
plugin_generator: Optional[PluginPromptGenerator] = None
|
||||
class ToolAssistantAgent(ConversableAgent):
|
||||
"""Tool Assistant Agent."""
|
||||
|
||||
profile: ProfileConfig = ProfileConfig(
|
||||
name=DynConfig(
|
||||
@@ -57,37 +51,6 @@ class PluginAssistantAgent(ConversableAgent):
|
||||
)
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
"""Create a new instance of PluginAssistantAgent."""
|
||||
"""Create a new instance of ToolAssistantAgent."""
|
||||
super().__init__(**kwargs)
|
||||
self._init_actions([PluginAction])
|
||||
|
||||
# @property
|
||||
# def introduce(self, **kwargs) -> str:
|
||||
# """Introduce the agent."""
|
||||
# if not self.plugin_generator:
|
||||
# raise ValueError("PluginGenerator is not loaded.")
|
||||
# return self.desc.format(
|
||||
# tool_infos=self.plugin_generator.generate_commands_string()
|
||||
# )
|
||||
|
||||
async def preload_resource(self):
|
||||
"""Preload the resource."""
|
||||
plugin_loader_client: ResourcePluginClient = (
|
||||
self.not_null_resource_loader.get_resource_api(
|
||||
ResourceType.Plugin, ResourcePluginClient
|
||||
)
|
||||
)
|
||||
item_list = []
|
||||
for item in self.resources:
|
||||
if item.type == ResourceType.Plugin:
|
||||
item_list.append(item.value)
|
||||
plugin_generator = self.plugin_generator
|
||||
for item in item_list:
|
||||
plugin_generator = await plugin_loader_client.load_plugin(
|
||||
item, plugin_generator
|
||||
)
|
||||
self.plugin_generator = plugin_generator
|
||||
|
||||
def prepare_act_param(self) -> Dict[str, Any]:
|
||||
"""Prepare the act parameter."""
|
||||
return {"plugin_generator": self.plugin_generator}
|
||||
self._init_actions([ToolAction])
|
@@ -1,6 +0,0 @@
|
||||
"""Plugin module for agent."""
|
||||
|
||||
from .commands.command_manage import CommandRegistry # noqa: F401
|
||||
from .generator import PluginPromptGenerator # noqa: F401
|
||||
|
||||
__ALL__ = ["PluginPromptGenerator", "CommandRegistry"]
|
@@ -1 +0,0 @@
|
||||
"""Commands Module."""
|
@@ -1 +0,0 @@
|
||||
"""Built-in commands for DB-GPT."""
|
@@ -1,2 +0,0 @@
|
||||
"""Visualize Data."""
|
||||
from .show_chart_gen import static_message_img_path # noqa: F401
|
@@ -1,354 +0,0 @@
|
||||
"""Chart display command implementation."""
|
||||
import logging
|
||||
import os
|
||||
import uuid
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from dbgpt.configs.model_config import PILOT_PATH
|
||||
from dbgpt.util.string_utils import is_scientific_notation
|
||||
|
||||
from ...command_manage import command
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pandas import DataFrame
|
||||
|
||||
static_message_img_path = os.path.join(PILOT_PATH, "message/img")
|
||||
|
||||
|
||||
def data_pre_classification(df: "DataFrame"):
|
||||
"""Return the x and y coordinates of the chart."""
|
||||
import pandas as pd
|
||||
|
||||
# Data pre-classification
|
||||
columns = df.columns.tolist()
|
||||
|
||||
number_columns = []
|
||||
non_numeric_colums = []
|
||||
|
||||
# Collect columns with less than 10 unique values
|
||||
non_numeric_colums_value_map = {}
|
||||
numeric_colums_value_map = {}
|
||||
for column_name in columns:
|
||||
if pd.api.types.is_numeric_dtype(df[column_name].dtypes):
|
||||
number_columns.append(column_name)
|
||||
unique_values = df[column_name].unique()
|
||||
numeric_colums_value_map.update({column_name: len(unique_values)})
|
||||
else:
|
||||
non_numeric_colums.append(column_name)
|
||||
unique_values = df[column_name].unique()
|
||||
non_numeric_colums_value_map.update({column_name: len(unique_values)})
|
||||
|
||||
sorted_numeric_colums_value_map = dict(
|
||||
sorted(numeric_colums_value_map.items(), key=lambda x: x[1])
|
||||
)
|
||||
numeric_colums_sort_list = list(sorted_numeric_colums_value_map.keys())
|
||||
|
||||
sorted_colums_value_map = dict(
|
||||
sorted(non_numeric_colums_value_map.items(), key=lambda x: x[1])
|
||||
)
|
||||
non_numeric_colums_sort_list = list(sorted_colums_value_map.keys())
|
||||
|
||||
# Analyze x-coordinate
|
||||
if len(non_numeric_colums_sort_list) > 0:
|
||||
x_cloumn = non_numeric_colums_sort_list[-1]
|
||||
non_numeric_colums_sort_list.remove(x_cloumn)
|
||||
else:
|
||||
x_cloumn = number_columns[0]
|
||||
numeric_colums_sort_list.remove(x_cloumn)
|
||||
|
||||
# Analyze y-coordinate
|
||||
if len(numeric_colums_sort_list) > 0:
|
||||
y_column = numeric_colums_sort_list[0]
|
||||
numeric_colums_sort_list.remove(y_column)
|
||||
else:
|
||||
raise ValueError("Not enough numeric columns for chart!")
|
||||
|
||||
return x_cloumn, y_column, non_numeric_colums_sort_list, numeric_colums_sort_list
|
||||
|
||||
|
||||
def zh_font_set():
|
||||
"""Set Chinese font."""
|
||||
from matplotlib import pyplot as plt
|
||||
from matplotlib.font_manager import FontManager
|
||||
|
||||
font_names = [
|
||||
"Heiti TC",
|
||||
"Songti SC",
|
||||
"STHeiti Light",
|
||||
"Microsoft YaHei",
|
||||
"SimSun",
|
||||
"SimHei",
|
||||
"KaiTi",
|
||||
]
|
||||
fm = FontManager()
|
||||
mat_fonts = set(f.name for f in fm.ttflist) # noqa: C401
|
||||
can_use_fonts = []
|
||||
for font_name in font_names:
|
||||
if font_name in mat_fonts:
|
||||
can_use_fonts.append(font_name)
|
||||
if len(can_use_fonts) > 0:
|
||||
plt.rcParams["font.sans-serif"] = can_use_fonts
|
||||
|
||||
|
||||
def format_axis(value, pos):
|
||||
"""Format axis."""
|
||||
# Judge whether scientific counting is needed
|
||||
if is_scientific_notation(value):
|
||||
return "{:.2f}".format(value)
|
||||
return value
|
||||
|
||||
|
||||
@command(
|
||||
"response_line_chart",
|
||||
"Line chart display, used to display comparative trend analysis data",
|
||||
'"df":"<data frame>"',
|
||||
)
|
||||
def response_line_chart(df: "DataFrame") -> str:
|
||||
"""Response line chart."""
|
||||
import matplotlib
|
||||
import matplotlib.pyplot as plt
|
||||
import matplotlib.ticker as mtick
|
||||
import pandas as pd
|
||||
import seaborn as sns
|
||||
from matplotlib.font_manager import FontManager
|
||||
|
||||
matplotlib.use("Agg")
|
||||
|
||||
logger.info("response_line_chart")
|
||||
if df.size <= 0:
|
||||
raise ValueError("No Data!")
|
||||
try:
|
||||
# set font
|
||||
# zh_font_set()
|
||||
font_names = [
|
||||
"Heiti TC",
|
||||
"Songti SC",
|
||||
"STHeiti Light",
|
||||
"Microsoft YaHei",
|
||||
"SimSun",
|
||||
"SimHei",
|
||||
"KaiTi",
|
||||
]
|
||||
fm = FontManager()
|
||||
mat_fonts = set(f.name for f in fm.ttflist) # noqa: C401
|
||||
can_use_fonts = []
|
||||
for font_name in font_names:
|
||||
if font_name in mat_fonts:
|
||||
can_use_fonts.append(font_name)
|
||||
if len(can_use_fonts) > 0:
|
||||
plt.rcParams["font.sans-serif"] = can_use_fonts
|
||||
|
||||
rc = {"font.sans-serif": can_use_fonts}
|
||||
plt.rcParams["axes.unicode_minus"] = False
|
||||
|
||||
sns.set(font=can_use_fonts[0], font_scale=0.8)
|
||||
sns.set_palette("Set3")
|
||||
sns.set_style("dark")
|
||||
sns.color_palette("hls", 10)
|
||||
sns.hls_palette(8, l=0.5, s=0.7)
|
||||
sns.set(context="notebook", style="ticks", rc=rc)
|
||||
|
||||
fig, ax = plt.subplots(figsize=(8, 5), dpi=100)
|
||||
x, y, non_num_columns, num_colmns = data_pre_classification(df)
|
||||
# Complex line chart implementation
|
||||
if len(num_colmns) > 0:
|
||||
num_colmns.append(y)
|
||||
df_melted = pd.melt(
|
||||
df,
|
||||
id_vars=x,
|
||||
value_vars=num_colmns,
|
||||
var_name="line",
|
||||
value_name="Value",
|
||||
)
|
||||
sns.lineplot(
|
||||
data=df_melted, x=x, y="Value", hue="line", ax=ax, palette="Set2"
|
||||
)
|
||||
else:
|
||||
sns.lineplot(data=df, x=x, y=y, ax=ax, palette="Set2")
|
||||
|
||||
ax.yaxis.set_major_formatter(mtick.FuncFormatter(format_axis))
|
||||
|
||||
chart_name = "line_" + str(uuid.uuid1()) + ".png"
|
||||
chart_path = static_message_img_path + "/" + chart_name
|
||||
plt.savefig(chart_path, dpi=100, transparent=True)
|
||||
|
||||
html_img = (
|
||||
"<img style='max-width: 100%; max-height: 70%;' "
|
||||
f'src="/images/{chart_name}" />'
|
||||
)
|
||||
return html_img
|
||||
except Exception as e:
|
||||
logging.error("Draw Line Chart failed!" + str(e))
|
||||
raise ValueError("Draw Line Chart failed!" + str(e))
|
||||
|
||||
|
||||
@command(
|
||||
"response_bar_chart",
|
||||
"Histogram, suitable for comparative analysis of multiple target values",
|
||||
'"df":"<data frame>"',
|
||||
)
|
||||
def response_bar_chart(df: "DataFrame") -> str:
|
||||
"""Response bar chart."""
|
||||
import matplotlib
|
||||
import matplotlib.pyplot as plt
|
||||
import matplotlib.ticker as mtick
|
||||
import pandas as pd
|
||||
import seaborn as sns
|
||||
from matplotlib.font_manager import FontManager
|
||||
|
||||
matplotlib.use("Agg")
|
||||
|
||||
logger.info("response_bar_chart")
|
||||
if df.size <= 0:
|
||||
raise ValueError("No Data!")
|
||||
|
||||
# set font
|
||||
# zh_font_set()
|
||||
font_names = [
|
||||
"Heiti TC",
|
||||
"Songti SC",
|
||||
"STHeiti Light",
|
||||
"Microsoft YaHei",
|
||||
"SimSun",
|
||||
"SimHei",
|
||||
"KaiTi",
|
||||
]
|
||||
fm = FontManager()
|
||||
mat_fonts = set(f.name for f in fm.ttflist) # noqa: C401
|
||||
can_use_fonts = []
|
||||
for font_name in font_names:
|
||||
if font_name in mat_fonts:
|
||||
can_use_fonts.append(font_name)
|
||||
if len(can_use_fonts) > 0:
|
||||
plt.rcParams["font.sans-serif"] = can_use_fonts
|
||||
|
||||
rc = {"font.sans-serif": can_use_fonts}
|
||||
# Fix the problem that the symbol cannot be displayed
|
||||
plt.rcParams["axes.unicode_minus"] = False
|
||||
# Fix chinese display problem
|
||||
sns.set(font=can_use_fonts[0], font_scale=0.8)
|
||||
# Set color theme
|
||||
sns.set_palette("Set3")
|
||||
sns.set_style("dark")
|
||||
sns.color_palette("hls", 10)
|
||||
sns.hls_palette(8, l=0.5, s=0.7)
|
||||
sns.set(context="notebook", style="ticks", rc=rc)
|
||||
|
||||
fig, ax = plt.subplots(figsize=(8, 5), dpi=100)
|
||||
|
||||
hue = None
|
||||
x, y, non_num_columns, num_colmns = data_pre_classification(df)
|
||||
if len(non_num_columns) >= 1:
|
||||
hue = non_num_columns[0]
|
||||
|
||||
if len(num_colmns) >= 1:
|
||||
if hue:
|
||||
if len(num_colmns) >= 2:
|
||||
can_use_columns = num_colmns[:2]
|
||||
else:
|
||||
can_use_columns = num_colmns
|
||||
sns.barplot(data=df, x=x, y=y, hue=hue, palette="Set2", ax=ax)
|
||||
for sub_y_column in can_use_columns:
|
||||
sns.barplot(
|
||||
data=df, x=x, y=sub_y_column, hue=hue, palette="Set2", ax=ax
|
||||
)
|
||||
else:
|
||||
if len(num_colmns) > 5:
|
||||
can_use_columns = num_colmns[:5]
|
||||
else:
|
||||
can_use_columns = num_colmns
|
||||
can_use_columns.append(y)
|
||||
|
||||
df_melted = pd.melt(
|
||||
df,
|
||||
id_vars=x,
|
||||
value_vars=can_use_columns,
|
||||
var_name="line",
|
||||
value_name="Value",
|
||||
)
|
||||
sns.barplot(
|
||||
data=df_melted, x=x, y="Value", hue="line", palette="Set2", ax=ax
|
||||
)
|
||||
else:
|
||||
sns.barplot(data=df, x=x, y=y, hue=hue, palette="Set2", ax=ax)
|
||||
|
||||
# Set the y-axis scale format to normal number format
|
||||
ax.yaxis.set_major_formatter(mtick.FuncFormatter(format_axis))
|
||||
|
||||
chart_name = "bar_" + str(uuid.uuid1()) + ".png"
|
||||
chart_path = static_message_img_path + "/" + chart_name
|
||||
plt.savefig(chart_path, dpi=100, transparent=True)
|
||||
html_img = (
|
||||
"<img style='max-width: 100%; max-height: 70%;' "
|
||||
f'src="/images/{chart_name}" />'
|
||||
)
|
||||
return html_img
|
||||
|
||||
|
||||
@command(
|
||||
"response_pie_chart",
|
||||
"Pie chart, suitable for scenarios such as proportion and distribution statistics",
|
||||
'"df":"<data frame>"',
|
||||
)
|
||||
def response_pie_chart(df: "DataFrame") -> str:
|
||||
"""Response pie chart."""
|
||||
import matplotlib
|
||||
import matplotlib.pyplot as plt
|
||||
import seaborn as sns
|
||||
from matplotlib.font_manager import FontManager
|
||||
|
||||
matplotlib.use("Agg")
|
||||
|
||||
logger.info("response_pie_chart")
|
||||
columns = df.columns.tolist()
|
||||
if df.size <= 0:
|
||||
raise ValueError("No Data!")
|
||||
# set font
|
||||
# zh_font_set()
|
||||
font_names = [
|
||||
"Heiti TC",
|
||||
"Songti SC",
|
||||
"STHeiti Light",
|
||||
"Microsoft YaHei",
|
||||
"SimSun",
|
||||
"SimHei",
|
||||
"KaiTi",
|
||||
]
|
||||
fm = FontManager()
|
||||
mat_fonts = set(f.name for f in fm.ttflist) # noqa: C401
|
||||
can_use_fonts = []
|
||||
for font_name in font_names:
|
||||
if font_name in mat_fonts:
|
||||
can_use_fonts.append(font_name)
|
||||
if len(can_use_fonts) > 0:
|
||||
plt.rcParams["font.sans-serif"] = can_use_fonts
|
||||
plt.rcParams["axes.unicode_minus"] = False
|
||||
|
||||
# Set the font style
|
||||
sns.set_palette("Set3")
|
||||
|
||||
fig, ax = plt.subplots(figsize=(8, 5), dpi=100)
|
||||
ax = df.plot(
|
||||
kind="pie",
|
||||
y=columns[1],
|
||||
ax=ax,
|
||||
labels=df[columns[0]].values,
|
||||
startangle=90,
|
||||
autopct="%1.1f%%",
|
||||
)
|
||||
|
||||
plt.axis("equal") # Make the pie chart a perfect circle
|
||||
# plt.title(columns[0])
|
||||
|
||||
chart_name = "pie_" + str(uuid.uuid1()) + ".png"
|
||||
chart_path = static_message_img_path + "/" + chart_name
|
||||
plt.savefig(chart_path, bbox_inches="tight", dpi=100, transparent=True)
|
||||
|
||||
html_img = (
|
||||
"<img style='max-width: 100%; max-height: 70%;' "
|
||||
f'src="/images/{chart_name}" />'
|
||||
)
|
||||
|
||||
return html_img
|
@@ -1,24 +0,0 @@
|
||||
"""Generate a table display for the response."""
|
||||
import logging
|
||||
|
||||
from pandas import DataFrame
|
||||
|
||||
from ...command_manage import command
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@command(
|
||||
"response_table",
|
||||
"Table display, suitable for display with many display columns or "
|
||||
"non-numeric columns",
|
||||
'"df":"<data frame>"',
|
||||
)
|
||||
def response_table(df: DataFrame) -> str:
|
||||
"""Response Table."""
|
||||
logger.info("response_table")
|
||||
html_table = df.to_html(index=False, escape=False, sparsify=False)
|
||||
table_str = "".join(html_table.split())
|
||||
table_str = table_str.replace("\n", " ")
|
||||
html = f""" \n<div class="w-full overflow-auto">{table_str}</div>\n """
|
||||
return html
|
@@ -1,40 +0,0 @@
|
||||
"""Generate text display content for the data frame."""
|
||||
import logging
|
||||
|
||||
from pandas import DataFrame
|
||||
|
||||
from ...command_manage import command
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@command(
|
||||
"response_data_text",
|
||||
"Text display, the default display method, suitable for single-line or "
|
||||
"simple content display",
|
||||
'"df":"<data frame>"',
|
||||
)
|
||||
def response_data_text(df: DataFrame) -> str:
|
||||
"""Generate text display content for the data frame."""
|
||||
logger.info("response_data_text")
|
||||
data = df.values
|
||||
|
||||
row_size = data.shape[0]
|
||||
value_str = ""
|
||||
text_info = ""
|
||||
if row_size > 1:
|
||||
html_table = df.to_html(index=False, escape=False, sparsify=False)
|
||||
table_str = "".join(html_table.split())
|
||||
html = f"""<div class="w-full overflow-auto">{table_str}</div>"""
|
||||
text_info = html.replace("\n", " ")
|
||||
elif row_size == 1:
|
||||
row = data[0]
|
||||
for value in row:
|
||||
if value_str:
|
||||
value_str = value_str + f", ** {value} **"
|
||||
else:
|
||||
value_str = f" ** {value} **"
|
||||
text_info = f" {value_str}"
|
||||
else:
|
||||
text_info = "##### No data found! #####"
|
||||
return text_info
|
@@ -1,169 +0,0 @@
|
||||
"""Command module."""
|
||||
import json
|
||||
from typing import Any, Dict
|
||||
|
||||
from dbgpt._private.config import Config
|
||||
from dbgpt.agent.plugin.generator import PluginPromptGenerator
|
||||
|
||||
from .exceptions import (
|
||||
CreateCommandException,
|
||||
ExecutionCommandException,
|
||||
NotCommandException,
|
||||
)
|
||||
|
||||
|
||||
def _resolve_pathlike_command_args(command_args):
|
||||
if "directory" in command_args and command_args["directory"] in {"", "/"}:
|
||||
# todo
|
||||
command_args["directory"] = ""
|
||||
else:
|
||||
for pathlike in ["filename", "directory", "clone_path"]:
|
||||
if pathlike in command_args:
|
||||
# todo
|
||||
command_args[pathlike] = ""
|
||||
return command_args
|
||||
|
||||
|
||||
def execute_ai_response_json(
|
||||
prompt: PluginPromptGenerator,
|
||||
ai_response,
|
||||
user_input: str | None = None,
|
||||
) -> str:
|
||||
"""Execute the command from the AI response.
|
||||
|
||||
Args:
|
||||
prompt(PluginPromptGenerator): The prompt generator
|
||||
ai_response: The response from the AI
|
||||
user_input(str): The user input
|
||||
|
||||
Returns:
|
||||
str: The result of the command
|
||||
"""
|
||||
from dbgpt.util.speech.say import say_text
|
||||
|
||||
cfg = Config()
|
||||
|
||||
command_name, arguments = get_command(ai_response)
|
||||
|
||||
if cfg.speak_mode:
|
||||
say_text(f"I want to execute {command_name}")
|
||||
|
||||
arguments = _resolve_pathlike_command_args(arguments)
|
||||
# Execute command
|
||||
if command_name is not None and command_name.lower().startswith("error"):
|
||||
result = f"Command {command_name} threw the following error: {arguments}"
|
||||
elif command_name == "human_feedback":
|
||||
result = f"Human feedback: {user_input}"
|
||||
else:
|
||||
for plugin in cfg.plugins:
|
||||
if not plugin.can_handle_pre_command():
|
||||
continue
|
||||
command_name, arguments = plugin.pre_command(command_name, arguments)
|
||||
command_result = execute_command(
|
||||
command_name,
|
||||
arguments,
|
||||
prompt,
|
||||
)
|
||||
result = f"{command_result}"
|
||||
return result
|
||||
|
||||
|
||||
def execute_command(
|
||||
command_name: str,
|
||||
arguments: Dict[str, Any],
|
||||
plugin_generator: PluginPromptGenerator,
|
||||
) -> Any:
|
||||
"""Execute the command and return the result.
|
||||
|
||||
Args:
|
||||
command_name (str): The name of the command to execute
|
||||
arguments (dict): The arguments for the command
|
||||
plugin_generator (PluginPromptGenerator): The plugin generator
|
||||
|
||||
Returns:
|
||||
str: The result of the command
|
||||
|
||||
Raises:
|
||||
NotCommandException: If the command is not found
|
||||
ExecutionCommandException: If an error occurs while executing the command
|
||||
"""
|
||||
cmd = None
|
||||
if plugin_generator.command_registry:
|
||||
cmd = plugin_generator.command_registry.commands.get(command_name)
|
||||
|
||||
# If the command is found, call it with the provided arguments
|
||||
if cmd:
|
||||
try:
|
||||
return cmd(**arguments)
|
||||
except Exception as e:
|
||||
raise CreateCommandException(f"Error: {str(e)}")
|
||||
# return f"Error: {str(e)}"
|
||||
# TODO: Change these to take in a file rather than pasted code, if
|
||||
# non-file is given, return instructions "Input should be a python
|
||||
# filepath, write your code to file and try again
|
||||
else:
|
||||
for command in plugin_generator.commands:
|
||||
if (
|
||||
command_name == command.label.lower()
|
||||
or command_name == command.name.lower()
|
||||
):
|
||||
try:
|
||||
# Delete non-defined parameters
|
||||
diff_ags = list(
|
||||
set(arguments.keys()).difference(set(command.args.keys()))
|
||||
)
|
||||
for arg_name in diff_ags:
|
||||
del arguments[arg_name]
|
||||
print(str(arguments))
|
||||
func = command.function
|
||||
if not func:
|
||||
raise ExecutionCommandException(
|
||||
f"Function not found for command: {command_name}"
|
||||
)
|
||||
return func(**arguments)
|
||||
except Exception as e:
|
||||
raise ExecutionCommandException(f"Execution error: {str(e)}")
|
||||
raise NotCommandException("Invalid command: " + command_name)
|
||||
|
||||
|
||||
def get_command(response_json: Dict):
|
||||
"""Create a command from the response JSON.
|
||||
|
||||
Parse the response and return the command name and arguments
|
||||
|
||||
Args:
|
||||
response_json (json): The response from the AI
|
||||
|
||||
Returns:
|
||||
tuple: The command name and arguments
|
||||
|
||||
Raises:
|
||||
json.decoder.JSONDecodeError: If the response is not valid JSON
|
||||
|
||||
Exception: If any other error occurs
|
||||
"""
|
||||
try:
|
||||
if "command" not in response_json:
|
||||
return "Error:", "Missing 'command' object in JSON"
|
||||
|
||||
if not isinstance(response_json, dict):
|
||||
return "Error:", f"'response_json' object is not dictionary {response_json}"
|
||||
|
||||
command = response_json["command"]
|
||||
if not isinstance(command, dict):
|
||||
return "Error:", "'command' object is not a dictionary"
|
||||
|
||||
if "name" not in command:
|
||||
return "Error:", "Missing 'name' field in 'command' object"
|
||||
|
||||
command_name = command["name"]
|
||||
|
||||
# Use an empty dictionary if 'args' field is not present in 'command' object
|
||||
arguments = command.get("args", {})
|
||||
|
||||
return command_name, arguments
|
||||
except json.decoder.JSONDecodeError:
|
||||
return "Error:", "Invalid JSON"
|
||||
# All other errors, return "Error: + error message"
|
||||
except Exception as e:
|
||||
return "Error:", str(e)
|
@@ -1,35 +0,0 @@
|
||||
"""Exceptions for the commands plugin."""
|
||||
|
||||
|
||||
class CommandException(Exception):
|
||||
"""Common command error exception."""
|
||||
|
||||
def __init__(self, message: str, error_type: str = "Common Error"):
|
||||
"""Create a new CommandException instance."""
|
||||
super().__init__(message)
|
||||
self.message = message
|
||||
self.error_type = error_type
|
||||
|
||||
|
||||
class CreateCommandException(CommandException):
|
||||
"""Create command error exception."""
|
||||
|
||||
def __init__(self, message: str, error_type="Create Command Error"):
|
||||
"""Create a new CreateCommandException instance."""
|
||||
super().__init__(message, error_type)
|
||||
|
||||
|
||||
class NotCommandException(CommandException):
|
||||
"""Command not found exception."""
|
||||
|
||||
def __init__(self, message: str, error_type="Not Command Error"):
|
||||
"""Create a new NotCommandException instance."""
|
||||
super().__init__(message, error_type)
|
||||
|
||||
|
||||
class ExecutionCommandException(CommandException):
|
||||
"""Command execution error exception."""
|
||||
|
||||
def __init__(self, message: str, error_type="Execution Command Error"):
|
||||
"""Create a new ExecutionCommandException instance."""
|
||||
super().__init__(message, error_type)
|
@@ -1,189 +0,0 @@
|
||||
"""A module for generating custom prompt strings."""
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional
|
||||
|
||||
from dbgpt._private.pydantic import BaseModel, Field
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .commands.command_manage import CommandRegistry
|
||||
|
||||
|
||||
class CommandEntry(BaseModel):
|
||||
"""CommandEntry class.
|
||||
|
||||
A class for storing information about a command.
|
||||
"""
|
||||
|
||||
label: str = Field(
|
||||
...,
|
||||
description="The label of the command.",
|
||||
)
|
||||
name: str = Field(
|
||||
...,
|
||||
description="The name of the command.",
|
||||
)
|
||||
args: Dict[str, Any] = Field(
|
||||
default_factory=dict,
|
||||
description="A dictionary containing argument names and their values.",
|
||||
)
|
||||
function: Optional[Callable] = Field(
|
||||
None,
|
||||
description="A callable function to be called when the command is executed.",
|
||||
)
|
||||
|
||||
|
||||
class PluginPromptGenerator:
|
||||
"""PluginPromptGenerator class.
|
||||
|
||||
A class for generating custom prompt strings based on constraints, commands,
|
||||
resources, and performance evaluations.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
"""Create a new PromptGenerator object.
|
||||
|
||||
Initialize the PromptGenerator object with empty lists of constraints,
|
||||
commands, resources, and performance evaluations.
|
||||
"""
|
||||
from .commands.command_manage import CommandRegistry
|
||||
|
||||
self._constraints: List[str] = []
|
||||
self._commands: List[CommandEntry] = []
|
||||
self._resources: List[str] = []
|
||||
self._performance_evaluation: List[str] = []
|
||||
self._command_registry: CommandRegistry = CommandRegistry()
|
||||
|
||||
@property
|
||||
def constraints(self) -> List[str]:
|
||||
"""Return the list of constraints."""
|
||||
return self._constraints
|
||||
|
||||
@property
|
||||
def commands(self) -> List[CommandEntry]:
|
||||
"""Return the list of commands."""
|
||||
return self._commands
|
||||
|
||||
@property
|
||||
def resources(self) -> List[str]:
|
||||
"""Return the list of resources."""
|
||||
return self._resources
|
||||
|
||||
@property
|
||||
def performance_evaluation(self) -> List[str]:
|
||||
"""Return the list of performance evaluations."""
|
||||
return self._performance_evaluation
|
||||
|
||||
@property
|
||||
def command_registry(self) -> "CommandRegistry":
|
||||
"""Return the command registry."""
|
||||
return self._command_registry
|
||||
|
||||
def set_command_registry(self, command_registry: "CommandRegistry") -> None:
|
||||
"""Set the command registry.
|
||||
|
||||
Args:
|
||||
command_registry: CommandRegistry
|
||||
"""
|
||||
self._command_registry = command_registry
|
||||
|
||||
def add_constraint(self, constraint: str) -> None:
|
||||
"""Add a constraint to the constraints list.
|
||||
|
||||
Args:
|
||||
constraint (str): The constraint to be added.
|
||||
"""
|
||||
self._constraints.append(constraint)
|
||||
|
||||
def add_command(
|
||||
self,
|
||||
command_label: str,
|
||||
command_name: str,
|
||||
args: Optional[Dict[str, Any]] = None,
|
||||
function: Optional[Callable] = None,
|
||||
) -> None:
|
||||
"""Add a command to the commands.
|
||||
|
||||
Add a command to the commands list with a label, name, and optional arguments.
|
||||
|
||||
Args:
|
||||
command_label (str): The label of the command.
|
||||
command_name (str): The name of the command.
|
||||
args (dict, optional): A dictionary containing argument names and their
|
||||
values. Defaults to None.
|
||||
function (callable, optional): A callable function to be called when
|
||||
the command is executed. Defaults to None.
|
||||
"""
|
||||
if args is None:
|
||||
args = {}
|
||||
|
||||
command_args = {arg_key: arg_value for arg_key, arg_value in args.items()}
|
||||
|
||||
command = CommandEntry(
|
||||
label=command_label,
|
||||
name=command_name,
|
||||
args=command_args,
|
||||
function=function,
|
||||
)
|
||||
self._commands.append(command)
|
||||
|
||||
def _generate_command_string(self, command: CommandEntry) -> str:
|
||||
"""
|
||||
Generate a formatted string representation of a command.
|
||||
|
||||
Args:
|
||||
command (dict): A dictionary containing command information.
|
||||
|
||||
Returns:
|
||||
str: The formatted command string.
|
||||
"""
|
||||
args_string = ", ".join(
|
||||
f'"{key}": "{value}"' for key, value in command.args.items()
|
||||
)
|
||||
return f'"{command.name}": {command.label} , args: {args_string}'
|
||||
|
||||
def add_resource(self, resource: str) -> None:
|
||||
"""
|
||||
Add a resource to the resources list.
|
||||
|
||||
Args:
|
||||
resource (str): The resource to be added.
|
||||
"""
|
||||
self._resources.append(resource)
|
||||
|
||||
def add_performance_evaluation(self, evaluation: str) -> None:
|
||||
"""
|
||||
Add a performance evaluation item to the performance_evaluation list.
|
||||
|
||||
Args:
|
||||
evaluation (str): The evaluation item to be added.
|
||||
"""
|
||||
self._performance_evaluation.append(evaluation)
|
||||
|
||||
def _generate_numbered_list(self, items: List[Any], item_type="list") -> str:
|
||||
"""
|
||||
Generate a numbered list from given items based on the item_type.
|
||||
|
||||
Args:
|
||||
items (list): A list of items to be numbered.
|
||||
item_type (str, optional): The type of items in the list.
|
||||
Defaults to 'list'.
|
||||
|
||||
Returns:
|
||||
str: The formatted numbered list.
|
||||
"""
|
||||
if item_type == "command":
|
||||
command_strings = []
|
||||
if self._command_registry:
|
||||
command_strings += [
|
||||
str(item)
|
||||
for item in self._command_registry.commands.values()
|
||||
if item.enabled
|
||||
]
|
||||
# terminate command is added manually
|
||||
command_strings += [self._generate_command_string(item) for item in items]
|
||||
return "\n".join(f"{i+1}. {item}" for i, item in enumerate(command_strings))
|
||||
else:
|
||||
return "\n".join(f"{i+1}. {item}" for i, item in enumerate(items))
|
||||
|
||||
def generate_commands_string(self) -> str:
|
||||
"""Return a formatted string representation of the commands list."""
|
||||
return f"{self._generate_numbered_list(self._commands, item_type='command')}"
|
@@ -1,36 +0,0 @@
|
||||
"""Plugin loader module."""
|
||||
|
||||
import logging
|
||||
from typing import List, Optional
|
||||
|
||||
from .generator import PluginPromptGenerator
|
||||
from .plugins_util import scan_plugins
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PluginLoader:
|
||||
"""Plugin Loader Class."""
|
||||
|
||||
def load_plugins(
|
||||
self, plugin_path: str, available_plugins: Optional[List[str]] = None
|
||||
) -> PluginPromptGenerator:
|
||||
"""Load plugins from plugin path."""
|
||||
available = available_plugins if available_plugins else ""
|
||||
logger.info(f"load_plugin path:{plugin_path}, available:{available}")
|
||||
plugins = scan_plugins(plugin_path)
|
||||
|
||||
generator: PluginPromptGenerator = PluginPromptGenerator()
|
||||
# load select plugin
|
||||
if available_plugins and len(available_plugins) > 0:
|
||||
for plugin in plugins:
|
||||
if plugin._name in available_plugins:
|
||||
if not plugin.can_handle_post_prompt():
|
||||
continue
|
||||
generator = plugin.post_prompt(generator)
|
||||
else:
|
||||
for plugin in plugins:
|
||||
if not plugin.can_handle_post_prompt():
|
||||
continue
|
||||
generator = plugin.post_prompt(generator)
|
||||
return generator
|
@@ -1,21 +1,49 @@
|
||||
"""Resource module for Agent."""
|
||||
from .resource_api import AgentResource, ResourceClient, ResourceType # noqa: F401
|
||||
from .resource_db_api import ResourceDbClient, SqliteLoadClient # noqa: F401
|
||||
from .resource_knowledge_api import ResourceKnowledgeClient # noqa: F401
|
||||
from .resource_loader import ResourceLoader # noqa: F401
|
||||
from .resource_plugin_api import ( # noqa: F401
|
||||
PluginFileLoadClient,
|
||||
ResourcePluginClient,
|
||||
"""Resource module for agent."""
|
||||
|
||||
from .base import ( # noqa: F401
|
||||
AgentResource,
|
||||
Resource,
|
||||
ResourceParameters,
|
||||
ResourceType,
|
||||
)
|
||||
from .database import ( # noqa: F401
|
||||
DBParameters,
|
||||
DBResource,
|
||||
RDBMSConnectorResource,
|
||||
SQLiteDBResource,
|
||||
)
|
||||
from .knowledge import RetrieverResource, RetrieverResourceParameters # noqa: F401
|
||||
from .manage import ( # noqa: F401
|
||||
RegisterResource,
|
||||
ResourceManager,
|
||||
get_resource_manager,
|
||||
initialize_resource,
|
||||
)
|
||||
from .pack import PackResourceParameters, ResourcePack # noqa: F401
|
||||
from .tool.base import BaseTool, FunctionTool, ToolParameter, tool # noqa: F401
|
||||
from .tool.pack import AutoGPTPluginToolPack, ToolPack # noqa: F401
|
||||
|
||||
__all__ = [
|
||||
"AgentResource",
|
||||
"ResourceClient",
|
||||
"Resource",
|
||||
"ResourceParameters",
|
||||
"ResourceType",
|
||||
"ResourceDbClient",
|
||||
"SqliteLoadClient",
|
||||
"ResourceKnowledgeClient",
|
||||
"ResourceLoader",
|
||||
"PluginFileLoadClient",
|
||||
"ResourcePluginClient",
|
||||
"DBParameters",
|
||||
"DBResource",
|
||||
"RDBMSConnectorResource",
|
||||
"SQLiteDBResource",
|
||||
"RetrieverResource",
|
||||
"RetrieverResourceParameters",
|
||||
"RegisterResource",
|
||||
"ResourceManager",
|
||||
"get_resource_manager",
|
||||
"initialize_resource",
|
||||
"PackResourceParameters",
|
||||
"ResourcePack",
|
||||
"BaseTool",
|
||||
"FunctionTool",
|
||||
"ToolParameter",
|
||||
"tool",
|
||||
"AutoGPTPluginToolPack",
|
||||
"ToolPack",
|
||||
]
|
||||
|
240
dbgpt/agent/resource/base.py
Normal file
240
dbgpt/agent/resource/base.py
Normal file
@@ -0,0 +1,240 @@
|
||||
"""Resources for the agent."""
|
||||
|
||||
import dataclasses
|
||||
import json
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, Generic, List, Optional, Type, TypeVar, cast
|
||||
|
||||
from dbgpt._private.pydantic import BaseModel, model_to_dict
|
||||
from dbgpt.util.parameter_utils import BaseParameters, _get_parameter_descriptions
|
||||
|
||||
P = TypeVar("P", bound="ResourceParameters")
|
||||
T = TypeVar("T", bound="Resource")
|
||||
|
||||
|
||||
class ResourceType(str, Enum):
|
||||
"""Resource type enumeration."""
|
||||
|
||||
DB = "database"
|
||||
Knowledge = "knowledge"
|
||||
Internet = "internet"
|
||||
Tool = "tool"
|
||||
Plugin = "plugin"
|
||||
TextFile = "text_file"
|
||||
ExcelFile = "excel_file"
|
||||
ImageFile = "image_file"
|
||||
AWELFlow = "awel_flow"
|
||||
# Resource type for resource pack
|
||||
Pack = "pack"
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class ResourceParameters(BaseParameters):
|
||||
"""Resource parameters class.
|
||||
|
||||
It defines the parameters for building a resource.
|
||||
"""
|
||||
|
||||
name: str = dataclasses.field(metadata={"help": "Resource name", "tags": "fixed"})
|
||||
|
||||
@classmethod
|
||||
def _resource_version(cls) -> str:
|
||||
"""Return the resource version."""
|
||||
return "v2"
|
||||
|
||||
@classmethod
|
||||
def to_configurations(
|
||||
cls, parameters: Type["ResourceParameters"], version: Optional[str] = None
|
||||
) -> Any:
|
||||
"""Convert the parameters to configurations."""
|
||||
return _get_parameter_descriptions(parameters)
|
||||
|
||||
|
||||
class Resource(ABC, Generic[P]):
|
||||
"""Resource for the agent."""
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def type(cls) -> ResourceType:
|
||||
"""Return the resource type."""
|
||||
|
||||
@classmethod
|
||||
def type_alias(cls) -> str:
|
||||
"""Return the resource type alias."""
|
||||
return cls.type().value
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def name(self) -> str:
|
||||
"""Return the resource name."""
|
||||
|
||||
@classmethod
|
||||
def resource_parameters_class(cls) -> Type[P]:
|
||||
"""Return the parameters class."""
|
||||
return ResourceParameters
|
||||
|
||||
def prefer_resource_parameters_class(self) -> Type[P]:
|
||||
"""Return the parameters class.
|
||||
|
||||
You can override this method to return a different parameters class.
|
||||
It will be used to initialize the resource with parameters.
|
||||
"""
|
||||
return self.resource_parameters_class()
|
||||
|
||||
def initialize_with_parameters(self, resource_parameters: P):
|
||||
"""Initialize the resource with parameters."""
|
||||
pass
|
||||
|
||||
def preload_resource(self):
|
||||
"""Preload the resource."""
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def from_resource(
|
||||
cls: Type[T],
|
||||
resource: Optional["Resource"],
|
||||
expected_type: Optional[ResourceType] = None,
|
||||
) -> List[T]:
|
||||
"""Create a resource from another resource.
|
||||
|
||||
Another resource can be a pack or a single resource, if it is a pack, it will
|
||||
return all resources which type is the same as the current resource.
|
||||
|
||||
Args:
|
||||
resource(Resource): The resource.
|
||||
expected_type(ResourceType): The expected resource type.
|
||||
Returns:
|
||||
List[Resource]: The resources.
|
||||
"""
|
||||
if not resource:
|
||||
return []
|
||||
typed_resources = []
|
||||
for r in resource.get_resource_by_type(expected_type or cls.type()):
|
||||
typed_resources.append(cast(T, r))
|
||||
return typed_resources
|
||||
|
||||
@abstractmethod
|
||||
async def get_prompt(
|
||||
self,
|
||||
*,
|
||||
lang: str = "en",
|
||||
prompt_type: str = "default",
|
||||
question: Optional[str] = None,
|
||||
resource_name: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
"""Get the prompt.
|
||||
|
||||
Args:
|
||||
lang(str): The language.
|
||||
prompt_type(str): The prompt type.
|
||||
question(str): The question.
|
||||
resource_name(str): The resource name, just for the pack, it will be used
|
||||
to select specific resource in the pack.
|
||||
"""
|
||||
|
||||
def execute(self, *args, resource_name: Optional[str] = None, **kwargs) -> Any:
|
||||
"""Execute the resource."""
|
||||
raise NotImplementedError
|
||||
|
||||
async def async_execute(
|
||||
self, *args, resource_name: Optional[str] = None, **kwargs
|
||||
) -> Any:
|
||||
"""Execute the resource asynchronously."""
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def is_async(self) -> bool:
|
||||
"""Return whether the resource is asynchronous."""
|
||||
return False
|
||||
|
||||
@property
|
||||
def is_pack(self) -> bool:
|
||||
"""Return whether the resource is a pack."""
|
||||
return False
|
||||
|
||||
@property
|
||||
def sub_resources(self) -> List["Resource"]:
|
||||
"""Return the resources."""
|
||||
if not self.is_pack:
|
||||
raise ValueError("The resource is not a pack, no sub-resources.")
|
||||
return []
|
||||
|
||||
def get_resource_by_type(self, resource_type: ResourceType) -> List["Resource"]:
|
||||
"""Get resources by type.
|
||||
|
||||
If the resource is a pack, it will search the sub-resources. Otherwise, it will
|
||||
return itself if the type matches.
|
||||
|
||||
Args:
|
||||
resource_type(ResourceType): The resource type.
|
||||
|
||||
Returns:
|
||||
List[Resource]: The resources.
|
||||
"""
|
||||
if not self.is_pack:
|
||||
if self.type() == resource_type:
|
||||
return [self]
|
||||
else:
|
||||
return []
|
||||
resources = []
|
||||
for resource in self.sub_resources:
|
||||
if resource.type() == resource_type:
|
||||
resources.append(resource)
|
||||
return resources
|
||||
|
||||
|
||||
class AgentResource(BaseModel):
|
||||
"""Agent resource class."""
|
||||
|
||||
type: str
|
||||
name: str
|
||||
value: str
|
||||
is_dynamic: bool = (
|
||||
False # Is the current resource predefined or dynamically passed in?
|
||||
)
|
||||
|
||||
def resource_prompt_template(self, **kwargs) -> str:
|
||||
"""Get the resource prompt template."""
|
||||
return "{data_type} --{data_introduce}"
|
||||
|
||||
@staticmethod
|
||||
def from_dict(d: Dict[str, Any]) -> Optional["AgentResource"]:
|
||||
"""Create an AgentResource object from a dictionary."""
|
||||
if d is None:
|
||||
return None
|
||||
return AgentResource(
|
||||
type=d.get("type"),
|
||||
name=d.get("name"),
|
||||
introduce=d.get("introduce"),
|
||||
value=d.get("value", None),
|
||||
is_dynamic=d.get("is_dynamic", False),
|
||||
parameters=d.get("parameters", None),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def from_json_list_str(d: Optional[str]) -> Optional[List["AgentResource"]]:
|
||||
"""Create a list of AgentResource objects from a json string."""
|
||||
if d is None:
|
||||
return None
|
||||
try:
|
||||
json_array = json.loads(d)
|
||||
except Exception:
|
||||
raise ValueError(f"Illegal AgentResource json string!{d}")
|
||||
if not isinstance(json_array, list):
|
||||
raise ValueError(f"Illegal AgentResource json string!{d}")
|
||||
json_list = []
|
||||
for item in json_array:
|
||||
r = AgentResource.from_dict(item)
|
||||
if r:
|
||||
json_list.append(r)
|
||||
return json_list
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert the AgentResource object to a dictionary."""
|
||||
temp = model_to_dict(self)
|
||||
for field, value in temp.items():
|
||||
if isinstance(value, Enum):
|
||||
temp[field] = value.value
|
||||
return temp
|
203
dbgpt/agent/resource/database.py
Normal file
203
dbgpt/agent/resource/database.py
Normal file
@@ -0,0 +1,203 @@
|
||||
"""Database resource module."""
|
||||
|
||||
import dataclasses
|
||||
import logging
|
||||
from concurrent.futures import Executor, ThreadPoolExecutor
|
||||
from typing import Any, Generic, List, Optional, Tuple, Union
|
||||
|
||||
import cachetools
|
||||
|
||||
from dbgpt.datasource.rdbms.base import RDBMSConnector
|
||||
from dbgpt.util.cache_utils import cached
|
||||
from dbgpt.util.executor_utils import blocking_func_to_async
|
||||
|
||||
from .base import P, Resource, ResourceParameters, ResourceType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_DEFAULT_PROMPT_TEMPLATE = (
|
||||
"Database type: {db_type}, related table structure definition: {schemas}"
|
||||
)
|
||||
_DEFAULT_PROMPT_TEMPLATE_ZH = "数据库类型:{db_type},相关表结构定义:{schemas}"
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class DBParameters(ResourceParameters):
|
||||
"""DB parameters class."""
|
||||
|
||||
db_name: str = dataclasses.field(metadata={"help": "DB name"})
|
||||
|
||||
|
||||
class DBResource(Resource[P], Generic[P]):
|
||||
"""Database resource class."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
db_type: Optional[str] = None,
|
||||
db_name: Optional[str] = None,
|
||||
dialect: Optional[str] = None,
|
||||
executor: Optional[Executor] = None,
|
||||
prompt_template: str = _DEFAULT_PROMPT_TEMPLATE,
|
||||
):
|
||||
"""Initialize the DB resource."""
|
||||
self._name = name
|
||||
self._db_type = db_type
|
||||
self._db_name = db_name
|
||||
self._dialect = dialect or db_type
|
||||
# Executor for running async tasks
|
||||
self._executor = executor or ThreadPoolExecutor()
|
||||
self._prompt_template = prompt_template
|
||||
|
||||
@classmethod
|
||||
def type(cls) -> ResourceType:
|
||||
"""Return the resource type."""
|
||||
return ResourceType.DB
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
"""Return the resource name."""
|
||||
return self._name
|
||||
|
||||
@property
|
||||
def db_type(self) -> str:
|
||||
"""Return the resource name."""
|
||||
if not self._db_type:
|
||||
raise ValueError("Database type is not set.")
|
||||
return self._db_type
|
||||
|
||||
@property
|
||||
def dialect(self) -> str:
|
||||
"""Return the resource name."""
|
||||
if not self._dialect:
|
||||
raise ValueError("Dialect is not set.")
|
||||
return self._dialect
|
||||
|
||||
@cached(cachetools.TTLCache(maxsize=100, ttl=10))
|
||||
async def get_prompt(
|
||||
self,
|
||||
*,
|
||||
lang: str = "en",
|
||||
prompt_type: str = "default",
|
||||
question: Optional[str] = None,
|
||||
resource_name: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
"""Get the prompt."""
|
||||
if not self._db_name:
|
||||
return "No database name provided."
|
||||
schema_info = await blocking_func_to_async(
|
||||
self._executor, self.get_schema_link, db=self._db_name, question=question
|
||||
)
|
||||
return self._prompt_template.format(db_type=self._db_type, schemas=schema_info)
|
||||
|
||||
def execute(self, *args, resource_name: Optional[str] = None, **kwargs) -> Any:
|
||||
"""Execute the resource."""
|
||||
copy_kwargs = kwargs.copy()
|
||||
if "db" not in copy_kwargs:
|
||||
copy_kwargs["db"] = self._db_name
|
||||
return self._sync_query(*args, **copy_kwargs)
|
||||
|
||||
async def async_execute(
|
||||
self, *args, resource_name: Optional[str] = None, **kwargs
|
||||
) -> Any:
|
||||
"""Execute the resource asynchronously."""
|
||||
copy_kwargs = kwargs.copy()
|
||||
if "db" not in copy_kwargs:
|
||||
copy_kwargs["db"] = self._db_name
|
||||
return await self.query(*args, **copy_kwargs)
|
||||
|
||||
@property
|
||||
def is_async(self) -> bool:
|
||||
"""Return whether the resource is asynchronous."""
|
||||
return True
|
||||
|
||||
def get_schema_link(
|
||||
self, db: str, question: Optional[str] = None
|
||||
) -> Union[str, List[str]]:
|
||||
"""Return the schema link of the database."""
|
||||
raise NotImplementedError("The run method should be implemented in a subclass.")
|
||||
|
||||
async def query_to_df(self, sql: str, db: Optional[str] = None):
|
||||
"""Return the query result as a DataFrame."""
|
||||
import pandas as pd
|
||||
|
||||
field_names, result = await self.query(sql, db=db)
|
||||
return pd.DataFrame(result, columns=field_names)
|
||||
|
||||
async def query(self, sql: str, db: Optional[str] = None):
|
||||
"""Return the query result."""
|
||||
db_name = db or self._db_name
|
||||
return await blocking_func_to_async(
|
||||
self._executor, self._sync_query, db=db_name, sql=sql
|
||||
)
|
||||
|
||||
def _sync_query(self, db: str, sql: str):
|
||||
"""Return the query result."""
|
||||
raise NotImplementedError("The run method should be implemented in a subclass.")
|
||||
|
||||
|
||||
class RDBMSConnectorResource(DBResource[DBParameters]):
|
||||
"""Connector resource class."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
connector: Optional[RDBMSConnector] = None,
|
||||
db_name: Optional[str] = None,
|
||||
db_type: Optional[str] = None,
|
||||
dialect: Optional[str] = None,
|
||||
executor: Optional[Executor] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize the connector resource."""
|
||||
if not db_type and connector:
|
||||
db_type = connector.db_type
|
||||
if not dialect and connector:
|
||||
dialect = connector.dialect
|
||||
if not db_name and connector:
|
||||
db_name = connector.get_current_db_name()
|
||||
self._connector = connector
|
||||
super().__init__(
|
||||
name,
|
||||
db_type=db_type,
|
||||
db_name=db_name,
|
||||
dialect=dialect,
|
||||
executor=executor,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@property
|
||||
def connector(self) -> RDBMSConnector:
|
||||
"""Return the connector."""
|
||||
if not self._connector:
|
||||
raise ValueError("Connector is not set.")
|
||||
return self._connector
|
||||
|
||||
def get_schema_link(
|
||||
self, db: str, question: Optional[str] = None
|
||||
) -> Union[str, List[str]]:
|
||||
"""Return the schema link of the database."""
|
||||
from dbgpt.rag.summary.rdbms_db_summary import _parse_db_summary
|
||||
|
||||
return _parse_db_summary(self.connector)
|
||||
|
||||
def _sync_query(self, db: str, sql: str) -> Tuple[Tuple, List]:
|
||||
"""Return the query result."""
|
||||
result_lst = self.connector.run(sql)
|
||||
columns = result_lst[0]
|
||||
values = result_lst[1:]
|
||||
return columns, values
|
||||
|
||||
|
||||
class SQLiteDBResource(RDBMSConnectorResource):
|
||||
"""SQLite database resource class."""
|
||||
|
||||
def __init__(
|
||||
self, name: str, db_name: str, executor: Optional[Executor] = None, **kwargs
|
||||
):
|
||||
"""Initialize the SQLite database resource."""
|
||||
from dbgpt.datasource.rdbms.conn_sqlite import SQLiteConnector
|
||||
|
||||
conn = SQLiteConnector.from_file_path(db_name)
|
||||
super().__init__(name, conn, executor=executor, **kwargs)
|
95
dbgpt/agent/resource/knowledge.py
Normal file
95
dbgpt/agent/resource/knowledge.py
Normal file
@@ -0,0 +1,95 @@
|
||||
"""Knowledge resource."""
|
||||
|
||||
import dataclasses
|
||||
from typing import TYPE_CHECKING, Any, List, Optional, Type
|
||||
|
||||
import cachetools
|
||||
|
||||
from dbgpt.util.cache_utils import cached
|
||||
|
||||
from .base import Resource, ResourceParameters, ResourceType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from dbgpt.core import Chunk
|
||||
from dbgpt.rag.retriever.base import BaseRetriever
|
||||
from dbgpt.storage.vector_store.filters import MetadataFilters
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class RetrieverResourceParameters(ResourceParameters):
|
||||
"""Retriever resource parameters."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class RetrieverResource(Resource[ResourceParameters]):
|
||||
"""Retriever resource.
|
||||
|
||||
Retrieve knowledge chunks from a retriever.
|
||||
"""
|
||||
|
||||
def __init__(self, name: str, retriever: "BaseRetriever"):
|
||||
"""Create a new RetrieverResource."""
|
||||
self._name = name
|
||||
self._retriever = retriever
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
"""Return the resource name."""
|
||||
return self._name
|
||||
|
||||
@property
|
||||
def retriever(self) -> "BaseRetriever":
|
||||
"""Return the retriever."""
|
||||
return self._retriever
|
||||
|
||||
@classmethod
|
||||
def type(cls) -> ResourceType:
|
||||
"""Return the resource type."""
|
||||
return ResourceType.Knowledge
|
||||
|
||||
@classmethod
|
||||
def resource_parameters_class(cls) -> Type[ResourceParameters]:
|
||||
"""Return the resource parameters class."""
|
||||
return RetrieverResourceParameters
|
||||
|
||||
@cached(cachetools.TTLCache(maxsize=100, ttl=10))
|
||||
async def get_prompt(
|
||||
self,
|
||||
*,
|
||||
lang: str = "en",
|
||||
prompt_type: str = "default",
|
||||
question: Optional[str] = None,
|
||||
resource_name: Optional[str] = None,
|
||||
**kwargs
|
||||
) -> str:
|
||||
"""Get the prompt for the resource."""
|
||||
if not question:
|
||||
raise ValueError("Question is required for knowledge resource.")
|
||||
chunks = await self.retrieve(question)
|
||||
content = "\n".join([chunk.content for chunk in chunks])
|
||||
prompt_template = "known information: {content}"
|
||||
prompt_template_zh = "已知信息: {content}"
|
||||
if lang == "en":
|
||||
return prompt_template.format(content=content)
|
||||
return prompt_template_zh.format(content=content)
|
||||
|
||||
async def async_execute(
|
||||
self, *args, resource_name: Optional[str] = None, **kwargs
|
||||
) -> Any:
|
||||
"""Execute the resource asynchronously."""
|
||||
return await self.retrieve(*args, **kwargs)
|
||||
|
||||
async def retrieve(
|
||||
self, query: str, filters: Optional["MetadataFilters"] = None
|
||||
) -> List["Chunk"]:
|
||||
"""Retrieve knowledge chunks.
|
||||
|
||||
Args:
|
||||
query (str): query text.
|
||||
filters: (Optional[MetadataFilters]) metadata filters.
|
||||
|
||||
Returns:
|
||||
List[Chunk]: list of chunks
|
||||
"""
|
||||
return await self.retriever.aretrieve(query, filters)
|
239
dbgpt/agent/resource/manage.py
Normal file
239
dbgpt/agent/resource/manage.py
Normal file
@@ -0,0 +1,239 @@
|
||||
"""Resource manager."""
|
||||
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from typing import Any, Dict, List, Optional, Type, Union, cast
|
||||
|
||||
from dbgpt._private.pydantic import BaseModel, ConfigDict, model_validator
|
||||
from dbgpt.component import BaseComponent, ComponentType, SystemApp
|
||||
from dbgpt.util.parameter_utils import ParameterDescription
|
||||
|
||||
from .base import AgentResource, Resource, ResourceParameters, ResourceType
|
||||
from .pack import ResourcePack
|
||||
from .tool.pack import ToolResourceType, _is_function_tool, _to_tool_list
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RegisterResource(BaseModel):
|
||||
"""Register resource model."""
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
name: Optional[str] = None
|
||||
resource_type: ResourceType
|
||||
resource_type_alias: Optional[str] = None
|
||||
resource_cls: Type[Resource]
|
||||
resource_instance: Optional[Resource] = None
|
||||
is_class: bool = True
|
||||
|
||||
@property
|
||||
def key(self) -> str:
|
||||
"""Return the key."""
|
||||
full_cls = f"{self.resource_cls.__module__}.{self.resource_cls.__qualname__}"
|
||||
name = self.name or full_cls
|
||||
resource_type_alias = self.resource_type_alias or self.resource_type.value
|
||||
return f"{resource_type_alias}:{name}"
|
||||
|
||||
@property
|
||||
def type_unique_key(self) -> str:
|
||||
"""Return the key."""
|
||||
resource_type_alias = self.resource_type_alias or self.resource_type.value
|
||||
return resource_type_alias
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def pre_fill(cls, values):
|
||||
"""Pre-fill the model."""
|
||||
if not isinstance(values, dict):
|
||||
return values
|
||||
resource_instance = values.get("resource_instance")
|
||||
if resource_instance is not None:
|
||||
values["name"] = values["name"] or resource_instance.name
|
||||
values["is_class"] = False
|
||||
if not isinstance(resource_instance, Resource):
|
||||
raise ValueError(
|
||||
f"resource_instance must be a Resource instance, not "
|
||||
f"{type(resource_instance)}"
|
||||
)
|
||||
if not values.get("resource_type"):
|
||||
values["resource_type"] = values["resource_cls"].type()
|
||||
if not values.get("resource_type_alias"):
|
||||
values["resource_type_alias"] = values["resource_cls"].type_alias()
|
||||
return values
|
||||
|
||||
def get_parameter_class(self) -> Type[ResourceParameters]:
|
||||
"""Return the parameter description."""
|
||||
if self.is_class:
|
||||
return self.resource_cls.resource_parameters_class()
|
||||
return self.resource_instance.prefer_resource_parameters_class() # type: ignore
|
||||
|
||||
|
||||
class ResourceManager(BaseComponent):
|
||||
"""Resource manager.
|
||||
|
||||
To manage the resources.
|
||||
"""
|
||||
|
||||
name = ComponentType.RESOURCE_MANAGER
|
||||
|
||||
def __init__(self, system_app: SystemApp):
|
||||
"""Create a new AgentManager."""
|
||||
super().__init__(system_app)
|
||||
self.system_app = system_app
|
||||
self._resources: Dict[str, RegisterResource] = {}
|
||||
self._type_to_resources: Dict[str, List[RegisterResource]] = defaultdict(list)
|
||||
|
||||
def init_app(self, system_app: SystemApp):
|
||||
"""Initialize the AgentManager."""
|
||||
self.system_app = system_app
|
||||
|
||||
def after_start(self):
|
||||
"""Register all resources."""
|
||||
# TODO: Register some internal resources
|
||||
pass
|
||||
|
||||
def register_resource(
|
||||
self,
|
||||
resource_cls: Optional[Type[Resource]] = None,
|
||||
resource_instance: Optional[Union[Resource, ToolResourceType]] = None,
|
||||
resource_type: Optional[ResourceType] = None,
|
||||
resource_type_alias: Optional[str] = None,
|
||||
):
|
||||
"""Register a resource."""
|
||||
if resource_instance and _is_function_tool(resource_instance):
|
||||
resource_instance = _to_tool_list(resource_instance)[0] # type: ignore
|
||||
|
||||
if resource_cls is None and resource_instance is None:
|
||||
raise ValueError("Resource class or instance must be provided.")
|
||||
name: Optional[str] = None
|
||||
if resource_instance is not None:
|
||||
resource_cls = resource_cls or type(resource_instance) # type: ignore
|
||||
name = resource_instance.name # type: ignore
|
||||
resource = RegisterResource(
|
||||
name=name,
|
||||
resource_cls=resource_cls,
|
||||
resource_instance=resource_instance,
|
||||
resource_type=resource_type,
|
||||
resource_type_alias=resource_type_alias,
|
||||
)
|
||||
self._resources[resource.key] = resource
|
||||
self._type_to_resources[resource.type_unique_key].append(resource)
|
||||
|
||||
def get_supported_resources(
|
||||
self, version: Optional[str] = None
|
||||
) -> Dict[str, List[ParameterDescription]]:
|
||||
"""Return the resources."""
|
||||
results = {}
|
||||
for key, resource in self._resources.items():
|
||||
parameter_class = resource.get_parameter_class()
|
||||
resource_type = resource.type_unique_key
|
||||
configs: Any = parameter_class.to_configurations(
|
||||
parameter_class, version=version
|
||||
)
|
||||
if (
|
||||
version == "v1"
|
||||
and isinstance(configs, list)
|
||||
and len(configs) > 0
|
||||
and isinstance(configs[0], ParameterDescription)
|
||||
):
|
||||
# v1, not compatible with class
|
||||
configs = []
|
||||
if not resource.is_class:
|
||||
for r in self._type_to_resources[resource_type]:
|
||||
if not r.is_class:
|
||||
configs.append(r.resource_instance.name) # type: ignore
|
||||
results[resource_type] = configs
|
||||
return results
|
||||
|
||||
def build_resource_by_type(
|
||||
self,
|
||||
type_unique_key: str,
|
||||
agent_resource: AgentResource,
|
||||
version: Optional[str] = None,
|
||||
) -> Resource:
|
||||
"""Return the resource by type."""
|
||||
item = self._type_to_resources.get(type_unique_key)
|
||||
if not item:
|
||||
raise ValueError(f"Resource type {type_unique_key} not found.")
|
||||
inst_items = [i for i in item if not i.is_class]
|
||||
if inst_items:
|
||||
if version == "v1":
|
||||
for i in inst_items:
|
||||
if (
|
||||
i.resource_instance
|
||||
and i.resource_instance.name == agent_resource.value
|
||||
):
|
||||
return i.resource_instance
|
||||
raise ValueError(
|
||||
f"Resource {agent_resource.value} not found in {type_unique_key}"
|
||||
)
|
||||
return cast(Resource, inst_items[0].resource_instance)
|
||||
elif len(inst_items) > 1:
|
||||
raise ValueError(
|
||||
f"Multiple instances of resource {type_unique_key} found, "
|
||||
f"please specify the resource name."
|
||||
)
|
||||
else:
|
||||
single_item = item[0]
|
||||
try:
|
||||
parameter_cls = single_item.get_parameter_class()
|
||||
param = parameter_cls.from_dict(agent_resource.to_dict())
|
||||
resource_inst = single_item.resource_cls(**param.to_dict())
|
||||
return resource_inst
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to build resource {single_item.key}: {str(e)}")
|
||||
raise ValueError(
|
||||
f"Failed to build resource {single_item.key}: {str(e)}"
|
||||
)
|
||||
|
||||
def build_resource(
|
||||
self,
|
||||
agent_resources: Optional[List[AgentResource]] = None,
|
||||
version: Optional[str] = None,
|
||||
) -> Optional[Resource]:
|
||||
"""Build a resource.
|
||||
|
||||
If there is only one resource, return the resource instance, otherwise return a
|
||||
ResourcePack.
|
||||
|
||||
Args:
|
||||
agent_resources: The agent resources.
|
||||
version: The resource version.
|
||||
|
||||
Returns:
|
||||
Optional[Resource]: The resource instance.
|
||||
"""
|
||||
if not agent_resources:
|
||||
return None
|
||||
dependencies: List[Resource] = []
|
||||
for resource in agent_resources:
|
||||
resource_inst = self.build_resource_by_type(
|
||||
resource.type, resource, version=version
|
||||
)
|
||||
dependencies.append(resource_inst)
|
||||
if len(dependencies) == 1:
|
||||
return dependencies[0]
|
||||
else:
|
||||
return ResourcePack(dependencies)
|
||||
|
||||
|
||||
_SYSTEM_APP: Optional[SystemApp] = None
|
||||
|
||||
|
||||
def initialize_resource(system_app: SystemApp):
|
||||
"""Initialize the resource manager."""
|
||||
global _SYSTEM_APP
|
||||
_SYSTEM_APP = system_app
|
||||
resource_manager = ResourceManager(system_app)
|
||||
system_app.register_instance(resource_manager)
|
||||
|
||||
|
||||
def get_resource_manager(system_app: Optional[SystemApp] = None) -> ResourceManager:
|
||||
"""Return the resource manager."""
|
||||
if not _SYSTEM_APP:
|
||||
if not system_app:
|
||||
system_app = SystemApp()
|
||||
initialize_resource(system_app)
|
||||
app = system_app or _SYSTEM_APP
|
||||
return ResourceManager.get_instance(cast(SystemApp, app))
|
115
dbgpt/agent/resource/pack.py
Normal file
115
dbgpt/agent/resource/pack.py
Normal file
@@ -0,0 +1,115 @@
|
||||
"""Resource pack module.
|
||||
|
||||
Resource pack is a collection of resources(also, it is a resource) that can be executed
|
||||
together.
|
||||
"""
|
||||
|
||||
import dataclasses
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from .base import Resource, ResourceParameters, ResourceType
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class PackResourceParameters(ResourceParameters):
|
||||
"""Resource pack parameters class."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class ResourcePack(Resource[PackResourceParameters]):
|
||||
"""Resource pack class."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
resources: List[Resource],
|
||||
name: str = "Resource Pack",
|
||||
prompt_separator: str = "\n",
|
||||
):
|
||||
"""Initialize the resource pack."""
|
||||
self._resources: Dict[str, Resource] = {
|
||||
resource.name: resource for resource in resources
|
||||
}
|
||||
self._name = name
|
||||
self._prompt_separator = prompt_separator
|
||||
|
||||
@classmethod
|
||||
def type(cls) -> ResourceType:
|
||||
"""Return the resource type."""
|
||||
return ResourceType.Pack
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
"""Return the resource name."""
|
||||
return self._name
|
||||
|
||||
def _get_resource_by_name(self, name: str) -> Optional[Resource]:
|
||||
"""Get the resource by name."""
|
||||
return self._resources.get(name, None)
|
||||
|
||||
async def get_prompt(
|
||||
self,
|
||||
*,
|
||||
lang: str = "en",
|
||||
prompt_type: str = "default",
|
||||
question: Optional[str] = None,
|
||||
resource_name: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
"""Get the prompt."""
|
||||
prompt_list = []
|
||||
for name, resource in self._resources.items():
|
||||
prompt = await resource.get_prompt(
|
||||
lang=lang,
|
||||
prompt_type=prompt_type,
|
||||
question=question,
|
||||
resource_name=resource_name,
|
||||
**kwargs,
|
||||
)
|
||||
prompt_list.append(prompt)
|
||||
return self._prompt_separator.join(prompt_list)
|
||||
|
||||
def append(self, resource: Resource, overwrite: bool = False):
|
||||
"""Append a resource to the pack."""
|
||||
name = resource.name
|
||||
if name in self._resources and not overwrite:
|
||||
raise ValueError(f"Resource {name} already exists in the pack.")
|
||||
self._resources[name] = resource
|
||||
|
||||
def execute(
|
||||
self,
|
||||
*args,
|
||||
resource_name: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> Any:
|
||||
"""Execute the resource."""
|
||||
if not resource_name:
|
||||
raise ValueError("No resource name provided, will not execute.")
|
||||
resource = self._resources.get(resource_name)
|
||||
if resource:
|
||||
return resource.execute(*args, **kwargs)
|
||||
raise ValueError("No resource parameters provided, will not execute.")
|
||||
|
||||
async def async_execute(
|
||||
self,
|
||||
*args,
|
||||
resource_name: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> Any:
|
||||
"""Execute the resource asynchronously."""
|
||||
if not resource_name:
|
||||
raise ValueError("No resource name provided, will not execute.")
|
||||
resource = self._resources.get(resource_name)
|
||||
if resource:
|
||||
return await resource.async_execute(*args, **kwargs)
|
||||
raise ValueError("No resource parameters provided, will not execute.")
|
||||
|
||||
@property
|
||||
def is_pack(self) -> bool:
|
||||
"""Return whether the resource is a pack."""
|
||||
return True
|
||||
|
||||
@property
|
||||
def sub_resources(self) -> List[Resource]:
|
||||
"""Return the resources."""
|
||||
return list(self._resources.values())
|
@@ -1,126 +0,0 @@
|
||||
"""Resource API for the agent."""
|
||||
import json
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from dbgpt._private.pydantic import BaseModel, model_to_dict
|
||||
|
||||
|
||||
class ResourceType(Enum):
|
||||
"""Resource type enumeration."""
|
||||
|
||||
DB = "database"
|
||||
Knowledge = "knowledge"
|
||||
Internet = "internet"
|
||||
Plugin = "plugin"
|
||||
TextFile = "text_file"
|
||||
ExcelFile = "excel_file"
|
||||
ImageFile = "image_file"
|
||||
AWELFlow = "awel_flow"
|
||||
|
||||
|
||||
class AgentResource(BaseModel):
|
||||
"""Agent resource class."""
|
||||
|
||||
type: ResourceType
|
||||
name: str
|
||||
value: str
|
||||
is_dynamic: bool = (
|
||||
False # Is the current resource predefined or dynamically passed in?
|
||||
)
|
||||
|
||||
def resource_prompt_template(self, **kwargs) -> str:
|
||||
"""Get the resource prompt template."""
|
||||
return "{data_type} --{data_introduce}"
|
||||
|
||||
@staticmethod
|
||||
def from_dict(d: Dict[str, Any]) -> Optional["AgentResource"]:
|
||||
"""Create an AgentResource object from a dictionary."""
|
||||
if d is None:
|
||||
return None
|
||||
return AgentResource(
|
||||
type=ResourceType(d.get("type")),
|
||||
name=d.get("name"),
|
||||
introduce=d.get("introduce"),
|
||||
value=d.get("value", None),
|
||||
is_dynamic=d.get("is_dynamic", False),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def from_json_list_str(d: Optional[str]) -> Optional[List["AgentResource"]]:
|
||||
"""Create a list of AgentResource objects from a json string."""
|
||||
if d is None:
|
||||
return None
|
||||
try:
|
||||
json_array = json.loads(d)
|
||||
except Exception:
|
||||
raise ValueError(f"Illegal AgentResource json string!{d}")
|
||||
if not isinstance(json_array, list):
|
||||
raise ValueError(f"Illegal AgentResource json string!{d}")
|
||||
json_list = []
|
||||
for item in json_array:
|
||||
r = AgentResource.from_dict(item)
|
||||
if r:
|
||||
json_list.append(r)
|
||||
return json_list
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert the AgentResource object to a dictionary."""
|
||||
temp = model_to_dict(self)
|
||||
for field, value in temp.items():
|
||||
if isinstance(value, Enum):
|
||||
temp[field] = value.value
|
||||
return temp
|
||||
|
||||
|
||||
class ResourceClient(ABC):
|
||||
"""Resource client interface."""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def type(self) -> ResourceType:
|
||||
"""Return the resource type."""
|
||||
|
||||
async def get_data_introduce(
|
||||
self, resource: AgentResource, question: Optional[str] = None
|
||||
) -> Union[str, List[str]]:
|
||||
"""
|
||||
Get the content introduction prompt of the specified resource.
|
||||
|
||||
Args:
|
||||
resource(AgentResource): The specified resource.
|
||||
question(str): The question to be asked.
|
||||
|
||||
Returns:
|
||||
str: The introduction content.
|
||||
"""
|
||||
return ""
|
||||
|
||||
def get_data_type(self, resource: AgentResource) -> str:
|
||||
"""Return the data type of the specified resource.
|
||||
|
||||
Args:
|
||||
resource(AgentResource): The specified resource.
|
||||
|
||||
Returns:
|
||||
str: The data type.
|
||||
"""
|
||||
return ""
|
||||
|
||||
async def get_resource_prompt(
|
||||
self, resource: AgentResource, question: Optional[str] = None
|
||||
) -> str:
|
||||
"""Get the resource prompt.
|
||||
|
||||
Args:
|
||||
resource(AgentResource): The specified resource.
|
||||
question(str): The question to be asked.
|
||||
|
||||
Returns:
|
||||
str: The resource prompt.
|
||||
"""
|
||||
return resource.resource_prompt_template().format(
|
||||
data_type=self.get_data_type(resource),
|
||||
data_introduce=await self.get_data_introduce(resource, question),
|
||||
)
|
@@ -1,127 +0,0 @@
|
||||
"""Database resource client API."""
|
||||
import logging
|
||||
from contextlib import contextmanager
|
||||
from typing import TYPE_CHECKING, Iterator, List, Optional, Union
|
||||
|
||||
from .resource_api import AgentResource, ResourceClient, ResourceType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ResourceDbClient(ResourceClient):
|
||||
"""Database resource client API."""
|
||||
|
||||
@property
|
||||
def type(self):
|
||||
"""Return the resource type."""
|
||||
return ResourceType.DB
|
||||
|
||||
def get_data_type(self, resource: AgentResource) -> str:
|
||||
"""Return the data type of the resource."""
|
||||
return super().get_data_type(resource)
|
||||
|
||||
async def get_data_introduce(
|
||||
self, resource: AgentResource, question: Optional[str] = None
|
||||
) -> Union[str, List[str]]:
|
||||
"""Return the data introduce of the resource."""
|
||||
return await self.get_schema_link(resource.value, question)
|
||||
|
||||
async def get_schema_link(
|
||||
self, db: str, question: Optional[str] = None
|
||||
) -> Union[str, List[str]]:
|
||||
"""Return the schema link of the database."""
|
||||
raise NotImplementedError("The run method should be implemented in a subclass.")
|
||||
|
||||
async def query_to_df(self, dbe: str, sql: str):
|
||||
"""Return the query result as a DataFrame."""
|
||||
raise NotImplementedError("The run method should be implemented in a subclass.")
|
||||
|
||||
async def query(self, db: str, sql: str):
|
||||
"""Return the query result."""
|
||||
raise NotImplementedError("The run method should be implemented in a subclass.")
|
||||
|
||||
async def run_sql(self, db: str, sql: str):
|
||||
"""Run the SQL."""
|
||||
raise NotImplementedError("The run method should be implemented in a subclass.")
|
||||
|
||||
|
||||
class SqliteLoadClient(ResourceDbClient):
|
||||
"""SQLite resource client."""
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sqlalchemy.orm.session import Session
|
||||
|
||||
def __init__(self):
|
||||
"""Create a SQLite resource client."""
|
||||
super(SqliteLoadClient, self).__init__()
|
||||
|
||||
def get_data_type(self, resource: AgentResource) -> str:
|
||||
"""Return the data type of the resource."""
|
||||
return "sqlite"
|
||||
|
||||
@contextmanager
|
||||
def connect(self, db) -> Iterator["Session"]:
|
||||
"""Connect to the database."""
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
engine = create_engine("sqlite:///" + db, echo=True)
|
||||
Session = sessionmaker(bind=engine)
|
||||
session = Session()
|
||||
try:
|
||||
yield session
|
||||
session.commit()
|
||||
except Exception:
|
||||
session.rollback()
|
||||
raise
|
||||
finally:
|
||||
session.close()
|
||||
|
||||
async def get_schema_link(
|
||||
self, db: str, question: Optional[str] = None
|
||||
) -> Union[str, List[str]]:
|
||||
"""Return the schema link of the database."""
|
||||
from sqlalchemy import text
|
||||
|
||||
with self.connect(db) as connect:
|
||||
_tables_sql = """
|
||||
SELECT name FROM sqlite_master WHERE type='table'
|
||||
"""
|
||||
cursor = connect.execute(text(_tables_sql))
|
||||
tables_results = cursor.fetchall()
|
||||
results = []
|
||||
for row in tables_results:
|
||||
table_name = row[0]
|
||||
_sql = f"""
|
||||
PRAGMA table_info({table_name})
|
||||
"""
|
||||
cursor_colums = connect.execute(text(_sql))
|
||||
colum_results = cursor_colums.fetchall()
|
||||
table_colums = []
|
||||
for row_col in colum_results:
|
||||
field_info = list(row_col)
|
||||
table_colums.append(field_info[1])
|
||||
|
||||
results.append(f"{table_name}({','.join(table_colums)});")
|
||||
return results
|
||||
|
||||
async def query_to_df(self, db: str, sql: str):
|
||||
"""Return the query result as a DataFrame."""
|
||||
import pandas as pd
|
||||
|
||||
field_names, result = await self.query(db, sql)
|
||||
return pd.DataFrame(result, columns=field_names)
|
||||
|
||||
async def query(self, db: str, sql: str):
|
||||
"""Return the query result."""
|
||||
from sqlalchemy import text
|
||||
|
||||
with self.connect(db) as connect:
|
||||
logger.info(f"Query[{sql}]")
|
||||
if not sql:
|
||||
return []
|
||||
cursor = connect.execute(text(sql))
|
||||
if cursor.returns_rows: # type: ignore
|
||||
result = cursor.fetchall()
|
||||
field_names = tuple(i[0:] for i in cursor.keys())
|
||||
return field_names, result
|
@@ -1,23 +0,0 @@
|
||||
"""Knowledge resource API for the agent."""
|
||||
from typing import Any, Optional
|
||||
|
||||
from .resource_api import ResourceClient, ResourceType
|
||||
|
||||
|
||||
class ResourceKnowledgeClient(ResourceClient):
|
||||
"""Knowledge resource client."""
|
||||
|
||||
@property
|
||||
def type(self):
|
||||
"""Return the resource type."""
|
||||
return ResourceType.Knowledge
|
||||
|
||||
async def get_kn(self, space_name: str, question: Optional[str] = None) -> Any:
|
||||
"""Get the knowledge content."""
|
||||
raise NotImplementedError("The run method should be implemented in a subclass.")
|
||||
|
||||
async def add_kn(
|
||||
self, space_name: str, kn_name: str, type: str, content: Optional[Any]
|
||||
):
|
||||
"""Add knowledge content."""
|
||||
raise NotImplementedError("The run method should be implemented in a subclass.")
|
@@ -1,40 +0,0 @@
|
||||
"""Resource loader module."""
|
||||
from collections import defaultdict
|
||||
from typing import Optional, Type, TypeVar
|
||||
|
||||
from .resource_api import ResourceClient, ResourceType
|
||||
|
||||
T = TypeVar("T", bound=ResourceClient)
|
||||
|
||||
|
||||
class ResourceLoader:
|
||||
"""Resource loader."""
|
||||
|
||||
def __init__(self):
|
||||
"""Create a new resource loader."""
|
||||
self._resource_api_instance = defaultdict(ResourceClient)
|
||||
|
||||
def get_resource_api(
|
||||
self,
|
||||
resource_type: Optional[ResourceType],
|
||||
cls: Optional[Type[T]] = None,
|
||||
check_instance: bool = True,
|
||||
) -> Optional[T]:
|
||||
"""Get the resource loader for the given resource type."""
|
||||
if not resource_type:
|
||||
return None
|
||||
|
||||
if resource_type not in self._resource_api_instance:
|
||||
raise ValueError(
|
||||
f"No loader available for resource of type {resource_type.value}"
|
||||
)
|
||||
inst = self._resource_api_instance[resource_type]
|
||||
if check_instance and cls and not isinstance(inst, cls):
|
||||
raise ValueError(
|
||||
f"Resource loader for {resource_type.value} is not an instance of {cls}"
|
||||
)
|
||||
return inst
|
||||
|
||||
def register_resource_api(self, api_instance: ResourceClient):
|
||||
"""Register the resource API instance."""
|
||||
self._resource_api_instance[api_instance.type] = api_instance
|
@@ -1,90 +0,0 @@
|
||||
"""Resource plugin client API."""
|
||||
import logging
|
||||
import os
|
||||
from typing import Any, Dict, List, Optional, Union, cast
|
||||
|
||||
from ..plugin.commands.command_manage import execute_command
|
||||
from ..plugin.generator import PluginPromptGenerator
|
||||
from ..plugin.plugins_util import scan_plugin_file, scan_plugins
|
||||
from ..resource.resource_api import AgentResource
|
||||
from .resource_api import ResourceClient, ResourceType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ResourcePluginClient(ResourceClient):
|
||||
"""Resource plugin client."""
|
||||
|
||||
@property
|
||||
def type(self):
|
||||
"""Return the resource type."""
|
||||
return ResourceType.Plugin
|
||||
|
||||
def get_data_type(self, resource: AgentResource) -> str:
|
||||
"""Return the data type of the specified resource."""
|
||||
return "Tools"
|
||||
|
||||
async def get_data_introduce(
|
||||
self, resource: AgentResource, question: Optional[str] = None
|
||||
) -> Union[str, List[str]]:
|
||||
"""Get the content introduction prompt of the specified resource."""
|
||||
return await self.plugins_prompt(resource.value)
|
||||
|
||||
async def load_plugin(
|
||||
self,
|
||||
value: str,
|
||||
plugin_generator: Optional[PluginPromptGenerator] = None,
|
||||
) -> PluginPromptGenerator:
|
||||
"""Load the plugin."""
|
||||
raise NotImplementedError("The run method should be implemented in a subclass.")
|
||||
|
||||
async def plugins_prompt(
|
||||
self, value: str, plugin_generator: Optional[PluginPromptGenerator] = None
|
||||
) -> str:
|
||||
"""Get the plugin commands prompt."""
|
||||
plugin_generator = await self.load_plugin(value)
|
||||
return plugin_generator.generate_commands_string()
|
||||
|
||||
async def execute_command(
|
||||
self,
|
||||
command_name: str,
|
||||
arguments: Dict[str, Any],
|
||||
plugin_generator: PluginPromptGenerator,
|
||||
):
|
||||
"""Execute the command."""
|
||||
if plugin_generator is None:
|
||||
raise ValueError("No plugin commands loaded into the executable!")
|
||||
return execute_command(command_name, arguments, plugin_generator)
|
||||
|
||||
|
||||
class PluginFileLoadClient(ResourcePluginClient):
|
||||
"""File plugin load client.
|
||||
|
||||
Load the plugin from the local file.
|
||||
"""
|
||||
|
||||
async def load_plugin(
|
||||
self, value: str, plugin_generator: Optional[PluginPromptGenerator] = None
|
||||
) -> PluginPromptGenerator:
|
||||
"""Load the plugin."""
|
||||
logger.info(f"PluginFileLoadClient load plugin:{value}")
|
||||
if plugin_generator is None:
|
||||
plugin_generator = PluginPromptGenerator()
|
||||
plugins = []
|
||||
if os.path.isabs(value):
|
||||
if not os.path.exists(value):
|
||||
raise ValueError(f"Wrong plugin file path configured {value}!")
|
||||
if os.path.isfile(value):
|
||||
plugins.extend(scan_plugin_file(value))
|
||||
else:
|
||||
plugins.extend(scan_plugins(value))
|
||||
else:
|
||||
raise ValueError(
|
||||
f"The current mode cannot support plug-in loading with relative "
|
||||
f"paths: {value}"
|
||||
)
|
||||
for plugin in plugins:
|
||||
if not plugin.can_handle_post_prompt():
|
||||
continue
|
||||
plugin_generator = plugin.post_prompt(plugin_generator)
|
||||
return cast(PluginPromptGenerator, plugin_generator)
|
4
dbgpt/agent/resource/tool/__init__.py
Normal file
4
dbgpt/agent/resource/tool/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
"""Tool resources.
|
||||
|
||||
Tool is a special type of resource that is used to execute a function or a command.
|
||||
"""
|
1
dbgpt/agent/resource/tool/autogpt/__init__.py
Normal file
1
dbgpt/agent/resource/tool/autogpt/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Some compatible tools for autogpt."""
|
@@ -1,4 +1,8 @@
|
||||
"""Load plugins from a directory or a zip file."""
|
||||
"""Load plugins from a directory or a zip file.
|
||||
|
||||
This module provides utility functions to load auto_gpt plugins from a directory or a
|
||||
zip file.
|
||||
"""
|
||||
|
||||
import datetime
|
||||
import glob
|
||||
@@ -8,8 +12,6 @@ import os
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, List, Optional
|
||||
|
||||
from dbgpt._private.config import Config
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from auto_gpt_plugin_template import AutoGPTPluginTemplate
|
||||
|
||||
@@ -131,30 +133,6 @@ def scan_plugins(
|
||||
return loaded_plugins
|
||||
|
||||
|
||||
def denylist_allowlist_check(plugin_name: str, cfg: Config) -> bool:
|
||||
"""Check if the plugin is in the allowlist or denylist.
|
||||
|
||||
Args:
|
||||
plugin_name (str): Name of the plugin.
|
||||
cfg (Config): Config object.
|
||||
|
||||
Returns:
|
||||
True or False
|
||||
"""
|
||||
logger.debug(f"Checking if plugin {plugin_name} should be loaded")
|
||||
if plugin_name in cfg.plugins_denylist:
|
||||
logger.debug(f"Not loading plugin {plugin_name} as it was in the denylist.")
|
||||
return False
|
||||
if plugin_name in cfg.plugins_allowlist:
|
||||
logger.debug(f"Loading plugin {plugin_name} as it was in the allowlist.")
|
||||
return True
|
||||
ack = input(
|
||||
f"WARNING: Plugin {plugin_name} found. But not in the"
|
||||
f" allowlist... Load? ({cfg.authorise_key}/{cfg.exit_key}): "
|
||||
)
|
||||
return ack.lower() == cfg.authorise_key
|
||||
|
||||
|
||||
def update_from_git(
|
||||
download_path: str,
|
||||
github_repo: str = "",
|
366
dbgpt/agent/resource/tool/base.py
Normal file
366
dbgpt/agent/resource/tool/base.py
Normal file
@@ -0,0 +1,366 @@
|
||||
"""Tool resources."""
|
||||
|
||||
import asyncio
|
||||
import dataclasses
|
||||
import functools
|
||||
import inspect
|
||||
import json
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Awaitable, Callable, Dict, List, Optional, Type, Union, cast
|
||||
|
||||
from dbgpt._private.pydantic import BaseModel, Field, model_validator
|
||||
from dbgpt.util.configure.base import _MISSING, _MISSING_TYPE
|
||||
from dbgpt.util.function_utils import parse_param_description, type_to_string
|
||||
|
||||
from ..base import Resource, ResourceParameters, ResourceType
|
||||
|
||||
ToolFunc = Union[Callable[..., Any], Callable[..., Awaitable[Any]]]
|
||||
|
||||
DB_GPT_TOOL_IDENTIFIER = "dbgpt_tool"
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class ToolResourceParameters(ResourceParameters):
|
||||
"""Tool resource parameters class."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class ToolParameter(BaseModel):
|
||||
"""Parameter for a tool."""
|
||||
|
||||
name: str = Field(..., description="Parameter name")
|
||||
title: str = Field(
|
||||
...,
|
||||
description="Parameter title, default to the name with the first letter "
|
||||
"capitalized",
|
||||
)
|
||||
type: str = Field(..., description="Parameter type", examples=["string", "integer"])
|
||||
description: str = Field(..., description="Parameter description")
|
||||
required: bool = Field(True, description="Whether the parameter is required")
|
||||
default: Optional[Any] = Field(
|
||||
_MISSING, description="Default value for the parameter"
|
||||
)
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def pre_fill(cls, values):
|
||||
"""Pre-fill the model."""
|
||||
if not isinstance(values, dict):
|
||||
return values
|
||||
if "title" not in values:
|
||||
values["title"] = values["name"].replace("_", " ").title()
|
||||
if "description" not in values:
|
||||
values["description"] = values["title"]
|
||||
return values
|
||||
|
||||
|
||||
class BaseTool(Resource[ToolResourceParameters], ABC):
|
||||
"""Base class for a tool."""
|
||||
|
||||
@classmethod
|
||||
def type(cls) -> ResourceType:
|
||||
"""Return the resource type."""
|
||||
return ResourceType.Tool
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def description(self) -> str:
|
||||
"""Return the description of the tool."""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def args(self) -> Dict[str, ToolParameter]:
|
||||
"""Return the arguments of the tool."""
|
||||
|
||||
async def get_prompt(
|
||||
self,
|
||||
*,
|
||||
lang: str = "en",
|
||||
prompt_type: str = "default",
|
||||
question: Optional[str] = None,
|
||||
resource_name: Optional[str] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Get the prompt."""
|
||||
prompt_template = (
|
||||
"{name}: Call this tool to interact with the {name} API. "
|
||||
"What is the {name} API useful for? {description} "
|
||||
"Parameters: {parameters}"
|
||||
)
|
||||
prompt_template_zh = (
|
||||
"{name}:调用此工具与 {name} API进行交互。{name} API 有什么用?{description} "
|
||||
"参数:{parameters}"
|
||||
)
|
||||
template = prompt_template if lang == "en" else prompt_template_zh
|
||||
if prompt_type == "openai":
|
||||
properties = {}
|
||||
required_list = []
|
||||
for key, value in self.args.items():
|
||||
properties[key] = {
|
||||
"type": value.type,
|
||||
"description": value.description,
|
||||
}
|
||||
if value.required:
|
||||
required_list.append(key)
|
||||
parameters_dict = {
|
||||
"type": "object",
|
||||
"properties": properties,
|
||||
"required": required_list,
|
||||
}
|
||||
parameters_string = json.dumps(parameters_dict, ensure_ascii=False)
|
||||
else:
|
||||
parameters = []
|
||||
for key, value in self.args.items():
|
||||
parameters.append(
|
||||
{
|
||||
"name": key,
|
||||
"type": value.type,
|
||||
"description": value.description,
|
||||
"required": value.required,
|
||||
}
|
||||
)
|
||||
parameters_string = json.dumps(parameters, ensure_ascii=False)
|
||||
return template.format(
|
||||
name=self.name,
|
||||
description=self.description,
|
||||
parameters=parameters_string,
|
||||
)
|
||||
|
||||
|
||||
class FunctionTool(BaseTool):
|
||||
"""Function tool.
|
||||
|
||||
Wrap a function as a tool.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
func: ToolFunc,
|
||||
description: Optional[str] = None,
|
||||
args: Optional[Dict[str, Union[ToolParameter, Dict[str, Any]]]] = None,
|
||||
args_schema: Optional[Type[BaseModel]] = None,
|
||||
):
|
||||
"""Create a tool from a function."""
|
||||
if not description:
|
||||
description = _parse_docstring(func)
|
||||
if not description:
|
||||
raise ValueError("The description is required")
|
||||
self._name = name
|
||||
self._description = cast(str, description)
|
||||
self._args: Dict[str, ToolParameter] = _parse_args(func, args, args_schema)
|
||||
self._func = func
|
||||
self._is_async = asyncio.iscoroutinefunction(func)
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
"""Return the name of the tool."""
|
||||
return self._name
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
"""Return the description of the tool."""
|
||||
return self._description
|
||||
|
||||
@property
|
||||
def args(self) -> Dict[str, ToolParameter]:
|
||||
"""Return the arguments of the tool."""
|
||||
return self._args
|
||||
|
||||
@property
|
||||
def is_async(self) -> bool:
|
||||
"""Return whether the tool is asynchronous."""
|
||||
return self._is_async
|
||||
|
||||
def execute(
|
||||
self,
|
||||
*args,
|
||||
resource_name: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> Any:
|
||||
"""Execute the tool.
|
||||
|
||||
Args:
|
||||
*args: The positional arguments.
|
||||
resource_name (str, optional): The tool name to be executed(not used for
|
||||
specific tool).
|
||||
**kwargs: The keyword arguments.
|
||||
"""
|
||||
if self._is_async:
|
||||
raise ValueError("The function is asynchronous")
|
||||
return self._func(*args, **kwargs)
|
||||
|
||||
async def async_execute(
|
||||
self,
|
||||
*args,
|
||||
resource_name: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> Any:
|
||||
"""Execute the tool asynchronously.
|
||||
|
||||
Args:
|
||||
*args: The positional arguments.
|
||||
resource_name (str, optional): The tool name to be executed(not used for
|
||||
specific tool).
|
||||
**kwargs: The keyword arguments.
|
||||
"""
|
||||
if not self._is_async:
|
||||
raise ValueError("The function is synchronous")
|
||||
return await self._func(*args, **kwargs)
|
||||
|
||||
|
||||
def tool(
|
||||
*decorator_args: Union[str, Callable],
|
||||
description: Optional[str] = None,
|
||||
args: Optional[Dict[str, Union[ToolParameter, Dict[str, Any]]]] = None,
|
||||
args_schema: Optional[Type[BaseModel]] = None,
|
||||
) -> Callable[..., Any]:
|
||||
"""Create a tool from a function."""
|
||||
|
||||
def _create_decorator(name: str):
|
||||
def decorator(func: ToolFunc):
|
||||
tool_name = name or func.__name__
|
||||
ft = FunctionTool(tool_name, func, description, args, args_schema)
|
||||
|
||||
@functools.wraps(func)
|
||||
def sync_wrapper(*f_args, **kwargs):
|
||||
return ft.execute(*f_args, **kwargs)
|
||||
|
||||
@functools.wraps(func)
|
||||
async def async_wrapper(*f_args, **kwargs):
|
||||
return await ft.async_execute(*f_args, **kwargs)
|
||||
|
||||
if asyncio.iscoroutinefunction(func):
|
||||
wrapper = async_wrapper
|
||||
else:
|
||||
wrapper = sync_wrapper
|
||||
wrapper._tool = ft # type: ignore
|
||||
setattr(wrapper, DB_GPT_TOOL_IDENTIFIER, True)
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
if len(decorator_args) == 1 and callable(decorator_args[0]):
|
||||
# @tool
|
||||
old_func = decorator_args[0]
|
||||
return _create_decorator(old_func.__name__)(old_func)
|
||||
elif len(decorator_args) == 1 and isinstance(decorator_args[0], str):
|
||||
# @tool("google_search")
|
||||
return _create_decorator(decorator_args[0])
|
||||
elif (
|
||||
len(decorator_args) == 2
|
||||
and isinstance(decorator_args[0], str)
|
||||
and callable(decorator_args[1])
|
||||
):
|
||||
# @tool("google_search", description="Search on Google")
|
||||
return _create_decorator(decorator_args[0])(decorator_args[1])
|
||||
elif len(decorator_args) == 0:
|
||||
# use function name as tool name
|
||||
def _partial(func: ToolFunc):
|
||||
return _create_decorator(func.__name__)(func)
|
||||
|
||||
return _partial
|
||||
else:
|
||||
raise ValueError("Invalid usage of @tool")
|
||||
|
||||
|
||||
def _parse_docstring(func: ToolFunc) -> str:
|
||||
"""Parse the docstring of the function."""
|
||||
docstring = func.__doc__
|
||||
if docstring is None:
|
||||
return ""
|
||||
return docstring.strip()
|
||||
|
||||
|
||||
def _parse_args(
|
||||
func: ToolFunc,
|
||||
args: Optional[Dict[str, Union[ToolParameter, Dict[str, Any]]]] = None,
|
||||
args_schema: Optional[Type[BaseModel]] = None,
|
||||
) -> Dict[str, ToolParameter]:
|
||||
"""Parse the arguments of the function."""
|
||||
# Check args all values are ToolParameter
|
||||
parsed_args = {}
|
||||
if args is not None:
|
||||
if all(isinstance(v, ToolParameter) for v in args.values()):
|
||||
return args # type: ignore
|
||||
if all(isinstance(v, dict) for v in args.values()):
|
||||
for k, v in args.items():
|
||||
param_name = v.get("name", k)
|
||||
param_title = v.get("title", param_name.replace("_", " ").title())
|
||||
param_type = v["type"]
|
||||
param_description = v.get("description", param_title)
|
||||
param_default = v.get("default", _MISSING)
|
||||
param_required = v.get("required", param_default is _MISSING)
|
||||
parsed_args[k] = ToolParameter(
|
||||
name=param_name,
|
||||
title=param_title,
|
||||
type=param_type,
|
||||
description=param_description,
|
||||
default=param_default,
|
||||
required=param_required,
|
||||
)
|
||||
return parsed_args
|
||||
raise ValueError("args should be a dict of ToolParameter or dict")
|
||||
|
||||
if args_schema is not None:
|
||||
return _parse_args_from_schema(args_schema)
|
||||
signature = inspect.signature(func)
|
||||
|
||||
for param in signature.parameters.values():
|
||||
real_type = param.annotation
|
||||
param_name = param.name
|
||||
param_title = param_name.replace("_", " ").title()
|
||||
|
||||
if param.default is not inspect.Parameter.empty:
|
||||
param_default = param.default
|
||||
param_required = False
|
||||
else:
|
||||
param_default = _MISSING
|
||||
param_required = True
|
||||
param_type = type_to_string(real_type, "unknown")
|
||||
param_description = parse_param_description(param_name, real_type)
|
||||
parsed_args[param_name] = ToolParameter(
|
||||
name=param_name,
|
||||
title=param_title,
|
||||
type=param_type,
|
||||
description=param_description,
|
||||
default=param_default,
|
||||
required=param_required,
|
||||
)
|
||||
return parsed_args
|
||||
|
||||
|
||||
def _parse_args_from_schema(args_schema: Type[BaseModel]) -> Dict[str, ToolParameter]:
|
||||
"""Parse the arguments from a Pydantic schema."""
|
||||
pydantic_args = args_schema.schema()["properties"]
|
||||
parsed_args = {}
|
||||
for key, value in pydantic_args.items():
|
||||
param_name = key
|
||||
param_title = value.get("title", param_name.replace("_", " ").title())
|
||||
if "type" in value:
|
||||
param_type = value["type"]
|
||||
elif "anyOf" in value:
|
||||
# {"anyOf": [{"type": "string"}, {"type": "null"}]}
|
||||
any_of: List[Dict[str, Any]] = value["anyOf"]
|
||||
if len(any_of) == 2 and any("null" in t["type"] for t in any_of):
|
||||
param_type = next(t["type"] for t in any_of if "null" not in t["type"])
|
||||
else:
|
||||
param_type = json.dumps({"anyOf": value["anyOf"]}, ensure_ascii=False)
|
||||
else:
|
||||
raise ValueError(f"Invalid schema for {key}")
|
||||
param_description = value.get("description", param_title)
|
||||
param_default = value.get("default", _MISSING)
|
||||
param_required = False
|
||||
if isinstance(param_default, _MISSING_TYPE) and param_default == _MISSING:
|
||||
param_required = True
|
||||
|
||||
parsed_args[key] = ToolParameter(
|
||||
name=param_name,
|
||||
title=param_title,
|
||||
type=param_type,
|
||||
description=param_description,
|
||||
default=param_default,
|
||||
required=param_required,
|
||||
)
|
||||
return parsed_args
|
35
dbgpt/agent/resource/tool/exceptions.py
Normal file
35
dbgpt/agent/resource/tool/exceptions.py
Normal file
@@ -0,0 +1,35 @@
|
||||
"""Exceptions for the tool."""
|
||||
|
||||
|
||||
class ToolException(Exception):
|
||||
"""Common tool error exception."""
|
||||
|
||||
def __init__(self, message: str, error_type: str = "Common Error"):
|
||||
"""Create a new ToolException instance."""
|
||||
super().__init__(message)
|
||||
self.message = message
|
||||
self.error_type = error_type
|
||||
|
||||
|
||||
class CreateToolException(ToolException):
|
||||
"""Create tool error exception."""
|
||||
|
||||
def __init__(self, message: str, error_type="Create Command Error"):
|
||||
"""Create a new CreateToolException instance."""
|
||||
super().__init__(message, error_type)
|
||||
|
||||
|
||||
class ToolNotFoundException(ToolException):
|
||||
"""Tool not found exception."""
|
||||
|
||||
def __init__(self, message: str, error_type="Not Command Error"):
|
||||
"""Create a new ToolNotFoundException instance."""
|
||||
super().__init__(message, error_type)
|
||||
|
||||
|
||||
class ToolExecutionException(ToolException):
|
||||
"""Tool execution error exception."""
|
||||
|
||||
def __init__(self, message: str, error_type="Execution Command Error"):
|
||||
"""Create a new ToolExecutionException instance."""
|
||||
super().__init__(message, error_type)
|
213
dbgpt/agent/resource/tool/pack.py
Normal file
213
dbgpt/agent/resource/tool/pack.py
Normal file
@@ -0,0 +1,213 @@
|
||||
"""Tool resource pack module."""
|
||||
|
||||
import os
|
||||
from typing import Any, Callable, Dict, List, Optional, Type, Union, cast
|
||||
|
||||
from ..base import ResourceType, T
|
||||
from ..pack import Resource, ResourcePack
|
||||
from .base import DB_GPT_TOOL_IDENTIFIER, BaseTool, FunctionTool, ToolFunc
|
||||
from .exceptions import ToolExecutionException, ToolNotFoundException
|
||||
|
||||
ToolResourceType = Union[BaseTool, List[BaseTool], ToolFunc, List[ToolFunc]]
|
||||
|
||||
|
||||
def _is_function_tool(resources: Any) -> bool:
|
||||
return (
|
||||
callable(resources)
|
||||
and hasattr(resources, DB_GPT_TOOL_IDENTIFIER)
|
||||
and getattr(resources, DB_GPT_TOOL_IDENTIFIER)
|
||||
and hasattr(resources, "_tool")
|
||||
and isinstance(getattr(resources, "_tool"), BaseTool)
|
||||
)
|
||||
|
||||
|
||||
def _to_tool_list(resources: ToolResourceType) -> List[BaseTool]:
|
||||
if isinstance(resources, BaseTool):
|
||||
return [resources]
|
||||
elif isinstance(resources, list) and all(
|
||||
isinstance(r, BaseTool) for r in resources
|
||||
):
|
||||
return cast(List[BaseTool], resources)
|
||||
elif isinstance(resources, list) and all(_is_function_tool(r) for r in resources):
|
||||
return [cast(FunctionTool, getattr(r, "_tool")) for r in resources]
|
||||
elif _is_function_tool(resources):
|
||||
function_tool = cast(FunctionTool, getattr(resources, "_tool"))
|
||||
return [function_tool]
|
||||
raise ValueError("Invalid tool resource type")
|
||||
|
||||
|
||||
class ToolPack(ResourcePack):
|
||||
"""Tool resource pack class."""
|
||||
|
||||
def __init__(
|
||||
self, resources: ToolResourceType, name: str = "Tool Resource Pack", **kwargs
|
||||
):
|
||||
"""Initialize the tool resource pack."""
|
||||
tools = cast(List[Resource], _to_tool_list(resources))
|
||||
super().__init__(resources=tools, name=name, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def from_resource(
|
||||
cls: Type[T],
|
||||
resource: Optional[Resource],
|
||||
expected_type: Optional[ResourceType] = None,
|
||||
) -> List[T]:
|
||||
"""Create a resource from another resource."""
|
||||
if not resource:
|
||||
return []
|
||||
if isinstance(resource, ToolPack):
|
||||
return [cast(T, resource)]
|
||||
tools = super().from_resource(resource, ResourceType.Tool)
|
||||
if not tools:
|
||||
return []
|
||||
typed_tools = [cast(BaseTool, t) for t in tools]
|
||||
return [ToolPack(typed_tools)] # type: ignore
|
||||
|
||||
def add_command(
|
||||
self,
|
||||
command_label: str,
|
||||
command_name: str,
|
||||
args: Optional[Dict[str, Any]] = None,
|
||||
function: Optional[Callable] = None,
|
||||
) -> None:
|
||||
"""Add a command to the commands.
|
||||
|
||||
Compatible with the Auto-GPT old plugin system.
|
||||
|
||||
Add a command to the commands list with a label, name, and optional arguments.
|
||||
|
||||
Args:
|
||||
command_label (str): The label of the command.
|
||||
command_name (str): The name of the command.
|
||||
args (dict, optional): A dictionary containing argument names and their
|
||||
values. Defaults to None.
|
||||
function (callable, optional): A callable function to be called when
|
||||
the command is executed. Defaults to None.
|
||||
"""
|
||||
if args is not None:
|
||||
tool_args = {}
|
||||
for name, value in args.items():
|
||||
tool_args[name] = {
|
||||
"name": name,
|
||||
"type": "str",
|
||||
"description": value,
|
||||
}
|
||||
else:
|
||||
tool_args = {}
|
||||
if not function:
|
||||
raise ValueError("Function must be provided")
|
||||
|
||||
ft = FunctionTool(
|
||||
name=command_name,
|
||||
func=function,
|
||||
args=tool_args,
|
||||
description=command_label,
|
||||
)
|
||||
self.append(ft)
|
||||
|
||||
def _get_execution_tool(
|
||||
self,
|
||||
name: Optional[str] = None,
|
||||
) -> BaseTool:
|
||||
if not name and name not in self._resources:
|
||||
raise ToolNotFoundException("No tool found for execution")
|
||||
return cast(BaseTool, self._resources[name])
|
||||
|
||||
def _get_call_args(self, arguments: Dict[str, Any], tl: BaseTool) -> Dict[str, Any]:
|
||||
"""Get the call arguments."""
|
||||
# Delete non-defined parameters
|
||||
diff_args = list(set(arguments.keys()).difference(set(tl.args.keys())))
|
||||
for arg_name in diff_args:
|
||||
del arguments[arg_name]
|
||||
return arguments
|
||||
|
||||
def execute(
|
||||
self,
|
||||
*args,
|
||||
resource_name: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> Any:
|
||||
"""Execute the tool.
|
||||
|
||||
Args:
|
||||
*args: The positional arguments.
|
||||
resource_name (str, optional): The tool name to be executed.
|
||||
**kwargs: The keyword arguments.
|
||||
|
||||
Returns:
|
||||
Any: The result of the tool execution.
|
||||
"""
|
||||
tl = self._get_execution_tool(resource_name)
|
||||
try:
|
||||
arguments = {k: v for k, v in kwargs.items()}
|
||||
arguments = self._get_call_args(arguments, tl)
|
||||
if tl.is_async:
|
||||
raise ToolExecutionException("Async execution is not supported")
|
||||
else:
|
||||
return tl.execute(**arguments)
|
||||
except Exception as e:
|
||||
raise ToolExecutionException(f"Execution error: {str(e)}")
|
||||
|
||||
async def async_execute(
|
||||
self,
|
||||
*args,
|
||||
resource_name: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> Any:
|
||||
"""Execute the tool asynchronously.
|
||||
|
||||
Args:
|
||||
*args: The positional arguments.
|
||||
resource_name (str, optional): The tool name to be executed.
|
||||
**kwargs: The keyword arguments.
|
||||
|
||||
Returns:
|
||||
Any: The result of the tool execution.
|
||||
"""
|
||||
tl = self._get_execution_tool(resource_name)
|
||||
try:
|
||||
arguments = {k: v for k, v in kwargs.items()}
|
||||
arguments = self._get_call_args(arguments, tl)
|
||||
if tl.is_async:
|
||||
return await tl.async_execute(**arguments)
|
||||
else:
|
||||
# TODO: Execute in a separate executor
|
||||
return tl.execute(**arguments)
|
||||
except Exception as e:
|
||||
raise ToolExecutionException(f"Execution error: {str(e)}")
|
||||
|
||||
|
||||
class AutoGPTPluginToolPack(ToolPack):
|
||||
"""Auto-GPT plugin tool pack class."""
|
||||
|
||||
def __init__(self, plugin_path: Union[str, List[str]], **kwargs):
|
||||
"""Create an Auto-GPT plugin tool pack."""
|
||||
super().__init__([], **kwargs)
|
||||
self._plugin_path = plugin_path
|
||||
self._loaded = False
|
||||
|
||||
def preload_resource(self):
|
||||
"""Preload the resource."""
|
||||
from .autogpt.plugins_util import scan_plugin_file, scan_plugins
|
||||
|
||||
if self._loaded:
|
||||
return
|
||||
paths = (
|
||||
[self._plugin_path]
|
||||
if isinstance(self._plugin_path, str)
|
||||
else self._plugin_path
|
||||
)
|
||||
plugins = []
|
||||
for path in paths:
|
||||
if os.path.isabs(path):
|
||||
if not os.path.exists(path):
|
||||
raise ValueError(f"Wrong plugin path configured {path}!")
|
||||
if os.path.isfile(path):
|
||||
plugins.extend(scan_plugin_file(path))
|
||||
else:
|
||||
plugins.extend(scan_plugins(path))
|
||||
for plugin in plugins:
|
||||
if not plugin.can_handle_post_prompt():
|
||||
continue
|
||||
plugin.post_prompt(self)
|
||||
self._loaded = True
|
200
dbgpt/agent/resource/tool/tests/test_base_tool.py
Normal file
200
dbgpt/agent/resource/tool/tests/test_base_tool.py
Normal file
@@ -0,0 +1,200 @@
|
||||
import asyncio
|
||||
import json
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
import pytest
|
||||
from typing_extensions import Annotated, Doc
|
||||
|
||||
from dbgpt._private.pydantic import BaseModel, Field
|
||||
|
||||
from ..base import BaseTool, FunctionTool, ToolParameter, tool
|
||||
|
||||
|
||||
class TestBaseTool(BaseTool):
|
||||
@property
|
||||
def name(self):
|
||||
return "test_tool"
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return "This is a test tool."
|
||||
|
||||
@property
|
||||
def args(self):
|
||||
return {}
|
||||
|
||||
def execute(self, *args, **kwargs):
|
||||
return "executed"
|
||||
|
||||
async def async_execute(self, *args, **kwargs):
|
||||
return "async executed"
|
||||
|
||||
|
||||
def test_base_tool():
|
||||
tool = TestBaseTool()
|
||||
assert tool.name == "test_tool"
|
||||
assert tool.description == "This is a test tool."
|
||||
assert tool.execute() == "executed"
|
||||
assert asyncio.run(tool.async_execute()) == "async executed"
|
||||
|
||||
|
||||
def test_function_tool_sync() -> None:
|
||||
def two_sum(a: int, b: int) -> int:
|
||||
"""Add two numbers."""
|
||||
return a + b
|
||||
|
||||
ft = FunctionTool(name="sample", func=two_sum)
|
||||
assert ft.execute(1, 2) == 3
|
||||
with pytest.raises(ValueError):
|
||||
asyncio.run(ft.async_execute(1, 2))
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_function_tool_async() -> None:
|
||||
async def sample_async_func(a: int, b: int) -> int:
|
||||
"""Add two numbers asynchronously."""
|
||||
return a + b
|
||||
|
||||
ft = FunctionTool(name="sample_async", func=sample_async_func)
|
||||
with pytest.raises(ValueError):
|
||||
ft.execute(1, 2)
|
||||
assert await ft.async_execute(1, 2) == 3
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_function_tool_sync_with_args() -> None:
|
||||
def two_sum(a: int, b: int) -> int:
|
||||
"""Add two numbers."""
|
||||
return a + b
|
||||
|
||||
ft = FunctionTool(
|
||||
name="sample",
|
||||
func=two_sum,
|
||||
args={
|
||||
"a": {"type": "integer", "name": "a", "description": "The first number."},
|
||||
"b": {"type": "integer", "name": "b", "description": "The second number."},
|
||||
},
|
||||
)
|
||||
ft1 = FunctionTool(
|
||||
name="sample",
|
||||
func=two_sum,
|
||||
args={
|
||||
"a": ToolParameter(
|
||||
type="integer", name="a", description="The first number."
|
||||
),
|
||||
"b": ToolParameter(
|
||||
type="integer", name="b", description="The second number."
|
||||
),
|
||||
},
|
||||
)
|
||||
assert ft.description == "Add two numbers."
|
||||
assert ft.args.keys() == {"a", "b"}
|
||||
assert ft.args["a"].type == "integer"
|
||||
assert ft.args["a"].name == "a"
|
||||
assert ft.args["a"].description == "The first number."
|
||||
assert ft.args["a"].title == "A"
|
||||
dict_params = [
|
||||
{
|
||||
"name": "a",
|
||||
"type": "integer",
|
||||
"description": "The first number.",
|
||||
"required": True,
|
||||
},
|
||||
{
|
||||
"name": "b",
|
||||
"type": "integer",
|
||||
"description": "The second number.",
|
||||
"required": True,
|
||||
},
|
||||
]
|
||||
json_params = json.dumps(dict_params, ensure_ascii=False)
|
||||
expected_prompt = (
|
||||
f"sample: Call this tool to interact with the sample API. What is the "
|
||||
f"sample API useful for? Add two numbers. Parameters: {json_params}"
|
||||
)
|
||||
assert await ft.get_prompt() == expected_prompt
|
||||
assert await ft1.get_prompt() == expected_prompt
|
||||
assert ft.execute(1, 2) == 3
|
||||
with pytest.raises(ValueError):
|
||||
await ft.async_execute(1, 2)
|
||||
|
||||
|
||||
def test_function_tool_sync_with_complex_types() -> None:
|
||||
@tool
|
||||
def complex_func(
|
||||
a: int,
|
||||
b: Annotated[int, Doc("The second number.")],
|
||||
c: Annotated[str, Doc("The third string.")],
|
||||
d: List[int],
|
||||
e: Annotated[Dict[str, int], Doc("A dictionary of integers.")],
|
||||
f: Optional[float] = None,
|
||||
g: str | None = None,
|
||||
) -> int:
|
||||
"""A complex function."""
|
||||
return (
|
||||
a + b + len(c) + sum(d) + sum(e.values()) + (f or 0) + (len(g) if g else 0)
|
||||
)
|
||||
|
||||
ft: FunctionTool = complex_func._tool
|
||||
assert ft.description == "A complex function."
|
||||
assert ft.args.keys() == {"a", "b", "c", "d", "e", "f", "g"}
|
||||
assert ft.args["a"].type == "integer"
|
||||
assert ft.args["a"].description == "A"
|
||||
assert ft.args["b"].type == "integer"
|
||||
assert ft.args["b"].description == "The second number."
|
||||
assert ft.args["c"].type == "string"
|
||||
assert ft.args["c"].description == "The third string."
|
||||
assert ft.args["d"].type == "array"
|
||||
assert ft.args["d"].description == "D"
|
||||
assert ft.args["e"].type == "object"
|
||||
assert ft.args["e"].description == "A dictionary of integers."
|
||||
assert ft.args["f"].type == "float"
|
||||
assert ft.args["f"].description == "F"
|
||||
assert ft.args["g"].type == "string"
|
||||
assert ft.args["g"].description == "G"
|
||||
|
||||
|
||||
def test_function_tool_sync_with_args_schema() -> None:
|
||||
class ArgsSchema(BaseModel):
|
||||
a: int = Field(description="The first number.")
|
||||
b: int = Field(description="The second number.")
|
||||
c: Optional[str] = Field(None, description="The third string.")
|
||||
d: List[int] = Field(description="Numbers.")
|
||||
|
||||
@tool(args_schema=ArgsSchema)
|
||||
def complex_func(a: int, b: int, c: Optional[str] = None) -> int:
|
||||
"""A complex function."""
|
||||
return a + b + len(c) if c else 0
|
||||
|
||||
ft: FunctionTool = complex_func._tool
|
||||
assert ft.description == "A complex function."
|
||||
assert ft.args.keys() == {"a", "b", "c", "d"}
|
||||
assert ft.args["a"].type == "integer"
|
||||
assert ft.args["a"].description == "The first number."
|
||||
assert ft.args["b"].type == "integer"
|
||||
assert ft.args["b"].description == "The second number."
|
||||
assert ft.args["c"].type == "string"
|
||||
assert ft.args["c"].description == "The third string."
|
||||
assert ft.args["d"].type == "array"
|
||||
assert ft.args["d"].description == "Numbers."
|
||||
|
||||
|
||||
def test_tool_decorator() -> None:
|
||||
@tool(description="Add two numbers")
|
||||
def add(a: int, b: int) -> int:
|
||||
"""Add two numbers."""
|
||||
return a + b
|
||||
|
||||
assert add(1, 2) == 3
|
||||
assert add._tool.name == "add"
|
||||
assert add._tool.description == "Add two numbers"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tool_decorator_async() -> None:
|
||||
@tool
|
||||
async def async_add(a: int, b: int) -> int:
|
||||
"""Asynchronously add two numbers."""
|
||||
return a + b
|
||||
|
||||
assert await async_add(1, 2) == 3
|
@@ -1,187 +1,19 @@
|
||||
"""Module for managing commands and command plugins."""
|
||||
|
||||
import functools
|
||||
import importlib
|
||||
import inspect
|
||||
import json
|
||||
import logging
|
||||
import xml.etree.ElementTree as ET
|
||||
from datetime import datetime
|
||||
from typing import Any, Callable, Dict, List, Optional, Union
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from dbgpt._private.pydantic import BaseModel
|
||||
from dbgpt.agent.core.schema import Status
|
||||
from dbgpt.util.json_utils import serialize
|
||||
from dbgpt.util.string_utils import extract_content, extract_content_open_ending
|
||||
|
||||
from .command import execute_command
|
||||
|
||||
# Unique identifier for auto-gpt commands
|
||||
AUTO_GPT_COMMAND_IDENTIFIER = "auto_gpt_command"
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Command:
|
||||
"""A class representing a command.
|
||||
|
||||
Attributes:
|
||||
name (str): The name of the command.
|
||||
description (str): A brief description of what the command does.
|
||||
signature (str): The signature of the function that the command executes.
|
||||
Defaults to None.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
description: str,
|
||||
method: Callable[..., Any],
|
||||
signature: str = "",
|
||||
enabled: bool = True,
|
||||
disabled_reason: Optional[str] = None,
|
||||
):
|
||||
"""Create a new Command object."""
|
||||
self.name = name
|
||||
self.description = description
|
||||
self.method = method
|
||||
self.signature = signature if signature else str(inspect.signature(self.method))
|
||||
self.enabled = enabled
|
||||
self.disabled_reason = disabled_reason
|
||||
|
||||
def __call__(self, *args, **kwargs) -> Any:
|
||||
"""Run the command."""
|
||||
if not self.enabled:
|
||||
return f"Command '{self.name}' is disabled: {self.disabled_reason}"
|
||||
return self.method(*args, **kwargs)
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""Return a string representation of the Command object."""
|
||||
return f"{self.name}: {self.description}, args: {self.signature}"
|
||||
|
||||
|
||||
class CommandRegistry:
|
||||
"""Command registry class.
|
||||
|
||||
The CommandRegistry class is a manager for a collection of Command objects.
|
||||
It allows the registration, modification, and retrieval of Command objects,
|
||||
as well as the scanning and loading of command plugins from a specified
|
||||
directory.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
"""Create a new CommandRegistry object."""
|
||||
self.commands = {}
|
||||
|
||||
def _import_module(self, module_name: str) -> Any:
|
||||
return importlib.import_module(module_name)
|
||||
|
||||
def _reload_module(self, module: Any) -> Any:
|
||||
return importlib.reload(module)
|
||||
|
||||
def register(self, cmd: Command) -> None:
|
||||
"""Register a new Command object with the registry."""
|
||||
self.commands[cmd.name] = cmd
|
||||
|
||||
def unregister(self, command_name: str):
|
||||
"""Unregisters a Command object from the registry."""
|
||||
if command_name in self.commands:
|
||||
del self.commands[command_name]
|
||||
else:
|
||||
raise KeyError(f"Command '{command_name}' not found in registry.")
|
||||
|
||||
def reload_commands(self) -> None:
|
||||
"""Reload all loaded command plugins."""
|
||||
for cmd_name in self.commands:
|
||||
cmd = self.commands[cmd_name]
|
||||
module = self._import_module(cmd.__module__)
|
||||
reloaded_module = self._reload_module(module)
|
||||
if hasattr(reloaded_module, "register"):
|
||||
reloaded_module.register(self)
|
||||
|
||||
def is_valid_command(self, name: str) -> bool:
|
||||
"""Check if the specified command name is registered."""
|
||||
return name in self.commands
|
||||
|
||||
def get_command(self, name: str) -> Callable[..., Any]:
|
||||
"""Return the Command object with the specified name."""
|
||||
return self.commands[name]
|
||||
|
||||
def call(self, command_name: str, **kwargs) -> Any:
|
||||
"""Run command."""
|
||||
if command_name not in self.commands:
|
||||
raise KeyError(f"Command '{command_name}' not found in registry.")
|
||||
command = self.commands[command_name]
|
||||
return command(**kwargs)
|
||||
|
||||
def command_prompt(self) -> str:
|
||||
"""Return a string representation of all registered `Command` objects."""
|
||||
commands_list = [
|
||||
f"{idx + 1}. {str(cmd)}" for idx, cmd in enumerate(self.commands.values())
|
||||
]
|
||||
return "\n".join(commands_list)
|
||||
|
||||
def import_commands(self, module_name: str) -> None:
|
||||
"""Import module.
|
||||
|
||||
Import the specified Python module containing command plugins.
|
||||
|
||||
This method imports the associated module and registers any functions or
|
||||
classes that are decorated with the `AUTO_GPT_COMMAND_IDENTIFIER` attribute
|
||||
as `Command` objects. The registered `Command` objects are then added to the
|
||||
`commands` dictionary of the `CommandRegistry` object.
|
||||
|
||||
Args:
|
||||
module_name (str): The name of the module to import for command plugins.
|
||||
"""
|
||||
module = importlib.import_module(module_name)
|
||||
|
||||
for attr_name in dir(module):
|
||||
attr = getattr(module, attr_name)
|
||||
# Register decorated functions
|
||||
if hasattr(attr, AUTO_GPT_COMMAND_IDENTIFIER) and getattr(
|
||||
attr, AUTO_GPT_COMMAND_IDENTIFIER
|
||||
):
|
||||
self.register(attr.command)
|
||||
# Register command classes
|
||||
elif (
|
||||
inspect.isclass(attr) and issubclass(attr, Command) and attr != Command
|
||||
):
|
||||
cmd_instance = attr() # type: ignore
|
||||
self.register(cmd_instance)
|
||||
|
||||
|
||||
def command(
|
||||
name: str,
|
||||
description: str,
|
||||
signature: str = "",
|
||||
enabled: bool = True,
|
||||
disabled_reason: Optional[str] = None,
|
||||
) -> Callable[..., Any]:
|
||||
"""Register a function as a command."""
|
||||
|
||||
def decorator(func: Callable[..., Any]) -> Command:
|
||||
cmd = Command(
|
||||
name=name,
|
||||
description=description,
|
||||
method=func,
|
||||
signature=signature,
|
||||
enabled=enabled,
|
||||
disabled_reason=disabled_reason,
|
||||
)
|
||||
|
||||
@functools.wraps(func)
|
||||
def wrapper(*args, **kwargs) -> Any:
|
||||
return func(*args, **kwargs)
|
||||
|
||||
wrapper.command = cmd # type: ignore
|
||||
|
||||
setattr(wrapper, AUTO_GPT_COMMAND_IDENTIFIER, True)
|
||||
|
||||
return wrapper # type: ignore
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
class PluginStatus(BaseModel):
|
||||
"""A class representing the status of a plugin."""
|
||||
|
||||
@@ -399,28 +231,6 @@ class ApiCall:
|
||||
param["data"] = data
|
||||
return json.dumps(param, ensure_ascii=False)
|
||||
|
||||
def run(self, llm_text):
|
||||
"""Run the API calls."""
|
||||
if self._is_need_wait_plugin_call(
|
||||
llm_text
|
||||
) and self.check_last_plugin_call_ready(llm_text):
|
||||
# wait api call generate complete
|
||||
self.update_from_context(llm_text)
|
||||
for key, value in self.plugin_status_map.items():
|
||||
if value.status == Status.TODO.value:
|
||||
value.status = Status.RUNNING.value
|
||||
logger.info(f"Plugin execution:{value.name},{value.args}")
|
||||
try:
|
||||
value.api_result = execute_command(
|
||||
value.name, value.args, self.plugin_generator
|
||||
)
|
||||
value.status = Status.COMPLETE.value
|
||||
except Exception as e:
|
||||
value.status = Status.FAILED.value
|
||||
value.err_msg = str(e)
|
||||
value.end_time = datetime.now().timestamp() * 1000
|
||||
return self.api_view_context(llm_text)
|
||||
|
||||
def run_display_sql(self, llm_text, sql_run_func):
|
||||
"""Run the API calls for displaying SQL data."""
|
||||
if self._is_need_wait_plugin_call(
|
@@ -31,7 +31,6 @@ def async_db_summary(system_app: SystemApp):
|
||||
|
||||
|
||||
def server_init(param: "WebServerParameters", system_app: SystemApp):
|
||||
from dbgpt.agent.plugin.commands.command_manage import CommandRegistry
|
||||
|
||||
# logger.info(f"args: {args}")
|
||||
# init config
|
||||
@@ -43,28 +42,6 @@ def server_init(param: "WebServerParameters", system_app: SystemApp):
|
||||
# load_native_plugins(cfg)
|
||||
signal.signal(signal.SIGINT, signal_handler)
|
||||
|
||||
# Loader plugins and commands
|
||||
command_categories = []
|
||||
# exclude commands
|
||||
command_categories = [
|
||||
x for x in command_categories if x not in cfg.disabled_command_categories
|
||||
]
|
||||
command_registry = CommandRegistry()
|
||||
for command_category in command_categories:
|
||||
command_registry.import_commands(command_category)
|
||||
|
||||
cfg.command_registry = command_registry
|
||||
|
||||
command_dispaly_commands = [
|
||||
"dbgpt.agent.plugin.commands.built_in.display_type.show_chart_gen",
|
||||
"dbgpt.agent.plugin.commands.built_in.display_type.show_table_gen",
|
||||
"dbgpt.agent.plugin.commands.built_in.display_type.show_text_gen",
|
||||
]
|
||||
command_dispaly_registry = CommandRegistry()
|
||||
for command in command_dispaly_commands:
|
||||
command_dispaly_registry.import_commands(command)
|
||||
cfg.command_display = command_dispaly_commands
|
||||
|
||||
|
||||
def _create_model_start_listener(system_app: SystemApp):
|
||||
def startup_event(wh):
|
||||
|
@@ -47,6 +47,8 @@ def initialize_components(
|
||||
)
|
||||
_initialize_model_cache(system_app)
|
||||
_initialize_awel(system_app, param)
|
||||
# Initialize resource manager of agent
|
||||
_initialize_resource_manager(system_app)
|
||||
_initialize_agent(system_app)
|
||||
_initialize_openapi(system_app)
|
||||
# Register serve apps
|
||||
@@ -85,6 +87,25 @@ def _initialize_agent(system_app: SystemApp):
|
||||
initialize_agent(system_app)
|
||||
|
||||
|
||||
def _initialize_resource_manager(system_app: SystemApp):
|
||||
from dbgpt.agent.expand.resources.dbgpt_tool import list_dbgpt_support_models
|
||||
from dbgpt.agent.expand.resources.search_tool import baidu_search
|
||||
from dbgpt.agent.resource.base import ResourceType
|
||||
from dbgpt.agent.resource.manage import get_resource_manager, initialize_resource
|
||||
from dbgpt.serve.agent.resource.datasource import DatasourceResource
|
||||
from dbgpt.serve.agent.resource.knowledge import KnowledgeSpaceRetrieverResource
|
||||
from dbgpt.serve.agent.resource.plugin import PluginToolPack
|
||||
|
||||
initialize_resource(system_app)
|
||||
rm = get_resource_manager(system_app)
|
||||
rm.register_resource(DatasourceResource)
|
||||
rm.register_resource(KnowledgeSpaceRetrieverResource)
|
||||
rm.register_resource(PluginToolPack, resource_type=ResourceType.Tool)
|
||||
# Register a search tool
|
||||
rm.register_resource(resource_instance=baidu_search)
|
||||
rm.register_resource(resource_instance=list_dbgpt_support_models)
|
||||
|
||||
|
||||
def _initialize_openapi(system_app: SystemApp):
|
||||
from dbgpt.app.openapi.api_v1.editor.service import EditorService
|
||||
|
||||
|
@@ -21,7 +21,12 @@ from dbgpt.app.base import (
|
||||
# initialize_components import time cost about 0.1s
|
||||
from dbgpt.app.component_configs import initialize_components
|
||||
from dbgpt.component import SystemApp
|
||||
from dbgpt.configs.model_config import EMBEDDING_MODEL_CONFIG, LLM_MODEL_CONFIG, LOGDIR
|
||||
from dbgpt.configs.model_config import (
|
||||
EMBEDDING_MODEL_CONFIG,
|
||||
LLM_MODEL_CONFIG,
|
||||
LOGDIR,
|
||||
STATIC_MESSAGE_IMG_PATH,
|
||||
)
|
||||
from dbgpt.serve.core import add_exception_handler
|
||||
from dbgpt.util.fastapi import create_app, replace_router
|
||||
from dbgpt.util.i18n_utils import _, set_default_language
|
||||
@@ -88,14 +93,10 @@ def mount_routers(app: FastAPI):
|
||||
|
||||
|
||||
def mount_static_files(app: FastAPI):
|
||||
from dbgpt.agent.plugin.commands.built_in.display_type import (
|
||||
static_message_img_path,
|
||||
)
|
||||
|
||||
os.makedirs(static_message_img_path, exist_ok=True)
|
||||
os.makedirs(STATIC_MESSAGE_IMG_PATH, exist_ok=True)
|
||||
app.mount(
|
||||
"/images",
|
||||
StaticFiles(directory=static_message_img_path, html=True),
|
||||
StaticFiles(directory=STATIC_MESSAGE_IMG_PATH, html=True),
|
||||
name="static2",
|
||||
)
|
||||
app.mount(
|
||||
|
@@ -80,13 +80,6 @@ def get_db_list():
|
||||
return db_params
|
||||
|
||||
|
||||
def plugins_select_info():
|
||||
plugins_infos: dict = {}
|
||||
for plugin in CFG.plugins:
|
||||
plugins_infos.update({f"【{plugin._name}】=>{plugin._description}": plugin._name})
|
||||
return plugins_infos
|
||||
|
||||
|
||||
def get_db_list_info():
|
||||
dbs = CFG.local_db_manager.get_db_list()
|
||||
params: dict = {}
|
||||
@@ -242,8 +235,6 @@ async def params_list(chat_mode: str = ChatScene.ChatNormal.value()):
|
||||
return Result.succ(get_db_list())
|
||||
elif ChatScene.ChatDashboard.value() == chat_mode:
|
||||
return Result.succ(get_db_list())
|
||||
elif ChatScene.ChatExecution.value() == chat_mode:
|
||||
return Result.succ(plugins_select_info())
|
||||
elif ChatScene.ChatKnowledge.value() == chat_mode:
|
||||
return Result.succ(knowledge_list())
|
||||
elif ChatScene.ChatKnowledge.ExtractRefineSummary.value() == chat_mode:
|
||||
|
@@ -179,7 +179,7 @@ async def editor_chart_run(run_param: dict = Body()):
|
||||
end_time = time.time() * 1000
|
||||
sql_run_data: SqlRunData = SqlRunData(
|
||||
result_info="",
|
||||
run_cost=(end_time - start_time) / 1000,
|
||||
run_cost=int((end_time - start_time) / 1000),
|
||||
colunms=colunms,
|
||||
values=sql_result,
|
||||
)
|
||||
|
@@ -1,75 +0,0 @@
|
||||
import logging
|
||||
from typing import Dict, List
|
||||
|
||||
from dbgpt._private.config import Config
|
||||
from dbgpt.agent.plugin.commands.command_manage import ApiCall
|
||||
from dbgpt.agent.plugin.generator import PluginPromptGenerator
|
||||
from dbgpt.app.scene import BaseChat, ChatScene
|
||||
from dbgpt.component import ComponentType
|
||||
from dbgpt.serve.agent.hub.controller import ModulePlugin
|
||||
from dbgpt.util.tracer import root_tracer, trace
|
||||
|
||||
CFG = Config()
|
||||
|
||||
logger = logging.getLogger("chat_agent")
|
||||
|
||||
|
||||
class ChatAgent(BaseChat):
|
||||
"""Chat With Agent through plugin"""
|
||||
|
||||
chat_scene: str = ChatScene.ChatAgent.value()
|
||||
keep_end_rounds = 0
|
||||
|
||||
def __init__(self, chat_param: Dict):
|
||||
"""Chat Agent Module Initialization
|
||||
Args:
|
||||
- chat_param: Dict
|
||||
- chat_session_id: (str) chat session_id
|
||||
- current_user_input: (str) current user input
|
||||
- model_name:(str) llm model name
|
||||
- select_param:(str) agent plugin
|
||||
"""
|
||||
if not chat_param["select_param"]:
|
||||
raise ValueError("Please select a Plugin!")
|
||||
self.select_plugins = chat_param["select_param"].split(",")
|
||||
|
||||
chat_param["chat_mode"] = ChatScene.ChatAgent
|
||||
super().__init__(chat_param=chat_param)
|
||||
self.plugins_prompt_generator: PluginPromptGenerator = PluginPromptGenerator()
|
||||
self.plugins_prompt_generator.set_command_registry(CFG.command_registry)
|
||||
|
||||
# load select plugin
|
||||
agent_module = CFG.SYSTEM_APP.get_component(
|
||||
ComponentType.PLUGIN_HUB, ModulePlugin
|
||||
)
|
||||
self.plugins_prompt_generator = agent_module.load_select_plugin(
|
||||
self.plugins_prompt_generator, self.select_plugins
|
||||
)
|
||||
|
||||
self.api_call = ApiCall(plugin_generator=self.plugins_prompt_generator)
|
||||
|
||||
@trace()
|
||||
async def generate_input_values(self) -> Dict[str, str]:
|
||||
input_values = {
|
||||
"user_goal": self.current_user_input,
|
||||
"expand_constraints": self.__list_to_prompt_str(
|
||||
list(self.plugins_prompt_generator.constraints)
|
||||
),
|
||||
"tool_list": self.plugins_prompt_generator.generate_commands_string(),
|
||||
}
|
||||
return input_values
|
||||
|
||||
def stream_plugin_call(self, text):
|
||||
text = (
|
||||
text.replace("\\n", " ")
|
||||
.replace("\n", " ")
|
||||
.replace("\_", "_")
|
||||
.replace("\\", " ")
|
||||
)
|
||||
with root_tracer.start_span(
|
||||
"ChatAgent.stream_plugin_call.api_call", metadata={"text": text}
|
||||
):
|
||||
return self.api_call.run(text)
|
||||
|
||||
def __list_to_prompt_str(self, list: List) -> str:
|
||||
return "\n".join(f"{i + 1 + 1}. {item}" for i, item in enumerate(list))
|
@@ -1,23 +0,0 @@
|
||||
from dbgpt.core._private.example_base import ExampleSelector
|
||||
|
||||
## Two examples are defined by default
|
||||
EXAMPLES = [
|
||||
{
|
||||
"messages": [
|
||||
{"type": "human", "data": {"content": "查询xxx", "example": True}},
|
||||
{
|
||||
"type": "ai",
|
||||
"data": {
|
||||
"content": """{
|
||||
\"thoughts\": \"thought text\",
|
||||
\"speak\": \"thoughts summary to say to user\",
|
||||
\"command\": {\"name\": \"command name\", \"args\": {\"arg name\": \"value\"}},
|
||||
}""",
|
||||
"example": True,
|
||||
},
|
||||
},
|
||||
]
|
||||
},
|
||||
]
|
||||
|
||||
plugin_example = ExampleSelector(examples_record=EXAMPLES, use_example=True)
|
@@ -1,20 +0,0 @@
|
||||
from typing import Dict, NamedTuple
|
||||
|
||||
from dbgpt.core.interface.output_parser import BaseOutputParser
|
||||
|
||||
|
||||
class PluginAction(NamedTuple):
|
||||
command: Dict
|
||||
speak: str = ""
|
||||
thoughts: str = ""
|
||||
|
||||
|
||||
class PluginChatOutputParser(BaseOutputParser):
|
||||
def parse_view_response(self, speak, data, prompt_response) -> str:
|
||||
### tool out data to table view
|
||||
print(f"parse_view_response:{speak},{str(data)}")
|
||||
view_text = f"##### {speak}" + "\n" + str(data)
|
||||
return view_text
|
||||
|
||||
def get_format_instructions(self) -> str:
|
||||
pass
|
@@ -1,82 +0,0 @@
|
||||
from dbgpt._private.config import Config
|
||||
from dbgpt.app.scene import AppScenePromptTemplateAdapter, ChatScene
|
||||
from dbgpt.app.scene.chat_execution.out_parser import PluginChatOutputParser
|
||||
from dbgpt.core import ChatPromptTemplate, HumanPromptTemplate, SystemPromptTemplate
|
||||
|
||||
CFG = Config()
|
||||
|
||||
_PROMPT_SCENE_DEFINE_EN = "You are a universal AI assistant."
|
||||
|
||||
_DEFAULT_TEMPLATE_EN = """
|
||||
You need to analyze the user goals and, under the given constraints, prioritize using one of the following tools to solve the user goals.
|
||||
Tool list:
|
||||
{tool_list}
|
||||
Constraint:
|
||||
1. After finding the available tools from the tool list given below, please output the following content to use the tool. Please make sure that the following content only appears once in the output result:
|
||||
<api-call><name>Selected Tool name</name><args><arg1>value</arg1><arg2>value</arg2></args></api-call>
|
||||
2. Please generate the above call text according to the definition of the corresponding tool in the tool list. The reference case is as follows:
|
||||
Introduction to tool function: "Tool name", args: "Parameter 1": "<Parameter 1 value description>", "Parameter 2": "<Parameter 2 value description>" Corresponding call text: <api-call>< name>Tool name</name><args><parameter 1>value</parameter 1><parameter 2>value</parameter 2></args></api-call>
|
||||
3. Generate the call of each tool according to the above constraints. The prompt text for tool use needs to be generated before the tool is used.
|
||||
4. If the user goals cannot be understood and the intention is unclear, give priority to using search engine tools
|
||||
5. Parameter content may need to be inferred based on the user's goals, not just extracted from text
|
||||
6. Constraint conditions and tool information are used as auxiliary information for the reasoning process and should not be expressed in the output content to the user.
|
||||
{expand_constraints}
|
||||
User goals:
|
||||
{user_goal}
|
||||
"""
|
||||
|
||||
_PROMPT_SCENE_DEFINE_ZH = "你是一个通用AI助手!"
|
||||
|
||||
_DEFAULT_TEMPLATE_ZH = """
|
||||
根据用户目标,请一步步思考,如何在满足下面约束条件的前提下,优先使用给出工具回答或者完成用户目标。
|
||||
|
||||
约束条件:
|
||||
1.从下面给定工具列表找到可用的工具后,请输出以下内容用来使用工具, 注意要确保下面内容在输出结果中只出现一次:
|
||||
<api-call><name>Selected Tool name</name><args><arg1>value</arg1><arg2>value</arg2></args></api-call>
|
||||
2.请根据工具列表对应工具的定义来生成上述调用文本, 参考案例如下:
|
||||
工具作用介绍: "工具名称", args: "参数1": "<参数1取值描述>","参数2": "<参数2取值描述>" 对应调用文本:<api-call><name>工具名称</name><args><参数1>value</参数1><参数2>value</参数2></args></api-call>
|
||||
3.根据上面约束的方式生成每个工具的调用,对于工具使用的提示文本,需要在工具使用前生成
|
||||
4.如果用户目标无法理解和意图不明确,优先使用搜索引擎工具
|
||||
5.参数内容可能需要根据用户的目标推理得到,不仅仅是从文本提取
|
||||
6.约束条件和工具信息作为推理过程的辅助信息,对应内容不要表达在给用户的输出内容中
|
||||
7.不要把<api-call></api-call>部分内容放在markdown标签里
|
||||
{expand_constraints}
|
||||
|
||||
工具列表:
|
||||
{tool_list}
|
||||
|
||||
用户目标:
|
||||
{user_goal}
|
||||
"""
|
||||
|
||||
_DEFAULT_TEMPLATE = (
|
||||
_DEFAULT_TEMPLATE_EN if CFG.LANGUAGE == "en" else _DEFAULT_TEMPLATE_ZH
|
||||
)
|
||||
|
||||
|
||||
_PROMPT_SCENE_DEFINE = (
|
||||
_PROMPT_SCENE_DEFINE_EN if CFG.LANGUAGE == "en" else _PROMPT_SCENE_DEFINE_ZH
|
||||
)
|
||||
|
||||
RESPONSE_FORMAT = None
|
||||
|
||||
|
||||
### Whether the model service is streaming output
|
||||
PROMPT_NEED_STREAM_OUT = True
|
||||
|
||||
prompt = ChatPromptTemplate(
|
||||
messages=[
|
||||
SystemPromptTemplate.from_template(_PROMPT_SCENE_DEFINE + _DEFAULT_TEMPLATE),
|
||||
HumanPromptTemplate.from_template("{user_goal}"),
|
||||
]
|
||||
)
|
||||
|
||||
prompt_adapter = AppScenePromptTemplateAdapter(
|
||||
prompt=prompt,
|
||||
template_scene=ChatScene.ChatAgent.value(),
|
||||
stream_out=PROMPT_NEED_STREAM_OUT,
|
||||
output_parser=PluginChatOutputParser(is_stream_out=PROMPT_NEED_STREAM_OUT),
|
||||
need_historical_messages=False,
|
||||
temperature=1,
|
||||
)
|
||||
CFG.prompt_template_registry.register(prompt_adapter, is_default=True)
|
@@ -29,14 +29,15 @@ Give the correct {dialect} analysis SQL
|
||||
4.Carefully check the correctness of the SQL, the SQL must be correct, display method and summary of brief analysis thinking, and respond in the following json format:
|
||||
{response}
|
||||
The important thing is: Please make sure to only return the json string, do not add any other content (for direct processing by the program), and the json can be parsed by Python json.loads
|
||||
5. Please use the same language as the "user"
|
||||
"""
|
||||
|
||||
RESPONSE_FORMAT = [
|
||||
{
|
||||
"thoughts": "Current thinking and value of data analysis",
|
||||
"showcase": "What type of charts to show",
|
||||
"sql": "data analysis SQL",
|
||||
"title": "Data Analysis Title",
|
||||
"showcase": "What type of charts to show",
|
||||
"thoughts": "Current thinking and value of data analysis",
|
||||
}
|
||||
]
|
||||
|
||||
|
@@ -3,7 +3,7 @@ import os
|
||||
from typing import Dict
|
||||
|
||||
from dbgpt._private.config import Config
|
||||
from dbgpt.agent.plugin.commands.command_manage import ApiCall
|
||||
from dbgpt.agent.util.api_call import ApiCall
|
||||
from dbgpt.app.scene import BaseChat, ChatScene
|
||||
from dbgpt.app.scene.chat_data.chat_excel.excel_learning.chat import ExcelLearning
|
||||
from dbgpt.app.scene.chat_data.chat_excel.excel_reader import ExcelReader
|
||||
@@ -45,7 +45,7 @@ class ChatExcel(BaseChat):
|
||||
KNOWLEDGE_UPLOAD_ROOT_PATH, chat_mode.value(), self.select_param
|
||||
)
|
||||
)
|
||||
self.api_call = ApiCall(display_registry=CFG.command_display)
|
||||
self.api_call = ApiCall()
|
||||
super().__init__(chat_param=chat_param)
|
||||
|
||||
@trace()
|
||||
|
@@ -1,7 +1,7 @@
|
||||
from typing import Dict
|
||||
|
||||
from dbgpt._private.config import Config
|
||||
from dbgpt.agent.plugin.commands.command_manage import ApiCall
|
||||
from dbgpt.agent.util.api_call import ApiCall
|
||||
from dbgpt.app.scene import BaseChat, ChatScene
|
||||
from dbgpt.util.executor_utils import blocking_func_to_async
|
||||
from dbgpt.util.tracer import root_tracer, trace
|
||||
@@ -40,7 +40,7 @@ class ChatWithDbAutoExecute(BaseChat):
|
||||
self.database = CFG.local_db_manager.get_connector(self.db_name)
|
||||
|
||||
self.top_k: int = 50
|
||||
self.api_call = ApiCall(display_registry=CFG.command_display)
|
||||
self.api_call = ApiCall()
|
||||
|
||||
@trace()
|
||||
async def generate_input_values(self) -> Dict:
|
||||
|
@@ -1,83 +0,0 @@
|
||||
from typing import Dict, List
|
||||
|
||||
from dbgpt._private.config import Config
|
||||
from dbgpt.agent.plugin.commands.command import execute_command
|
||||
from dbgpt.agent.plugin.generator import PluginPromptGenerator
|
||||
from dbgpt.app.scene import BaseChat, ChatScene
|
||||
from dbgpt.util.tracer import trace
|
||||
|
||||
CFG = Config()
|
||||
|
||||
|
||||
class ChatWithPlugin(BaseChat):
|
||||
"""Chat With Plugin"""
|
||||
|
||||
chat_scene: str = ChatScene.ChatExecution.value()
|
||||
plugins_prompt_generator: PluginPromptGenerator
|
||||
select_plugin: str = None
|
||||
|
||||
def __init__(self, chat_param: Dict):
|
||||
"""Chat Dashboard Module Initialization
|
||||
Args:
|
||||
- chat_param: Dict
|
||||
- chat_session_id: (str) chat session_id
|
||||
- current_user_input: (str) current user input
|
||||
- model_name:(str) llm model name
|
||||
- select_param:(str) plugin selector
|
||||
"""
|
||||
self.plugin_selector = chat_param["select_param"]
|
||||
chat_param["chat_mode"] = ChatScene.ChatExecution
|
||||
super().__init__(chat_param=chat_param)
|
||||
self.plugins_prompt_generator = PluginPromptGenerator()
|
||||
self.plugins_prompt_generator.set_command_registry(CFG.command_registry)
|
||||
# 加载插件中可用命令
|
||||
self.select_plugin = self.plugin_selector
|
||||
if self.select_plugin:
|
||||
for plugin in CFG.plugins:
|
||||
if plugin._name == self.plugin_selector:
|
||||
if not plugin.can_handle_post_prompt():
|
||||
continue
|
||||
self.plugins_prompt_generator = plugin.post_prompt(
|
||||
self.plugins_prompt_generator
|
||||
)
|
||||
|
||||
else:
|
||||
for plugin in CFG.plugins:
|
||||
if not plugin.can_handle_post_prompt():
|
||||
continue
|
||||
self.plugins_prompt_generator = plugin.post_prompt(
|
||||
self.plugins_prompt_generator
|
||||
)
|
||||
|
||||
@trace()
|
||||
async def generate_input_values(self) -> Dict:
|
||||
input_values = {
|
||||
"input": self.current_user_input,
|
||||
"constraints": self.__list_to_prompt_str(
|
||||
list(self.plugins_prompt_generator.constraints)
|
||||
),
|
||||
"commands_infos": self.plugins_prompt_generator.generate_commands_string(),
|
||||
}
|
||||
return input_values
|
||||
|
||||
def do_action(self, prompt_response):
|
||||
print(f"do_action:{prompt_response}")
|
||||
## plugin command run
|
||||
return execute_command(
|
||||
str(prompt_response.command.get("name")),
|
||||
prompt_response.command.get("args", {}),
|
||||
self.plugins_prompt_generator,
|
||||
)
|
||||
|
||||
def chat_show(self):
|
||||
super().chat_show()
|
||||
|
||||
def __list_to_prompt_str(self, list: List) -> str:
|
||||
return "\n".join(f"{i + 1 + 1}. {item}" for i, item in enumerate(list))
|
||||
|
||||
def generate(self, p) -> str:
|
||||
return super().generate(p)
|
||||
|
||||
@property
|
||||
def chat_type(self) -> str:
|
||||
return ChatScene.ChatExecution.value
|
@@ -1,23 +0,0 @@
|
||||
from dbgpt.core._private.example_base import ExampleSelector
|
||||
|
||||
## Two examples are defined by default
|
||||
EXAMPLES = [
|
||||
{
|
||||
"messages": [
|
||||
{"type": "human", "data": {"content": "查询xxx", "example": True}},
|
||||
{
|
||||
"type": "ai",
|
||||
"data": {
|
||||
"content": """{
|
||||
\"thoughts\": \"thought text\",
|
||||
\"speak\": \"thoughts summary to say to user\",
|
||||
\"command\": {\"name\": \"command name\", \"args\": {\"arg name\": \"value\"}},
|
||||
}""",
|
||||
"example": True,
|
||||
},
|
||||
},
|
||||
]
|
||||
},
|
||||
]
|
||||
|
||||
plugin_example = ExampleSelector(examples_record=EXAMPLES, use_example=True)
|
@@ -1,45 +0,0 @@
|
||||
import json
|
||||
import logging
|
||||
from typing import Dict, NamedTuple
|
||||
|
||||
from dbgpt.core.interface.output_parser import BaseOutputParser, T
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PluginAction(NamedTuple):
|
||||
command: Dict
|
||||
speak: str = ""
|
||||
thoughts: str = ""
|
||||
|
||||
|
||||
class PluginChatOutputParser(BaseOutputParser):
|
||||
def parse_prompt_response(self, model_out_text) -> T:
|
||||
clean_json_str = super().parse_prompt_response(model_out_text)
|
||||
print(clean_json_str)
|
||||
if not clean_json_str:
|
||||
raise ValueError("model server response not have json!")
|
||||
try:
|
||||
response = json.loads(clean_json_str)
|
||||
except Exception as e:
|
||||
raise ValueError("model server out not fllow the prompt!")
|
||||
|
||||
speak = ""
|
||||
thoughts = ""
|
||||
for key in sorted(response):
|
||||
if key.strip() == "command":
|
||||
command = response[key]
|
||||
if key.strip() == "thoughts":
|
||||
thoughts = response[key]
|
||||
if key.strip() == "speak":
|
||||
speak = response[key]
|
||||
return PluginAction(command, speak, thoughts)
|
||||
|
||||
def parse_view_response(self, speak, data, prompt_response) -> str:
|
||||
### tool out data to table view
|
||||
print(f"parse_view_response:{speak},{str(data)}")
|
||||
view_text = f"##### {speak}" + "\n" + str(data)
|
||||
return view_text
|
||||
|
||||
def get_format_instructions(self) -> str:
|
||||
pass
|
@@ -1,61 +0,0 @@
|
||||
import json
|
||||
|
||||
from dbgpt._private.config import Config
|
||||
from dbgpt.app.scene import AppScenePromptTemplateAdapter, ChatScene
|
||||
from dbgpt.app.scene.chat_execution.out_parser import PluginChatOutputParser
|
||||
from dbgpt.core import (
|
||||
ChatPromptTemplate,
|
||||
HumanPromptTemplate,
|
||||
MessagesPlaceholder,
|
||||
SystemPromptTemplate,
|
||||
)
|
||||
|
||||
CFG = Config()
|
||||
|
||||
PROMPT_SCENE_DEFINE = "You are an AI designed to solve the user's goals with given commands, please follow the constraints of the system's input for your answers."
|
||||
|
||||
_DEFAULT_TEMPLATE = """
|
||||
Goals:
|
||||
{input}
|
||||
|
||||
Constraints:
|
||||
0.Exclusively use the commands listed in double quotes e.g. "command name"
|
||||
{constraints}
|
||||
|
||||
Commands:
|
||||
{commands_infos}
|
||||
|
||||
Please response strictly according to the following json format:
|
||||
{response}
|
||||
Ensure the response is correct json and can be parsed by Python json.loads
|
||||
"""
|
||||
|
||||
RESPONSE_FORMAT = {
|
||||
"thoughts": "thought text",
|
||||
"speak": "thoughts summary to say to user",
|
||||
"command": {"name": "command name", "args": {"arg name": "value"}},
|
||||
}
|
||||
|
||||
### Whether the model service is streaming output
|
||||
PROMPT_NEED_STREAM_OUT = False
|
||||
|
||||
prompt = ChatPromptTemplate(
|
||||
messages=[
|
||||
SystemPromptTemplate.from_template(
|
||||
PROMPT_SCENE_DEFINE + _DEFAULT_TEMPLATE,
|
||||
response_format=json.dumps(RESPONSE_FORMAT, indent=4),
|
||||
),
|
||||
MessagesPlaceholder(variable_name="chat_history"),
|
||||
HumanPromptTemplate.from_template("{input}"),
|
||||
]
|
||||
)
|
||||
|
||||
prompt_adapter = AppScenePromptTemplateAdapter(
|
||||
prompt=prompt,
|
||||
template_scene=ChatScene.ChatExecution.value(),
|
||||
stream_out=PROMPT_NEED_STREAM_OUT,
|
||||
output_parser=PluginChatOutputParser(is_stream_out=PROMPT_NEED_STREAM_OUT),
|
||||
need_historical_messages=False,
|
||||
)
|
||||
|
||||
CFG.prompt_template_registry.register(prompt_adapter, is_default=True)
|
@@ -8,8 +8,6 @@ class ChatFactory(metaclass=Singleton):
|
||||
@staticmethod
|
||||
def get_implementation(chat_mode, **kwargs):
|
||||
# Lazy loading
|
||||
from dbgpt.app.scene.chat_agent.chat import ChatAgent
|
||||
from dbgpt.app.scene.chat_agent.prompt import prompt
|
||||
from dbgpt.app.scene.chat_dashboard.chat import ChatDashboard
|
||||
from dbgpt.app.scene.chat_dashboard.prompt import prompt
|
||||
from dbgpt.app.scene.chat_data.chat_excel.excel_analyze.chat import ChatExcel
|
||||
@@ -19,8 +17,6 @@ class ChatFactory(metaclass=Singleton):
|
||||
from dbgpt.app.scene.chat_db.auto_execute.prompt import prompt
|
||||
from dbgpt.app.scene.chat_db.professional_qa.chat import ChatWithDbQA
|
||||
from dbgpt.app.scene.chat_db.professional_qa.prompt import prompt
|
||||
from dbgpt.app.scene.chat_execution.chat import ChatWithPlugin
|
||||
from dbgpt.app.scene.chat_execution.prompt import prompt
|
||||
from dbgpt.app.scene.chat_knowledge.extract_entity.chat import ExtractEntity
|
||||
from dbgpt.app.scene.chat_knowledge.extract_entity.prompt import prompt
|
||||
from dbgpt.app.scene.chat_knowledge.extract_triplet.chat import ExtractTriplet
|
||||
|
@@ -1,4 +1,5 @@
|
||||
"""this module contains the schemas for the dbgpt client."""
|
||||
|
||||
import json
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
@@ -143,7 +144,7 @@ class AgentResourceType(Enum):
|
||||
class AgentResourceModel(BaseModel):
|
||||
"""Agent resource model."""
|
||||
|
||||
type: AgentResourceType
|
||||
type: str
|
||||
name: str
|
||||
value: str
|
||||
is_dynamic: bool = (
|
||||
@@ -156,7 +157,7 @@ class AgentResourceModel(BaseModel):
|
||||
if d is None:
|
||||
return None
|
||||
return AgentResourceModel(
|
||||
type=AgentResourceType(d.get("type")),
|
||||
type=d.get("type"),
|
||||
name=d.get("name"),
|
||||
introduce=d.get("introduce"),
|
||||
value=d.get("value", None),
|
||||
|
@@ -2,6 +2,7 @@
|
||||
|
||||
Manages the lifecycle and registration of components.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
@@ -87,6 +88,7 @@ class ComponentType(str, Enum):
|
||||
UNIFIED_METADATA_DB_MANAGER_FACTORY = "dbgpt_unified_metadata_db_manager_factory"
|
||||
CONNECTOR_MANAGER = "dbgpt_connector_manager"
|
||||
AGENT_MANAGER = "dbgpt_agent_manager"
|
||||
RESOURCE_MANAGER = "dbgpt_resource_manager"
|
||||
|
||||
|
||||
_EMPTY_DEFAULT_COMPONENT = "_EMPTY_DEFAULT_COMPONENT"
|
||||
|
@@ -8,6 +8,7 @@ ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__fi
|
||||
MODEL_PATH = os.path.join(ROOT_PATH, "models")
|
||||
PILOT_PATH = os.path.join(ROOT_PATH, "pilot")
|
||||
LOGDIR = os.getenv("DBGPT_LOG_DIR", os.path.join(ROOT_PATH, "logs"))
|
||||
STATIC_MESSAGE_IMG_PATH = os.path.join(PILOT_PATH, "message/img")
|
||||
|
||||
DATASETS_DIR = os.path.join(PILOT_PATH, "datasets")
|
||||
DATA_DIR = os.path.join(PILOT_PATH, "data")
|
||||
|
@@ -1,4 +1,5 @@
|
||||
"""SQLite connector."""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import tempfile
|
||||
@@ -122,6 +123,18 @@ class SQLiteConnector(RDBMSConnector):
|
||||
(table_comment[0], table_comment[1]) for table_comment in table_comments
|
||||
]
|
||||
|
||||
def get_current_db_name(self) -> str:
|
||||
"""Get current database name.
|
||||
|
||||
Returns:
|
||||
str: database name
|
||||
"""
|
||||
full_path = self._engine.url.database
|
||||
db_name = os.path.basename(full_path)
|
||||
if db_name.endswith(".db"):
|
||||
db_name = db_name[:-3]
|
||||
return db_name
|
||||
|
||||
def table_simple_info(self) -> Iterable[str]:
|
||||
"""Get table simple info."""
|
||||
_tables_sql = """
|
||||
|
@@ -17,7 +17,8 @@ from dbgpt.agent.core.memory.gpts.gpts_memory import GptsMemory
|
||||
from dbgpt.agent.core.plan import AutoPlanChatManager, DefaultAWELLayoutManager
|
||||
from dbgpt.agent.core.schema import Status
|
||||
from dbgpt.agent.core.user_proxy_agent import UserProxyAgent
|
||||
from dbgpt.agent.resource.resource_loader import ResourceLoader
|
||||
from dbgpt.agent.resource.base import Resource
|
||||
from dbgpt.agent.resource.manage import get_resource_manager
|
||||
from dbgpt.agent.util.llm.llm import LLMConfig, LLMStrategyType
|
||||
from dbgpt.app.openapi.api_view_model import Result
|
||||
from dbgpt.app.scene.base import ChatScene
|
||||
@@ -32,9 +33,6 @@ from dbgpt.util.json_utils import serialize
|
||||
from ..db.gpts_app import GptsApp, GptsAppDao, GptsAppQuery
|
||||
from ..db.gpts_conversations_db import GptsConversationsDao, GptsConversationsEntity
|
||||
from ..db.gpts_manage_db import GptsInstanceEntity
|
||||
from ..resource_loader.datasource_load_client import DatasourceLoadClient
|
||||
from ..resource_loader.knowledge_space_load_client import KnowledgeSpaceLoadClient
|
||||
from ..resource_loader.plugin_hub_load_client import PluginHubLoadClient
|
||||
from ..team.base import TeamMode
|
||||
from .db_gpts_memory import MetaDbGptsMessageMemory, MetaDbGptsPlansMemory
|
||||
|
||||
@@ -93,8 +91,6 @@ class MultiAgents(BaseComponent, ABC):
|
||||
from dbgpt.agent.core.memory.hybrid import HybridMemory
|
||||
from dbgpt.configs.model_config import EMBEDDING_MODEL_CONFIG
|
||||
from dbgpt.rag.embedding.embedding_factory import EmbeddingFactory
|
||||
from dbgpt.storage.vector_store.base import VectorStoreConfig
|
||||
from dbgpt.storage.vector_store.connector import VectorStoreConnector
|
||||
|
||||
memory_key = f"{dbgpts_name}_{conv_id}"
|
||||
if memory_key in self.agent_memory_map:
|
||||
@@ -105,13 +101,17 @@ class MultiAgents(BaseComponent, ABC):
|
||||
model_name=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL]
|
||||
)
|
||||
vstore_name = f"_chroma_agent_memory_{dbgpts_name}_{conv_id}"
|
||||
vector_store_connector = VectorStoreConnector(
|
||||
vector_store_type=CFG.VECTOR_STORE_TYPE,
|
||||
vector_store_config=VectorStoreConfig(
|
||||
name=vstore_name, embedding_fn=embedding_fn
|
||||
),
|
||||
# Just use chroma store now
|
||||
# vector_store_connector = VectorStoreConnector(
|
||||
# vector_store_type=CFG.VECTOR_STORE_TYPE,
|
||||
# vector_store_config=VectorStoreConfig(
|
||||
# name=vstore_name, embedding_fn=embedding_fn
|
||||
# ),
|
||||
# )
|
||||
memory = HybridMemory[AgentMemoryFragment].from_chroma(
|
||||
vstore_name=vstore_name,
|
||||
embeddings=embedding_fn,
|
||||
)
|
||||
memory = HybridMemory[AgentMemoryFragment].from_vstore(vector_store_connector)
|
||||
agent_memory = AgentMemory(memory, gpts_memory=self.memory)
|
||||
self.agent_memory_map[memory_key] = agent_memory
|
||||
return agent_memory
|
||||
@@ -243,14 +243,7 @@ class MultiAgents(BaseComponent, ABC):
|
||||
agent_memory: Optional[AgentMemory] = None,
|
||||
):
|
||||
employees: List[Agent] = []
|
||||
# Prepare resource loader
|
||||
resource_loader = ResourceLoader()
|
||||
plugin_hub_loader = PluginHubLoadClient()
|
||||
resource_loader.register_resource_api(plugin_hub_loader)
|
||||
datasource_loader = DatasourceLoadClient()
|
||||
resource_loader.register_resource_api(datasource_loader)
|
||||
knowledge_space_loader = KnowledgeSpaceLoadClient()
|
||||
resource_loader.register_resource_api(knowledge_space_loader)
|
||||
rm = get_resource_manager()
|
||||
context: AgentContext = AgentContext(
|
||||
conv_id=conv_uid,
|
||||
gpts_app_name=gpts_app.app_name,
|
||||
@@ -264,6 +257,7 @@ class MultiAgents(BaseComponent, ABC):
|
||||
).create()
|
||||
self.llm_provider = DefaultLLMClient(worker_manager, auto_convert_message=True)
|
||||
|
||||
depend_resource: Optional[Resource] = None
|
||||
for record in gpts_app.details:
|
||||
cls: Type[ConversableAgent] = get_agent_manager().get_by_name(
|
||||
record.agent_name
|
||||
@@ -273,12 +267,13 @@ class MultiAgents(BaseComponent, ABC):
|
||||
llm_strategy=LLMStrategyType(record.llm_strategy),
|
||||
strategy_context=record.llm_strategy_value,
|
||||
)
|
||||
depend_resource = rm.build_resource(record.resources, version="v1")
|
||||
|
||||
agent = (
|
||||
await cls()
|
||||
.bind(context)
|
||||
.bind(llm_config)
|
||||
.bind(record.resources)
|
||||
.bind(resource_loader)
|
||||
.bind(depend_resource)
|
||||
.bind(agent_memory)
|
||||
.build()
|
||||
)
|
||||
@@ -298,7 +293,7 @@ class MultiAgents(BaseComponent, ABC):
|
||||
manager = (
|
||||
await manager.bind(context)
|
||||
.bind(llm_config)
|
||||
.bind(resource_loader)
|
||||
.bind(depend_resource)
|
||||
.bind(agent_memory)
|
||||
.build()
|
||||
)
|
||||
@@ -308,7 +303,7 @@ class MultiAgents(BaseComponent, ABC):
|
||||
user_proxy: UserProxyAgent = (
|
||||
await UserProxyAgent()
|
||||
.bind(context)
|
||||
.bind(resource_loader)
|
||||
.bind(depend_resource)
|
||||
.bind(agent_memory)
|
||||
.build()
|
||||
)
|
||||
|
@@ -4,10 +4,8 @@ from fastapi import APIRouter
|
||||
|
||||
from dbgpt._private.config import Config
|
||||
from dbgpt.agent.core.agent_manage import get_agent_manager
|
||||
from dbgpt.agent.resource.resource_api import ResourceType
|
||||
from dbgpt.agent.resource.manage import get_resource_manager
|
||||
from dbgpt.agent.util.llm.llm import LLMStrategyType
|
||||
from dbgpt.app.knowledge.api import knowledge_space_service
|
||||
from dbgpt.app.knowledge.request.request import KnowledgeSpaceRequest
|
||||
from dbgpt.app.openapi.api_view_model import Result
|
||||
from dbgpt.serve.agent.app.gpts_server import available_llms
|
||||
from dbgpt.serve.agent.db.gpts_app import (
|
||||
@@ -17,7 +15,6 @@ from dbgpt.serve.agent.db.gpts_app import (
|
||||
GptsAppQuery,
|
||||
GptsAppResponse,
|
||||
)
|
||||
from dbgpt.serve.agent.hub.plugin_hub import plugin_hub
|
||||
from dbgpt.serve.agent.team.base import TeamMode
|
||||
|
||||
CFG = Config()
|
||||
@@ -109,7 +106,8 @@ async def team_mode_list():
|
||||
@router.get("/v1/resource-type/list")
|
||||
async def team_mode_list():
|
||||
try:
|
||||
return Result.succ([type.value for type in ResourceType])
|
||||
resources = get_resource_manager().get_supported_resources(version="v1")
|
||||
return Result.succ(list(resources.keys()))
|
||||
except Exception as ex:
|
||||
return Result.failed(code="E000X", msg=f"query resource type list error: {ex}")
|
||||
|
||||
@@ -146,29 +144,8 @@ async def app_resources(
|
||||
Get agent resources, such as db, knowledge, internet, plugin.
|
||||
"""
|
||||
try:
|
||||
results = []
|
||||
match type:
|
||||
case ResourceType.DB.value:
|
||||
dbs = CFG.local_db_manager.get_db_list()
|
||||
results = [db["db_name"] for db in dbs]
|
||||
if name:
|
||||
results = [r for r in results if name in r]
|
||||
case ResourceType.Knowledge.value:
|
||||
knowledge_spaces = knowledge_space_service.get_knowledge_space(
|
||||
KnowledgeSpaceRequest()
|
||||
)
|
||||
results = [ks.name for ks in knowledge_spaces]
|
||||
if name:
|
||||
results = [r for r in results if name in r]
|
||||
case ResourceType.Plugin.value:
|
||||
plugins = plugin_hub.get_my_plugin(user_code)
|
||||
results = [plugin.name for plugin in plugins]
|
||||
if name:
|
||||
results = [r for r in results if name in r]
|
||||
case ResourceType.Internet.value:
|
||||
return Result.succ(None)
|
||||
case ResourceType.File.value:
|
||||
return Result.succ(None)
|
||||
resources = get_resource_manager().get_supported_resources("v1")
|
||||
results = resources.get(type, [])
|
||||
return Result.succ(results)
|
||||
except Exception as ex:
|
||||
return Result.failed(code="E000X", msg=f"query app resources error: {ex}")
|
||||
|
@@ -15,7 +15,7 @@ from dbgpt._private.pydantic import (
|
||||
model_validator,
|
||||
)
|
||||
from dbgpt.agent.core.plan import AWELTeamContext
|
||||
from dbgpt.agent.resource.resource_api import AgentResource
|
||||
from dbgpt.agent.resource.base import AgentResource
|
||||
from dbgpt.serve.agent.team.base import TeamMode
|
||||
from dbgpt.storage.metadata import BaseDao, Model
|
||||
|
||||
|
@@ -4,8 +4,8 @@ from typing import List
|
||||
|
||||
from fastapi import APIRouter, Body, File, UploadFile
|
||||
|
||||
from dbgpt.agent.plugin.generator import PluginPromptGenerator
|
||||
from dbgpt.agent.plugin.plugins_util import scan_plugins
|
||||
from dbgpt.agent.resource.tool.autogpt.plugins_util import scan_plugins
|
||||
from dbgpt.agent.resource.tool.pack import AutoGPTPluginToolPack
|
||||
from dbgpt.app.openapi.api_view_model import Result
|
||||
from dbgpt.component import BaseComponent, ComponentType, SystemApp
|
||||
from dbgpt.configs.model_config import PLUGINS_DIR
|
||||
@@ -30,25 +30,15 @@ class ModulePlugin(BaseComponent, ABC):
|
||||
|
||||
def __init__(self):
|
||||
# load plugins
|
||||
self.plugins = scan_plugins(PLUGINS_DIR)
|
||||
self.refresh_plugins()
|
||||
|
||||
def init_app(self, system_app: SystemApp):
|
||||
system_app.app.include_router(router, prefix="/api", tags=["Agent"])
|
||||
|
||||
def refresh_plugins(self):
|
||||
self.plugins = scan_plugins(PLUGINS_DIR)
|
||||
|
||||
def load_select_plugin(
|
||||
self, generator: PluginPromptGenerator, select_plugins: List[str]
|
||||
) -> PluginPromptGenerator:
|
||||
logger.info(f"load_select_plugin:{select_plugins}")
|
||||
# load select plugin
|
||||
for plugin in self.plugins:
|
||||
if plugin._name in select_plugins:
|
||||
if not plugin.can_handle_post_prompt():
|
||||
continue
|
||||
generator = plugin.post_prompt(generator)
|
||||
return generator
|
||||
self.tools = AutoGPTPluginToolPack(PLUGINS_DIR)
|
||||
self.tools.preload_resource()
|
||||
|
||||
|
||||
module_plugin = ModulePlugin()
|
||||
|
@@ -9,7 +9,7 @@ from typing import Any
|
||||
from fastapi import UploadFile
|
||||
|
||||
from dbgpt.agent.core.schema import PluginStorageType
|
||||
from dbgpt.agent.plugin.plugins_util import scan_plugins, update_from_git
|
||||
from dbgpt.agent.resource.tool.autogpt.plugins_util import scan_plugins, update_from_git
|
||||
from dbgpt.configs.model_config import PLUGINS_DIR
|
||||
|
||||
from ..db.my_plugin_db import MyPluginDao, MyPluginEntity
|
||||
|
95
dbgpt/serve/agent/resource/datasource.py
Normal file
95
dbgpt/serve/agent/resource/datasource.py
Normal file
@@ -0,0 +1,95 @@
|
||||
import dataclasses
|
||||
import logging
|
||||
from typing import Any, List, Optional, Type, Union, cast
|
||||
|
||||
from dbgpt._private.config import Config
|
||||
from dbgpt.agent.resource.database import DBParameters, RDBMSConnectorResource
|
||||
from dbgpt.util import ParameterDescription
|
||||
|
||||
CFG = Config()
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class DatasourceDBParameters(DBParameters):
|
||||
"""The DB parameters for the datasource."""
|
||||
|
||||
db_name: str = dataclasses.field(metadata={"help": "DB name"})
|
||||
|
||||
@classmethod
|
||||
def _resource_version(cls) -> str:
|
||||
"""Return the resource version."""
|
||||
return "v1"
|
||||
|
||||
@classmethod
|
||||
def to_configurations(
|
||||
cls,
|
||||
parameters: Type["DatasourceDBParameters"],
|
||||
version: Optional[str] = None,
|
||||
) -> Any:
|
||||
"""Convert the parameters to configurations."""
|
||||
conf: List[ParameterDescription] = cast(
|
||||
List[ParameterDescription], super().to_configurations(parameters)
|
||||
)
|
||||
version = version or cls._resource_version()
|
||||
if version != "v1":
|
||||
return conf
|
||||
# Compatible with old version
|
||||
for param in conf:
|
||||
if param.param_name == "db_name":
|
||||
return param.valid_values or []
|
||||
return []
|
||||
|
||||
@classmethod
|
||||
def from_dict(
|
||||
cls, data: dict, ignore_extra_fields: bool = True
|
||||
) -> "DatasourceDBParameters":
|
||||
"""Create a new instance from a dictionary."""
|
||||
copied_data = data.copy()
|
||||
if "db_name" not in copied_data and "value" in copied_data:
|
||||
copied_data["db_name"] = copied_data.pop("value")
|
||||
return super().from_dict(copied_data, ignore_extra_fields=ignore_extra_fields)
|
||||
|
||||
|
||||
class DatasourceResource(RDBMSConnectorResource):
|
||||
def __init__(self, name: str, db_name: Optional[str] = None, **kwargs):
|
||||
conn = CFG.local_db_manager.get_connector(db_name)
|
||||
super().__init__(name, connector=conn, db_name=db_name, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def resource_parameters_class(cls) -> Type[DatasourceDBParameters]:
|
||||
dbs = CFG.local_db_manager.get_db_list()
|
||||
results = [db["db_name"] for db in dbs]
|
||||
|
||||
@dataclasses.dataclass
|
||||
class _DynDBParameters(DatasourceDBParameters):
|
||||
db_name: str = dataclasses.field(
|
||||
metadata={"help": "DB name", "valid_values": results}
|
||||
)
|
||||
|
||||
return _DynDBParameters
|
||||
|
||||
def get_schema_link(
|
||||
self, db: str, question: Optional[str] = None
|
||||
) -> Union[str, List[str]]:
|
||||
"""Return the schema link of the database."""
|
||||
try:
|
||||
from dbgpt.rag.summary.db_summary_client import DBSummaryClient
|
||||
except ImportError:
|
||||
raise ValueError("Could not import DBSummaryClient. ")
|
||||
client = DBSummaryClient(system_app=CFG.SYSTEM_APP)
|
||||
table_infos = None
|
||||
try:
|
||||
table_infos = client.get_db_summary(
|
||||
db,
|
||||
question,
|
||||
CFG.KNOWLEDGE_SEARCH_TOP_SIZE,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"db summary find error!{str(e)}")
|
||||
if not table_infos:
|
||||
conn = CFG.local_db_manager.get_connector(db)
|
||||
table_infos = conn.table_simple_info()
|
||||
|
||||
return table_infos
|
89
dbgpt/serve/agent/resource/knowledge.py
Normal file
89
dbgpt/serve/agent/resource/knowledge.py
Normal file
@@ -0,0 +1,89 @@
|
||||
import dataclasses
|
||||
import logging
|
||||
from typing import Any, List, Optional, Type, cast
|
||||
|
||||
from dbgpt._private.config import Config
|
||||
from dbgpt.agent.resource.knowledge import (
|
||||
RetrieverResource,
|
||||
RetrieverResourceParameters,
|
||||
)
|
||||
from dbgpt.serve.rag.retriever.knowledge_space import KnowledgeSpaceRetriever
|
||||
from dbgpt.util import ParameterDescription
|
||||
|
||||
CFG = Config()
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class KnowledgeSpaceLoadResourceParameters(RetrieverResourceParameters):
|
||||
space_name: str = dataclasses.field(
|
||||
default=None, metadata={"help": "Knowledge space name"}
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _resource_version(cls) -> str:
|
||||
"""Return the resource version."""
|
||||
return "v1"
|
||||
|
||||
@classmethod
|
||||
def to_configurations(
|
||||
cls,
|
||||
parameters: Type["KnowledgeSpaceLoadResourceParameters"],
|
||||
version: Optional[str] = None,
|
||||
) -> Any:
|
||||
"""Convert the parameters to configurations."""
|
||||
conf: List[ParameterDescription] = cast(
|
||||
List[ParameterDescription], super().to_configurations(parameters)
|
||||
)
|
||||
version = version or cls._resource_version()
|
||||
if version != "v1":
|
||||
return conf
|
||||
# Compatible with old version
|
||||
for param in conf:
|
||||
if param.param_name == "space_name":
|
||||
return param.valid_values or []
|
||||
return []
|
||||
|
||||
@classmethod
|
||||
def from_dict(
|
||||
cls, data: dict, ignore_extra_fields: bool = True
|
||||
) -> "KnowledgeSpaceLoadResourceParameters":
|
||||
"""Create a new instance from a dictionary."""
|
||||
copied_data = data.copy()
|
||||
if "space_name" not in copied_data and "value" in copied_data:
|
||||
copied_data["space_name"] = copied_data.pop("value")
|
||||
return super().from_dict(copied_data, ignore_extra_fields=ignore_extra_fields)
|
||||
|
||||
|
||||
class KnowledgeSpaceRetrieverResource(RetrieverResource):
|
||||
"""Knowledge Space retriever resource."""
|
||||
|
||||
def __init__(self, name: str, space_name: str):
|
||||
retriever = KnowledgeSpaceRetriever(space_name=space_name)
|
||||
super().__init__(name, retriever=retriever)
|
||||
|
||||
@classmethod
|
||||
def resource_parameters_class(cls) -> Type[KnowledgeSpaceLoadResourceParameters]:
|
||||
from dbgpt.app.knowledge.request.request import KnowledgeSpaceRequest
|
||||
from dbgpt.app.knowledge.service import KnowledgeService
|
||||
|
||||
knowledge_space_service = KnowledgeService()
|
||||
knowledge_spaces = knowledge_space_service.get_knowledge_space(
|
||||
KnowledgeSpaceRequest()
|
||||
)
|
||||
results = [ks.name for ks in knowledge_spaces]
|
||||
|
||||
@dataclasses.dataclass
|
||||
class _DynamicKnowledgeSpaceLoadResourceParameters(
|
||||
KnowledgeSpaceLoadResourceParameters
|
||||
):
|
||||
space_name: str = dataclasses.field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "Knowledge space name",
|
||||
"valid_values": results,
|
||||
},
|
||||
)
|
||||
|
||||
return _DynamicKnowledgeSpaceLoadResourceParameters
|
92
dbgpt/serve/agent/resource/plugin.py
Normal file
92
dbgpt/serve/agent/resource/plugin.py
Normal file
@@ -0,0 +1,92 @@
|
||||
import dataclasses
|
||||
import logging
|
||||
from typing import Any, List, Optional, Type, cast
|
||||
|
||||
from dbgpt._private.config import Config
|
||||
from dbgpt.agent.resource.pack import PackResourceParameters
|
||||
from dbgpt.agent.resource.tool.pack import ToolPack
|
||||
from dbgpt.component import ComponentType
|
||||
from dbgpt.serve.agent.hub.controller import ModulePlugin
|
||||
from dbgpt.util.parameter_utils import ParameterDescription
|
||||
|
||||
CFG = Config()
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class PluginPackResourceParameters(PackResourceParameters):
|
||||
tool_name: str = dataclasses.field(metadata={"help": "Tool name"})
|
||||
|
||||
@classmethod
|
||||
def _resource_version(cls) -> str:
|
||||
"""Return the resource version."""
|
||||
return "v1"
|
||||
|
||||
@classmethod
|
||||
def to_configurations(
|
||||
cls,
|
||||
parameters: Type["PluginPackResourceParameters"],
|
||||
version: Optional[str] = None,
|
||||
) -> Any:
|
||||
"""Convert the parameters to configurations."""
|
||||
conf: List[ParameterDescription] = cast(
|
||||
List[ParameterDescription], super().to_configurations(parameters)
|
||||
)
|
||||
version = version or cls._resource_version()
|
||||
if version != "v1":
|
||||
return conf
|
||||
# Compatible with old version
|
||||
for param in conf:
|
||||
if param.param_name == "tool_name":
|
||||
return param.valid_values or []
|
||||
return []
|
||||
|
||||
@classmethod
|
||||
def from_dict(
|
||||
cls, data: dict, ignore_extra_fields: bool = True
|
||||
) -> "PluginPackResourceParameters":
|
||||
"""Create a new instance from a dictionary."""
|
||||
copied_data = data.copy()
|
||||
if "tool_name" not in copied_data and "value" in copied_data:
|
||||
copied_data["tool_name"] = copied_data.pop("value")
|
||||
return super().from_dict(copied_data, ignore_extra_fields=ignore_extra_fields)
|
||||
|
||||
|
||||
class PluginToolPack(ToolPack):
|
||||
def __init__(self, tool_name: str, **kwargs):
|
||||
kwargs.pop("name")
|
||||
super().__init__([], name="Plugin Tool Pack", **kwargs)
|
||||
# Select tool name
|
||||
self._tool_name = tool_name
|
||||
|
||||
@classmethod
|
||||
def type_alias(cls) -> str:
|
||||
return "tool(autogpt_plugins)"
|
||||
|
||||
@classmethod
|
||||
def resource_parameters_class(cls) -> Type[PluginPackResourceParameters]:
|
||||
agent_module: ModulePlugin = CFG.SYSTEM_APP.get_component(
|
||||
ComponentType.PLUGIN_HUB, ModulePlugin
|
||||
)
|
||||
tool_names = []
|
||||
for name, sub_tool in agent_module.tools._resources.items():
|
||||
tool_names.append(name)
|
||||
|
||||
@dataclasses.dataclass
|
||||
class _DynPluginPackResourceParameters(PluginPackResourceParameters):
|
||||
tool_name: str = dataclasses.field(
|
||||
metadata={"help": "Tool name", "valid_values": tool_names}
|
||||
)
|
||||
|
||||
return _DynPluginPackResourceParameters
|
||||
|
||||
def preload_resource(self):
|
||||
"""Preload the resource."""
|
||||
agent_module: ModulePlugin = CFG.SYSTEM_APP.get_component(
|
||||
ComponentType.PLUGIN_HUB, ModulePlugin
|
||||
)
|
||||
tool = agent_module.tools._resources.get(self._tool_name)
|
||||
if not tool:
|
||||
raise ValueError(f"Tool {self._tool_name} not found")
|
||||
self._resources = {tool.name: tool}
|
@@ -1,66 +0,0 @@
|
||||
import logging
|
||||
from typing import List, Optional, Union
|
||||
|
||||
from dbgpt._private.config import Config
|
||||
from dbgpt.agent.resource.resource_api import AgentResource
|
||||
from dbgpt.agent.resource.resource_db_api import ResourceDbClient
|
||||
from dbgpt.component import ComponentType
|
||||
from dbgpt.util.executor_utils import ExecutorFactory, blocking_func_to_async
|
||||
from dbgpt.util.tracer import root_tracer
|
||||
|
||||
CFG = Config()
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DatasourceLoadClient(ResourceDbClient):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
# The executor to submit blocking function
|
||||
self._executor = CFG.SYSTEM_APP.get_component(
|
||||
ComponentType.EXECUTOR_DEFAULT, ExecutorFactory
|
||||
).create()
|
||||
|
||||
def get_data_type(self, resource: AgentResource) -> str:
|
||||
conn = CFG.local_db_manager.get_connector(resource.value)
|
||||
return conn.db_type
|
||||
|
||||
async def get_schema_link(
|
||||
self, db: str, question: Optional[str] = None
|
||||
) -> Union[str, List[str]]:
|
||||
try:
|
||||
from dbgpt.rag.summary.db_summary_client import DBSummaryClient
|
||||
except ImportError:
|
||||
raise ValueError("Could not import DBSummaryClient. ")
|
||||
client = DBSummaryClient(system_app=CFG.SYSTEM_APP)
|
||||
table_infos = None
|
||||
try:
|
||||
with root_tracer.start_span("ChatWithDbAutoExecute.get_db_summary"):
|
||||
table_infos = await blocking_func_to_async(
|
||||
self._executor,
|
||||
client.get_db_summary,
|
||||
db,
|
||||
question,
|
||||
CFG.KNOWLEDGE_SEARCH_TOP_SIZE,
|
||||
)
|
||||
except Exception as e:
|
||||
print("db summary find error!" + str(e))
|
||||
if not table_infos:
|
||||
conn = CFG.local_db_manager.get_connector(db)
|
||||
table_infos = await blocking_func_to_async(
|
||||
self._executor, conn.table_simple_info
|
||||
)
|
||||
|
||||
return table_infos
|
||||
|
||||
async def query_to_df(self, db: str, sql: str):
|
||||
conn = CFG.local_db_manager.get_connector(db)
|
||||
return conn.run_to_df(sql)
|
||||
|
||||
async def query(self, db: str, sql: str):
|
||||
conn = CFG.local_db_manager.get_connector(db)
|
||||
return conn.query_ex(sql)
|
||||
|
||||
async def run_sql(self, db: str, sql: str):
|
||||
conn = CFG.local_db_manager.get_connector(db)
|
||||
return conn.run(sql)
|
@@ -1,35 +0,0 @@
|
||||
import logging
|
||||
from typing import Any, List, Optional, Union
|
||||
|
||||
from dbgpt._private.config import Config
|
||||
from dbgpt.agent.resource.resource_api import AgentResource
|
||||
from dbgpt.agent.resource.resource_knowledge_api import ResourceKnowledgeClient
|
||||
from dbgpt.core import Chunk
|
||||
from dbgpt.serve.rag.retriever.knowledge_space import KnowledgeSpaceRetriever
|
||||
|
||||
CFG = Config()
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class KnowledgeSpaceLoadClient(ResourceKnowledgeClient):
|
||||
async def get_space_desc(self, space_name) -> str:
|
||||
pass
|
||||
|
||||
async def get_kn(
|
||||
self, space_name: str, question: Optional[str] = None
|
||||
) -> List[Chunk]:
|
||||
kn_retriver = KnowledgeSpaceRetriever(space_name=space_name)
|
||||
chunks: List[Chunk] = kn_retriver.retrieve(question)
|
||||
return chunks
|
||||
|
||||
async def add_kn(
|
||||
self, space_name: str, kn_name: str, type: str, content: Optional[Any]
|
||||
):
|
||||
kn_retriver = KnowledgeSpaceRetriever(space_name=space_name)
|
||||
|
||||
async def get_data_introduce(
|
||||
self, resource: AgentResource, question: Optional[str] = None
|
||||
) -> Union[str, List[str]]:
|
||||
docs = await self.get_kn(resource.value, question)
|
||||
return "\n".join([doc.content for doc in docs])
|
@@ -1,40 +0,0 @@
|
||||
import json
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from dbgpt._private.config import Config
|
||||
from dbgpt.agent.plugin.generator import PluginPromptGenerator
|
||||
from dbgpt.agent.resource.resource_plugin_api import ResourcePluginClient
|
||||
from dbgpt.component import ComponentType
|
||||
from dbgpt.serve.agent.hub.controller import ModulePlugin
|
||||
from dbgpt.util.executor_utils import ExecutorFactory
|
||||
|
||||
CFG = Config()
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PluginHubLoadClient(ResourcePluginClient):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
# The executor to submit blocking function
|
||||
self._executor = CFG.SYSTEM_APP.get_component(
|
||||
ComponentType.EXECUTOR_DEFAULT, ExecutorFactory
|
||||
).create()
|
||||
|
||||
async def load_plugin(
|
||||
self, value: str, plugin_generator: Optional[PluginPromptGenerator] = None
|
||||
) -> PluginPromptGenerator:
|
||||
logger.info(f"PluginHubLoadClient load plugin:{value}")
|
||||
if plugin_generator is None:
|
||||
plugin_generator = PluginPromptGenerator()
|
||||
plugin_generator.set_command_registry(CFG.command_registry)
|
||||
|
||||
agent_module = CFG.SYSTEM_APP.get_component(
|
||||
ComponentType.PLUGIN_HUB, ModulePlugin
|
||||
)
|
||||
plugin_generator = agent_module.load_select_plugin(
|
||||
plugin_generator, json.dumps(value)
|
||||
)
|
||||
|
||||
return plugin_generator
|
@@ -1,4 +1,5 @@
|
||||
"""Postgres vector store."""
|
||||
|
||||
import logging
|
||||
from typing import List, Optional
|
||||
|
||||
@@ -67,10 +68,10 @@ class PGVectorStore(VectorStoreBase):
|
||||
self.collection_name = vector_store_config.name
|
||||
|
||||
self.vector_store_client = PGVector(
|
||||
embedding_function=self.embeddings,
|
||||
embedding_function=self.embeddings, # type: ignore
|
||||
collection_name=self.collection_name,
|
||||
connection_string=self.connection_string,
|
||||
) # mypy: ignore
|
||||
)
|
||||
|
||||
def similar_search(
|
||||
self, text: str, topk: int, filters: Optional[MetadataFilters] = None
|
||||
@@ -97,7 +98,7 @@ class PGVectorStore(VectorStoreBase):
|
||||
List[str]: chunk ids.
|
||||
"""
|
||||
lc_documents = [Chunk.chunk2langchain(chunk) for chunk in chunks]
|
||||
self.vector_store_client.from_documents(lc_documents)
|
||||
self.vector_store_client.from_documents(lc_documents) # type: ignore
|
||||
return [str(chunk.chunk_id) for chunk in lc_documents]
|
||||
|
||||
def delete_vector_name(self, vector_name: str):
|
||||
|
112
dbgpt/util/cache_utils.py
Normal file
112
dbgpt/util/cache_utils.py
Normal file
@@ -0,0 +1,112 @@
|
||||
"""Cache utils.
|
||||
|
||||
Adapted from https://github.com/hephex/asyncache/blob/master/asyncache/__init__.py.
|
||||
It has stopped updating since 2022. So I copied the code here for future reference.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import functools
|
||||
from contextlib import AbstractContextManager
|
||||
from typing import Any, Callable, MutableMapping, Optional, Protocol, TypeVar
|
||||
|
||||
from cachetools import keys
|
||||
|
||||
_KT = TypeVar("_KT")
|
||||
_T = TypeVar("_T")
|
||||
|
||||
|
||||
class IdentityFunction(Protocol): # pylint: disable=too-few-public-methods
|
||||
"""
|
||||
Type for a function returning the same type as the one it received.
|
||||
"""
|
||||
|
||||
def __call__(self, __x: _T) -> _T:
|
||||
...
|
||||
|
||||
|
||||
class NullContext:
|
||||
"""A class for noop context managers."""
|
||||
|
||||
def __enter__(self):
|
||||
"""Return ``self`` upon entering the runtime context."""
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback):
|
||||
"""Raise any exception triggered within the runtime context."""
|
||||
return None
|
||||
|
||||
async def __aenter__(self):
|
||||
"""Return ``self`` upon entering the runtime context."""
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_value, traceback):
|
||||
"""Raise any exception triggered within the runtime context."""
|
||||
return None
|
||||
|
||||
|
||||
def cached(
|
||||
cache: Optional[MutableMapping[_KT, Any]],
|
||||
# ignoring the mypy error to be consistent with the type used
|
||||
# in https://github.com/python/typeshed/tree/master/stubs/cachetools
|
||||
key: Callable[..., _KT] = keys.hashkey, # type:ignore
|
||||
lock: Optional["AbstractContextManager[Any]"] = None,
|
||||
) -> IdentityFunction:
|
||||
"""
|
||||
Decorator to wrap a function or a coroutine with a memoizing callable
|
||||
that saves results in a cache.
|
||||
|
||||
When ``lock`` is provided for a standard function, it's expected to
|
||||
implement ``__enter__`` and ``__exit__`` that will be used to lock
|
||||
the cache when gets updated. If it wraps a coroutine, ``lock``
|
||||
must implement ``__aenter__`` and ``__aexit__``.
|
||||
"""
|
||||
lock = lock or NullContext()
|
||||
|
||||
def decorator(func):
|
||||
if asyncio.iscoroutinefunction(func):
|
||||
|
||||
async def wrapper(*args, **kwargs):
|
||||
k = key(*args, **kwargs)
|
||||
try:
|
||||
async with lock:
|
||||
return cache[k]
|
||||
|
||||
except KeyError:
|
||||
pass # key not found
|
||||
|
||||
val = await func(*args, **kwargs)
|
||||
|
||||
try:
|
||||
async with lock:
|
||||
cache[k] = val
|
||||
|
||||
except ValueError:
|
||||
pass # val too large
|
||||
|
||||
return val
|
||||
|
||||
else:
|
||||
|
||||
def wrapper(*args, **kwargs):
|
||||
k = key(*args, **kwargs)
|
||||
try:
|
||||
with lock:
|
||||
return cache[k]
|
||||
|
||||
except KeyError:
|
||||
pass # key not found
|
||||
|
||||
val = func(*args, **kwargs)
|
||||
|
||||
try:
|
||||
with lock:
|
||||
cache[k] = val
|
||||
|
||||
except ValueError:
|
||||
pass # val too large
|
||||
|
||||
return val
|
||||
|
||||
return functools.wraps(func)(wrapper)
|
||||
|
||||
return decorator
|
@@ -1,9 +1,20 @@
|
||||
import asyncio
|
||||
import inspect
|
||||
from functools import wraps
|
||||
from typing import Any, get_args, get_origin, get_type_hints
|
||||
from typing import (
|
||||
Any,
|
||||
Dict,
|
||||
List,
|
||||
Optional,
|
||||
Union,
|
||||
_UnionGenericAlias,
|
||||
get_args,
|
||||
get_origin,
|
||||
get_type_hints,
|
||||
)
|
||||
|
||||
from typeguard import check_type
|
||||
from typing_extensions import Annotated, Doc, _AnnotatedAlias
|
||||
|
||||
|
||||
def _is_typing(obj):
|
||||
@@ -119,3 +130,61 @@ def rearrange_args_by_type(func):
|
||||
return await func(*sorted_args, **sorted_kwargs)
|
||||
|
||||
return async_wrapper if asyncio.iscoroutinefunction(func) else sync_wrapper
|
||||
|
||||
|
||||
def type_to_string(obj: Any, default_type: str = "unknown") -> str:
|
||||
"""Convert a type to a string representation."""
|
||||
type_map = {
|
||||
int: "integer",
|
||||
str: "string",
|
||||
float: "float",
|
||||
bool: "boolean",
|
||||
Any: "any",
|
||||
List: "array",
|
||||
dict: "object",
|
||||
}
|
||||
# Check NoneType
|
||||
if obj is type(None):
|
||||
return "null"
|
||||
|
||||
# Get the origin of the type
|
||||
origin = getattr(obj, "__origin__", None)
|
||||
if origin:
|
||||
if _is_typing(origin) and not isinstance(obj, _UnionGenericAlias):
|
||||
obj = origin
|
||||
origin = origin.__origin__
|
||||
# Handle special cases like List[int]
|
||||
if origin is Union and hasattr(obj, "__args__"):
|
||||
subtypes = ", ".join(
|
||||
type_to_string(t) for t in obj.__args__ if t is not type(None)
|
||||
)
|
||||
# return f"Optional[{subtypes}]"
|
||||
return subtypes
|
||||
elif origin is list or origin is List:
|
||||
subtypes = ", ".join(type_to_string(t) for t in obj.__args__)
|
||||
# return f"List[{subtypes}]"
|
||||
return "array"
|
||||
elif origin in [dict, Dict]:
|
||||
key_type, value_type = (type_to_string(t) for t in obj.__args__)
|
||||
# return f"Dict[{key_type}, {value_type}]"
|
||||
return "object"
|
||||
return type_map.get(origin, default_type)
|
||||
else:
|
||||
if hasattr(obj, "__args__"):
|
||||
subtypes = ", ".join(
|
||||
type_to_string(t) for t in obj.__args__ if t is not type(None)
|
||||
)
|
||||
return subtypes
|
||||
|
||||
return type_map.get(obj, default_type)
|
||||
|
||||
|
||||
def parse_param_description(name: str, obj: Any) -> str:
|
||||
default_type_title = name.replace("_", " ").title()
|
||||
if isinstance(obj, _AnnotatedAlias):
|
||||
metadata = obj.__metadata__
|
||||
docs = [arg for arg in metadata if isinstance(arg, Doc)]
|
||||
doc_str = docs[0].documentation if docs else default_type_title
|
||||
else:
|
||||
doc_str = default_type_title
|
||||
return doc_str
|
||||
|
@@ -104,6 +104,9 @@ class BaseParameters:
|
||||
"""
|
||||
return _dict_to_command_args(asdict(self), args_prefix=args_prefix)
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return asdict(self)
|
||||
|
||||
|
||||
def _get_dataclass_print_str(obj):
|
||||
class_name = obj.__class__.__name__
|
||||
|
@@ -21,7 +21,6 @@ from dbgpt.agent import (
|
||||
AgentMemory,
|
||||
AutoPlanChatManager,
|
||||
LLMConfig,
|
||||
ResourceLoader,
|
||||
UserProxyAgent,
|
||||
)
|
||||
from dbgpt.agent.expand.code_assistant_agent import CodeAssistantAgent
|
||||
@@ -37,18 +36,15 @@ async def main():
|
||||
|
||||
agent_memory = AgentMemory()
|
||||
|
||||
llm_client = OpenAILLMClient(model_alias="gpt-4")
|
||||
llm_client = OpenAILLMClient(model_alias="gpt-4o")
|
||||
context: AgentContext = AgentContext(
|
||||
conv_id="test456", gpts_app_name="代码分析助手", max_new_tokens=2048
|
||||
)
|
||||
|
||||
resource_loader = ResourceLoader()
|
||||
|
||||
coder = (
|
||||
await CodeAssistantAgent()
|
||||
.bind(context)
|
||||
.bind(LLMConfig(llm_client=llm_client))
|
||||
.bind(resource_loader)
|
||||
.bind(agent_memory)
|
||||
.build()
|
||||
)
|
||||
|
@@ -15,26 +15,20 @@
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
|
||||
from dbgpt.agent import (
|
||||
AgentContext,
|
||||
AgentMemory,
|
||||
AgentResource,
|
||||
LLMConfig,
|
||||
ResourceLoader,
|
||||
ResourceType,
|
||||
UserProxyAgent,
|
||||
WrappedAWELLayoutManager,
|
||||
)
|
||||
from dbgpt.agent.expand.plugin_assistant_agent import PluginAssistantAgent
|
||||
from dbgpt.agent.expand.resources.search_tool import baidu_search
|
||||
from dbgpt.agent.expand.summary_assistant_agent import SummaryAssistantAgent
|
||||
from dbgpt.agent.resource import PluginFileLoadClient
|
||||
from dbgpt.configs.model_config import ROOT_PATH
|
||||
from dbgpt.agent.expand.tool_assistant_agent import ToolAssistantAgent
|
||||
from dbgpt.agent.resource import ToolPack
|
||||
from dbgpt.util.tracer import initialize_tracer
|
||||
|
||||
test_plugin_dir = os.path.join(ROOT_PATH, "examples/test_files/plugins")
|
||||
|
||||
initialize_tracer("/tmp/agent_trace.jsonl", create_system_app=True)
|
||||
|
||||
|
||||
@@ -45,23 +39,14 @@ async def main():
|
||||
context: AgentContext = AgentContext(conv_id="test456", gpts_app_name="信息析助手")
|
||||
|
||||
agent_memory = AgentMemory()
|
||||
resource_loader = ResourceLoader()
|
||||
plugin_file_loader = PluginFileLoadClient()
|
||||
resource_loader.register_resource_api(plugin_file_loader)
|
||||
|
||||
plugin_resource = AgentResource(
|
||||
type=ResourceType.Plugin,
|
||||
name="test",
|
||||
value=test_plugin_dir,
|
||||
)
|
||||
|
||||
tools = ToolPack([baidu_search])
|
||||
tool_engineer = (
|
||||
await PluginAssistantAgent()
|
||||
await ToolAssistantAgent()
|
||||
.bind(context)
|
||||
.bind(LLMConfig(llm_client=llm_client))
|
||||
.bind(agent_memory)
|
||||
.bind([plugin_resource])
|
||||
.bind(resource_loader)
|
||||
.bind(tools)
|
||||
.build()
|
||||
)
|
||||
summarizer = (
|
||||
@@ -86,7 +71,7 @@ async def main():
|
||||
await user_proxy.initiate_chat(
|
||||
recipient=manager,
|
||||
reviewer=user_proxy,
|
||||
message="查询成都今天天气",
|
||||
message="查询北京今天天气",
|
||||
# message="查询今天的最新热点财经新闻",
|
||||
# message="Find papers on gpt-4 in the past three weeks on arxiv, and organize their titles, authors, and links into a markdown table",
|
||||
# message="find papers on LLM applications from arxiv in the last month, create a markdown table of different domains.",
|
||||
|
80
examples/agents/custom_tool_agent_example.py
Normal file
80
examples/agents/custom_tool_agent_example.py
Normal file
@@ -0,0 +1,80 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
|
||||
from typing_extensions import Annotated, Doc
|
||||
|
||||
from dbgpt.agent import AgentContext, AgentMemory, LLMConfig, UserProxyAgent
|
||||
from dbgpt.agent.expand.tool_assistant_agent import ToolAssistantAgent
|
||||
from dbgpt.agent.resource import ToolPack, tool
|
||||
|
||||
logging.basicConfig(
|
||||
stream=sys.stdout,
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
|
||||
)
|
||||
|
||||
|
||||
@tool
|
||||
def simple_calculator(first_number: int, second_number: int, operator: str) -> float:
|
||||
"""Simple calculator tool. Just support +, -, *, /."""
|
||||
if operator == "+":
|
||||
return first_number + second_number
|
||||
elif operator == "-":
|
||||
return first_number - second_number
|
||||
elif operator == "*":
|
||||
return first_number * second_number
|
||||
elif operator == "/":
|
||||
return first_number / second_number
|
||||
else:
|
||||
raise ValueError(f"Invalid operator: {operator}")
|
||||
|
||||
|
||||
@tool
|
||||
def count_directory_files(path: Annotated[str, Doc("The directory path")]) -> int:
|
||||
"""Count the number of files in a directory."""
|
||||
if not os.path.isdir(path):
|
||||
raise ValueError(f"Invalid directory path: {path}")
|
||||
return len(os.listdir(path))
|
||||
|
||||
|
||||
async def main():
|
||||
from dbgpt.model.proxy import OpenAILLMClient
|
||||
|
||||
llm_client = OpenAILLMClient(model_alias="gpt-3.5-turbo")
|
||||
context: AgentContext = AgentContext(conv_id="test456")
|
||||
|
||||
agent_memory = AgentMemory()
|
||||
|
||||
tools = ToolPack([simple_calculator, count_directory_files])
|
||||
|
||||
user_proxy = await UserProxyAgent().bind(agent_memory).bind(context).build()
|
||||
|
||||
tool_engineer = (
|
||||
await ToolAssistantAgent()
|
||||
.bind(context)
|
||||
.bind(LLMConfig(llm_client=llm_client))
|
||||
.bind(agent_memory)
|
||||
.bind(tools)
|
||||
.build()
|
||||
)
|
||||
|
||||
await user_proxy.initiate_chat(
|
||||
recipient=tool_engineer,
|
||||
reviewer=user_proxy,
|
||||
message="Calculate the product of 10 and 99",
|
||||
)
|
||||
|
||||
await user_proxy.initiate_chat(
|
||||
recipient=tool_engineer,
|
||||
reviewer=user_proxy,
|
||||
message="Count the number of files in /tmp",
|
||||
)
|
||||
|
||||
# dbgpt-vis message infos
|
||||
print(await agent_memory.gpts_memory.one_chat_completions("test456"))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
@@ -17,17 +17,9 @@
|
||||
import asyncio
|
||||
import os
|
||||
|
||||
from dbgpt.agent import (
|
||||
AgentContext,
|
||||
AgentMemory,
|
||||
AgentResource,
|
||||
LLMConfig,
|
||||
ResourceLoader,
|
||||
ResourceType,
|
||||
UserProxyAgent,
|
||||
)
|
||||
from dbgpt.agent.expand.plugin_assistant_agent import PluginAssistantAgent
|
||||
from dbgpt.agent.resource import PluginFileLoadClient
|
||||
from dbgpt.agent import AgentContext, AgentMemory, LLMConfig, UserProxyAgent
|
||||
from dbgpt.agent.expand.tool_assistant_agent import ToolAssistantAgent
|
||||
from dbgpt.agent.resource import AutoGPTPluginToolPack
|
||||
|
||||
current_dir = os.getcwd()
|
||||
parent_dir = os.path.dirname(current_dir)
|
||||
@@ -42,25 +34,16 @@ async def main():
|
||||
|
||||
agent_memory = AgentMemory()
|
||||
|
||||
plugin_resource = AgentResource(
|
||||
type=ResourceType.Plugin,
|
||||
name="test",
|
||||
value=test_plugin_dir,
|
||||
)
|
||||
|
||||
resource_loader = ResourceLoader()
|
||||
plugin_file_loader = PluginFileLoadClient()
|
||||
resource_loader.register_resource_api(plugin_file_loader)
|
||||
tools = AutoGPTPluginToolPack(test_plugin_dir)
|
||||
|
||||
user_proxy = await UserProxyAgent().bind(agent_memory).bind(context).build()
|
||||
|
||||
tool_engineer = (
|
||||
await PluginAssistantAgent()
|
||||
await ToolAssistantAgent()
|
||||
.bind(context)
|
||||
.bind(LLMConfig(llm_client=llm_client))
|
||||
.bind(agent_memory)
|
||||
.bind([plugin_resource])
|
||||
.bind(resource_loader)
|
||||
.bind(tools)
|
||||
.build()
|
||||
)
|
||||
|
||||
|
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user