community[minor]: add JsonRequestsWrapper tool (#15374)

**Description:** This new feature enhances the flexibility of pipeline
integration, particularly when working with RESTful APIs.
``JsonRequestsWrapper`` allows for the decoding of JSON output, instead
of the only option for text output.

---------

Co-authored-by: Zhichao HAN <hanzhichao2000@hotmail.com>
This commit is contained in:
Zhichao HAN 2024-01-16 04:27:19 +08:00 committed by GitHub
parent d334efc848
commit 5cf06db3b3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 238 additions and 42 deletions

View File

@ -113,10 +113,63 @@
"requests.get(\"https://www.google.com\")" "requests.get(\"https://www.google.com\")"
] ]
}, },
{
"cell_type": "markdown",
"id": "4b0bf1d0",
"metadata": {},
"source": [
"If you need the output to be decoded from JSON, you can use the ``JsonRequestsWrapper``."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "3f27ee3d",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"Type - <class 'dict'>\n",
"\n",
"Content: \n",
"```\n",
"{'count': 5707, 'name': 'jackson', 'age': 38}\n",
"```\n",
"\n",
"\n"
]
}
],
"source": [
"from langchain_community.utilities.requests import JsonRequestsWrapper\n",
"\n",
"requests = JsonRequestsWrapper()\n",
"\n",
"\n",
"rval = requests.get(\"https://api.agify.io/?name=jackson\")\n",
"\n",
"print(\n",
" f\"\"\"\n",
"\n",
"Type - {type(rval)}\n",
"\n",
"Content: \n",
"```\n",
"{rval}\n",
"```\n",
"\n",
"\"\"\"\n",
")"
]
},
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"id": "3f27ee3d", "id": "52a1aa15",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [] "source": []
@ -138,7 +191,7 @@
"name": "python", "name": "python",
"nbconvert_exporter": "python", "nbconvert_exporter": "python",
"pygments_lexer": "ipython3", "pygments_lexer": "ipython3",
"version": "3.11.2" "version": "3.10.13"
} }
}, },
"nbformat": 4, "nbformat": 4,

View File

