OpenAI runnable constructor (#12455)

Bagatur
2023-10-29 13:40:30 -07:00
committed by GitHub
parent a830b809f3
commit 1815ea2fdb
3 changed files with 368 additions and 134 deletions
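
For readers skimming the diff, the pattern the updated notebook assembles by hand (`prompt | llm.bind(**llm_kwargs) | output_parser`) can be sketched without LangChain or an OpenAI key. This is a rough illustration under stated assumptions, not the library's implementation: the two function specs, the helper names `build_llm_kwargs` and `parse_function_call`, and the fake model reply are all hypothetical. It mirrors two behaviours visible in the notebook: `function_call` is forced only when exactly one function is supplied, and the parsed result is the bare arguments for a single function but a name-plus-arguments dict for several.

```python
import json
from typing import Any, Dict, List

# Hand-written specs in the OpenAI functions format. The field values are
# illustrative assumptions, not copied from the notebook.
RECORD_PERSON: Dict[str, Any] = {
    "name": "record_person",
    "description": "Record some basic identifying information about a person.",
    "parameters": {
        "type": "object",
        "properties": {"name": {"type": "string"}, "age": {"type": "integer"}},
        "required": ["name", "age"],
    },
}
RECORD_DOG: Dict[str, Any] = {
    "name": "record_dog",
    "description": "Record some basic identifying information about a dog.",
    "parameters": {
        "type": "object",
        "properties": {"name": {"type": "string"}, "color": {"type": "string"}},
        "required": ["name", "color"],
    },
}


def build_llm_kwargs(functions: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Hypothetical helper mirroring the notebook cell: always pass
    `functions`, and force a specific call only when there is exactly one."""
    llm_kwargs: Dict[str, Any] = {"functions": list(functions)}
    if len(functions) == 1:
        llm_kwargs["function_call"] = {"name": functions[0]["name"]}
    return llm_kwargs


def parse_function_call(message: Dict[str, Any], multiple: bool) -> Dict[str, Any]:
    """Hypothetical stand-in for an output parser: decode the JSON-encoded
    arguments of the function call carried by an assistant message."""
    call = message["function_call"]
    arguments = json.loads(call["arguments"])
    if multiple:
        # With several candidate functions, keep the chosen function's name.
        return {"name": call["name"], "arguments": arguments}
    return arguments


# A fake assistant message shaped like an OpenAI function-call reply.
fake_message = {
    "role": "assistant",
    "content": None,
    "function_call": {
        "name": "record_dog",
        "arguments": json.dumps({"name": "Henry", "color": "brown"}),
    },
}

single = build_llm_kwargs([RECORD_PERSON])
multi = build_llm_kwargs([RECORD_PERSON, RECORD_DOG])
parsed = parse_function_call(fake_message, multiple=True)
print(single["function_call"], "function_call" in multi, parsed)
```

In the notebook itself, `create_openai_fn_runnable` wraps these steps, so the manual version is mainly useful for seeing which keyword arguments end up bound to the model.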

@@ -24,6 +24,8 @@
"from langchain.chains.openai_functions import (\n",
" create_openai_fn_chain,\n",
" create_structured_output_chain,\n",
" create_openai_fn_runnable,\n",
" create_structured_output_runnable,\n",
")\n",
"from langchain.chat_models import ChatOpenAI\n",
"from langchain.prompts import ChatPromptTemplate"
@@ -35,9 +37,7 @@
"metadata": {},
"source": [
"## Getting structured outputs\n",
"We can take advantage of OpenAI functions to try and force the model to return a particular kind of structured output. We'll use `create_structured_output_chain` to create our chain, which takes the desired structured output either as a Pydantic class or as JsonSchema.\n",
"\n",
"See here for relevant [reference docs](https://api.python.langchain.com/en/latest/chains/langchain.chains.openai_functions.base.create_structured_output_chain.html)."
"We can take advantage of OpenAI functions to try to force the model to return a particular kind of structured output. We'll use `create_structured_output_runnable` to create our runnable, which takes the desired output schema either as a Pydantic class or as JSON Schema."
]
},
{
@@ -73,21 +73,6 @@
"id": "b459a33e",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"\u001B[1m> Entering new LLMChain chain...\u001B[0m\n",
"Prompt after formatting:\n",
"\u001B[32;1m\u001B[1;3mSystem: You are a world class algorithm for extracting information in structured formats.\n",
"Human: Use the given format to extract information from the following input: Sally is 13\n",
"Human: Tip: Make sure to answer in the correct format\u001B[0m\n",
"\n",
"\u001B[1m> Finished chain.\u001B[0m\n"
]
},
{
"data": {
"text/plain": [
@@ -110,8 +95,8 @@
" ]\n",
")\n",
"\n",
"chain = create_structured_output_chain(Person, llm, prompt, verbose=True)\n",
"chain.run(\"Sally is 13\")"
"runnable = create_structured_output_runnable(Person, llm, prompt)\n",
"runnable.invoke({\"input\": \"Sally is 13\"})"
]
},
{
@@ -124,32 +109,17 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 4,
"id": "4d8ea815",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"\u001B[1m> Entering new LLMChain chain...\u001B[0m\n",
"Prompt after formatting:\n",
"\u001B[32;1m\u001B[1;3mSystem: You are a world class algorithm for extracting information in structured formats.\n",
"Human: Use the given format to extract information from the following input: Sally is 13, Joey just turned 12 and loves spinach. Caroline is 10 years older than Sally.\n",
"Human: Tip: Make sure to answer in the correct format\u001B[0m\n",
"\n",
"\u001B[1m> Finished chain.\u001B[0m\n"
]
},
{
"data": {
"text/plain": [
"People(people=[Person(name='Sally', age=13, fav_food=''), Person(name='Joey', age=12, fav_food='spinach'), Person(name='Caroline', age=23, fav_food='')])"
]
},
"execution_count": 5,
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
@@ -164,9 +134,9 @@
" people: Sequence[Person] = Field(..., description=\"The people in the text\")\n",
"\n",
"\n",
"chain = create_structured_output_chain(People, llm, prompt, verbose=True)\n",
"chain.run(\n",
" \"Sally is 13, Joey just turned 12 and loves spinach. Caroline is 10 years older than Sally.\"\n",
"runnable = create_structured_output_runnable(People, llm, prompt)\n",
"runnable.invoke(\n",
" {\"input\": \"Sally is 13, Joey just turned 12 and loves spinach. Caroline is 10 years older than Sally.\"}\n",
")"
]
},
@@ -182,7 +152,7 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 5,
"id": "3484415e",
"metadata": {},
"outputs": [],
@@ -206,9 +176,39 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 6,
"id": "be9b76b3",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'name': 'Sally', 'age': 13}"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"runnable = create_structured_output_runnable(json_schema, llm, prompt)\n",
"runnable.invoke({\"input\": \"Sally is 13\"})"
]
},
{
"cell_type": "markdown",
"id": "5f38ca2d-eb65-4836-9a21-9eaaa8c6c47c",
"metadata": {},
"source": [
"### [Legacy] LLMChain-based approach"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "4cf8d9b8-043b-414d-81e5-1a53c4881845",
"metadata": {},
"outputs": [
{
"name": "stdout",
@@ -216,19 +216,19 @@
"text": [
"\n",
"\n",
"\u001B[1m> Entering new LLMChain chain...\u001B[0m\n",
"\u001b[1m> Entering new LLMChain chain...\u001b[0m\n",
"Prompt after formatting:\n",
"\u001B[32;1m\u001B[1;3mSystem: You are a world class algorithm for extracting information in structured formats.\n",
"\u001b[32;1m\u001b[1;3mSystem: You are a world class algorithm for extracting information in structured formats.\n",
"Human: Use the given format to extract information from the following input: Sally is 13\n",
"Human: Tip: Make sure to answer in the correct format\u001B[0m\n",
"Human: Tip: Make sure to answer in the correct format\u001b[0m\n",
"\n",
"\u001B[1m> Finished chain.\u001B[0m\n"
"\u001b[1m> Finished chain.\u001b[0m\n"
]
},
{
"data": {
"text/plain": [
"{'name': 'Sally', 'age': 13}"
"Person(name='Sally', age=13, fav_food='Unknown')"
]
},
"execution_count": 7,
@@ -237,7 +237,7 @@
}
],
"source": [
"chain = create_structured_output_chain(json_schema, llm, prompt, verbose=True)\n",
"chain = create_structured_output_chain(Person, llm, prompt, verbose=True)\n",
"chain.run(\"Sally is 13\")"
]
},
@@ -247,14 +247,12 @@
"metadata": {},
"source": [
"## Creating a generic OpenAI functions chain\n",
"To create a generic OpenAI functions chain, we can use the `create_openai_fn_chain` method. This is the same as `create_structured_output_chain` except that instead of taking a single output schema, it takes a sequence of function definitions.\n",
"To create a generic OpenAI functions chain, we can use the `create_openai_fn_runnable` function. It works like `create_structured_output_runnable`, except that it takes a sequence of function definitions instead of a single output schema.\n",
"\n",
"Functions can be passed in as:\n",
"- dicts conforming to the OpenAI functions spec,\n",
"- Pydantic classes, in which case they should have docstring descriptions of the function they represent and descriptions for each of the parameters,\n",
"- Python functions, in which case they should have docstring descriptions of the function and args, along with type hints.\n",
"\n",
"See here for relevant [reference docs](https://api.python.langchain.com/en/latest/chains/langchain.chains.openai_functions.base.create_openai_fn_chain.html)."
"- Python functions, in which case they should have docstring descriptions of the function and args, along with type hints."
]
},
{
@@ -267,7 +265,7 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 8,
"id": "17f52508",
"metadata": {},
"outputs": [],
@@ -290,37 +288,13 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": 13,
"id": "a4658ad8",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"\u001B[1m> Entering new LLMChain chain...\u001B[0m\n",
"Prompt after formatting:\n",
"\u001B[32;1m\u001B[1;3mSystem: You are a world class algorithm for recording entities.\n",
"Human: Make calls to the relevant function to record the entities in the following input: Harry was a chubby brown beagle who loved chicken\n",
"Human: Tip: Make sure to answer in the correct format\u001B[0m\n",
"\n",
"\u001B[1m> Finished chain.\u001B[0m\n"
]
},
{
"data": {
"text/plain": [
"RecordDog(name='Harry', color='brown', fav_food='chicken')"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"outputs": [],
"source": [
"from langchain.chains.openai_functions import convert_to_openai_function, get_openai_output_parser\n",
"\n",
"prompt = ChatPromptTemplate.from_messages(\n",
" [\n",
" (\"system\", \"You are a world class algorithm for recording entities.\"),\n",
@@ -329,8 +303,63 @@
" ]\n",
")\n",
"\n",
"chain = create_openai_fn_chain([RecordPerson, RecordDog], llm, prompt, verbose=True)\n",
"chain.run(\"Harry was a chubby brown beagle who loved chicken\")"
"openai_functions = [convert_to_openai_function(f) for f in (RecordPerson, RecordDog)]\n",
"llm_kwargs = {\"functions\": openai_functions}\n",
"if len(openai_functions) == 1:\n",
" llm_kwargs[\"function_call\"] = {\"name\": openai_functions[0][\"name\"]}\n",
"output_parser = get_openai_output_parser((RecordPerson, RecordDog))\n",
"runnable = prompt | llm.bind(**llm_kwargs) | output_parser\n"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "a32148a2-8495-4a2b-942a-d605b131bf69",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"RecordDog(name='Harry', color='brown', fav_food='chicken')"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"runnable.invoke({\"input\": \"Harry was a chubby brown beagle who loved chicken\"})"
]
},
{
"cell_type": "markdown",
"id": "b57b2ca4-6519-4f7e-9b62-9ce14aad914f",
"metadata": {},
"source": [
"For convenience, we can use the `create_openai_fn_runnable` function to help build our Runnable."
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "88538970-91b3-4eea-9c2b-47210713492a",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"RecordDog(name='Harry', color='brown', fav_food='chicken')"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"runnable = create_openai_fn_runnable([RecordPerson, RecordDog], llm, prompt)\n",
"runnable.invoke({\"input\": \"Harry was a chubby brown beagle who loved chicken\"})"
]
},
{
@@ -346,32 +375,17 @@
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": 16,
"id": "95ac5825",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"\u001B[1m> Entering new LLMChain chain...\u001B[0m\n",
"Prompt after formatting:\n",
"\u001B[32;1m\u001B[1;3mSystem: You are a world class algorithm for recording entities.\n",
"Human: Make calls to the relevant function to record the entities in the following input: The most important thing to remember about Tommy, my 12 year old, is that he'll do anything for apple pie.\n",
"Human: Tip: Make sure to answer in the correct format\u001B[0m\n",
"\n",
"\u001B[1m> Finished chain.\u001B[0m\n"
]
},
{
"data": {
"text/plain": [
"{'name': 'Tommy', 'age': 12, 'fav_food': {'food': 'apple pie'}}"
]
},
"execution_count": 11,
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
@@ -397,9 +411,9 @@
" return f\"Recording person {name} of age {age} with favorite food {fav_food.food}!\"\n",
"\n",
"\n",
"chain = create_openai_fn_chain([record_person], llm, prompt, verbose=True)\n",
"chain.run(\n",
" \"The most important thing to remember about Tommy, my 12 year old, is that he'll do anything for apple pie.\"\n",
"runnable = create_openai_fn_runnable([record_person], llm, prompt)\n",
"runnable.invoke(\n",
" {\"input\": \"The most important thing to remember about Tommy, my 12 year old, is that he'll do anything for apple pie.\"}\n",
")"
]
},
@@ -416,25 +430,10 @@
},
{
"cell_type": "code",
"execution_count": 12,
"execution_count": 17,
"id": "8b0d11de",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"\u001B[1m> Entering new LLMChain chain...\u001B[0m\n",
"Prompt after formatting:\n",
"\u001B[32;1m\u001B[1;3mSystem: You are a world class algorithm for recording entities.\n",
"Human: Make calls to the relevant function to record the entities in the following input: I can't find my dog Henry anywhere, he's a small brown beagle. Could you send a message about him?\n",
"Human: Tip: Make sure to answer in the correct format\u001B[0m\n",
"\n",
"\u001B[1m> Finished chain.\u001B[0m\n"
]
},
{
"data": {
"text/plain": [
@@ -442,7 +441,7 @@
" 'arguments': {'name': 'Henry', 'color': 'brown', 'fav_food': {'food': None}}}"
]
},
"execution_count": 12,
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
@@ -459,12 +458,57 @@
" return f\"Recording dog {name} of color {color} with favorite food {fav_food}!\"\n",
"\n",
"\n",
"chain = create_openai_fn_chain([record_person, record_dog], llm, prompt, verbose=True)\n",
"chain.run(\n",
" \"I can't find my dog Henry anywhere, he's a small brown beagle. Could you send a message about him?\"\n",
"runnable = create_openai_fn_runnable([record_person, record_dog], llm, prompt)\n",
"runnable.invoke(\n",
" {\"input\": \"I can't find my dog Henry anywhere, he's a small brown beagle. Could you send a message about him?\"}\n",
")"
]
},
{
"cell_type": "markdown",
"id": "c81e301d-3125-4b25-8a74-86ba9562952c",
"metadata": {},
"source": [
"### [Legacy] LLMChain-based approach"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "32711985-8dac-448a-ad65-cd3dd5e45fbe",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"\u001b[1m> Entering new LLMChain chain...\u001b[0m\n",
"Prompt after formatting:\n",
"\u001b[32;1m\u001b[1;3mSystem: You are a world class algorithm for recording entities.\n",
"Human: Make calls to the relevant function to record the entities in the following input: Harry was a chubby brown beagle who loved chicken\n",
"Human: Tip: Make sure to answer in the correct format\u001b[0m\n",
"\n",
"\u001b[1m> Finished chain.\u001b[0m\n"
]
},
{
"data": {
"text/plain": [
"RecordDog(name='Harry', color='brown', fav_food='chicken')"
]
},
"execution_count": 18,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"chain = create_openai_fn_chain([RecordPerson, RecordDog], llm, prompt, verbose=True)\n",
"chain.run(\"Harry was a chubby brown beagle who loved chicken\")"
]
},
{
"cell_type": "markdown",
"id": "5f93686b",