mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-07-27 13:57:46 +00:00
feat(agent): Support MCP authentication
This commit is contained in:
parent
0b6633703d
commit
2e0e140b51
@ -139,7 +139,7 @@ class ConversableAgent(Role, Agent):
|
|||||||
async def preload_resource(self) -> None:
|
async def preload_resource(self) -> None:
|
||||||
"""Preload resources before agent initialization."""
|
"""Preload resources before agent initialization."""
|
||||||
if self.resource:
|
if self.resource:
|
||||||
await self.blocking_func_to_async(self.resource.preload_resource)
|
await self.resource.preload_resource()
|
||||||
|
|
||||||
async def build(self, is_retry_chat: bool = False) -> "ConversableAgent":
|
async def build(self, is_retry_chat: bool = False) -> "ConversableAgent":
|
||||||
"""Build the agent."""
|
"""Build the agent."""
|
||||||
|
@ -166,6 +166,16 @@ class ReActAction(ToolAction):
|
|||||||
name = parsed_step.action
|
name = parsed_step.action
|
||||||
action_input = parsed_step.action_input
|
action_input = parsed_step.action_input
|
||||||
action_input_str = action_input
|
action_input_str = action_input
|
||||||
|
|
||||||
|
if not name:
|
||||||
|
terminal_content = str(action_input_str if action_input_str else ai_message)
|
||||||
|
return ActionOutput(
|
||||||
|
is_exe_success=True,
|
||||||
|
content=terminal_content,
|
||||||
|
observations=terminal_content,
|
||||||
|
terminate=True,
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Try to parse the action input to dict
|
# Try to parse the action input to dict
|
||||||
if action_input and isinstance(action_input, str):
|
if action_input and isinstance(action_input, str):
|
||||||
|
@ -13,11 +13,11 @@ from dbgpt.agent import (
|
|||||||
ResourceType,
|
ResourceType,
|
||||||
)
|
)
|
||||||
from dbgpt.agent.core.role import AgentRunMode
|
from dbgpt.agent.core.role import AgentRunMode
|
||||||
from dbgpt.agent.resource import BaseTool, ToolPack
|
from dbgpt.agent.resource import BaseTool, ResourcePack, ToolPack
|
||||||
from dbgpt.agent.util.react_parser import ReActOutputParser
|
from dbgpt.agent.util.react_parser import ReActOutputParser
|
||||||
from dbgpt.util.configure import DynConfig
|
from dbgpt.util.configure import DynConfig
|
||||||
|
|
||||||
from .actions.react_action import ReActAction
|
from .actions.react_action import ReActAction, Terminate
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -113,7 +113,7 @@ class ReActAgent(ConversableAgent):
|
|||||||
"""Init indicator AssistantAgent."""
|
"""Init indicator AssistantAgent."""
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
self._init_actions([ReActAction])
|
self._init_actions([ReActAction, Terminate])
|
||||||
|
|
||||||
async def _a_init_reply_message(
|
async def _a_init_reply_message(
|
||||||
self,
|
self,
|
||||||
@ -150,6 +150,36 @@ class ReActAgent(ConversableAgent):
|
|||||||
}
|
}
|
||||||
return reply_message
|
return reply_message
|
||||||
|
|
||||||
|
async def preload_resource(self) -> None:
|
||||||
|
await super().preload_resource()
|
||||||
|
self._check_and_add_terminate()
|
||||||
|
|
||||||
|
def _check_and_add_terminate(self):
|
||||||
|
if not self.resource:
|
||||||
|
return
|
||||||
|
_is_has_terminal = False
|
||||||
|
|
||||||
|
def _has_terminal(r: Resource):
|
||||||
|
nonlocal _is_has_terminal
|
||||||
|
if r.type() == ResourceType.Tool and isinstance(r, Terminate):
|
||||||
|
_is_has_terminal = True
|
||||||
|
return r
|
||||||
|
|
||||||
|
_has_add_terminal = False
|
||||||
|
|
||||||
|
def _add_terminate(r: Resource):
|
||||||
|
nonlocal _has_add_terminal
|
||||||
|
if not _has_add_terminal and isinstance(r, ResourcePack):
|
||||||
|
terminal = Terminate()
|
||||||
|
r._resources[terminal.name] = terminal
|
||||||
|
_has_add_terminal = True
|
||||||
|
return r
|
||||||
|
|
||||||
|
self.resource.apply(apply_func=_has_terminal)
|
||||||
|
if not _is_has_terminal:
|
||||||
|
# Add terminal action to the resource
|
||||||
|
self.resource.apply(apply_pack_func=_add_terminate)
|
||||||
|
|
||||||
async def load_resource(self, question: str, is_retry_chat: bool = False):
|
async def load_resource(self, question: str, is_retry_chat: bool = False):
|
||||||
"""Load agent bind resource."""
|
"""Load agent bind resource."""
|
||||||
if self.resource:
|
if self.resource:
|
||||||
|
@ -272,7 +272,12 @@ class Resource(ABC, Generic[P]):
|
|||||||
|
|
||||||
def apply(
|
def apply(
|
||||||
self,
|
self,
|
||||||
apply_func: Callable[["Resource"], Union["Resource", List["Resource"], None]],
|
apply_func: Optional[
|
||||||
|
Callable[["Resource"], Union["Resource", List["Resource"], None]]
|
||||||
|
] = None,
|
||||||
|
apply_pack_func: Optional[
|
||||||
|
Callable[["Resource"], Union["Resource", None]]
|
||||||
|
] = None,
|
||||||
) -> Union["Resource", None]:
|
) -> Union["Resource", None]:
|
||||||
"""Apply the function to the resource."""
|
"""Apply the function to the resource."""
|
||||||
return self
|
return self
|
||||||
|
@ -125,16 +125,27 @@ class ResourcePack(Resource[PackResourceParameters]):
|
|||||||
return list(self._resources.values())
|
return list(self._resources.values())
|
||||||
|
|
||||||
def apply(
|
def apply(
|
||||||
self, apply_func: Callable[[Resource], Union[Resource, List[Resource], None]]
|
self,
|
||||||
|
apply_func: Optional[
|
||||||
|
Callable[[Resource], Union[Resource, List[Resource], None]]
|
||||||
|
] = None,
|
||||||
|
apply_pack_func: Optional[
|
||||||
|
Callable[["Resource"], Union["Resource", None]]
|
||||||
|
] = None,
|
||||||
) -> Union[Resource, None]:
|
) -> Union[Resource, None]:
|
||||||
"""Apply the function to the resource."""
|
"""Apply the function to the resource."""
|
||||||
if not self.is_pack:
|
if not self.is_pack:
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
if not apply_func and not apply_pack_func:
|
||||||
|
raise ValueError("No function provided to apply to the resource pack.")
|
||||||
|
|
||||||
def _apply_func_to_resource(
|
def _apply_func_to_resource(
|
||||||
resource: Resource,
|
resource: Resource,
|
||||||
) -> Union[Resource, List[Resource], None]:
|
) -> Union[Resource, List[Resource], None]:
|
||||||
if resource.is_pack:
|
if resource.is_pack:
|
||||||
|
if apply_pack_func is not None:
|
||||||
|
return apply_pack_func(resource)
|
||||||
resources = []
|
resources = []
|
||||||
resource_copy = cast(ResourcePack, copy.copy(resource))
|
resource_copy = cast(ResourcePack, copy.copy(resource))
|
||||||
for resource_copy in resource_copy.sub_resources:
|
for resource_copy in resource_copy.sub_resources:
|
||||||
@ -149,7 +160,10 @@ class ResourcePack(Resource[PackResourceParameters]):
|
|||||||
resource.name: resource for resource in resources
|
resource.name: resource for resource in resources
|
||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
return apply_func(resource)
|
if apply_func is not None:
|
||||||
|
return apply_func(resource)
|
||||||
|
else:
|
||||||
|
return resource
|
||||||
|
|
||||||
new_resource = _apply_func_to_resource(self)
|
new_resource = _apply_func_to_resource(self)
|
||||||
resource_copy = cast(ResourcePack, copy.copy(self))
|
resource_copy = cast(ResourcePack, copy.copy(self))
|
||||||
|
@ -278,12 +278,48 @@ class AutoGPTPluginToolPack(ToolPack):
|
|||||||
|
|
||||||
|
|
||||||
class MCPToolPack(ToolPack):
|
class MCPToolPack(ToolPack):
|
||||||
def __init__(self, mcp_servers: Union[str, List[str]], **kwargs):
|
"""MCP tool pack class.
|
||||||
|
|
||||||
|
Wrap the MCP SSE server as a tool pack.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
tools = MCPToolPack("http://127.0.0.1:8000/sse")
|
||||||
|
|
||||||
|
If you want to pass the token to the server, you can use the headers parameter:
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
tools = MCPToolPack(
|
||||||
|
"http://127.0.0.1:8000/sse"
|
||||||
|
default_headers={"Authorization": "Bearer your_token"}
|
||||||
|
)
|
||||||
|
# Set the default headers for ech server
|
||||||
|
tools2 = MCPToolPack(
|
||||||
|
"http://127.0.0.1:8000/sse"
|
||||||
|
headers = {
|
||||||
|
"http://127.0.0.1:8000/sse": {
|
||||||
|
"Authorization": "Bearer your_token"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
mcp_servers: Union[str, List[str]],
|
||||||
|
headers: Optional[Dict[str, Dict[str, Any]]] = None,
|
||||||
|
default_headers: Optional[Dict[str, Any]] = None,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
"""Create an Auto-GPT plugin tool pack."""
|
"""Create an Auto-GPT plugin tool pack."""
|
||||||
super().__init__([], **kwargs)
|
super().__init__([], **kwargs)
|
||||||
self._mcp_servers = mcp_servers
|
self._mcp_servers = mcp_servers
|
||||||
self._loaded = False
|
self._loaded = False
|
||||||
self.tool_server_map = {}
|
self.tool_server_map = {}
|
||||||
|
self._default_headers = default_headers or {}
|
||||||
|
self._headers_map = headers or {}
|
||||||
|
self.server_headers_map = {}
|
||||||
|
|
||||||
def switch_mcp_input_schema(self, input_schema: dict):
|
def switch_mcp_input_schema(self, input_schema: dict):
|
||||||
args = {}
|
args = {}
|
||||||
@ -324,7 +360,10 @@ class MCPToolPack(ToolPack):
|
|||||||
server_list = self._mcp_servers.split(";")
|
server_list = self._mcp_servers.split(";")
|
||||||
|
|
||||||
for server in server_list:
|
for server in server_list:
|
||||||
async with sse_client(url=server) as (read, write):
|
server_headers = self._headers_map.get(server, self._default_headers)
|
||||||
|
self.server_headers_map[server] = server_headers
|
||||||
|
|
||||||
|
async with sse_client(url=server, headers=server_headers) as (read, write):
|
||||||
async with ClientSession(read, write) as session:
|
async with ClientSession(read, write) as session:
|
||||||
# Initialize the connection
|
# Initialize the connection
|
||||||
await session.initialize()
|
await session.initialize()
|
||||||
@ -338,7 +377,10 @@ class MCPToolPack(ToolPack):
|
|||||||
tool_name=tool_name, server=server, **kwargs
|
tool_name=tool_name, server=server, **kwargs
|
||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
async with sse_client(url=server) as (read, write):
|
headers_to_use = self.server_headers_map.get(server, {})
|
||||||
|
async with sse_client(
|
||||||
|
url=server, headers=headers_to_use
|
||||||
|
) as (read, write):
|
||||||
async with ClientSession(read, write) as session:
|
async with ClientSession(read, write) as session:
|
||||||
# Initialize the connection
|
# Initialize the connection
|
||||||
await session.initialize()
|
await session.initialize()
|
||||||
|
@ -49,7 +49,27 @@ class MCPPackResourceParameters(PackResourceParameters):
|
|||||||
|
|
||||||
class MCPSSEToolPack(MCPToolPack):
|
class MCPSSEToolPack(MCPToolPack):
|
||||||
def __init__(self, mcp_servers: Union[str, List[str]], **kwargs):
|
def __init__(self, mcp_servers: Union[str, List[str]], **kwargs):
|
||||||
super().__init__(mcp_servers=mcp_servers, **kwargs)
|
"""Initialize the MCPSSEToolPack with the given MCP servers."""
|
||||||
|
headers = {}
|
||||||
|
if "token" in kwargs and kwargs["token"]:
|
||||||
|
# token is not supported in sse mode
|
||||||
|
servers = (
|
||||||
|
mcp_servers.split(";") if isinstance(mcp_servers, str) else mcp_servers
|
||||||
|
)
|
||||||
|
tokens = (
|
||||||
|
kwargs["token"].split(";")
|
||||||
|
if isinstance(kwargs["token"], str)
|
||||||
|
else kwargs["token"]
|
||||||
|
)
|
||||||
|
if len(servers) == len(tokens):
|
||||||
|
for i, token in enumerate(tokens):
|
||||||
|
headers[servers[i]] = {"Authorization": f"Bearer {token}"}
|
||||||
|
else:
|
||||||
|
token = tokens[0]
|
||||||
|
for server in servers:
|
||||||
|
headers[server] = {"Authorization": f"Bearer {token}"}
|
||||||
|
kwargs.pop("token")
|
||||||
|
super().__init__(mcp_servers=mcp_servers, headers=headers, **kwargs)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def type_alias(cls) -> str:
|
def type_alias(cls) -> str:
|
||||||
@ -64,7 +84,17 @@ class MCPSSEToolPack(MCPToolPack):
|
|||||||
mcp_servers: str = dataclasses.field(
|
mcp_servers: str = dataclasses.field(
|
||||||
default="http://127.0.0.1:8000/sse",
|
default="http://127.0.0.1:8000/sse",
|
||||||
metadata={
|
metadata={
|
||||||
"help": _("MCP SSE Server URL, split by ':'"),
|
"help": _("MCP SSE Server URL, split by ';'"),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
token: Optional[str] = dataclasses.field(
|
||||||
|
default=None,
|
||||||
|
metadata={
|
||||||
|
"help": _(
|
||||||
|
'MCP SSE Server token, split by ";", It will be '
|
||||||
|
'added to the header({"Authorization": "Bearer your_token"}'
|
||||||
|
),
|
||||||
|
"tags": "privacy",
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user