@ -1,7 +1,7 @@
# flake8: noqa # flake8: noqa
"""Tools for making requests to an API endpoint.""" """Tools for making requests to an API endpoint."""
import json import json
from typing import Any, Dict, Optional from typing import Any, Dict, Optional, Union
from langchain_core.pydantic_v1 import BaseModel from langchain_core.pydantic_v1 import BaseModel
from langchain_core.callbacks import ( from langchain_core.callbacks import (
@ -9,7 +9,7 @@ from langchain_core.callbacks import (
CallbackManagerForToolRun, CallbackManagerForToolRun,
) )
from langchain_community.utilities.requests import TextRequestsWrapper from langchain_community.utilities.requests import GenericRequestsWrapper
from langchain_core.tools import BaseTool from langchain_core.tools import BaseTool
@ -26,7 +26,7 @@ def _clean_url(url: str) -> str:
class BaseRequestsTool(BaseModel): class BaseRequestsTool(BaseModel):
"""Base class for requests tools.""" """Base class for requests tools."""
requests_wrapper: TextRequestsWrapper requests_wrapper: GenericRequestsWrapper
class RequestsGetTool(BaseRequestsTool, BaseTool): class RequestsGetTool(BaseRequestsTool, BaseTool):
@ -37,7 +37,7 @@ class RequestsGetTool(BaseRequestsTool, BaseTool):
def _run( def _run(
self, url: str, run_manager: Optional[CallbackManagerForToolRun] = None self, url: str, run_manager: Optional[CallbackManagerForToolRun] = None
) -> str: ) -> Union[str, Dict[str, Any]]:
"""Run the tool.""" """Run the tool."""
return self.requests_wrapper.get(_clean_url(url)) return self.requests_wrapper.get(_clean_url(url))
@ -45,7 +45,7 @@ class RequestsGetTool(BaseRequestsTool, BaseTool):
self, self,
url: str, url: str,
run_manager: Optional[AsyncCallbackManagerForToolRun] = None, run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
) -> str: ) -> Union[str, Dict[str, Any]]:
"""Run the tool asynchronously.""" """Run the tool asynchronously."""
return await self.requests_wrapper.aget(_clean_url(url)) return await self.requests_wrapper.aget(_clean_url(url))
@ -64,7 +64,7 @@ class RequestsPostTool(BaseRequestsTool, BaseTool):
def _run( def _run(
self, text: str, run_manager: Optional[CallbackManagerForToolRun] = None self, text: str, run_manager: Optional[CallbackManagerForToolRun] = None
) -> str: ) -> Union[str, Dict[str, Any]]:
"""Run the tool.""" """Run the tool."""
try: try:
data = _parse_input(text) data = _parse_input(text)
@ -76,7 +76,7 @@ class RequestsPostTool(BaseRequestsTool, BaseTool):
self, self,
text: str, text: str,
run_manager: Optional[AsyncCallbackManagerForToolRun] = None, run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
) -> str: ) -> Union[str, Dict[str, Any]]:
"""Run the tool asynchronously.""" """Run the tool asynchronously."""
try: try:
data = _parse_input(text) data = _parse_input(text)
@ -101,7 +101,7 @@ class RequestsPatchTool(BaseRequestsTool, BaseTool):
def _run( def _run(
self, text: str, run_manager: Optional[CallbackManagerForToolRun] = None self, text: str, run_manager: Optional[CallbackManagerForToolRun] = None
) -> str: ) -> Union[str, Dict[str, Any]]:
"""Run the tool.""" """Run the tool."""
try: try:
data = _parse_input(text) data = _parse_input(text)
@ -113,7 +113,7 @@ class RequestsPatchTool(BaseRequestsTool, BaseTool):
self, self,
text: str, text: str,
run_manager: Optional[AsyncCallbackManagerForToolRun] = None, run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
) -> str: ) -> Union[str, Dict[str, Any]]:
"""Run the tool asynchronously.""" """Run the tool asynchronously."""
try: try:
data = _parse_input(text) data = _parse_input(text)
@ -138,7 +138,7 @@ class RequestsPutTool(BaseRequestsTool, BaseTool):
def _run( def _run(
self, text: str, run_manager: Optional[CallbackManagerForToolRun] = None self, text: str, run_manager: Optional[CallbackManagerForToolRun] = None
) -> str: ) -> Union[str, Dict[str, Any]]:
"""Run the tool.""" """Run the tool."""
try: try:
data = _parse_input(text) data = _parse_input(text)
@ -150,7 +150,7 @@ class RequestsPutTool(BaseRequestsTool, BaseTool):
self, self,
text: str, text: str,
run_manager: Optional[AsyncCallbackManagerForToolRun] = None, run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
) -> str: ) -> Union[str, Dict[str, Any]]:
"""Run the tool asynchronously.""" """Run the tool asynchronously."""
try: try:
data = _parse_input(text) data = _parse_input(text)
@ -171,7 +171,7 @@ class RequestsDeleteTool(BaseRequestsTool, BaseTool):
self, self,
url: str, url: str,
run_manager: Optional[CallbackManagerForToolRun] = None, run_manager: Optional[CallbackManagerForToolRun] = None,
) -> str: ) -> Union[str, Dict[str, Any]]:
"""Run the tool.""" """Run the tool."""
return self.requests_wrapper.delete(_clean_url(url)) return self.requests_wrapper.delete(_clean_url(url))
@ -179,6 +179,6 @@ class RequestsDeleteTool(BaseRequestsTool, BaseTool):
self, self,
url: str, url: str,
run_manager: Optional[AsyncCallbackManagerForToolRun] = None, run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
) -> str: ) -> Union[str, Dict[str, Any]]:
"""Run the tool asynchronously.""" """Run the tool asynchronously."""
return await self.requests_wrapper.adelete(_clean_url(url)) return await self.requests_wrapper.adelete(_clean_url(url))

