diff --git a/libs/community/langchain_community/agent_toolkits/openapi/toolkit.py b/libs/community/langchain_community/agent_toolkits/openapi/toolkit.py index 5b7f3fcbd52..77d9fedf6fb 100644 --- a/libs/community/langchain_community/agent_toolkits/openapi/toolkit.py +++ b/libs/community/langchain_community/agent_toolkits/openapi/toolkit.py @@ -41,15 +41,32 @@ class RequestsToolkit(BaseToolkit): """ requests_wrapper: TextRequestsWrapper + allow_dangerous_requests: bool = False + """Allow dangerous requests. See documentation for details.""" def get_tools(self) -> List[BaseTool]: """Return a list of tools.""" return [ - RequestsGetTool(requests_wrapper=self.requests_wrapper), - RequestsPostTool(requests_wrapper=self.requests_wrapper), - RequestsPatchTool(requests_wrapper=self.requests_wrapper), - RequestsPutTool(requests_wrapper=self.requests_wrapper), - RequestsDeleteTool(requests_wrapper=self.requests_wrapper), + RequestsGetTool( + requests_wrapper=self.requests_wrapper, + allow_dangerous_requests=self.allow_dangerous_requests, + ), + RequestsPostTool( + requests_wrapper=self.requests_wrapper, + allow_dangerous_requests=self.allow_dangerous_requests, + ), + RequestsPatchTool( + requests_wrapper=self.requests_wrapper, + allow_dangerous_requests=self.allow_dangerous_requests, + ), + RequestsPutTool( + requests_wrapper=self.requests_wrapper, + allow_dangerous_requests=self.allow_dangerous_requests, + ), + RequestsDeleteTool( + requests_wrapper=self.requests_wrapper, + allow_dangerous_requests=self.allow_dangerous_requests, + ), ] @@ -66,6 +83,8 @@ class OpenAPIToolkit(BaseToolkit): json_agent: Any requests_wrapper: TextRequestsWrapper + allow_dangerous_requests: bool = False + """Allow dangerous requests. 
See documentation for details.""" def get_tools(self) -> List[BaseTool]: """Get the tools in the toolkit.""" @@ -74,7 +93,10 @@ class OpenAPIToolkit(BaseToolkit): func=self.json_agent.run, description=DESCRIPTION, ) - request_toolkit = RequestsToolkit(requests_wrapper=self.requests_wrapper) + request_toolkit = RequestsToolkit( + requests_wrapper=self.requests_wrapper, + allow_dangerous_requests=self.allow_dangerous_requests, + ) return [*request_toolkit.get_tools(), json_agent_tool] @classmethod @@ -83,8 +105,13 @@ class OpenAPIToolkit(BaseToolkit): llm: BaseLanguageModel, json_spec: JsonSpec, requests_wrapper: TextRequestsWrapper, + allow_dangerous_requests: bool = False, **kwargs: Any, ) -> OpenAPIToolkit: """Create json agent from llm, then initialize.""" json_agent = create_json_agent(llm, JsonToolkit(spec=json_spec), **kwargs) - return cls(json_agent=json_agent, requests_wrapper=requests_wrapper) + return cls( + json_agent=json_agent, + requests_wrapper=requests_wrapper, + allow_dangerous_requests=allow_dangerous_requests, + ) diff --git a/libs/community/langchain_community/tools/requests/tool.py b/libs/community/langchain_community/tools/requests/tool.py index 17985d5aa3a..ea725b782f3 100644 --- a/libs/community/langchain_community/tools/requests/tool.py +++ b/libs/community/langchain_community/tools/requests/tool.py @@ -28,6 +28,23 @@ class BaseRequestsTool(BaseModel): requests_wrapper: GenericRequestsWrapper + allow_dangerous_requests: bool = False + + def __init__(self, **kwargs: Any): + """Initialize the tool.""" + if not kwargs.get("allow_dangerous_requests", False): + raise ValueError( + "You must set allow_dangerous_requests to True to use this tool. " + "Requests can be dangerous and can lead to security vulnerabilities. " + "For example, users can ask a server to make a request to an internal " + "server. It's recommended to use requests through a proxy server " + "and avoid accepting inputs from untrusted sources without proper " + "sandboxing. " 
+ "Please see: https://python.langchain.com/docs/security for " + "further security information." + ) + super().__init__(**kwargs) + class RequestsGetTool(BaseRequestsTool, BaseTool): """Tool for making a GET request to an API endpoint.""" diff --git a/libs/community/tests/unit_tests/tools/requests/test_tool.py b/libs/community/tests/unit_tests/tools/requests/test_tool.py index b2d53dcc666..186d04393bc 100644 --- a/libs/community/tests/unit_tests/tools/requests/test_tool.py +++ b/libs/community/tests/unit_tests/tools/requests/test_tool.py @@ -72,34 +72,44 @@ def test_parse_input() -> None: def test_requests_get_tool(mock_requests_wrapper: TextRequestsWrapper) -> None: - tool = RequestsGetTool(requests_wrapper=mock_requests_wrapper) + tool = RequestsGetTool( + requests_wrapper=mock_requests_wrapper, allow_dangerous_requests=True + ) assert tool.run("https://example.com") == "get_response" assert asyncio.run(tool.arun("https://example.com")) == "aget_response" def test_requests_post_tool(mock_requests_wrapper: TextRequestsWrapper) -> None: - tool = RequestsPostTool(requests_wrapper=mock_requests_wrapper) + tool = RequestsPostTool( + requests_wrapper=mock_requests_wrapper, allow_dangerous_requests=True + ) input_text = '{"url": "https://example.com", "data": {"key": "value"}}' assert tool.run(input_text) == "post {'key': 'value'}" assert asyncio.run(tool.arun(input_text)) == "apost {'key': 'value'}" def test_requests_patch_tool(mock_requests_wrapper: TextRequestsWrapper) -> None: - tool = RequestsPatchTool(requests_wrapper=mock_requests_wrapper) + tool = RequestsPatchTool( + requests_wrapper=mock_requests_wrapper, allow_dangerous_requests=True + ) input_text = '{"url": "https://example.com", "data": {"key": "value"}}' assert tool.run(input_text) == "patch {'key': 'value'}" assert asyncio.run(tool.arun(input_text)) == "apatch {'key': 'value'}" def test_requests_put_tool(mock_requests_wrapper: TextRequestsWrapper) -> None: - tool = 
RequestsPutTool(requests_wrapper=mock_requests_wrapper) + tool = RequestsPutTool( + requests_wrapper=mock_requests_wrapper, allow_dangerous_requests=True + ) input_text = '{"url": "https://example.com", "data": {"key": "value"}}' assert tool.run(input_text) == "put {'key': 'value'}" assert asyncio.run(tool.arun(input_text)) == "aput {'key': 'value'}" def test_requests_delete_tool(mock_requests_wrapper: TextRequestsWrapper) -> None: - tool = RequestsDeleteTool(requests_wrapper=mock_requests_wrapper) + tool = RequestsDeleteTool( + requests_wrapper=mock_requests_wrapper, allow_dangerous_requests=True + ) assert tool.run("https://example.com") == "delete_response" assert asyncio.run(tool.arun("https://example.com")) == "adelete_response" @@ -154,7 +164,9 @@ def mock_json_requests_wrapper() -> JsonRequestsWrapper: def test_requests_get_tool_json( mock_json_requests_wrapper: JsonRequestsWrapper, ) -> None: - tool = RequestsGetTool(requests_wrapper=mock_json_requests_wrapper) + tool = RequestsGetTool( + requests_wrapper=mock_json_requests_wrapper, allow_dangerous_requests=True + ) assert tool.run("https://example.com") == {"response": "get_response"} assert asyncio.run(tool.arun("https://example.com")) == { "response": "aget_response" @@ -164,7 +176,9 @@ def test_requests_get_tool_json( def test_requests_post_tool_json( mock_json_requests_wrapper: JsonRequestsWrapper, ) -> None: - tool = RequestsPostTool(requests_wrapper=mock_json_requests_wrapper) + tool = RequestsPostTool( + requests_wrapper=mock_json_requests_wrapper, allow_dangerous_requests=True + ) input_text = '{"url": "https://example.com", "data": {"key": "value"}}' assert tool.run(input_text) == {"response": 'post {"key": "value"}'} assert asyncio.run(tool.arun(input_text)) == {"response": 'apost {"key": "value"}'} @@ -173,7 +187,9 @@ def test_requests_post_tool_json( def test_requests_patch_tool_json( mock_json_requests_wrapper: JsonRequestsWrapper, ) -> None: - tool = 
RequestsPatchTool(requests_wrapper=mock_json_requests_wrapper) + tool = RequestsPatchTool( + requests_wrapper=mock_json_requests_wrapper, allow_dangerous_requests=True + ) input_text = '{"url": "https://example.com", "data": {"key": "value"}}' assert tool.run(input_text) == {"response": 'patch {"key": "value"}'} assert asyncio.run(tool.arun(input_text)) == {"response": 'apatch {"key": "value"}'} @@ -182,7 +198,9 @@ def test_requests_patch_tool_json( def test_requests_put_tool_json( mock_json_requests_wrapper: JsonRequestsWrapper, ) -> None: - tool = RequestsPutTool(requests_wrapper=mock_json_requests_wrapper) + tool = RequestsPutTool( + requests_wrapper=mock_json_requests_wrapper, allow_dangerous_requests=True + ) input_text = '{"url": "https://example.com", "data": {"key": "value"}}' assert tool.run(input_text) == {"response": 'put {"key": "value"}'} assert asyncio.run(tool.arun(input_text)) == {"response": 'aput {"key": "value"}'} @@ -191,7 +209,9 @@ def test_requests_put_tool_json( def test_requests_delete_tool_json( mock_json_requests_wrapper: JsonRequestsWrapper, ) -> None: - tool = RequestsDeleteTool(requests_wrapper=mock_json_requests_wrapper) + tool = RequestsDeleteTool( + requests_wrapper=mock_json_requests_wrapper, allow_dangerous_requests=True + ) assert tool.run("https://example.com") == {"response": "delete_response"} assert asyncio.run(tool.arun("https://example.com")) == { "response": "adelete_response" diff --git a/libs/langchain/langchain/agents/load_tools.py b/libs/langchain/langchain/agents/load_tools.py index ab49c39a4a0..5577c9178e1 100644 --- a/libs/langchain/langchain/agents/load_tools.py +++ b/libs/langchain/langchain/agents/load_tools.py @@ -106,23 +106,48 @@ from langchain_community.utilities.reddit_search import RedditSearchAPIWrapper def _get_tools_requests_get() -> BaseTool: - return RequestsGetTool(requests_wrapper=TextRequestsWrapper()) + # Dangerous requests are allowed here, because there's another flag that the user + # has to 
provide in order to actually opt in. + # This is a private function and should not be used directly. + return RequestsGetTool( + requests_wrapper=TextRequestsWrapper(), allow_dangerous_requests=True + ) def _get_tools_requests_post() -> BaseTool: - return RequestsPostTool(requests_wrapper=TextRequestsWrapper()) + # Dangerous requests are allowed here, because there's another flag that the user + # has to provide in order to actually opt in. + # This is a private function and should not be used directly. + return RequestsPostTool( + requests_wrapper=TextRequestsWrapper(), allow_dangerous_requests=True + ) def _get_tools_requests_patch() -> BaseTool: - return RequestsPatchTool(requests_wrapper=TextRequestsWrapper()) + # Dangerous requests are allowed here, because there's another flag that the user + # has to provide in order to actually opt in. + # This is a private function and should not be used directly. + return RequestsPatchTool( + requests_wrapper=TextRequestsWrapper(), allow_dangerous_requests=True + ) def _get_tools_requests_put() -> BaseTool: - return RequestsPutTool(requests_wrapper=TextRequestsWrapper()) + # Dangerous requests are allowed here, because there's another flag that the user + # has to provide in order to actually opt in. + # This is a private function and should not be used directly. + return RequestsPutTool( + requests_wrapper=TextRequestsWrapper(), allow_dangerous_requests=True + ) def _get_tools_requests_delete() -> BaseTool: - return RequestsDeleteTool(requests_wrapper=TextRequestsWrapper()) + # Dangerous requests are allowed here, because there's another flag that the user + # has to provide in order to actually opt in. + # This is a private function and should not be used directly. 
+ return RequestsDeleteTool( + requests_wrapper=TextRequestsWrapper(), allow_dangerous_requests=True + ) def _get_terminal() -> BaseTool: @@ -134,6 +159,15 @@ def _get_sleep() -> BaseTool: _BASE_TOOLS: Dict[str, Callable[[], BaseTool]] = { + "sleep": _get_sleep, +} + +DANGEROUS_TOOLS = { + # Tools that contain some level of risk. + # Please use with caution and read the documentation of these tools + # to understand the risks and how to mitigate them. + # Refer to https://python.langchain.com/docs/security + # for more information. "requests": _get_tools_requests_get, # preserved for backwards compatibility "requests_get": _get_tools_requests_get, "requests_post": _get_tools_requests_post, @@ -141,7 +175,6 @@ _BASE_TOOLS: Dict[str, Callable[[], BaseTool]] = { "requests_put": _get_tools_requests_put, "requests_delete": _get_tools_requests_delete, "terminal": _get_terminal, - "sleep": _get_sleep, } @@ -541,6 +574,7 @@ def load_tools( tool_names: List[str], llm: Optional[BaseLanguageModel] = None, callbacks: Callbacks = None, + allow_dangerous_tools: bool = False, **kwargs: Any, ) -> List[BaseTool]: """Load tools based on their name. @@ -566,6 +600,15 @@ def load_tools( llm: An optional language model, may be needed to initialize certain tools. callbacks: Optional callback manager or list of callback handlers. If not provided, default global callback manager will be used. + allow_dangerous_tools: Optional flag to allow dangerous tools. + Tools that contain some level of risk. + Please use with caution and read the documentation of these tools + to understand the risks and how to mitigate them. + Refer to https://python.langchain.com/docs/security + for more information. + Please note that this list may not be fully exhaustive. + It is your responsibility to understand which tools + you're using and the risks associated with them. Returns: List of tools. 
@@ -574,10 +617,26 @@ def load_tools( callbacks = _handle_callbacks( callback_manager=kwargs.get("callback_manager"), callbacks=callbacks ) - # print(_BASE_TOOLS) - # print(1) for name in tool_names: - if name == "requests": + if name in DANGEROUS_TOOLS and not allow_dangerous_tools: + raise ValueError( + f"{name} is a dangerous tool. You cannot use it without opting in " + "by setting allow_dangerous_tools to True. " + "Most tools have some inherent risk to them merely because they are " + 'allowed to interact with the "real world". ' + "Please refer to LangChain security guidelines " + "at https://python.langchain.com/docs/security. " + "Some tools have been designated as dangerous because they pose " + "risk that is not intuitively obvious. For example, a tool that " + "allows an agent to make requests to the web can also be used " + "to make requests to a server that is only accessible from the " + "server hosting the code. " + "Again, all tools carry some risk, and it's your responsibility to " + "understand which tools you're using and the risks associated with " + "them." 
+ ) + + if name in {"requests"}: warnings.warn( "tool name `requests` is deprecated - " "please use `requests_all` or specify the requests method" @@ -590,6 +649,8 @@ def load_tools( tool_names.extend(requests_method_tools) elif name in _BASE_TOOLS: tools.append(_BASE_TOOLS[name]()) + elif name in DANGEROUS_TOOLS: + tools.append(DANGEROUS_TOOLS[name]()) elif name in _LLM_TOOLS: if llm is None: raise ValueError(f"Tool {name} requires an LLM to be provided") @@ -628,4 +689,5 @@ def get_all_tool_names() -> List[str]: + list(_EXTRA_OPTIONAL_TOOLS) + list(_EXTRA_LLM_TOOLS) + list(_LLM_TOOLS) + + list(DANGEROUS_TOOLS) ) diff --git a/libs/langchain/tests/unit_tests/agents/test_tools.py b/libs/langchain/tests/unit_tests/agents/test_tools.py index dbedaacec45..d32a57e4e42 100644 --- a/libs/langchain/tests/unit_tests/agents/test_tools.py +++ b/libs/langchain/tests/unit_tests/agents/test_tools.py @@ -71,7 +71,11 @@ def test_load_tools_with_callback_manager_raises_deprecation_warning() -> None: """Test load_tools raises a deprecation for old callback manager kwarg.""" callback_manager = MagicMock() with pytest.warns(DeprecationWarning, match="callback_manager is deprecated"): - tools = load_tools(["requests_get"], callback_manager=callback_manager) + tools = load_tools( + ["requests_get"], + callback_manager=callback_manager, + allow_dangerous_tools=True, + ) assert len(tools) == 1 assert tools[0].callbacks == callback_manager @@ -79,7 +83,11 @@ def test_load_tools_with_callback_manager_raises_deprecation_warning() -> None: def test_load_tools_with_callbacks_is_called() -> None: """Test callbacks are called when provided to load_tools fn.""" callbacks = [FakeCallbackHandler()] - tools = load_tools(["requests_get"], callbacks=callbacks) # type: ignore + tools = load_tools( + ["requests_get"], # type: ignore + callbacks=callbacks, # type: ignore + allow_dangerous_tools=True, + ) assert len(tools) == 1 # Patch the requests.get() method to return a mock response with 
unittest.mock.patch(