requests wrapper (#2367)

This commit is contained in:
Harrison Chase 2023-04-03 21:57:19 -07:00 committed by GitHub
parent 10dab053b4
commit fe1eb8ca5f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 128 additions and 47 deletions

View File

@ -41,7 +41,7 @@
"from langchain.agents.agent_toolkits import JsonToolkit\n", "from langchain.agents.agent_toolkits import JsonToolkit\n",
"from langchain.chains import LLMChain\n", "from langchain.chains import LLMChain\n",
"from langchain.llms.openai import OpenAI\n", "from langchain.llms.openai import OpenAI\n",
"from langchain.requests import RequestsWrapper\n", "from langchain.requests import TextRequestsWrapper\n",
"from langchain.tools.json.tool import JsonSpec" "from langchain.tools.json.tool import JsonSpec"
] ]
}, },

View File

@ -35,7 +35,7 @@
"from langchain.agents import create_openapi_agent\n", "from langchain.agents import create_openapi_agent\n",
"from langchain.agents.agent_toolkits import OpenAPIToolkit\n", "from langchain.agents.agent_toolkits import OpenAPIToolkit\n",
"from langchain.llms.openai import OpenAI\n", "from langchain.llms.openai import OpenAI\n",
"from langchain.requests import RequestsWrapper\n", "from langchain.requests import TextRequestsWrapper\n",
"from langchain.tools.json.tool import JsonSpec" "from langchain.tools.json.tool import JsonSpec"
] ]
}, },
@ -54,7 +54,7 @@
"headers = {\n", "headers = {\n",
" \"Authorization\": f\"Bearer {os.getenv('OPENAI_API_KEY')}\"\n", " \"Authorization\": f\"Bearer {os.getenv('OPENAI_API_KEY')}\"\n",
"}\n", "}\n",
"requests_wrapper=RequestsWrapper(headers=headers)\n", "requests_wrapper=TextRequestsWrapper(headers=headers)\n",
"openapi_toolkit = OpenAPIToolkit.from_llm(OpenAI(temperature=0), json_spec, requests_wrapper, verbose=True)\n", "openapi_toolkit = OpenAPIToolkit.from_llm(OpenAI(temperature=0), json_spec, requests_wrapper, verbose=True)\n",
"openapi_agent_executor = create_openapi_agent(\n", "openapi_agent_executor = create_openapi_agent(\n",
" llm=OpenAI(temperature=0),\n", " llm=OpenAI(temperature=0),\n",

View File