View File

@ -1,10 +1,11 @@
"""Lightweight wrapper around requests library, with async support.""" """Lightweight wrapper around requests library, with async support."""
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from typing import Any, AsyncGenerator, Dict, Optional from typing import Any, AsyncGenerator, Dict, Literal, Optional, Union
import aiohttp import aiohttp
import requests import requests
from langchain_core.pydantic_v1 import BaseModel, Extra from langchain_core.pydantic_v1 import BaseModel, Extra
from requests import Response
class Requests(BaseModel): class Requests(BaseModel):
@ -108,15 +109,13 @@ class Requests(BaseModel):
yield response yield response
class TextRequestsWrapper(BaseModel): class GenericRequestsWrapper(BaseModel):
"""Lightweight wrapper around requests library. """Lightweight wrapper around requests library."""
The main purpose of this wrapper is to always return a text output.
"""
headers: Optional[Dict[str, str]] = None headers: Optional[Dict[str, str]] = None
aiosession: Optional[aiohttp.ClientSession] = None aiosession: Optional[aiohttp.ClientSession] = None
auth: Optional[Any] = None auth: Optional[Any] = None
response_content_type: Literal["text", "json"] = "text"
class Config: class Config:
"""Configuration for this pydantic object.""" """Configuration for this pydantic object."""
@ -130,50 +129,96 @@ class TextRequestsWrapper(BaseModel):
headers=self.headers, aiosession=self.aiosession, auth=self.auth headers=self.headers, aiosession=self.aiosession, auth=self.auth
) )
def get(self, url: str, **kwargs: Any) -> str: def _get_resp_content(self, response: Response) -> Union[str, Dict[str, Any]]:
if self.response_content_type == "text":
return response.text
elif self.response_content_type == "json":
return response.json()
else:
raise ValueError(f"Invalid return type: {self.response_content_type}")
def _aget_resp_content(
self, response: aiohttp.ClientResponse
) -> Union[str, Dict[str, Any]]:
if self.response_content_type == "text":
return response.text()
elif self.response_content_type == "json":
return response.json()
else:
raise ValueError(f"Invalid return type: {self.response_content_type}")
def get(self, url: str, **kwargs: Any) -> Union[str, Dict[str, Any]]:
"""GET the URL and return the text.""" """GET the URL and return the text."""
return self.requests.get(url, **kwargs).text return self._get_resp_content(self.requests.get(url, **kwargs))
def post(self, url: str, data: Dict[str, Any], **kwargs: Any) -> str: def post(
self, url: str, data: Dict[str, Any], **kwargs: Any
) -> Union[str, Dict[str, Any]]:
"""POST to the URL and return the text.""" """POST to the URL and return the text."""
return self.requests.post(url, data, **kwargs).text return self._get_resp_content(self.requests.post(url, data, **kwargs))
def patch(self, url: str, data: Dict[str, Any], **kwargs: Any) -> str: def patch(
self, url: str, data: Dict[str, Any], **kwargs: Any
) -> Union[str, Dict[str, Any]]:
"""PATCH the URL and return the text.""" """PATCH the URL and return the text."""
return self.requests.patch(url, data, **kwargs).text return self._get_resp_content(self.requests.patch(url, data, **kwargs))
def put(self, url: str, data: Dict[str, Any], **kwargs: Any) -> str: def put(
self, url: str, data: Dict[str, Any], **kwargs: Any
) -> Union[str, Dict[str, Any]]:
"""PUT the URL and return the text.""" """PUT the URL and return the text."""
return self.requests.put(url, data, **kwargs).text return self._get_resp_content(self.requests.put(url, data, **kwargs))
def delete(self, url: str, **kwargs: Any) -> str: def delete(self, url: str, **kwargs: Any) -> Union[str, Dict[str, Any]]:
"""DELETE the URL and return the text.""" """DELETE the URL and return the text."""
return self.requests.delete(url, **kwargs).text return self._get_resp_content(self.requests.delete(url, **kwargs))
async def aget(self, url: str, **kwargs: Any) -> str: async def aget(self, url: str, **kwargs: Any) -> Union[str, Dict[str, Any]]:
"""GET the URL and return the text asynchronously.""" """GET the URL and return the text asynchronously."""
async with self.requests.aget(url, **kwargs) as response: async with self.requests.aget(url, **kwargs) as response:
return await response.text() return await self._aget_resp_content(response)
async def apost(self, url: str, data: Dict[str, Any], **kwargs: Any) -> str: async def apost(
self, url: str, data: Dict[str, Any], **kwargs: Any
) -> Union[str, Dict[str, Any]]:
"""POST to the URL and return the text asynchronously.""" """POST to the URL and return the text asynchronously."""
async with self.requests.apost(url, data, **kwargs) as response: async with self.requests.apost(url, data, **kwargs) as response:
return await response.text() return await self._aget_resp_content(response)
async def apatch(self, url: str, data: Dict[str, Any], **kwargs: Any) -> str: async def apatch(
self, url: str, data: Dict[str, Any], **kwargs: Any
) -> Union[str, Dict[str, Any]]:
"""PATCH the URL and return the text asynchronously.""" """PATCH the URL and return the text asynchronously."""
async with self.requests.apatch(url, data, **kwargs) as response: async with self.requests.apatch(url, data, **kwargs) as response:
return await response.text() return await self._aget_resp_content(response)
async def aput(self, url: str, data: Dict[str, Any], **kwargs: Any) -> str: async def aput(
self, url: str, data: Dict[str, Any], **kwargs: Any
) -> Union[str, Dict[str, Any]]:
"""PUT the URL and return the text asynchronously.""" """PUT the URL and return the text asynchronously."""
async with self.requests.aput(url, data, **kwargs) as response: async with self.requests.aput(url, data, **kwargs) as response:
return await response.text() return await self._aget_resp_content(response)
async def adelete(self, url: str, **kwargs: Any) -> str: async def adelete(self, url: str, **kwargs: Any) -> Union[str, Dict[str, Any]]:
"""DELETE the URL and return the text asynchronously.""" """DELETE the URL and return the text asynchronously."""
async with self.requests.adelete(url, **kwargs) as response: async with self.requests.adelete(url, **kwargs) as response:
return await response.text() return await self._aget_resp_content(response)
class JsonRequestsWrapper(GenericRequestsWrapper):
"""Lightweight wrapper around requests library, with async support.
The main purpose of this wrapper is to always return a json output."""
response_content_type: Literal["text", "json"] = "json"
class TextRequestsWrapper(GenericRequestsWrapper):
"""Lightweight wrapper around requests library, with async support.
The main purpose of this wrapper is to always return a text output."""
response_content_type: Literal["text", "json"] = "text"
# For backwards compatibility # For backwards compatibility

