mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-07-21 03:19:19 +00:00
feat(agent): Support MCP authentication
This commit is contained in:
parent
0b6633703d
commit
2e0e140b51
@ -139,7 +139,7 @@ class ConversableAgent(Role, Agent):
|
||||
async def preload_resource(self) -> None:
|
||||
"""Preload resources before agent initialization."""
|
||||
if self.resource:
|
||||
await self.blocking_func_to_async(self.resource.preload_resource)
|
||||
await self.resource.preload_resource()
|
||||
|
||||
async def build(self, is_retry_chat: bool = False) -> "ConversableAgent":
|
||||
"""Build the agent."""
|
||||
|
@ -166,6 +166,16 @@ class ReActAction(ToolAction):
|
||||
name = parsed_step.action
|
||||
action_input = parsed_step.action_input
|
||||
action_input_str = action_input
|
||||
|
||||
if not name:
|
||||
terminal_content = str(action_input_str if action_input_str else ai_message)
|
||||
return ActionOutput(
|
||||
is_exe_success=True,
|
||||
content=terminal_content,
|
||||
observations=terminal_content,
|
||||
terminate=True,
|
||||
)
|
||||
|
||||
try:
|
||||
# Try to parse the action input to dict
|
||||
if action_input and isinstance(action_input, str):
|
||||
|
@ -13,11 +13,11 @@ from dbgpt.agent import (
|
||||
ResourceType,
|
||||
)
|
||||
from dbgpt.agent.core.role import AgentRunMode
|
||||
from dbgpt.agent.resource import BaseTool, ToolPack
|
||||
from dbgpt.agent.resource import BaseTool, ResourcePack, ToolPack
|
||||
from dbgpt.agent.util.react_parser import ReActOutputParser
|
||||
from dbgpt.util.configure import DynConfig
|
||||
|
||||
from .actions.react_action import ReActAction
|
||||
from .actions.react_action import ReActAction, Terminate
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -113,7 +113,7 @@ class ReActAgent(ConversableAgent):
|
||||
"""Init indicator AssistantAgent."""
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self._init_actions([ReActAction])
|
||||
self._init_actions([ReActAction, Terminate])
|
||||
|
||||
async def _a_init_reply_message(
|
||||
self,
|
||||
@ -150,6 +150,36 @@ class ReActAgent(ConversableAgent):
|
||||
}
|
||||
return reply_message
|
||||
|
||||
async def preload_resource(self) -> None:
|
||||
await super().preload_resource()
|
||||
self._check_and_add_terminate()
|
||||
|
||||
def _check_and_add_terminate(self):
|
||||
if not self.resource:
|
||||
return
|
||||
_is_has_terminal = False
|
||||
|
||||
def _has_terminal(r: Resource):
|
||||
nonlocal _is_has_terminal
|
||||
if r.type() == ResourceType.Tool and isinstance(r, Terminate):
|
||||
_is_has_terminal = True
|
||||
return r
|
||||
|
||||
_has_add_terminal = False
|
||||
|
||||
def _add_terminate(r: Resource):
|
||||
nonlocal _has_add_terminal
|
||||
if not _has_add_terminal and isinstance(r, ResourcePack):
|
||||
terminal = Terminate()
|
||||
r._resources[terminal.name] = terminal
|
||||
_has_add_terminal = True
|
||||
return r
|
||||
|
||||
self.resource.apply(apply_func=_has_terminal)
|
||||
if not _is_has_terminal:
|
||||
# Add terminal action to the resource
|
||||
self.resource.apply(apply_pack_func=_add_terminate)
|
||||
|
||||
async def load_resource(self, question: str, is_retry_chat: bool = False):
|
||||
"""Load agent bind resource."""
|
||||
if self.resource:
|
||||
|
@ -272,7 +272,12 @@ class Resource(ABC, Generic[P]):
|
||||
|
||||
def apply(
|
||||
self,
|
||||
apply_func: Callable[["Resource"], Union["Resource", List["Resource"], None]],
|
||||
apply_func: Optional[
|
||||
Callable[["Resource"], Union["Resource", List["Resource"], None]]
|
||||
] = None,
|
||||
apply_pack_func: Optional[
|
||||
Callable[["Resource"], Union["Resource", None]]
|
||||
] = None,
|
||||
) -> Union["Resource", None]:
|
||||
"""Apply the function to the resource."""
|
||||
return self
|
||||
|
@ -125,16 +125,27 @@ class ResourcePack(Resource[PackResourceParameters]):
|
||||
return list(self._resources.values())
|
||||
|
||||
def apply(
|
||||
self, apply_func: Callable[[Resource], Union[Resource, List[Resource], None]]
|
||||
self,
|
||||
apply_func: Optional[
|
||||
Callable[[Resource], Union[Resource, List[Resource], None]]
|
||||
] = None,
|
||||
apply_pack_func: Optional[
|
||||
Callable[["Resource"], Union["Resource", None]]
|
||||
] = None,
|
||||
) -> Union[Resource, None]:
|
||||
"""Apply the function to the resource."""
|
||||
if not self.is_pack:
|
||||
return self
|
||||
|
||||
if not apply_func and not apply_pack_func:
|
||||
raise ValueError("No function provided to apply to the resource pack.")
|
||||
|
||||
def _apply_func_to_resource(
|
||||
resource: Resource,
|
||||
) -> Union[Resource, List[Resource], None]:
|
||||
if resource.is_pack:
|
||||
if apply_pack_func is not None:
|
||||
return apply_pack_func(resource)
|
||||
resources = []
|
||||
resource_copy = cast(ResourcePack, copy.copy(resource))
|
||||
for resource_copy in resource_copy.sub_resources:
|
||||
@ -149,7 +160,10 @@ class ResourcePack(Resource[PackResourceParameters]):
|
||||
resource.name: resource for resource in resources
|
||||
}
|
||||
else:
|
||||
return apply_func(resource)
|
||||
if apply_func is not None:
|
||||
return apply_func(resource)
|
||||
else:
|
||||
return resource
|
||||
|
||||
new_resource = _apply_func_to_resource(self)
|
||||
resource_copy = cast(ResourcePack, copy.copy(self))
|
||||
|
@ -278,12 +278,48 @@ class AutoGPTPluginToolPack(ToolPack):
|
||||
|
||||
|
||||
class MCPToolPack(ToolPack):
|
||||
def __init__(self, mcp_servers: Union[str, List[str]], **kwargs):
|
||||
"""MCP tool pack class.
|
||||
|
||||
Wrap the MCP SSE server as a tool pack.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
tools = MCPToolPack("http://127.0.0.1:8000/sse")
|
||||
|
||||
If you want to pass the token to the server, you can use the headers parameter:
|
||||
.. code-block:: python
|
||||
|
||||
tools = MCPToolPack(
|
||||
"http://127.0.0.1:8000/sse"
|
||||
default_headers={"Authorization": "Bearer your_token"}
|
||||
)
|
||||
# Set the default headers for ech server
|
||||
tools2 = MCPToolPack(
|
||||
"http://127.0.0.1:8000/sse"
|
||||
headers = {
|
||||
"http://127.0.0.1:8000/sse": {
|
||||
"Authorization": "Bearer your_token"
|
||||
}
|
||||
}
|
||||
)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
mcp_servers: Union[str, List[str]],
|
||||
headers: Optional[Dict[str, Dict[str, Any]]] = None,
|
||||
default_headers: Optional[Dict[str, Any]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Create an Auto-GPT plugin tool pack."""
|
||||
super().__init__([], **kwargs)
|
||||
self._mcp_servers = mcp_servers
|
||||
self._loaded = False
|
||||
self.tool_server_map = {}
|
||||
self._default_headers = default_headers or {}
|
||||
self._headers_map = headers or {}
|
||||
self.server_headers_map = {}
|
||||
|
||||
def switch_mcp_input_schema(self, input_schema: dict):
|
||||
args = {}
|
||||
@ -324,7 +360,10 @@ class MCPToolPack(ToolPack):
|
||||
server_list = self._mcp_servers.split(";")
|
||||
|
||||
for server in server_list:
|
||||
async with sse_client(url=server) as (read, write):
|
||||
server_headers = self._headers_map.get(server, self._default_headers)
|
||||
self.server_headers_map[server] = server_headers
|
||||
|
||||
async with sse_client(url=server, headers=server_headers) as (read, write):
|
||||
async with ClientSession(read, write) as session:
|
||||
# Initialize the connection
|
||||
await session.initialize()
|
||||
@ -338,7 +377,10 @@ class MCPToolPack(ToolPack):
|
||||
tool_name=tool_name, server=server, **kwargs
|
||||
):
|
||||
try:
|
||||
async with sse_client(url=server) as (read, write):
|
||||
headers_to_use = self.server_headers_map.get(server, {})
|
||||
async with sse_client(
|
||||
url=server, headers=headers_to_use
|
||||
) as (read, write):
|
||||
async with ClientSession(read, write) as session:
|
||||
# Initialize the connection
|
||||
await session.initialize()
|
||||
|
@ -49,7 +49,27 @@ class MCPPackResourceParameters(PackResourceParameters):
|
||||
|
||||
class MCPSSEToolPack(MCPToolPack):
|
||||
def __init__(self, mcp_servers: Union[str, List[str]], **kwargs):
|
||||
super().__init__(mcp_servers=mcp_servers, **kwargs)
|
||||
"""Initialize the MCPSSEToolPack with the given MCP servers."""
|
||||
headers = {}
|
||||
if "token" in kwargs and kwargs["token"]:
|
||||
# token is not supported in sse mode
|
||||
servers = (
|
||||
mcp_servers.split(";") if isinstance(mcp_servers, str) else mcp_servers
|
||||
)
|
||||
tokens = (
|
||||
kwargs["token"].split(";")
|
||||
if isinstance(kwargs["token"], str)
|
||||
else kwargs["token"]
|
||||
)
|
||||
if len(servers) == len(tokens):
|
||||
for i, token in enumerate(tokens):
|
||||
headers[servers[i]] = {"Authorization": f"Bearer {token}"}
|
||||
else:
|
||||
token = tokens[0]
|
||||
for server in servers:
|
||||
headers[server] = {"Authorization": f"Bearer {token}"}
|
||||
kwargs.pop("token")
|
||||
super().__init__(mcp_servers=mcp_servers, headers=headers, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def type_alias(cls) -> str:
|
||||
@ -64,7 +84,17 @@ class MCPSSEToolPack(MCPToolPack):
|
||||
mcp_servers: str = dataclasses.field(
|
||||
default="http://127.0.0.1:8000/sse",
|
||||
metadata={
|
||||
"help": _("MCP SSE Server URL, split by ':'"),
|
||||
"help": _("MCP SSE Server URL, split by ';'"),
|
||||
},
|
||||
)
|
||||
token: Optional[str] = dataclasses.field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": _(
|
||||
'MCP SSE Server token, split by ";", It will be '
|
||||
'added to the header({"Authorization": "Bearer your_token"}'
|
||||
),
|
||||
"tags": "privacy",
|
||||
},
|
||||
)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user