@ -17,7 +17,7 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"from langchain.utilities import RequestsWrapper" "from langchain.utilities import TextRequestsWrapper"
] ]
}, },
{ {
@ -27,7 +27,7 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"requests = RequestsWrapper()" "requests = TextRequestsWrapper()"
] ]
}, },
{ {

View File

@ -10,7 +10,7 @@ from langchain.agents.agent_toolkits.json.toolkit import JsonToolkit
from langchain.agents.agent_toolkits.openapi.prompt import DESCRIPTION from langchain.agents.agent_toolkits.openapi.prompt import DESCRIPTION
from langchain.agents.tools import Tool from langchain.agents.tools import Tool
from langchain.llms.base import BaseLLM from langchain.llms.base import BaseLLM
from langchain.requests import RequestsWrapper from langchain.requests import TextRequestsWrapper
from langchain.tools import BaseTool from langchain.tools import BaseTool
from langchain.tools.json.tool import JsonSpec from langchain.tools.json.tool import JsonSpec
from langchain.tools.requests.tool import ( from langchain.tools.requests.tool import (
@ -25,7 +25,7 @@ from langchain.tools.requests.tool import (
class RequestsToolkit(BaseToolkit): class RequestsToolkit(BaseToolkit):
"""Toolkit for making requests.""" """Toolkit for making requests."""
requests_wrapper: RequestsWrapper requests_wrapper: TextRequestsWrapper
def get_tools(self) -> List[BaseTool]: def get_tools(self) -> List[BaseTool]:
"""Return a list of tools.""" """Return a list of tools."""
@ -42,7 +42,7 @@ class OpenAPIToolkit(BaseToolkit):
"""Toolkit for interacting with a OpenAPI api.""" """Toolkit for interacting with a OpenAPI api."""
json_agent: AgentExecutor json_agent: AgentExecutor
requests_wrapper: RequestsWrapper requests_wrapper: TextRequestsWrapper
def get_tools(self) -> List[BaseTool]: def get_tools(self) -> List[BaseTool]:
"""Get the tools in the toolkit.""" """Get the tools in the toolkit."""
@ -59,7 +59,7 @@ class OpenAPIToolkit(BaseToolkit):
cls, cls,
llm: BaseLLM, llm: BaseLLM,
json_spec: JsonSpec, json_spec: JsonSpec,
requests_wrapper: RequestsWrapper, requests_wrapper: TextRequestsWrapper,
**kwargs: Any, **kwargs: Any,
) -> OpenAPIToolkit: ) -> OpenAPIToolkit:
"""Create json agent from llm, then initialize.""" """Create json agent from llm, then initialize."""

View File

@ -10,7 +10,7 @@ from langchain.chains.api.base import APIChain
from langchain.chains.llm_math.base import LLMMathChain from langchain.chains.llm_math.base import LLMMathChain
from langchain.chains.pal.base import PALChain from langchain.chains.pal.base import PALChain
from langchain.llms.base import BaseLLM from langchain.llms.base import BaseLLM
from langchain.requests import RequestsWrapper from langchain.requests import TextRequestsWrapper
from langchain.tools.base import BaseTool from langchain.tools.base import BaseTool
from langchain.tools.bing_search.tool import BingSearchRun from langchain.tools.bing_search.tool import BingSearchRun
from langchain.tools.google_search.tool import GoogleSearchResults, GoogleSearchRun from langchain.tools.google_search.tool import GoogleSearchResults, GoogleSearchRun
@ -42,23 +42,23 @@ def _get_python_repl() -> BaseTool:
def _get_tools_requests_get() -> BaseTool: def _get_tools_requests_get() -> BaseTool:
return RequestsGetTool(requests_wrapper=RequestsWrapper()) return RequestsGetTool(requests_wrapper=TextRequestsWrapper())
def _get_tools_requests_post() -> BaseTool: def _get_tools_requests_post() -> BaseTool:
return RequestsPostTool(requests_wrapper=RequestsWrapper()) return RequestsPostTool(requests_wrapper=TextRequestsWrapper())
def _get_tools_requests_patch() -> BaseTool: def _get_tools_requests_patch() -> BaseTool:
return RequestsPatchTool(requests_wrapper=RequestsWrapper()) return RequestsPatchTool(requests_wrapper=TextRequestsWrapper())
def _get_tools_requests_put() -> BaseTool: def _get_tools_requests_put() -> BaseTool:
return RequestsPutTool(requests_wrapper=RequestsWrapper()) return RequestsPutTool(requests_wrapper=TextRequestsWrapper())
def _get_tools_requests_delete() -> BaseTool: def _get_tools_requests_delete() -> BaseTool:
return RequestsDeleteTool(requests_wrapper=RequestsWrapper()) return RequestsDeleteTool(requests_wrapper=TextRequestsWrapper())
def _get_terminal() -> BaseTool: def _get_terminal() -> BaseTool:

View File

@ -9,7 +9,7 @@ from langchain.chains.api.prompt import API_RESPONSE_PROMPT, API_URL_PROMPT
from langchain.chains.base import Chain from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain from langchain.chains.llm import LLMChain
from langchain.prompts import BasePromptTemplate from langchain.prompts import BasePromptTemplate
from langchain.requests import RequestsWrapper from langchain.requests import TextRequestsWrapper
from langchain.schema import BaseLanguageModel from langchain.schema import BaseLanguageModel
@ -18,7 +18,7 @@ class APIChain(Chain, BaseModel):
api_request_chain: LLMChain api_request_chain: LLMChain
api_answer_chain: LLMChain api_answer_chain: LLMChain
requests_wrapper: RequestsWrapper = Field(exclude=True) requests_wrapper: TextRequestsWrapper = Field(exclude=True)
api_docs: str api_docs: str
question_key: str = "question" #: :meta private: question_key: str = "question" #: :meta private:
output_key: str = "output" #: :meta private: output_key: str = "output" #: :meta private:
@ -93,7 +93,7 @@ class APIChain(Chain, BaseModel):
) -> APIChain: ) -> APIChain:
"""Load chain from just an LLM and the api docs.""" """Load chain from just an LLM and the api docs."""
get_request_chain = LLMChain(llm=llm, prompt=api_url_prompt) get_request_chain = LLMChain(llm=llm, prompt=api_url_prompt)
requests_wrapper = RequestsWrapper(headers=headers) requests_wrapper = TextRequestsWrapper(headers=headers)
get_answer_chain = LLMChain(llm=llm, prompt=api_response_prompt) get_answer_chain = LLMChain(llm=llm, prompt=api_response_prompt)
return cls( return cls(
api_request_chain=get_request_chain, api_request_chain=get_request_chain,

View File

@ -7,7 +7,7 @@ from pydantic import BaseModel, Extra, Field, root_validator
from langchain.chains import LLMChain from langchain.chains import LLMChain
from langchain.chains.base import Chain from langchain.chains.base import Chain
from langchain.requests import RequestsWrapper from langchain.requests import TextRequestsWrapper
DEFAULT_HEADERS = { DEFAULT_HEADERS = {
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/87.0.4280.88 Safari/537.36" # noqa: E501 "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/87.0.4280.88 Safari/537.36" # noqa: E501
@ -18,8 +18,8 @@ class LLMRequestsChain(Chain, BaseModel):
"""Chain that hits a URL and then uses an LLM to parse results.""" """Chain that hits a URL and then uses an LLM to parse results."""
llm_chain: LLMChain llm_chain: LLMChain
requests_wrapper: RequestsWrapper = Field( requests_wrapper: TextRequestsWrapper = Field(
default_factory=RequestsWrapper, exclude=True default_factory=TextRequestsWrapper, exclude=True
) )
text_length: int = 8000 text_length: int = 8000
requests_key: str = "requests_result" #: :meta private: requests_key: str = "requests_result" #: :meta private:

View File

@ -6,8 +6,12 @@ import requests
from pydantic import BaseModel, Extra from pydantic import BaseModel, Extra
class RequestsWrapper(BaseModel): class Requests(BaseModel):
"""Lightweight wrapper around requests library.""" """Wrapper around requests to handle auth and async.
The main purpose of this wrapper is to handle authentication (by saving
headers) and enable easy async methods on the same base object.
"""
headers: Optional[Dict[str, str]] = None headers: Optional[Dict[str, str]] = None
aiosession: Optional[aiohttp.ClientSession] = None aiosession: Optional[aiohttp.ClientSession] = None
@ -18,56 +22,133 @@ class RequestsWrapper(BaseModel):
extra = Extra.forbid extra = Extra.forbid
arbitrary_types_allowed = True arbitrary_types_allowed = True
def get(self, url: str, **kwargs: Any) -> str: def get(self, url: str, **kwargs: Any) -> requests.Response:
"""GET the URL and return the text.""" """GET the URL and return the text."""
return requests.get(url, headers=self.headers, **kwargs).text return requests.get(url, headers=self.headers, **kwargs)
def post(self, url: str, data: Dict[str, Any], **kwargs: Any) -> str: def post(self, url: str, data: Dict[str, Any], **kwargs: Any) -> requests.Response:
"""POST to the URL and return the text.""" """POST to the URL and return the text."""
return requests.post(url, json=data, headers=self.headers, **kwargs).text return requests.post(url, json=data, headers=self.headers, **kwargs)
def patch(self, url: str, data: Dict[str, Any], **kwargs: Any) -> str: def patch(self, url: str, data: Dict[str, Any], **kwargs: Any) -> requests.Response:
"""PATCH the URL and return the text.""" """PATCH the URL and return the text."""
return requests.patch(url, json=data, headers=self.headers, **kwargs).text return requests.patch(url, json=data, headers=self.headers, **kwargs)
def put(self, url: str, data: Dict[str, Any], **kwargs: Any) -> str: def put(self, url: str, data: Dict[str, Any], **kwargs: Any) -> requests.Response:
"""PUT the URL and return the text.""" """PUT the URL and return the text."""
return requests.put(url, json=data, headers=self.headers, **kwargs).text return requests.put(url, json=data, headers=self.headers, **kwargs)
def delete(self, url: str, **kwargs: Any) -> str: def delete(self, url: str, **kwargs: Any) -> requests.Response:
"""DELETE the URL and return the text.""" """DELETE the URL and return the text."""
return requests.delete(url, headers=self.headers, **kwargs).text return requests.delete(url, headers=self.headers, **kwargs)
async def _arequest(self, method: str, url: str, **kwargs: Any) -> str: async def _arequest(
self, method: str, url: str, **kwargs: Any
) -> aiohttp.ClientResponse:
"""Make an async request.""" """Make an async request."""
if not self.aiosession: if not self.aiosession:
async with aiohttp.ClientSession() as session: async with aiohttp.ClientSession() as session:
async with session.request( async with session.request(
method, url, headers=self.headers, **kwargs method, url, headers=self.headers, **kwargs
) as response: ) as response:
return await response.text() return response
else: else:
async with self.aiosession.request( async with self.aiosession.request(
method, url, headers=self.headers, **kwargs method, url, headers=self.headers, **kwargs
) as response: ) as response:
return await response.text() return response
async def aget(self, url: str, **kwargs: Any) -> str: async def aget(self, url: str, **kwargs: Any) -> aiohttp.ClientResponse:
"""GET the URL and return the text asynchronously.""" """GET the URL and return the text asynchronously."""
return await self._arequest("GET", url, **kwargs) return await self._arequest("GET", url, **kwargs)
async def apost(self, url: str, data: Dict[str, Any], **kwargs: Any) -> str: async def apost(
self, url: str, data: Dict[str, Any], **kwargs: Any
) -> aiohttp.ClientResponse:
"""POST to the URL and return the text asynchronously.""" """POST to the URL and return the text asynchronously."""
return await self._arequest("POST", url, json=data, **kwargs) return await self._arequest("POST", url, json=data, **kwargs)
async def apatch(self, url: str, data: Dict[str, Any], **kwargs: Any) -> str: async def apatch(
self, url: str, data: Dict[str, Any], **kwargs: Any
) -> aiohttp.ClientResponse:
"""PATCH the URL and return the text asynchronously.""" """PATCH the URL and return the text asynchronously."""
return await self._arequest("PATCH", url, json=data, **kwargs) return await self._arequest("PATCH", url, json=data, **kwargs)
async def aput(self, url: str, data: Dict[str, Any], **kwargs: Any) -> str: async def aput(
self, url: str, data: Dict[str, Any], **kwargs: Any
) -> aiohttp.ClientResponse:
"""PUT the URL and return the text asynchronously.""" """PUT the URL and return the text asynchronously."""
return await self._arequest("PUT", url, json=data, **kwargs) return await self._arequest("PUT", url, json=data, **kwargs)
async def adelete(self, url: str, **kwargs: Any) -> str: async def adelete(self, url: str, **kwargs: Any) -> aiohttp.ClientResponse:
"""DELETE the URL and return the text asynchronously.""" """DELETE the URL and return the text asynchronously."""
return await self._arequest("DELETE", url, **kwargs) return await self._arequest("DELETE", url, **kwargs)
class TextRequestsWrapper(BaseModel):
"""Lightweight wrapper around requests library.
The main purpose of this wrapper is to always return a text output.
"""
headers: Optional[Dict[str, str]] = None
aiosession: Optional[aiohttp.ClientSession] = None
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
arbitrary_types_allowed = True
@property
def requests(self) -> Requests:
return Requests(headers=self.headers, aiosession=self.aiosession)
def get(self, url: str, **kwargs: Any) -> str:
"""GET the URL and return the text."""
return self.requests.get(url, **kwargs).text
def post(self, url: str, data: Dict[str, Any], **kwargs: Any) -> str:
"""POST to the URL and return the text."""
return self.requests.post(url, json=data, headers=self.headers, **kwargs).text
def patch(self, url: str, data: Dict[str, Any], **kwargs: Any) -> str:
"""PATCH the URL and return the text."""
return self.requests.patch(url, json=data, headers=self.headers, **kwargs).text
def put(self, url: str, data: Dict[str, Any], **kwargs: Any) -> str:
"""PUT the URL and return the text."""
return self.requests.put(url, json=data, headers=self.headers, **kwargs).text
def delete(self, url: str, **kwargs: Any) -> str:
"""DELETE the URL and return the text."""
return self.requests.delete(url, headers=self.headers, **kwargs).text
async def aget(self, url: str, **kwargs: Any) -> str:
"""GET the URL and return the text asynchronously."""
response = await self.requests.aget(url, **kwargs)
return await response.text()
async def apost(self, url: str, data: Dict[str, Any], **kwargs: Any) -> str:
"""POST to the URL and return the text asynchronously."""
response = await self.requests.apost(url, data, **kwargs)
return await response.text()
async def apatch(self, url: str, data: Dict[str, Any], **kwargs: Any) -> str:
"""PATCH the URL and return the text asynchronously."""
response = await self.requests.apatch(url, data, **kwargs)
return await response.text()
async def aput(self, url: str, data: Dict[str, Any], **kwargs: Any) -> str:
"""PUT the URL and return the text asynchronously."""
response = await self.requests.aput(url, data, **kwargs)
return await response.text()
async def adelete(self, url: str, **kwargs: Any) -> str:
"""DELETE the URL and return the text asynchronously."""
response = await self.requests.adelete(url, **kwargs)
return await response.text()
# For backwards compatibility
RequestsWrapper = TextRequestsWrapper

View File

@ -5,7 +5,7 @@ from typing import Any, Dict
from pydantic import BaseModel from pydantic import BaseModel
from langchain.requests import RequestsWrapper from langchain.requests import TextRequestsWrapper
from langchain.tools.base import BaseTool from langchain.tools.base import BaseTool
@ -17,7 +17,7 @@ def _parse_input(text: str) -> Dict[str, Any]:
class BaseRequestsTool(BaseModel): class BaseRequestsTool(BaseModel):
"""Base class for requests tools.""" """Base class for requests tools."""
requests_wrapper: RequestsWrapper requests_wrapper: TextRequestsWrapper
class RequestsGetTool(BaseRequestsTool, BaseTool): class RequestsGetTool(BaseRequestsTool, BaseTool):

View File

@ -1,6 +1,6 @@
"""General utilities.""" """General utilities."""
from langchain.python import PythonREPL from langchain.python import PythonREPL
from langchain.requests import RequestsWrapper from langchain.requests import TextRequestsWrapper
from langchain.utilities.apify import ApifyWrapper from langchain.utilities.apify import ApifyWrapper
from langchain.utilities.bash import BashProcess from langchain.utilities.bash import BashProcess
from langchain.utilities.bing_search import BingSearchAPIWrapper from langchain.utilities.bing_search import BingSearchAPIWrapper
@ -15,7 +15,7 @@ from langchain.utilities.wolfram_alpha import WolframAlphaAPIWrapper
__all__ = [ __all__ = [
"ApifyWrapper", "ApifyWrapper",
"BashProcess", "BashProcess",
"RequestsWrapper", "TextRequestsWrapper",
"PythonREPL", "PythonREPL",
"GoogleSearchAPIWrapper", "GoogleSearchAPIWrapper",
"GoogleSerperAPIWrapper", "GoogleSerperAPIWrapper",

View File

@ -8,11 +8,11 @@ import pytest
from langchain import LLMChain from langchain import LLMChain
from langchain.chains.api.base import APIChain from langchain.chains.api.base import APIChain
from langchain.chains.api.prompt import API_RESPONSE_PROMPT, API_URL_PROMPT from langchain.chains.api.prompt import API_RESPONSE_PROMPT, API_URL_PROMPT
from langchain.requests import RequestsWrapper from langchain.requests import TextRequestsWrapper
from tests.unit_tests.llms.fake_llm import FakeLLM from tests.unit_tests.llms.fake_llm import FakeLLM
class FakeRequestsChain(RequestsWrapper): class FakeRequestsChain(TextRequestsWrapper):
"""Fake requests chain just for testing purposes.""" """Fake requests chain just for testing purposes."""
output: str output: str