Support oai tool call (#13110)

Co-authored-by: Bagatur <baskaryan@gmail.com>
Co-authored-by: Nuno Campos <nuno@boringbits.io>
This commit is contained in:
Harrison Chase 2023-11-09 07:29:29 -08:00 committed by GitHub
parent c52725bdc5
commit 5f38770161
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
30 changed files with 1348 additions and 82 deletions

View File

@ -0,0 +1,210 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "2def22ea",
"metadata": {},
"source": [
"# Extration with OpenAI Tools\n",
"\n",
"Performing extraction has never been easier! OpenAI's tool calling ability is the perfect thing to use as it allows for extracting multiple different elements from text that are different types. \n",
"\n",
"Models after 1106 use tools and support \"parallel function calling\" which makes this super easy."
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "5c628496",
"metadata": {},
"outputs": [],
"source": [
"from langchain.chat_models import ChatOpenAI\n",
"from langchain.pydantic_v1 import BaseModel\n",
"from typing import Optional, List\n",
"from langchain.chains.openai_tools import create_extraction_chain_pydantic"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "afe9657b",
"metadata": {},
"outputs": [],
"source": [
"# Make sure to use a recent model that supports tools\n",
"model = ChatOpenAI(model=\"gpt-3.5-turbo-1106\")"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "bc0ca3b6",
"metadata": {},
"outputs": [],
"source": [
"# Pydantic is an easy way to define a schema\n",
"class Person(BaseModel):\n",
" \"\"\"Information about people to extract.\"\"\"\n",
" name: str\n",
" age: Optional[int] = None"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "2036af68",
"metadata": {},
"outputs": [],
"source": [
"chain = create_extraction_chain_pydantic(Person, model)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "1748ad21",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[Person(name='jane', age=2), Person(name='bob', age=3)]"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"chain.invoke({\"input\": \"jane is 2 and bob is 3\"})"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "c8262ce5",
"metadata": {},
"outputs": [],
"source": [
"# Let's define another element\n",
"class Class(BaseModel):\n",
" \"\"\"Information about classes to extract.\"\"\"\n",
" teacher: str\n",
" students: List[str]"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "4973c104",
"metadata": {},
"outputs": [],
"source": [
"chain = create_extraction_chain_pydantic([Person, Class], model)"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "e976a15e",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[Person(name='jane', age=2),\n",
" Person(name='bob', age=3),\n",
" Class(teacher='Mrs Sampson', students=['jane', 'bob'])]"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"chain.invoke({\"input\": \"jane is 2 and bob is 3 and they are in Mrs Sampson's class\"})"
]
},
{
"cell_type": "markdown",
"id": "6575a7d6",
"metadata": {},
"source": [
"## Under the hood\n",
"\n",
"Under the hood, this is a simple chain:"
]
},
{
"cell_type": "markdown",
"id": "b8ba83e5",
"metadata": {},
"source": [
"```python\n",
"from langchain.output_parsers import PydanticToolsParser\n",
"from langchain.utils.openai_functions import convert_pydantic_to_openai_function\n",
"from langchain.schema.runnable import Runnable\n",
"from typing import Union, List, Type, Optional\n",
"from langchain.pydantic_v1 import BaseModel\n",
"from langchain.prompts import ChatPromptTemplate\n",
"from langchain.schema.messages import SystemMessage\n",
"from langchain.schema.language_model import BaseLanguageModel\n",
"_EXTRACTION_TEMPLATE = \"\"\"Extract and save the relevant entities mentioned \\\n",
"in the following passage together with their properties.\n",
"\n",
"If a property is not present and is not required in the function parameters, do not include it in the output.\"\"\" # noqa: E501\n",
"\n",
"\n",
"def create_extraction_chain_pydantic(\n",
" pydantic_schemas: Union[List[Type[BaseModel]], Type[BaseModel]],\n",
" llm: BaseLanguageModel,\n",
" system_message: Optional[str] = _EXTRACTION_TEMPLATE,\n",
") -> Runnable:\n",
" if not isinstance(pydantic_schemas, list):\n",
" pydantic_schemas = [pydantic_schemas]\n",
" prompt = ChatPromptTemplate.from_messages({\n",
" (\"system\", _EXTRACTION_TEMPLATE),\n",
" (\"user\", \"{input}\")\n",
" })\n",
" functions = [convert_pydantic_to_openai_function(p) for p in pydantic_schemas]\n",
" tools = [{\"type\": \"function\", \"function\": d} for d in functiMons]\n",
" model = llm.bind(tools=tools)\n",
" chain = prompt | model | PydanticToolsParser(tools=pydantic_schemas)\n",
" return chain\n",
"```"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2eac6b68",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.1"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

View File

@ -12,6 +12,19 @@
"Suppose we have a simple prompt + model sequence:"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "950297ed-2d67-4091-8ea7-1d412d259d04",
"metadata": {},
"outputs": [],
"source": [
"from langchain.chat_models import ChatOpenAI\n",
"from langchain.prompts import ChatPromptTemplate\n",
"from langchain.schema import StrOutputParser\n",
"from langchain.schema.runnable import RunnablePassthrough"
]
},
{
"cell_type": "code",
"execution_count": 11,
@ -37,11 +50,6 @@
}
],
"source": [
"from langchain.chat_models import ChatOpenAI\n",
"from langchain.prompts import ChatPromptTemplate\n",
"from langchain.schema import StrOutputParser\n",
"from langchain.schema.runnable import RunnablePassthrough\n",
"\n",
"prompt = ChatPromptTemplate.from_messages(\n",
" [\n",
" (\n",
@ -105,31 +113,29 @@
},
{
"cell_type": "code",
"execution_count": 14,
"execution_count": 3,
"id": "f66a0fe4-fde0-4706-8863-d60253f211c7",
"metadata": {},
"outputs": [],
"source": [
"functions = [\n",
" {\n",
" \"name\": \"solver\",\n",
" \"description\": \"Formulates and solves an equation\",\n",
" \"parameters\": {\n",
" \"type\": \"object\",\n",
" \"properties\": {\n",
" \"equation\": {\n",
" \"type\": \"string\",\n",
" \"description\": \"The algebraic expression of the equation\",\n",
" },\n",
" \"solution\": {\n",
" \"type\": \"string\",\n",
" \"description\": \"The solution to the equation\",\n",
" },\n",
"function = {\n",
" \"name\": \"solver\",\n",
" \"description\": \"Formulates and solves an equation\",\n",
" \"parameters\": {\n",
" \"type\": \"object\",\n",
" \"properties\": {\n",
" \"equation\": {\n",
" \"type\": \"string\",\n",
" \"description\": \"The algebraic expression of the equation\",\n",
" },\n",
" \"solution\": {\n",
" \"type\": \"string\",\n",
" \"description\": \"The solution to the equation\",\n",
" },\n",
" \"required\": [\"equation\", \"solution\"],\n",
" },\n",
" }\n",
"]"
" \"required\": [\"equation\", \"solution\"],\n",
" },\n",
"}"
]
},
{
@ -161,19 +167,70 @@
" ]\n",
")\n",
"model = ChatOpenAI(model=\"gpt-4\", temperature=0).bind(\n",
" function_call={\"name\": \"solver\"}, functions=functions\n",
" function_call={\"name\": \"solver\"}, functions=[function]\n",
")\n",
"runnable = {\"equation_statement\": RunnablePassthrough()} | prompt | model\n",
"runnable.invoke(\"x raised to the third plus seven equals 12\")"
]
},
{
"cell_type": "markdown",
"id": "f07d7528-9269-4d6f-b12e-3669592a9e03",
"metadata": {},
"source": [
"## Attaching OpenAI tools"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 5,
"id": "2cdeeb4c-0c1f-43da-bd58-4f591d9e0671",
"metadata": {},
"outputs": [],
"source": []
"source": [
"tools = [\n",
" {\n",
" \"type\": \"function\",\n",
" \"function\": {\n",
" \"name\": \"get_current_weather\",\n",
" \"description\": \"Get the current weather in a given location\",\n",
" \"parameters\": {\n",
" \"type\": \"object\",\n",
" \"properties\": {\n",
" \"location\": {\n",
" \"type\": \"string\",\n",
" \"description\": \"The city and state, e.g. San Francisco, CA\",\n",
" },\n",
" \"unit\": {\"type\": \"string\", \"enum\": [\"celsius\", \"fahrenheit\"]},\n",
" },\n",
" \"required\": [\"location\"],\n",
" },\n",
" },\n",
" }\n",
"]"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "2b65beab-48bb-46ff-a5a4-ef8ac95a513c",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'call_zHN0ZHwrxM7nZDdqTp6dkPko', 'function': {'arguments': '{\"location\": \"San Francisco, CA\", \"unit\": \"celsius\"}', 'name': 'get_current_weather'}, 'type': 'function'}, {'id': 'call_aqdMm9HBSlFW9c9rqxTa7eQv', 'function': {'arguments': '{\"location\": \"New York, NY\", \"unit\": \"celsius\"}', 'name': 'get_current_weather'}, 'type': 'function'}, {'id': 'call_cx8E567zcLzYV2WSWVgO63f1', 'function': {'arguments': '{\"location\": \"Los Angeles, CA\", \"unit\": \"celsius\"}', 'name': 'get_current_weather'}, 'type': 'function'}]})"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model = ChatOpenAI(model=\"gpt-3.5-turbo-1106\").bind(tools=tools)\n",
"model.invoke(\"What's the weather in SF, NYC and LA?\")"
]
}
],
"metadata": {

View File

@ -140,7 +140,7 @@
"source": [
"from langchain import hub\n",
"from langchain.agents import AgentExecutor\n",
"from langchain.agents.format_scratchpad import format_to_openai_functions\n",
"from langchain.agents.format_scratchpad import format_to_openai_function_messages\n",
"from langchain.agents.output_parsers import OpenAIFunctionsAgentOutputParser\n",
"from langchain.chat_models import ChatOpenAI\n",
"from langchain.tools import DuckDuckGoSearchResults\n",
@ -165,7 +165,7 @@
"runnable_agent = (\n",
" {\n",
" \"input\": lambda x: x[\"input\"],\n",
" \"agent_scratchpad\": lambda x: format_to_openai_functions(\n",
" \"agent_scratchpad\": lambda x: format_to_openai_function_messages(\n",
" x[\"intermediate_steps\"]\n",
" ),\n",
" }\n",
@ -335,7 +335,7 @@
"source": [
"from langchain.chat_models import ChatOpenAI\n",
"from langchain.agents import AgentType, initialize_agent, load_tools, AgentExecutor\n",
"from langchain.agents.format_scratchpad import format_to_openai_functions\n",
"from langchain.agents.format_scratchpad import format_to_openai_function_messages\n",
"from langchain.agents.output_parsers import OpenAIFunctionsAgentOutputParser\n",
"from langchain.tools.render import format_tool_to_openai_function\n",
"from langchain import hub\n",
@ -351,7 +351,7 @@
" runnable_agent = (\n",
" {\n",
" \"input\": lambda x: x[\"input\"],\n",
" \"agent_scratchpad\": lambda x: format_to_openai_functions(\n",
" \"agent_scratchpad\": lambda x: format_to_openai_function_messages(\n",
" x[\"intermediate_steps\"]\n",
" ),\n",
" }\n",

View File

@ -143,7 +143,7 @@
"metadata": {},
"outputs": [],
"source": [
"from langchain.agents.format_scratchpad import format_to_openai_functions\n",
"from langchain.agents.format_scratchpad import format_to_openai_function_messages\n",
"from langchain.agents.output_parsers import OpenAIFunctionsAgentOutputParser"
]
},
@ -157,7 +157,7 @@
"agent = (\n",
" {\n",
" \"input\": lambda x: x[\"input\"],\n",
" \"agent_scratchpad\": lambda x: format_to_openai_functions(\n",
" \"agent_scratchpad\": lambda x: format_to_openai_function_messages(\n",
" x[\"intermediate_steps\"]\n",
" ),\n",
" }\n",

View File

@ -115,9 +115,7 @@
"cell_type": "code",
"execution_count": 6,
"id": "ba8e4cbe",
"metadata": {
"scrolled": false
},
"metadata": {},
"outputs": [
{
"name": "stdout",
@ -254,9 +252,7 @@
"cell_type": "code",
"execution_count": 19,
"id": "4362ebc7",
"metadata": {
"scrolled": false
},
"metadata": {},
"outputs": [
{
"name": "stdout",
@ -458,7 +454,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.1"
"version": "3.9.1"
}
},
"nbformat": 4,

View File

@ -0,0 +1,222 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "e10aa932",
"metadata": {},
"source": [
"# OpenAI tools\n",
"\n",
"With LCEL we can easily construc"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "ec89be68",
"metadata": {},
"outputs": [],
"source": [
"# ! pip install -U openai duckduckgo-search"
]
},
{
"cell_type": "markdown",
"id": "82787d8d",
"metadata": {},
"source": [
"## Initialize tools\n",
"\n",
"We will first create some tools we can use"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "b812b982",
"metadata": {},
"outputs": [],
"source": [
"from langchain.agents import initialize_agent, AgentType, Tool\n",
"from langchain.agents.format_scratchpad.openai_tools import format_to_openai_tool_messages\n",
"from langchain.agents.output_parsers.openai_tools import OpenAIToolsAgentOutputParser\n",
"from langchain.chat_models import ChatOpenAI\n",
"from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder\n",
"from langchain.tools import DuckDuckGoSearchRun\n",
"from langchain.tools.render import format_tool_to_openai_tool"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "23fc0aa6",
"metadata": {},
"outputs": [],
"source": [
"llm = ChatOpenAI(temperature=0, model=\"gpt-3.5-turbo-1106\")\n",
"tools = [DuckDuckGoSearchRun()]"
]
},
{
"cell_type": "markdown",
"id": "39c3ba21",
"metadata": {},
"source": [
"## Using LCEL\n",
"\n",
"We will first use LangChain Expression Language to create this agent"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "55292bed",
"metadata": {},
"outputs": [],
"source": [
"prompt = ChatPromptTemplate.from_messages(\n",
" [\n",
" (\"system\", \"You are a helpful assistant\"),\n",
" (\"user\", \"{input}\"),\n",
" MessagesPlaceholder(variable_name=\"agent_scratchpad\"),\n",
" ]\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "552421b3",
"metadata": {},
"outputs": [],
"source": [
"llm_with_tools = llm.bind(tools=[format_tool_to_openai_tool(t)for t in tools])"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "bf514eb4",
"metadata": {},
"outputs": [],
"source": [
"agent = (\n",
" {\n",
" \"input\": lambda x: x[\"input\"],\n",
" \"agent_scratchpad\": lambda x: format_to_openai_tool_messages(\n",
" x[\"intermediate_steps\"]\n",
" ),\n",
" }\n",
" | prompt\n",
" | llm_with_tools\n",
" | OpenAIToolsAgentOutputParser()\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "5125573e",
"metadata": {},
"outputs": [],
"source": [
"from langchain.agents import AgentExecutor"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "bdc7e506",
"metadata": {},
"outputs": [],
"source": [
"agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "2cd65218",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
"\u001b[32;1m\u001b[1;3m\n",
"Invoking: `duckduckgo_search` with `weather in Los Angeles today`\n",
"\n",
"\n",
"\u001b[0m\u001b[36;1m\u001b[1;3m60 °F like 60° Clear N 0 Today's temperature is forecast to be NEARLY THE SAME as yesterday. Radar Satellite WunderMap |Nexrad Today Wed 11/08 High 78 °F 0% Precip. / 0.00 in Sunny. High 78F.... Current Conditions Radar Forecasts Rivers and Lakes Climate and Past Weather Local Programs Fire Weather Show Caption Click a location below for detailed forecast. Last Map Update: Tue, Nov. 7, 2023 at 5:03:23 pm PST Watches, Warnings & Advisories Zoom Out Gale Warning Small Craft Advisory Wind Advisory Fire Weather Watch LOS ANGELES (KABC) -- Southern California will see mostly clear skies and mild temperatures on Monday. Los Angeles and Orange counties will see a few clouds in the morning, but they'll clear up... Storm No. 1: The first storm on Saturday will move quickly through Southern California, with most of the rain falling Saturday night. Light rain is possible late Friday for the Ventura County... Weather in Los Angeles today, California. Friday, 13 October 2023 . Day 1:00 PM +77°F . Broken cloud sky\u001b[0m\u001b[32;1m\u001b[1;3m\n",
"Invoking: `duckduckgo_search` with `weather in New York City today`\n",
"\n",
"\n",
"\u001b[0m\u001b[36;1m\u001b[1;3mToday Hourly 10-Day Calendar History Wundermap access_time 11:38 PM EDT on October 31, 2023 (GMT -4) | Updated 22 hours ago --° | 49° 56 °F like 56° Rain Shower N 5 Gusts 7mph Radar Satellite... 49°F 9°C More Information: Local Forecast Office More Local Wx 3 Day History Mobile Weather Hourly Weather Forecast Extended Forecast for New York NY Similar City Names Overnight Mostly Cloudy Low: 48 °F Saturday Partly Sunny High: 58 °F Saturday Night Mostly Cloudy Low: 48 °F Sunday Mostly Sunny High: 64 °F Sunday Night Mostly Clear Low: 45 °F 13°C Get Detailed info Tonight Partly Cloudy Low: 47°F Sunday Mostly Sunny High: 62°F change location New York, NY Weather Forecast Office NWS Forecast Office New York, NY Weather.gov > New York, NY Current Hazards Current Conditions Radar Forecasts Rivers and Lakes Climate and Past Weather Local Programs The weather today in New York City will be comfortable with temperatures around 68°F. During the evening and night time the temperatures will drop to 48 ° F. For deep dive information check out our hourly weather forecast for today down the page. Temperature. 68 ° / 48 °. Chance of rain. 0. %. Today's Weather - New York, NY Nov 02, 2023 3:32 AM Churchill School -- Feels like -- Hi -- Lo -- -- Live Radar Weather Radar Map WEATHER DETAILS New York, NY Windchill -- Daily Rain -- Dew Point -- Monthly Rain -- Humidity -- Avg. Wind -- Pressure -- Wind Gust -- Sunrise -- Moon -- Sunset -- UV Index Low WEATHER FORECAST New York, NY\u001b[0m\u001b[32;1m\u001b[1;3m\n",
"Invoking: `duckduckgo_search` with `weather in San Francisco today`\n",
"\n",
"\n",
"\u001b[0m\u001b[36;1m\u001b[1;3mWeather Underground provides local & long-range weather forecasts, weatherreports, maps & tropical weather conditions for the San Francisco area. ... Today Wed 11/08 High 67 ... Radar Forecasts Rivers and Lakes Climate and Past Weather Local Programs Click a location below for detailed forecast. Last Map Update: Wed, Nov. 8, 2023 at 5:03:31 am PST Watches, Warnings & Advisories Zoom Out Small Craft Advisory Frost Advisory Text Product Selector (Selected product opens in current window) Radar Forecasts Rivers and Lakes Climate and Past Weather Local Programs Rain Continues Monday Show Caption Click a location below for detailed forecast. Last Map Update: Tue, Nov. 7, 2023 at 1:05:27 am PST Watches, Warnings & Advisories Zoom Out Winter Weather Advisory Small Craft Advisory Gale Watch Today's and tonight's San Francisco, CA weather forecast, weather conditions and Doppler radar from The Weather Channel and Weather.com Nov 8, 2023. The National Weather Service forecast for the greater San Francisco Bay Area on Wednesday calls for sunny skies, with clouds expected to increase over some areas around the bay during ...\u001b[0m\u001b[32;1m\u001b[1;3mHere's the weather for today in Los Angeles, New York City, and San Francisco:\n",
"\n",
"- Los Angeles: The temperature is around 60°F with clear skies and mild temperatures. There are no precipitation forecasts for today.\n",
"- New York City: The temperature is around 68°F during the day and will drop to 48°F in the evening and night. There is no chance of rain forecasted for today.\n",
"- San Francisco: The temperature is expected to reach a high of 67°F with sunny skies. Clouds are expected to increase over some areas around the bay during the day.\n",
"\n",
"If you need more detailed information, feel free to ask!\u001b[0m\n",
"\n",
"\u001b[1m> Finished chain.\u001b[0m\n"
]
},
{
"data": {
"text/plain": [
"{'input': \"What's the weather in LA, NYC, and SF today\",\n",
" 'output': \"Here's the weather for today in Los Angeles, New York City, and San Francisco:\\n\\n- Los Angeles: The temperature is around 60°F with clear skies and mild temperatures. There are no precipitation forecasts for today.\\n- New York City: The temperature is around 68°F during the day and will drop to 48°F in the evening and night. There is no chance of rain forecasted for today.\\n- San Francisco: The temperature is expected to reach a high of 67°F with sunny skies. Clouds are expected to increase over some areas around the bay during the day.\\n\\nIf you need more detailed information, feel free to ask!\"}"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"agent_executor.invoke(\n",
" {\n",
" \"input\": \"What's the weather in LA, NYC, and SF today\"\n",
" }\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c750177f-af6c-4617-b2aa-69e3f1153a3e",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.1"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

View File

@ -205,7 +205,7 @@
"\n",
"- prompt: a simple prompt with placeholders for the user's question and then the `agent_scratchpad` (any intermediate steps)\n",
"- tools: we can attach the tools and `Response` format to the LLM as functions\n",
"- format scratchpad: in order to format the `agent_scratchpad` from intermediate steps, we will use the standard `format_to_openai_functions`. This takes intermediate steps and formats them as AIMessages and FunctionMessages.\n",
"- format scratchpad: in order to format the `agent_scratchpad` from intermediate steps, we will use the standard `format_to_openai_function_messages`. This takes intermediate steps and formats them as AIMessages and FunctionMessages.\n",
"- output parser: we will use our custom parser above to parse the response of the LLM\n",
"- AgentExecutor: we will use the standard AgentExecutor to run the loop of agent-tool-agent-tool..."
]
@ -220,7 +220,7 @@
"from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder\n",
"from langchain.chat_models import ChatOpenAI\n",
"from langchain.tools.render import format_tool_to_openai_function\n",
"from langchain.agents.format_scratchpad import format_to_openai_functions\n",
"from langchain.agents.format_scratchpad import format_to_openai_function_messages\n",
"from langchain.agents import AgentExecutor"
]
},
@ -278,7 +278,7 @@
" {\n",
" \"input\": lambda x: x[\"input\"],\n",
" # Format agent scratchpad from intermediate steps\n",
" \"agent_scratchpad\": lambda x: format_to_openai_functions(\n",
" \"agent_scratchpad\": lambda x: format_to_openai_function_messages(\n",
" x[\"intermediate_steps\"]\n",
" ),\n",
" }\n",

View File

@ -157,11 +157,11 @@ We will import two last utility functions: a component for formatting intermedia
```python
from langchain.agents.format_scratchpad import format_to_openai_functions
from langchain.agents.format_scratchpad import format_to_openai_function_messages
from langchain.agents.output_parsers import OpenAIFunctionsAgentOutputParser
agent = {
"input": lambda x: x["input"],
"agent_scratchpad": lambda x: format_to_openai_functions(x['intermediate_steps'])
"agent_scratchpad": lambda x: format_to_openai_function_messages(x['intermediate_steps'])
} | prompt | llm_with_tools | OpenAIFunctionsAgentOutputParser()
```
@ -287,7 +287,7 @@ We can then put it all together!
```python
agent = {
"input": lambda x: x["input"],
"agent_scratchpad": lambda x: format_to_openai_functions(x['intermediate_steps']),
"agent_scratchpad": lambda x: format_to_openai_function_messages(x['intermediate_steps']),
"chat_history": lambda x: x["chat_history"]
} | prompt | llm_with_tools | OpenAIFunctionsAgentOutputParser()
agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True)

View File

@ -25,6 +25,7 @@ from langchain.schema.messages import (
FunctionMessage,
HumanMessage,
SystemMessage,
ToolMessage,
)
@ -54,15 +55,18 @@ def convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
# Fix for azure
# Also OpenAI returns None for tool invocations
content = _dict.get("content", "") or ""
additional_kwargs: Dict = {}
if _dict.get("function_call"):
additional_kwargs = {"function_call": dict(_dict["function_call"])}
else:
additional_kwargs = {}
additional_kwargs["function_call"] = dict(_dict["function_call"])
if _dict.get("tool_calls"):
additional_kwargs["tool_calls"] = _dict["tool_calls"]
return AIMessage(content=content, additional_kwargs=additional_kwargs)
elif role == "system":
return SystemMessage(content=_dict["content"])
elif role == "function":
return FunctionMessage(content=_dict["content"], name=_dict["name"])
elif role == "tool":
return ToolMessage(content=_dict["content"], tool_call_id=_dict["tool_call_id"])
else:
return ChatMessage(content=_dict["content"], role=role)
@ -88,6 +92,11 @@ def convert_message_to_dict(message: BaseMessage) -> dict:
# If function call only, content is None not empty string
if message_dict["content"] == "":
message_dict["content"] = None
if "tool_calls" in message.additional_kwargs:
message_dict["tool_calls"] = message.additional_kwargs["tool_calls"]
# If tool calls only, content is None not empty string
if message_dict["content"] == "":
message_dict["content"] = None
elif isinstance(message, SystemMessage):
message_dict = {"role": "system", "content": message.content}
elif isinstance(message, FunctionMessage):
@ -96,6 +105,12 @@ def convert_message_to_dict(message: BaseMessage) -> dict:
"content": message.content,
"name": message.name,
}
elif isinstance(message, ToolMessage):
message_dict = {
"role": "tool",
"content": message.content,
"tool_call_id": message.tool_call_id,
}
else:
raise TypeError(f"Got unknown type {message}")
if "name" in message.additional_kwargs:

View File

@ -327,10 +327,21 @@ class AgentOutputParser(BaseOutputParser):
"""Parse text into agent action/finish."""
class RunnableAgent(BaseSingleActionAgent):
class MultiActionAgentOutputParser(BaseOutputParser):
"""Base class for parsing agent output into agent actions/finish."""
@abstractmethod
def parse(self, text: str) -> Union[List[AgentAction], AgentFinish]:
"""Parse text into agent actions/finish."""
class RunnableAgent(BaseMultiActionAgent):
"""Agent powered by runnables."""
runnable: Runnable[dict, Union[AgentAction, AgentFinish]]
runnable: Union[
Runnable[dict, Union[AgentAction, AgentFinish]],
Runnable[dict, Union[List[AgentAction], AgentFinish]],
]
"""Runnable to call to get agent action."""
_input_keys: List[str] = []
"""Input keys."""
@ -359,7 +370,10 @@ class RunnableAgent(BaseSingleActionAgent):
intermediate_steps: List[Tuple[AgentAction, str]],
callbacks: Callbacks = None,
**kwargs: Any,
) -> Union[AgentAction, AgentFinish]:
) -> Union[
List[AgentAction],
AgentFinish,
]:
"""Given input, decided what to do.
Args:
@ -373,6 +387,8 @@ class RunnableAgent(BaseSingleActionAgent):
"""
inputs = {**kwargs, **{"intermediate_steps": intermediate_steps}}
output = self.runnable.invoke(inputs, config={"callbacks": callbacks})
if isinstance(output, AgentAction):
output = [output]
return output
async def aplan(
@ -380,7 +396,10 @@ class RunnableAgent(BaseSingleActionAgent):
intermediate_steps: List[Tuple[AgentAction, str]],
callbacks: Callbacks = None,
**kwargs: Any,
) -> Union[AgentAction, AgentFinish]:
) -> Union[
List[AgentAction],
AgentFinish,
]:
"""Given input, decided what to do.
Args:
@ -394,6 +413,8 @@ class RunnableAgent(BaseSingleActionAgent):
"""
inputs = {**kwargs, **{"intermediate_steps": intermediate_steps}}
output = await self.runnable.ainvoke(inputs, config={"callbacks": callbacks})
if isinstance(output, AgentAction):
output = [output]
return output

View File

@ -8,12 +8,14 @@ differently before passing them into the LLM.
from langchain.agents.format_scratchpad.log import format_log_to_str
from langchain.agents.format_scratchpad.log_to_messages import format_log_to_messages
from langchain.agents.format_scratchpad.openai_functions import (
format_to_openai_function_messages,
format_to_openai_functions,
)
from langchain.agents.format_scratchpad.xml import format_xml
__all__ = [
"format_xml",
"format_to_openai_function_messages",
"format_to_openai_functions",
"format_log_to_str",
"format_log_to_messages",

View File

@ -49,14 +49,17 @@ def _create_function_message(
)
def format_to_openai_functions(
def format_to_openai_function_messages(
intermediate_steps: Sequence[Tuple[AgentAction, str]],
) -> List[BaseMessage]:
"""Format intermediate steps.
"""Convert (AgentAction, tool output) tuples into FunctionMessages.
Args:
intermediate_steps: Steps the LLM has taken to date, along with observations
Returns:
list of messages to send to the LLM for the next prediction
"""
messages = []
@ -64,3 +67,7 @@ def format_to_openai_functions(
messages.extend(_convert_agent_action_to_messages(agent_action, observation))
return messages
# Backwards compatibility
format_to_openai_functions = format_to_openai_function_messages

View File

@ -0,0 +1,57 @@
import json
from typing import List, Sequence, Tuple
from langchain.agents.output_parsers.openai_tools import OpenAIToolAgentAction
from langchain.schema.agent import AgentAction
from langchain.schema.messages import (
AIMessage,
BaseMessage,
ToolMessage,
)
def _create_tool_message(
    agent_action: OpenAIToolAgentAction, observation: str
) -> ToolMessage:
    """Convert an agent action and its observation into a ToolMessage.

    Args:
        agent_action: the tool invocation request from the agent.
        observation: the result of the tool invocation.

    Returns:
        ToolMessage that corresponds to the original tool invocation,
        carrying the matching ``tool_call_id`` so the API can pair the
        result with the request.
    """
    # Tool message content must be a string; serialize non-string
    # observations as JSON, falling back to str() for objects that are
    # not JSON-serializable.
    if not isinstance(observation, str):
        try:
            content = json.dumps(observation, ensure_ascii=False)
        except Exception:
            content = str(observation)
    else:
        content = observation
    return ToolMessage(
        tool_call_id=agent_action.tool_call_id,
        content=content,
    )
def format_to_openai_tool_messages(
    intermediate_steps: Sequence[Tuple[AgentAction, str]],
) -> List[BaseMessage]:
    """Convert (AgentAction, tool output) tuples into ToolMessages.

    Args:
        intermediate_steps: Steps the LLM has taken to date, along with
            observations.

    Returns:
        list of messages to send to the LLM for the next prediction.
    """
    messages: List[BaseMessage] = []
    for agent_action, observation in intermediate_steps:
        if isinstance(agent_action, OpenAIToolAgentAction):
            # Replay the message(s) that requested the tool call, followed
            # by the tool's result. Skip messages already collected so
            # parallel tool calls sharing one message_log are not
            # duplicated.
            new_messages = list(agent_action.message_log) + [
                _create_tool_message(agent_action, observation)
            ]
            messages.extend([new for new in new_messages if new not in messages])
        else:
            # Actions without a tool_call_id (e.g. non-tool agents) have
            # no ToolMessage equivalent; surface the raw log as an
            # AIMessage instead.
            messages.append(AIMessage(content=agent_action.log))
    return messages

View File

@ -2,7 +2,7 @@
from typing import Any, Dict, List
from langchain.agents.format_scratchpad.openai_functions import (
format_to_openai_functions,
format_to_openai_function_messages,
)
from langchain.memory.chat_memory import BaseChatMemory
from langchain.schema.language_model import BaseLanguageModel
@ -52,7 +52,7 @@ class AgentTokenBufferMemory(BaseChatMemory):
"""Save context from this conversation to buffer. Pruned."""
input_str, output_str = self._get_input_output(inputs, outputs)
self.chat_memory.add_user_message(input_str)
steps = format_to_openai_functions(outputs[self.intermediate_steps_key])
steps = format_to_openai_function_messages(outputs[self.intermediate_steps_key])
for msg in steps:
self.chat_memory.add_message(msg)
self.chat_memory.add_ai_message(output_str)

View File

@ -3,7 +3,7 @@ from typing import Any, List, Optional, Sequence, Tuple, Union
from langchain.agents import BaseSingleActionAgent
from langchain.agents.format_scratchpad.openai_functions import (
format_to_openai_functions,
format_to_openai_function_messages,
)
from langchain.agents.output_parsers.openai_functions import (
OpenAIFunctionsAgentOutputParser,
@ -93,7 +93,7 @@ class OpenAIFunctionsAgent(BaseSingleActionAgent):
Returns:
Action specifying what tool to use.
"""
agent_scratchpad = format_to_openai_functions(intermediate_steps)
agent_scratchpad = format_to_openai_function_messages(intermediate_steps)
selected_inputs = {
k: kwargs[k] for k in self.prompt.input_variables if k != "agent_scratchpad"
}
@ -132,7 +132,7 @@ class OpenAIFunctionsAgent(BaseSingleActionAgent):
Returns:
Action specifying what tool to use.
"""
agent_scratchpad = format_to_openai_functions(intermediate_steps)
agent_scratchpad = format_to_openai_function_messages(intermediate_steps)
selected_inputs = {
k: kwargs[k] for k in self.prompt.input_variables if k != "agent_scratchpad"
}

View File

@ -5,7 +5,7 @@ from typing import Any, List, Optional, Sequence, Tuple, Union
from langchain.agents import BaseMultiActionAgent
from langchain.agents.format_scratchpad.openai_functions import (
format_to_openai_functions,
format_to_openai_function_messages,
)
from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.manager import Callbacks
@ -208,7 +208,7 @@ class OpenAIMultiFunctionsAgent(BaseMultiActionAgent):
Returns:
Action specifying what tool to use.
"""
agent_scratchpad = format_to_openai_functions(intermediate_steps)
agent_scratchpad = format_to_openai_function_messages(intermediate_steps)
selected_inputs = {
k: kwargs[k] for k in self.prompt.input_variables if k != "agent_scratchpad"
}
@ -237,7 +237,7 @@ class OpenAIMultiFunctionsAgent(BaseMultiActionAgent):
Returns:
Action specifying what tool to use.
"""
agent_scratchpad = format_to_openai_functions(intermediate_steps)
agent_scratchpad = format_to_openai_function_messages(intermediate_steps)
selected_inputs = {
k: kwargs[k] for k in self.prompt.input_variables if k != "agent_scratchpad"
}

View File

@ -0,0 +1,106 @@
import asyncio
import json
from json import JSONDecodeError
from typing import List, Union
from langchain.agents.agent import MultiActionAgentOutputParser
from langchain.schema import (
AgentAction,
AgentFinish,
OutputParserException,
)
from langchain.schema.agent import AgentActionMessageLog
from langchain.schema.messages import (
AIMessage,
BaseMessage,
)
from langchain.schema.output import ChatGeneration, Generation
class OpenAIToolAgentAction(AgentActionMessageLog):
    """AgentAction for an OpenAI tool call, carrying the originating call id."""

    tool_call_id: str
    """Tool call that this message is responding to."""
def parse_ai_message_to_openai_tool_action(
    message: BaseMessage,
) -> Union[List[AgentAction], AgentFinish]:
    """Parse an AI message potentially containing tool_calls.

    Returns one OpenAIToolAgentAction per tool call, or an AgentFinish when
    the message carries no tool calls (its content is the final answer).
    """
    if not isinstance(message, AIMessage):
        raise TypeError(f"Expected an AI message got {type(message)}")

    tool_calls = message.additional_kwargs.get("tool_calls")
    if not tool_calls:
        # No tool calls requested: the content is the final output.
        return AgentFinish(
            return_values={"output": message.content}, log=str(message.content)
        )

    actions: List = []
    for tool_call in tool_calls:
        function = tool_call["function"]
        function_name = function["name"]
        try:
            _tool_input = json.loads(function["arguments"])
        except JSONDecodeError:
            raise OutputParserException(
                f"Could not parse tool input: {function} because "
                f"the `arguments` is not valid JSON."
            )
        # HACK: old-style tools with no schema are encoded with a single
        # `__arg1` argument (OpenAI does not accept a bare JSON array as
        # arguments); unpack it back into a plain string input here.
        if "__arg1" in _tool_input:
            tool_input = _tool_input["__arg1"]
        else:
            tool_input = _tool_input

        content_msg = f"responded: {message.content}\n" if message.content else "\n"
        log = f"\nInvoking: `{function_name}` with `{tool_input}`\n{content_msg}\n"
        actions.append(
            OpenAIToolAgentAction(
                tool=function_name,
                tool_input=tool_input,
                log=log,
                message_log=[message],
                tool_call_id=tool_call["id"],
            )
        )
    return actions
class OpenAIToolsAgentOutputParser(MultiActionAgentOutputParser):
    """Parses a message into agent actions/finish.

    Is meant to be used with OpenAI models, as it relies on the specific
    tool_calls parameter from OpenAI to convey what tools to use.

    If a tool_calls parameter is passed, then that is used to get
    the tool names and tool inputs.

    If one is not passed, then the AIMessage is assumed to be the final output.
    """

    @property
    def _type(self) -> str:
        # Identifier used to (de)serialize this parser.
        return "openai-tools-agent-output-parser"

    def parse_result(
        self, result: List[Generation], *, partial: bool = False
    ) -> Union[List[AgentAction], AgentFinish]:
        # Only ChatGeneration carries the message (and additional_kwargs)
        # needed to read the tool_calls.
        if not isinstance(result[0], ChatGeneration):
            raise ValueError("This output parser only works on ChatGeneration output")
        message = result[0].message
        return parse_ai_message_to_openai_tool_action(message)

    async def aparse_result(
        self, result: List[Generation], *, partial: bool = False
    ) -> Union[List[AgentAction], AgentFinish]:
        # Parsing is CPU-bound; run it in the default executor.
        # NOTE(review): `partial` is accepted for interface compatibility but
        # is not forwarded to parse_result (which also ignores it).
        return await asyncio.get_running_loop().run_in_executor(
            None, self.parse_result, result
        )

    def parse(self, text: str) -> Union[List[AgentAction], AgentFinish]:
        # This parser needs the full message object, not just raw text.
        raise ValueError("Can only parse messages")

View File

@ -0,0 +1,3 @@
from langchain.chains.openai_tools.extraction import create_extraction_chain_pydantic
__all__ = ["create_extraction_chain_pydantic"]

View File

@ -0,0 +1,30 @@
from typing import List, Type, Union
from langchain.output_parsers import PydanticToolsParser
from langchain.prompts import ChatPromptTemplate
from langchain.pydantic_v1 import BaseModel
from langchain.schema.language_model import BaseLanguageModel
from langchain.schema.runnable import Runnable
from langchain.utils.openai_functions import convert_pydantic_to_openai_function
_EXTRACTION_TEMPLATE = """Extract and save the relevant entities mentioned \
in the following passage together with their properties.
If a property is not present and is not required in the function parameters, do not include it in the output.""" # noqa: E501
def create_extraction_chain_pydantic(
    pydantic_schemas: Union[List[Type[BaseModel]], Type[BaseModel]],
    llm: BaseLanguageModel,
    system_message: str = _EXTRACTION_TEMPLATE,
) -> Runnable:
    """Create a chain that extracts instances of the given Pydantic schema(s).

    Args:
        pydantic_schemas: A Pydantic model class, or a list of them; each is
            exposed to the model as an OpenAI "function"-type tool.
        llm: Language model to use. Must support OpenAI tool calling
            (i.e. accept a ``tools`` bind kwarg).
        system_message: System prompt instructing the model how to extract.

    Returns:
        A runnable that takes ``{"input": <text>}`` and returns a list of
        populated Pydantic objects, one per extracted entity.
    """
    # Normalize the single-schema case to a list.
    if not isinstance(pydantic_schemas, list):
        pydantic_schemas = [pydantic_schemas]
    prompt = ChatPromptTemplate.from_messages(
        [("system", system_message), ("user", "{input}")]
    )
    # Wrap each pydantic-derived function description in the OpenAI tool format.
    functions = [convert_pydantic_to_openai_function(p) for p in pydantic_schemas]
    tools = [{"type": "function", "function": d} for d in functions]
    model = llm.bind(tools=tools)
    return prompt | model | PydanticToolsParser(tools=pydantic_schemas)

View File

@ -42,6 +42,7 @@ from langchain.schema.messages import (
FunctionMessageChunk,
HumanMessageChunk,
SystemMessageChunk,
ToolMessageChunk,
)
from langchain.schema.output import ChatGenerationChunk
from langchain.schema.runnable import Runnable
@ -115,15 +116,14 @@ def _convert_delta_to_message_chunk(
) -> BaseMessageChunk:
role = _dict.get("role")
content = _dict.get("content") or ""
additional_kwargs: Dict = {}
if _dict.get("function_call"):
additional_kwargs = {"function_call": dict(_dict["function_call"])}
if (
"name" in additional_kwargs["function_call"]
and additional_kwargs["function_call"]["name"] is None
):
additional_kwargs["function_call"]["name"] = ""
else:
additional_kwargs = {}
function_call = dict(_dict["function_call"])
if "name" in function_call and function_call["name"] is None:
function_call["name"] = ""
additional_kwargs["function_call"] = function_call
if _dict.get("tool_calls"):
additional_kwargs["tool_calls"] = _dict["tool_calls"]
if role == "user" or default_class == HumanMessageChunk:
return HumanMessageChunk(content=content)
@ -133,6 +133,8 @@ def _convert_delta_to_message_chunk(
return SystemMessageChunk(content=content)
elif role == "function" or default_class == FunctionMessageChunk:
return FunctionMessageChunk(content=content, name=_dict["name"])
elif role == "tool" or default_class == ToolMessageChunk:
return ToolMessageChunk(content=content, tool_call_id=_dict["tool_call_id"])
elif role or default_class == ChatMessageChunk:
return ChatMessageChunk(content=content, role=role)
else:
@ -630,7 +632,6 @@ class ChatOpenAI(BaseChatModel):
from langchain.chains.openai_functions.base import convert_to_openai_function
formatted_functions = [convert_to_openai_function(fn) for fn in functions]
function_call_ = None
if function_call is not None:
if len(formatted_functions) != 1:
raise ValueError(

View File

@ -23,6 +23,11 @@ from langchain.output_parsers.list import (
MarkdownListOutputParser,
NumberedListOutputParser,
)
from langchain.output_parsers.openai_tools import (
JsonOutputKeyToolsParser,
JsonOutputToolsParser,
PydanticToolsParser,
)
from langchain.output_parsers.pydantic import PydanticOutputParser
from langchain.output_parsers.rail_parser import GuardrailsOutputParser
from langchain.output_parsers.regex import RegexParser
@ -50,4 +55,7 @@ __all__ = [
"RetryWithErrorOutputParser",
"StructuredOutputParser",
"XMLOutputParser",
"JsonOutputToolsParser",
"PydanticToolsParser",
"JsonOutputKeyToolsParser",
]

View File

@ -0,0 +1,64 @@
import copy
import json
from typing import Any, List, Type
from langchain.pydantic_v1 import BaseModel
from langchain.schema import (
ChatGeneration,
Generation,
OutputParserException,
)
from langchain.schema.output_parser import (
BaseGenerationOutputParser,
)
class JsonOutputToolsParser(BaseGenerationOutputParser[Any]):
    """Parse tool calls from an OpenAI chat response into JSON dicts.

    Returns a list of ``{"type": <function name>, "args": <parsed arguments>}``
    dicts, one per tool call, or ``[]`` when the message has no tool calls.
    """

    def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
        generation = result[0]
        if not isinstance(generation, ChatGeneration):
            raise OutputParserException(
                "This output parser can only be used with a chat generation."
            )
        message = generation.message
        try:
            # Deep-copy so callers can mutate the returned args without
            # affecting the message's additional_kwargs.
            tool_calls = copy.deepcopy(message.additional_kwargs["tool_calls"])
        except KeyError:
            return []
        final_tools = []
        for tool_call in tool_calls:
            # Skip entries that are not function calls. (BUG FIX: the original
            # used `pass`, which fell through to a KeyError on the next line.)
            if "function" not in tool_call:
                continue
            function_args = tool_call["function"]["arguments"]
            final_tools.append(
                {
                    "type": tool_call["function"]["name"],
                    "args": json.loads(function_args),
                }
            )
        return final_tools
class JsonOutputKeyToolsParser(JsonOutputToolsParser):
    """Parse tool calls from an OpenAI response, keeping only one tool type."""

    key_name: str
    """The type of tools to return."""

    def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
        results = super().parse_result(result)
        # BUG FIX: filter on each item's "type". The original indexed the list
        # itself (`results["type"]`), which raises TypeError at runtime.
        return [res["args"] for res in results if res["type"] == self.key_name]
class PydanticToolsParser(JsonOutputToolsParser):
    """Parse tool calls from an OpenAI response into Pydantic model instances."""

    tools: List[Type[BaseModel]]

    def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
        # Resolve each tool call's "type" (the function name) to its Pydantic
        # class by class name, then instantiate it from the parsed arguments.
        by_name = {tool.__name__: tool for tool in self.tools}
        parsed_calls = super().parse_result(result)
        instances = []
        for call in parsed_calls:
            instances.append(by_name[call["type"]](**call["args"]))
        return instances

View File

@ -288,6 +288,44 @@ class FunctionMessageChunk(FunctionMessage, BaseMessageChunk):
return super().__add__(other)
class ToolMessage(BaseMessage):
    """A Message for passing the result of executing a tool back to a model."""

    tool_call_id: str
    """Tool call that this message is responding to."""

    # Discriminator value used to (de)serialize this message type.
    type: Literal["tool"] = "tool"


# Resolve forward references now that the (pydantic) model is fully defined.
ToolMessage.update_forward_refs()
class ToolMessageChunk(ToolMessage, BaseMessageChunk):
    """A Tool Message chunk (streaming variant of ToolMessage)."""

    # Ignoring mypy re-assignment here since we're overriding the value
    # to make sure that the chunk variant can be discriminated from the
    # non-chunk variant.
    type: Literal["ToolMessageChunk"] = "ToolMessageChunk"  # type: ignore[assignment]

    def __add__(self, other: Any) -> BaseMessageChunk:  # type: ignore
        """Concatenate two streamed chunks belonging to the same tool call."""
        if isinstance(other, ToolMessageChunk):
            # Chunks from different tool calls must never be merged.
            # (FIX: the message previously said "different names", but the
            # compared field is tool_call_id.)
            if self.tool_call_id != other.tool_call_id:
                raise ValueError(
                    "Cannot concatenate ToolMessageChunks with different "
                    "tool_call_ids."
                )
            return self.__class__(
                tool_call_id=self.tool_call_id,
                content=merge_content(self.content, other.content),
                additional_kwargs=self._merge_kwargs_dict(
                    self.additional_kwargs, other.additional_kwargs
                ),
            )
        # Fall back to generic chunk concatenation for other chunk types.
        return super().__add__(other)
class ChatMessage(BaseMessage):
"""A Message that can be assigned an arbitrary speaker (i.e. role)."""
@ -326,7 +364,9 @@ class ChatMessageChunk(ChatMessage, BaseMessageChunk):
return super().__add__(other)
AnyMessage = Union[AIMessage, HumanMessage, ChatMessage, SystemMessage, FunctionMessage]
AnyMessage = Union[
AIMessage, HumanMessage, ChatMessage, SystemMessage, FunctionMessage, ToolMessage
]
def _message_to_dict(message: BaseMessage) -> dict:
@ -357,6 +397,8 @@ def _message_from_dict(message: dict) -> BaseMessage:
return ChatMessage(**message["data"])
elif _type == "function":
return FunctionMessage(**message["data"])
elif _type == "tool":
return ToolMessage(**message["data"])
else:
raise ValueError(f"Got unexpected message type: {_type}")

View File

@ -9,6 +9,7 @@ from typing import List
from langchain.tools.base import BaseTool
from langchain.utils.openai_functions import (
FunctionDescription,
ToolDescription,
convert_pydantic_to_openai_function,
)
@ -67,3 +68,9 @@ def format_tool_to_openai_function(tool: BaseTool) -> FunctionDescription:
"type": "object",
},
}
def format_tool_to_openai_tool(tool: BaseTool) -> ToolDescription:
    """Format tool into the OpenAI tool API: the function description wrapped
    in a ``{"type": "function", "function": ...}`` envelope."""
    function = format_tool_to_openai_function(tool)
    return {"type": "function", "function": function}

View File

@ -1,4 +1,4 @@
from typing import Optional, Type, TypedDict
from typing import Literal, Optional, Type, TypedDict
from langchain.pydantic_v1 import BaseModel
from langchain.utils.json_schema import dereference_refs
@ -15,6 +15,13 @@ class FunctionDescription(TypedDict):
"""The parameters of the function."""
class ToolDescription(TypedDict):
    """Representation of a callable function to the OpenAI API, wrapped in the
    tool envelope expected by the ``tools`` request parameter."""

    # Always the literal "function" for function-type tools.
    type: Literal["function"]
    # The wrapped function's name, description, and parameters schema.
    function: FunctionDescription
def convert_pydantic_to_openai_function(
model: Type[BaseModel],
*,
@ -29,3 +36,16 @@ def convert_pydantic_to_openai_function(
"description": description or schema["description"],
"parameters": schema,
}
def convert_pydantic_to_openai_tool(
    model: Type[BaseModel],
    *,
    name: Optional[str] = None,
    description: Optional[str] = None,
) -> ToolDescription:
    """Converts a Pydantic model to a tool description for the OpenAI API.

    Args:
        model: Pydantic model class to convert.
        name: Optional override for the function name.
        description: Optional override for the function description.

    Returns:
        A dict of the form ``{"type": "function", "function": <description>}``.
    """
    function = convert_pydantic_to_openai_function(
        model, name=name, description=description
    )
    return {"type": "function", "function": function}

View File

@ -1,5 +1,5 @@
from langchain.agents.format_scratchpad.openai_functions import (
format_to_openai_functions,
format_to_openai_function_messages,
)
from langchain.schema.agent import AgentActionMessageLog
from langchain.schema.messages import AIMessage, FunctionMessage
@ -51,10 +51,10 @@ def test_calls_convert_agent_action_to_messages() -> None:
message3,
FunctionMessage(name="tool3", content="observation3"),
]
output = format_to_openai_functions(intermediate_steps)
output = format_to_openai_function_messages(intermediate_steps)
assert output == expected_messages
def test_handles_empty_input_list() -> None:
output = format_to_openai_functions([])
output = format_to_openai_function_messages([])
assert output == []

View File

@ -19,6 +19,9 @@ EXPECTED_ALL = [
"RetryWithErrorOutputParser",
"StructuredOutputParser",
"XMLOutputParser",
"JsonOutputToolsParser",
"PydanticToolsParser",
"JsonOutputKeyToolsParser",
]

View File

@ -1671,6 +1671,9 @@
dict({
'$ref': '#/definitions/FunctionMessage',
}),
dict({
'$ref': '#/definitions/ToolMessage',
}),
]),
}),
'type': 'array',
@ -1797,6 +1800,9 @@
dict({
'$ref': '#/definitions/FunctionMessage',
}),
dict({
'$ref': '#/definitions/ToolMessage',
}),
]),
}),
'title': 'Messages',
@ -1981,6 +1987,54 @@
'title': 'SystemMessage',
'type': 'object',
}),
'ToolMessage': dict({
'description': 'A Message for passing the result of executing a tool back to a model.',
'properties': dict({
'additional_kwargs': dict({
'title': 'Additional Kwargs',
'type': 'object',
}),
'content': dict({
'anyOf': list([
dict({
'type': 'string',
}),
dict({
'items': dict({
'anyOf': list([
dict({
'type': 'string',
}),
dict({
'type': 'object',
}),
]),
}),
'type': 'array',
}),
]),
'title': 'Content',
}),
'tool_call_id': dict({
'title': 'Tool Call Id',
'type': 'string',
}),
'type': dict({
'default': 'tool',
'enum': list([
'tool',
]),
'title': 'Type',
'type': 'string',
}),
}),
'required': list([
'content',
'tool_call_id',
]),
'title': 'ToolMessage',
'type': 'object',
}),
}),
'title': 'FakeListLLMInput',
})
@ -2015,6 +2069,9 @@
dict({
'$ref': '#/definitions/FunctionMessage',
}),
dict({
'$ref': '#/definitions/ToolMessage',
}),
]),
}),
'type': 'array',
@ -2141,6 +2198,9 @@
dict({
'$ref': '#/definitions/FunctionMessage',
}),
dict({
'$ref': '#/definitions/ToolMessage',
}),
]),
}),
'title': 'Messages',
@ -2325,6 +2385,54 @@
'title': 'SystemMessage',
'type': 'object',
}),
'ToolMessage': dict({
'description': 'A Message for passing the result of executing a tool back to a model.',
'properties': dict({
'additional_kwargs': dict({
'title': 'Additional Kwargs',
'type': 'object',
}),
'content': dict({
'anyOf': list([
dict({
'type': 'string',
}),
dict({
'items': dict({
'anyOf': list([
dict({
'type': 'string',
}),
dict({
'type': 'object',
}),
]),
}),
'type': 'array',
}),
]),
'title': 'Content',
}),
'tool_call_id': dict({
'title': 'Tool Call Id',
'type': 'string',
}),
'type': dict({
'default': 'tool',
'enum': list([
'tool',
]),
'title': 'Type',
'type': 'string',
}),
}),
'required': list([
'content',
'tool_call_id',
]),
'title': 'ToolMessage',
'type': 'object',
}),
}),
'title': 'FakeListChatModelInput',
})
@ -2347,6 +2455,9 @@
dict({
'$ref': '#/definitions/FunctionMessage',
}),
dict({
'$ref': '#/definitions/ToolMessage',
}),
]),
'definitions': dict({
'AIMessage': dict({
@ -2587,6 +2698,54 @@
'title': 'SystemMessage',
'type': 'object',
}),
'ToolMessage': dict({
'description': 'A Message for passing the result of executing a tool back to a model.',
'properties': dict({
'additional_kwargs': dict({
'title': 'Additional Kwargs',
'type': 'object',
}),
'content': dict({
'anyOf': list([
dict({
'type': 'string',
}),
dict({
'items': dict({
'anyOf': list([
dict({
'type': 'string',
}),
dict({
'type': 'object',
}),
]),
}),
'type': 'array',
}),
]),
'title': 'Content',
}),
'tool_call_id': dict({
'title': 'Tool Call Id',
'type': 'string',
}),
'type': dict({
'default': 'tool',
'enum': list([
'tool',
]),
'title': 'Type',
'type': 'string',
}),
}),
'required': list([
'content',
'tool_call_id',
]),
'title': 'ToolMessage',
'type': 'object',
}),
}),
'title': 'FakeListChatModelOutput',
})
@ -2722,6 +2881,9 @@
dict({
'$ref': '#/definitions/FunctionMessage',
}),
dict({
'$ref': '#/definitions/ToolMessage',
}),
]),
}),
'title': 'Messages',
@ -2906,6 +3068,54 @@
'title': 'SystemMessage',
'type': 'object',
}),
'ToolMessage': dict({
'description': 'A Message for passing the result of executing a tool back to a model.',
'properties': dict({
'additional_kwargs': dict({
'title': 'Additional Kwargs',
'type': 'object',
}),
'content': dict({
'anyOf': list([
dict({
'type': 'string',
}),
dict({
'items': dict({
'anyOf': list([
dict({
'type': 'string',
}),
dict({
'type': 'object',
}),
]),
}),
'type': 'array',
}),
]),
'title': 'Content',
}),
'tool_call_id': dict({
'title': 'Tool Call Id',
'type': 'string',
}),
'type': dict({
'default': 'tool',
'enum': list([
'tool',
]),
'title': 'Type',
'type': 'string',
}),
}),
'required': list([
'content',
'tool_call_id',
]),
'title': 'ToolMessage',
'type': 'object',
}),
}),
'title': 'ChatPromptTemplateOutput',
})
@ -3041,6 +3251,9 @@
dict({
'$ref': '#/definitions/FunctionMessage',
}),
dict({
'$ref': '#/definitions/ToolMessage',
}),
]),
}),
'title': 'Messages',
@ -3225,6 +3438,54 @@
'title': 'SystemMessage',
'type': 'object',
}),
'ToolMessage': dict({
'description': 'A Message for passing the result of executing a tool back to a model.',
'properties': dict({
'additional_kwargs': dict({
'title': 'Additional Kwargs',
'type': 'object',
}),
'content': dict({
'anyOf': list([
dict({
'type': 'string',
}),
dict({
'items': dict({
'anyOf': list([
dict({
'type': 'string',
}),
dict({
'type': 'object',
}),
]),
}),
'type': 'array',
}),
]),
'title': 'Content',
}),
'tool_call_id': dict({
'title': 'Tool Call Id',
'type': 'string',
}),
'type': dict({
'default': 'tool',
'enum': list([
'tool',
]),
'title': 'Type',
'type': 'string',
}),
}),
'required': list([
'content',
'tool_call_id',
]),
'title': 'ToolMessage',
'type': 'object',
}),
}),
'title': 'PromptTemplateOutput',
})
@ -3352,6 +3613,9 @@
dict({
'$ref': '#/definitions/FunctionMessage',
}),
dict({
'$ref': '#/definitions/ToolMessage',
}),
]),
}),
'title': 'Messages',
@ -3547,6 +3811,54 @@
'title': 'SystemMessage',
'type': 'object',
}),
'ToolMessage': dict({
'description': 'A Message for passing the result of executing a tool back to a model.',
'properties': dict({
'additional_kwargs': dict({
'title': 'Additional Kwargs',
'type': 'object',
}),
'content': dict({
'anyOf': list([
dict({
'type': 'string',
}),
dict({
'items': dict({
'anyOf': list([
dict({
'type': 'string',
}),
dict({
'type': 'object',
}),
]),
}),
'type': 'array',
}),
]),
'title': 'Content',
}),
'tool_call_id': dict({
'title': 'Tool Call Id',
'type': 'string',
}),
'type': dict({
'default': 'tool',
'enum': list([
'tool',
]),
'title': 'Type',
'type': 'string',
}),
}),
'required': list([
'content',
'tool_call_id',
]),
'title': 'ToolMessage',
'type': 'object',
}),
}),
'items': dict({
'$ref': '#/definitions/PromptTemplateOutput',
@ -3576,6 +3888,9 @@
dict({
'$ref': '#/definitions/FunctionMessage',
}),
dict({
'$ref': '#/definitions/ToolMessage',
}),
]),
'definitions': dict({
'AIMessage': dict({
@ -3816,6 +4131,54 @@
'title': 'SystemMessage',
'type': 'object',
}),
'ToolMessage': dict({
'description': 'A Message for passing the result of executing a tool back to a model.',
'properties': dict({
'additional_kwargs': dict({
'title': 'Additional Kwargs',
'type': 'object',
}),
'content': dict({
'anyOf': list([
dict({
'type': 'string',
}),
dict({
'items': dict({
'anyOf': list([
dict({
'type': 'string',
}),
dict({
'type': 'object',
}),
]),
}),
'type': 'array',
}),
]),
'title': 'Content',
}),
'tool_call_id': dict({
'title': 'Tool Call Id',
'type': 'string',
}),
'type': dict({
'default': 'tool',
'enum': list([
'tool',
]),
'title': 'Type',
'type': 'string',
}),
}),
'required': list([
'content',
'tool_call_id',
]),
'title': 'ToolMessage',
'type': 'object',
}),
}),
'title': 'CommaSeparatedListOutputParserInput',
})

View File

@ -285,6 +285,7 @@ def test_schemas(snapshot: SnapshotAssertion) -> None:
{"$ref": "#/definitions/ChatMessage"},
{"$ref": "#/definitions/SystemMessage"},
{"$ref": "#/definitions/FunctionMessage"},
{"$ref": "#/definitions/ToolMessage"},
]
},
}
@ -362,7 +363,7 @@ def test_schemas(snapshot: SnapshotAssertion) -> None:
},
"ChatMessage": {
"title": "ChatMessage",
"description": "A Message that can be assigned an arbitrary speaker (i.e. role).", # noqa: E501
"description": "A Message that can be assigned an arbitrary speaker (i.e. role).", # noqa
"type": "object",
"properties": {
"content": {
@ -393,7 +394,7 @@ def test_schemas(snapshot: SnapshotAssertion) -> None:
},
"SystemMessage": {
"title": "SystemMessage",
"description": "A Message for priming AI behavior, usually passed in as the first of a sequence\nof input messages.", # noqa: E501
"description": "A Message for priming AI behavior, usually passed in as the first of a sequence\nof input messages.", # noqa
"type": "object",
"properties": {
"content": {
@ -423,7 +424,7 @@ def test_schemas(snapshot: SnapshotAssertion) -> None:
},
"FunctionMessage": {
"title": "FunctionMessage",
"description": "A Message for passing the result of executing a function back to a model.", # noqa: E501
"description": "A Message for passing the result of executing a function back to a model.", # noqa
"type": "object",
"properties": {
"content": {
@ -452,6 +453,37 @@ def test_schemas(snapshot: SnapshotAssertion) -> None:
},
"required": ["content", "name"],
},
"ToolMessage": {
"title": "ToolMessage",
"description": "A Message for passing the result of executing a tool back to a model.", # noqa
"type": "object",
"properties": {
"content": {
"title": "Content",
"anyOf": [
{"type": "string"},
{
"type": "array",
"items": {
"anyOf": [{"type": "string"}, {"type": "object"}]
},
},
],
},
"additional_kwargs": {
"title": "Additional Kwargs",
"type": "object",
},
"type": {
"title": "Type",
"default": "tool",
"enum": ["tool"],
"type": "string",
},
"tool_call_id": {"title": "Tool Call Id", "type": "string"},
},
"required": ["content", "tool_call_id"],
},
},
}
assert chat_prompt.output_schema.schema() == snapshot

View File

@ -1,7 +1,7 @@
from typing import List, Tuple
from langchain.agents import AgentExecutor
from langchain.agents.format_scratchpad import format_to_openai_functions
from langchain.agents.format_scratchpad import format_to_openai_function_messages
from langchain.agents.output_parsers import OpenAIFunctionsAgentOutputParser
from langchain.chat_models import ChatOpenAI
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
@ -50,7 +50,7 @@ agent = (
{
"input": lambda x: x["input"],
"chat_history": lambda x: _format_chat_history(x["chat_history"]),
"agent_scratchpad": lambda x: format_to_openai_functions(
"agent_scratchpad": lambda x: format_to_openai_function_messages(
x["intermediate_steps"]
),
}