View File

@ -1,4 +1,5 @@
import asyncio import asyncio
import json
from typing import Any, Dict from typing import Any, Dict
import pytest import pytest
@ -11,7 +12,10 @@ from langchain_community.tools.requests.tool import (
RequestsPutTool, RequestsPutTool,
_parse_input, _parse_input,
) )
from langchain_community.utilities.requests import TextRequestsWrapper from langchain_community.utilities.requests import (
JsonRequestsWrapper,
TextRequestsWrapper,
)
class _MockTextRequestsWrapper(TextRequestsWrapper): class _MockTextRequestsWrapper(TextRequestsWrapper):
@ -98,3 +102,97 @@ def test_requests_delete_tool(mock_requests_wrapper: TextRequestsWrapper) -> Non
tool = RequestsDeleteTool(requests_wrapper=mock_requests_wrapper) tool = RequestsDeleteTool(requests_wrapper=mock_requests_wrapper)
assert tool.run("https://example.com") == "delete_response" assert tool.run("https://example.com") == "delete_response"
assert asyncio.run(tool.arun("https://example.com")) == "adelete_response" assert asyncio.run(tool.arun("https://example.com")) == "adelete_response"
class _MockJsonRequestsWrapper(JsonRequestsWrapper):
@staticmethod
def get(url: str, **kwargs: Any) -> Dict[str, Any]:
return {"response": "get_response"}
@staticmethod
async def aget(url: str, **kwargs: Any) -> Dict[str, Any]:
return {"response": "aget_response"}
@staticmethod
def post(url: str, data: Dict[str, Any], **kwargs: Any) -> Dict[str, Any]:
return {"response": f"post {json.dumps(data)}"}
@staticmethod
async def apost(url: str, data: Dict[str, Any], **kwargs: Any) -> Dict[str, Any]:
return {"response": f"apost {json.dumps(data)}"}
@staticmethod
def patch(url: str, data: Dict[str, Any], **kwargs: Any) -> Dict[str, Any]:
return {"response": f"patch {json.dumps(data)}"}
@staticmethod
async def apatch(url: str, data: Dict[str, Any], **kwargs: Any) -> Dict[str, Any]:
return {"response": f"apatch {json.dumps(data)}"}
@staticmethod
def put(url: str, data: Dict[str, Any], **kwargs: Any) -> Dict[str, Any]:
return {"response": f"put {json.dumps(data)}"}
@staticmethod
async def aput(url: str, data: Dict[str, Any], **kwargs: Any) -> Dict[str, Any]:
return {"response": f"aput {json.dumps(data)}"}
@staticmethod
def delete(url: str, **kwargs: Any) -> Dict[str, Any]:
return {"response": "delete_response"}
@staticmethod
async def adelete(url: str, **kwargs: Any) -> Dict[str, Any]:
return {"response": "adelete_response"}
@pytest.fixture
def mock_json_requests_wrapper() -> JsonRequestsWrapper:
return _MockJsonRequestsWrapper()
def test_requests_get_tool_json(
mock_json_requests_wrapper: JsonRequestsWrapper,
) -> None:
tool = RequestsGetTool(requests_wrapper=mock_json_requests_wrapper)
assert tool.run("https://example.com") == {"response": "get_response"}
assert asyncio.run(tool.arun("https://example.com")) == {
"response": "aget_response"
}
def test_requests_post_tool_json(
mock_json_requests_wrapper: JsonRequestsWrapper,
) -> None:
tool = RequestsPostTool(requests_wrapper=mock_json_requests_wrapper)
input_text = '{"url": "https://example.com", "data": {"key": "value"}}'
assert tool.run(input_text) == {"response": 'post {"key": "value"}'}
assert asyncio.run(tool.arun(input_text)) == {"response": 'apost {"key": "value"}'}
def test_requests_patch_tool_json(
mock_json_requests_wrapper: JsonRequestsWrapper,
) -> None:
tool = RequestsPatchTool(requests_wrapper=mock_json_requests_wrapper)
input_text = '{"url": "https://example.com", "data": {"key": "value"}}'
assert tool.run(input_text) == {"response": 'patch {"key": "value"}'}
assert asyncio.run(tool.arun(input_text)) == {"response": 'apatch {"key": "value"}'}
def test_requests_put_tool_json(
mock_json_requests_wrapper: JsonRequestsWrapper,
) -> None:
tool = RequestsPutTool(requests_wrapper=mock_json_requests_wrapper)
input_text = '{"url": "https://example.com", "data": {"key": "value"}}'
assert tool.run(input_text) == {"response": 'put {"key": "value"}'}
assert asyncio.run(tool.arun(input_text)) == {"response": 'aput {"key": "value"}'}
def test_requests_delete_tool_json(
mock_json_requests_wrapper: JsonRequestsWrapper,
) -> None:
tool = RequestsDeleteTool(requests_wrapper=mock_json_requests_wrapper)
assert tool.run("https://example.com") == {"response": "delete_response"}
assert asyncio.run(tool.arun("https://example.com")) == {
"response": "adelete_response"
}