Compare commits

1 commit

Author          SHA1         Message     Date
Chester Curme   02dd572763   add chain   2024-07-25 18:23:48 -04:00
2 changed files with 167 additions and 3 deletions

File 1 of 2

@@ -3,19 +3,29 @@
from __future__ import annotations
import json
-from typing import Any, Dict, List, NamedTuple, Optional, cast
+from typing import Any, Dict, List, Literal, NamedTuple, Optional, cast
import aiohttp
from langchain.chains.api.openapi.requests_chain import APIRequesterChain
from langchain.chains.api.openapi.response_chain import APIResponderChain
from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain_core.callbacks import CallbackManagerForChainRun, Callbacks
-from langchain_core.language_models import BaseLanguageModel
+from langchain_core.language_models import BaseChatModel, BaseLanguageModel
+from langchain_core.messages import (
+    AIMessage,
+    BaseMessage,
+    HumanMessage,
+    SystemMessage,
+)
from langchain_core.pydantic_v1 import BaseModel, Field
+from langchain_core.runnables import Runnable, chain
+from langchain_core.tools import BaseTool
from requests import Response
+from langchain_community.agent_toolkits.openapi.toolkit import RequestsToolkit
from langchain_community.tools.openapi.utils.api_models import APIOperation
-from langchain_community.utilities.requests import Requests
+from langchain_community.utilities.requests import Requests, TextRequestsWrapper
class _ParamMapping(NamedTuple):
@@ -228,3 +238,84 @@ class OpenAPIEndpointChain(Chain, BaseModel):
            callbacks=callbacks,
            **kwargs,
        )
def _prepend_system_message(query: str, system_message: str) -> List[BaseMessage]:
    return [SystemMessage(system_message), HumanMessage(query)]


def _invoke_llm(messages: List[BaseMessage], llm: Runnable) -> List[BaseMessage]:
    return messages + [llm.invoke(messages)]


def _execute_tools(
    messages: List[BaseMessage], tool_name_to_tool: Dict[str, BaseTool]
) -> List[BaseMessage]:
    """Execute the tool calls on the last AI message and append the results."""
    output_messages = []
    ai_message = next(
        message for message in messages[::-1] if isinstance(message, AIMessage)
    )
    for tool_call in ai_message.tool_calls:
        selected_tool = tool_name_to_tool[tool_call["name"]]
        tool_msg = selected_tool.invoke(tool_call)
        output_messages.append(tool_msg)
    return messages + output_messages
def create_openapi_endpoint_chain(
    llm: BaseChatModel,
    api_spec: str,
    system_message: Optional[str] = None,
    allow_dangerous_requests: bool = False,
    headers: Optional[Dict[str, str]] = None,
    aiosession: Optional[aiohttp.ClientSession] = None,
    auth: Optional[Any] = None,
    response_content_type: Literal["text", "json"] = "text",
    verify: bool = True,
    supported_tools: Optional[List[str]] = None,
) -> Runnable:
    requests_wrapper = TextRequestsWrapper(
        headers=headers,
        aiosession=aiosession,
        auth=auth,
        response_content_type=response_content_type,
        verify=verify,
    )
    toolkit = RequestsToolkit(
        requests_wrapper=requests_wrapper,
        allow_dangerous_requests=allow_dangerous_requests,
    )
    if supported_tools is None:
        supported_tools = [
            "requests_get",
            "requests_post",
            "requests_patch",
            "requests_put",
            "requests_delete",
        ]
    tools = [tool for tool in toolkit.get_tools() if tool.name in supported_tools]
    llm_with_tools = llm.bind_tools(tools)
    tool_name_to_tool = {tool.name: tool for tool in tools}
    if system_message is None:
        system_message = """
You have access to an API to help answer user queries.
Here is documentation on the API:
{api_spec}
"""
    system_message = system_message.format(api_spec=api_spec)

    @chain
    def prepend_system_message(query: str) -> List[BaseMessage]:
        return _prepend_system_message(query, system_message)

    @chain
    def invoke_llm(messages: List[BaseMessage]) -> List[BaseMessage]:
        return _invoke_llm(messages, llm_with_tools)

    @chain
    def execute_tools(messages: List[BaseMessage]) -> List[BaseMessage]:
        return _execute_tools(messages, tool_name_to_tool)

    return prepend_system_message | invoke_llm | execute_tools | llm
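
Reviewer note: a minimal usage sketch of the new factory, not part of the diff. It assumes langchain_openai is installed and OPENAI_API_KEY is set; the model name, spec string, and query are illustrative placeholders only.

# Usage sketch (assumptions: a tool-calling chat model such as ChatOpenAI;
# api_spec holds an OpenAPI spec as a YAML/JSON string; query is arbitrary).
from langchain_openai import ChatOpenAI

from langchain_community.chains.openapi.chain import create_openapi_endpoint_chain

llm = ChatOpenAI(model="gpt-4o-mini")  # hypothetical model choice
api_spec = "..."  # e.g. the YAML produced by _get_api_spec() in the test below
chain = create_openapi_endpoint_chain(
    llm,
    api_spec,
    allow_dangerous_requests=True,  # the HTTP request tools are opt-in
)
result = chain.invoke("What are the titles of the top two posts?")
print(result.content)  # the pipeline ends with the LLM, so result is an AIMessage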

File 2 of 2

@@ -0,0 +1,73 @@
from typing import Any, Dict, Union

import pytest
import requests
import yaml
from langchain_core.messages import AIMessage

from langchain_community.chains.openapi.chain import create_openapi_endpoint_chain


def _get_schema(response_json: Union[dict, list]) -> dict:
    if isinstance(response_json, list):
        response_json = response_json[0] if response_json else {}
    return {key: type(value).__name__ for key, value in response_json.items()}


def _get_api_spec() -> str:
    base_url = "https://jsonplaceholder.typicode.com"
    endpoints = [
        "/posts",
        "/comments",
    ]
    common_query_parameters = [
        {
            "name": "_limit",
            "in": "query",
            "required": False,
            "schema": {"type": "integer", "example": 2},
            "description": "Limit the number of results",
        }
    ]
    openapi_spec: Dict[str, Any] = {
        "openapi": "3.0.0",
        "info": {"title": "JSONPlaceholder API", "version": "1.0.0"},
        "servers": [{"url": base_url}],
        "paths": {},
    }
    # Iterate over the endpoints to construct the paths
    for endpoint in endpoints:
        response = requests.get(base_url + endpoint)
        if response.status_code == 200:
            schema = _get_schema(response.json())
            openapi_spec["paths"][endpoint] = {
                "get": {
                    "summary": f"Get {endpoint[1:]}",
                    "parameters": common_query_parameters,
                    "responses": {
                        "200": {
                            "description": "Successful response",
                            "content": {
                                "application/json": {
                                    "schema": {"type": "object", "properties": schema}
                                }
                            },
                        }
                    },
                }
            }
    return yaml.dump(openapi_spec, sort_keys=False)


@pytest.mark.requires("langchain_openai")
def test_create_openapi_endpoint_chain() -> None:
    from langchain_openai import ChatOpenAI

    llm = ChatOpenAI(model="gpt-3.5-turbo-0125")
    api_spec = _get_api_spec()
    chain = create_openapi_endpoint_chain(llm, api_spec, allow_dangerous_requests=True)
    result = chain.invoke("What are the titles of the top two posts?")
    assert isinstance(result, AIMessage)
    assert "sunt aut facere" in result.content and "qui est esse" in result.content