From 1815ea2fdbd020da9c5b8da114a6d7eae05f5b5f Mon Sep 17 00:00:00 2001
From: Bagatur <22008038+baskaryan@users.noreply.github.com>
Date: Sun, 29 Oct 2023 13:40:30 -0700
Subject: [PATCH] OpenAI runnable constructor (#12455)

---
 .../chains/how_to/openai_functions.ipynb      | 290 ++++++++++--------
 .../chains/openai_functions/__init__.py       |   8 +
 .../langchain/chains/openai_functions/base.py | 204 +++++++++++-
 3 files changed, 368 insertions(+), 134 deletions(-)

diff --git a/docs/docs/modules/chains/how_to/openai_functions.ipynb b/docs/docs/modules/chains/how_to/openai_functions.ipynb
index 3aa0f988159..675457ec419 100644
--- a/docs/docs/modules/chains/how_to/openai_functions.ipynb
+++ b/docs/docs/modules/chains/how_to/openai_functions.ipynb
@@ -24,6 +24,8 @@
     "from langchain.chains.openai_functions import (\n",
     "    create_openai_fn_chain,\n",
     "    create_structured_output_chain,\n",
+    "    create_openai_fn_runnable,\n",
+    "    create_structured_output_runnable,\n",
     ")\n",
     "from langchain.chat_models import ChatOpenAI\n",
     "from langchain.prompts import ChatPromptTemplate"
    ]
   },
@@ -35,9 +37,7 @@
    "metadata": {},
    "source": [
     "## Getting structured outputs\n",
-    "We can take advantage of OpenAI functions to try and force the model to return a particular kind of structured output. We'll use `create_structured_output_chain` to create our chain, which takes the desired structured output either as a Pydantic class or as JsonSchema.\n",
-    "\n",
-    "See here for relevant [reference docs](https://api.python.langchain.com/en/latest/chains/langchain.chains.openai_functions.base.create_structured_output_chain.html)."
+    "We can take advantage of OpenAI functions to try to force the model to return a particular kind of structured output. We'll use `create_structured_output_runnable` to create our runnable, which takes the desired structured output either as a Pydantic class or as JsonSchema."
    ]
   },
@@ -73,21 +73,6 @@
    "id": "b459a33e",
    "metadata": {},
    "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "\n",
-      "\n",
-      "\u001B[1m> Entering new LLMChain chain...\u001B[0m\n",
-      "Prompt after formatting:\n",
-      "\u001B[32;1m\u001B[1;3mSystem: You are a world class algorithm for extracting information in structured formats.\n",
-      "Human: Use the given format to extract information from the following input: Sally is 13\n",
-      "Human: Tip: Make sure to answer in the correct format\u001B[0m\n",
-      "\n",
-      "\u001B[1m> Finished chain.\u001B[0m\n"
-     ]
-    },
     {
      "data": {
       "text/plain": [
@@ -110,8 +95,8 @@
     "    ]\n",
     ")\n",
     "\n",
-    "chain = create_structured_output_chain(Person, llm, prompt, verbose=True)\n",
-    "chain.run(\"Sally is 13\")"
+    "runnable = create_structured_output_runnable(Person, llm, prompt)\n",
+    "runnable.invoke({\"input\": \"Sally is 13\"})"
    ]
   },
@@ -124,32 +109,17 @@
   {
    "cell_type": "code",
-   "execution_count": 5,
+   "execution_count": 4,
    "id": "4d8ea815",
    "metadata": {},
    "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "\n",
-      "\n",
-      "\u001B[1m> Entering new LLMChain chain...\u001B[0m\n",
-      "Prompt after formatting:\n",
-      "\u001B[32;1m\u001B[1;3mSystem: You are a world class algorithm for extracting information in structured formats.\n",
-      "Human: Use the given format to extract information from the following input: Sally is 13, Joey just turned 12 and loves spinach. Caroline is 10 years older than Sally.\n",
-      "Human: Tip: Make sure to answer in the correct format\u001B[0m\n",
-      "\n",
-      "\u001B[1m> Finished chain.\u001B[0m\n"
-     ]
-    },
     {
      "data": {
       "text/plain": [
       "People(people=[Person(name='Sally', age=13, fav_food=''), Person(name='Joey', age=12, fav_food='spinach'), Person(name='Caroline', age=23, fav_food='')])"
       ]
      },
-     "execution_count": 5,
+     "execution_count": 4,
      "metadata": {},
      "output_type": "execute_result"
     }
    ],
@@ -164,9 +134,9 @@
     "    people: Sequence[Person] = Field(..., description=\"The people in the text\")\n",
     "\n",
     "\n",
-    "chain = create_structured_output_chain(People, llm, prompt, verbose=True)\n",
-    "chain.run(\n",
-    "    \"Sally is 13, Joey just turned 12 and loves spinach. Caroline is 10 years older than Sally.\"\n",
+    "runnable = create_structured_output_runnable(People, llm, prompt)\n",
+    "runnable.invoke(\n",
+    "    {\"input\": \"Sally is 13, Joey just turned 12 and loves spinach. Caroline is 10 years older than Sally.\"}\n",
     ")"
    ]
   },
@@ -182,7 +152,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 6,
+   "execution_count": 5,
    "id": "3484415e",
    "metadata": {},
    "outputs": [],
@@ -206,9 +176,39 @@
   {
    "cell_type": "code",
-   "execution_count": 7,
+   "execution_count": 6,
    "id": "be9b76b3",
    "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "{'name': 'Sally', 'age': 13}"
+      ]
+     },
+     "execution_count": 6,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "runnable = create_structured_output_runnable(json_schema, llm, prompt)\n",
+    "runnable.invoke({\"input\": \"Sally is 13\"})"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "5f38ca2d-eb65-4836-9a21-9eaaa8c6c47c",
+   "metadata": {},
+   "source": [
+    "### [Legacy] LLMChain-based approach"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 7,
+   "id": "4cf8d9b8-043b-414d-81e5-1a53c4881845",
+   "metadata": {},
    "outputs": [
     {
      "name": "stdout",
     "output_type": "stream",
@@ -216,19 +216,19 @@
      "text": [
       "\n",
       "\n",
       "\u001b[1m> Entering new LLMChain chain...\u001b[0m\n",
       "Prompt after formatting:\n",
       "\u001b[32;1m\u001b[1;3mSystem: You are a world class algorithm for extracting information in structured formats.\n",
       "Human: Use the given format to extract information from the following input: Sally is 13\n",
       "Human: Tip: Make sure to answer in the correct format\u001b[0m\n",
       "\n",
       "\u001b[1m> Finished chain.\u001b[0m\n"
      ]
     },
     {
      "data": {
       "text/plain": [
-       "{'name': 'Sally', 'age': 13}"
+       "Person(name='Sally', age=13, fav_food='Unknown')"
       ]
      },
      "execution_count": 7,
      "metadata": {},
      "output_type": "execute_result"
     }
    ],
    "source": [
-    "chain = create_structured_output_chain(json_schema, llm, prompt, verbose=True)\n",
+    "chain = create_structured_output_chain(Person, llm, prompt, verbose=True)\n",
     "chain.run(\"Sally is 13\")"
    ]
   },
@@ -247,14 +247,12 @@
    "metadata": {},
    "source": [
     "## Creating a generic OpenAI functions chain\n",
-    "To create a generic OpenAI functions chain, we can use the `create_openai_fn_chain` method. This is the same as `create_structured_output_chain` except that instead of taking a single output schema, it takes a sequence of function definitions.\n",
+    "To create a generic OpenAI functions chain, we can use the `create_openai_fn_runnable` method. This is the same as `create_structured_output_runnable` except that instead of taking a single output schema, it takes a sequence of function definitions.\n",
     "\n",
     "Functions can be passed in as:\n",
     "- dicts conforming to OpenAI functions spec,\n",
     "- Pydantic classes, in which case they should have docstring descriptions of the function they represent and descriptions for each of the parameters,\n",
-    "- Python functions, in which case they should have docstring descriptions of the function and args, along with type hints.\n",
-    "\n",
-    "See here for relevant [reference docs](https://api.python.langchain.com/en/latest/chains/langchain.chains.openai_functions.base.create_openai_fn_chain.html)."
+    "- Python functions, in which case they should have docstring descriptions of the function and args, along with type hints."
    ]
   },
@@ -267,7 +265,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 9,
+   "execution_count": 8,
    "id": "17f52508",
    "metadata": {},
    "outputs": [],
@@ -290,37 +288,13 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 10,
+   "execution_count": 13,
    "id": "a4658ad8",
    "metadata": {},
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "\n",
-      "\n",
-      "\u001B[1m> Entering new LLMChain chain...\u001B[0m\n",
-      "Prompt after formatting:\n",
-      "\u001B[32;1m\u001B[1;3mSystem: You are a world class algorithm for recording entities.\n",
-      "Human: Make calls to the relevant function to record the entities in the following input: Harry was a chubby brown beagle who loved chicken\n",
-      "Human: Tip: Make sure to answer in the correct format\u001B[0m\n",
-      "\n",
-      "\u001B[1m> Finished chain.\u001B[0m\n"
-     ]
-    },
-    {
-     "data": {
-      "text/plain": [
-       "RecordDog(name='Harry', color='brown', fav_food='chicken')"
-      ]
-     },
-     "execution_count": 10,
-     "metadata": {},
-     "output_type": "execute_result"
-    }
-   ],
+   "outputs": [],
    "source": [
+    "from langchain.chains.openai_functions import convert_to_openai_function, get_openai_output_parser\n",
+    "\n",
     "prompt = ChatPromptTemplate.from_messages(\n",
     "    [\n",
     "        (\"system\", \"You are a world class algorithm for recording entities.\"),\n",
     "        (\"human\", \"Make calls to the relevant function to record the entities in the following input: {input}\"),\n",
     "        (\"human\", \"Tip: Make sure to answer in the correct format\"),\n",
     "    ]\n",
     ")\n",
     "\n",
-    "chain = create_openai_fn_chain([RecordPerson, RecordDog], llm, prompt, verbose=True)\n",
-    "chain.run(\"Harry was a chubby brown beagle who loved chicken\")"
+    "openai_functions = [convert_to_openai_function(f) for f in (RecordPerson, RecordDog)]\n",
+    "llm_kwargs = {\"functions\": openai_functions}\n",
+    "if len(openai_functions) == 1:\n",
+    "    llm_kwargs[\"function_call\"] = {\"name\": openai_functions[0][\"name\"]}\n",
+    "output_parser = get_openai_output_parser((RecordPerson, RecordDog))\n",
+    "runnable = prompt | llm.bind(**llm_kwargs) | output_parser\n"
    ]
   },
+  {
+   "cell_type": "code",
+   "execution_count": 14,
+   "id": "a32148a2-8495-4a2b-942a-d605b131bf69",
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "RecordDog(name='Harry', color='brown', fav_food='chicken')"
+      ]
+     },
+     "execution_count": 14,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "runnable.invoke({\"input\": \"Harry was a chubby brown beagle who loved chicken\"})"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "b57b2ca4-6519-4f7e-9b62-9ce14aad914f",
+   "metadata": {},
+   "source": [
+    "For convenience, we can use the `create_openai_fn_runnable` method to help build our Runnable."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 15,
+   "id": "88538970-91b3-4eea-9c2b-47210713492a",
+   "metadata": {},
+   "outputs": [
+    {
"data": { + "text/plain": [ + "RecordDog(name='Harry', color='brown', fav_food='chicken')" + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "runnable = create_openai_fn_runnable([RecordPerson, RecordDog], llm, prompt)\n", + "runnable.invoke({\"input\": \"Harry was a chubby brown beagle who loved chicken\"})" ] }, { @@ -346,32 +375,17 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 16, "id": "95ac5825", "metadata": {}, "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n", - "\n", - "\u001B[1m> Entering new LLMChain chain...\u001B[0m\n", - "Prompt after formatting:\n", - "\u001B[32;1m\u001B[1;3mSystem: You are a world class algorithm for recording entities.\n", - "Human: Make calls to the relevant function to record the entities in the following input: The most important thing to remember about Tommy, my 12 year old, is that he'll do anything for apple pie.\n", - "Human: Tip: Make sure to answer in the correct format\u001B[0m\n", - "\n", - "\u001B[1m> Finished chain.\u001B[0m\n" - ] - }, { "data": { "text/plain": [ "{'name': 'Tommy', 'age': 12, 'fav_food': {'food': 'apple pie'}}" ] }, - "execution_count": 11, + "execution_count": 16, "metadata": {}, "output_type": "execute_result" } @@ -397,9 +411,9 @@ " return f\"Recording person {name} of age {age} with favorite food {fav_food.food}!\"\n", "\n", "\n", - "chain = create_openai_fn_chain([record_person], llm, prompt, verbose=True)\n", - "chain.run(\n", - " \"The most important thing to remember about Tommy, my 12 year old, is that he'll do anything for apple pie.\"\n", + "runnable = create_openai_fn_runnable([record_person], llm, prompt)\n", + "runnable.invoke(\n", + " {\"input\": \"The most important thing to remember about Tommy, my 12 year old, is that he'll do anything for apple pie.\"}\n", ")" ] }, @@ -416,25 +430,10 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 17, "id": "8b0d11de", "metadata": {}, "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n", - "\n", - "\u001B[1m> Entering new LLMChain chain...\u001B[0m\n", - "Prompt after formatting:\n", - "\u001B[32;1m\u001B[1;3mSystem: You are a world class algorithm for recording entities.\n", - "Human: Make calls to the relevant function to record the entities in the following input: I can't find my dog Henry anywhere, he's a small brown beagle. Could you send a message about him?\n", - "Human: Tip: Make sure to answer in the correct format\u001B[0m\n", - "\n", - "\u001B[1m> Finished chain.\u001B[0m\n" - ] - }, { "data": { "text/plain": [ @@ -442,7 +441,7 @@ " 'arguments': {'name': 'Henry', 'color': 'brown', 'fav_food': {'food': None}}}" ] }, - "execution_count": 12, + "execution_count": 17, "metadata": {}, "output_type": "execute_result" } @@ -459,12 +458,57 @@ " return f\"Recording dog {name} of color {color} with favorite food {fav_food}!\"\n", "\n", "\n", - "chain = create_openai_fn_chain([record_person, record_dog], llm, prompt, verbose=True)\n", - "chain.run(\n", - " \"I can't find my dog Henry anywhere, he's a small brown beagle. Could you send a message about him?\"\n", + "runnable = create_openai_fn_runnable([record_person, record_dog], llm, prompt)\n", + "runnable.invoke(\n", + " {\"input\": \"I can't find my dog Henry anywhere, he's a small brown beagle. 
     ")"
    ]
   },
+  {
+   "cell_type": "markdown",
+   "id": "c81e301d-3125-4b25-8a74-86ba9562952c",
+   "metadata": {},
+   "source": [
+    "### [Legacy] LLMChain-based approach"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 18,
+   "id": "32711985-8dac-448a-ad65-cd3dd5e45fbe",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "\n",
+      "\n",
+      "\u001b[1m> Entering new LLMChain chain...\u001b[0m\n",
+      "Prompt after formatting:\n",
+      "\u001b[32;1m\u001b[1;3mSystem: You are a world class algorithm for recording entities.\n",
+      "Human: Make calls to the relevant function to record the entities in the following input: Harry was a chubby brown beagle who loved chicken\n",
+      "Human: Tip: Make sure to answer in the correct format\u001b[0m\n",
+      "\n",
+      "\u001b[1m> Finished chain.\u001b[0m\n"
+     ]
+    },
+    {
+     "data": {
+      "text/plain": [
+       "RecordDog(name='Harry', color='brown', fav_food='chicken')"
+      ]
+     },
+     "execution_count": 18,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "chain = create_openai_fn_chain([RecordPerson, RecordDog], llm, prompt, verbose=True)\n",
+    "chain.run(\"Harry was a chubby brown beagle who loved chicken\")"
+   ]
+  },
   {
    "cell_type": "markdown",
    "id": "5f93686b",
diff --git a/libs/langchain/langchain/chains/openai_functions/__init__.py b/libs/langchain/langchain/chains/openai_functions/__init__.py
index 3ac797b8490..8be27606d7a 100644
--- a/libs/langchain/langchain/chains/openai_functions/__init__.py
+++ b/libs/langchain/langchain/chains/openai_functions/__init__.py
@@ -1,6 +1,10 @@
 from langchain.chains.openai_functions.base import (
+    convert_to_openai_function,
     create_openai_fn_chain,
+    create_openai_fn_runnable,
     create_structured_output_chain,
+    create_structured_output_runnable,
+    get_openai_output_parser,
 )
 from langchain.chains.openai_functions.citation_fuzzy_match import (
     create_citation_fuzzy_match_chain,
@@ -19,6 +23,7 @@ from langchain.chains.openai_functions.tagging import (
 )
 
 __all__ = [
+    "convert_to_openai_function",
     "create_tagging_chain",
     "create_tagging_chain_pydantic",
     "create_extraction_chain_pydantic",
@@ -28,4 +33,7 @@ __all__ = [
     "create_qa_with_sources_chain",
     "create_structured_output_chain",
     "create_openai_fn_chain",
+    "create_structured_output_runnable",
+    "create_openai_fn_runnable",
+    "get_openai_output_parser",
 ]
diff --git a/libs/langchain/langchain/chains/openai_functions/base.py b/libs/langchain/langchain/chains/openai_functions/base.py
index 2de37edf128..80c6d95565b 100644
--- a/libs/langchain/langchain/chains/openai_functions/base.py
+++ b/libs/langchain/langchain/chains/openai_functions/base.py
@@ -23,6 +23,8 @@ from langchain.output_parsers.openai_functions import (
 from langchain.prompts import BasePromptTemplate
 from langchain.pydantic_v1 import BaseModel
 from langchain.schema import BaseLLMOutputParser
+from langchain.schema.output_parser import BaseGenerationOutputParser, BaseOutputParser
+from langchain.schema.runnable import Runnable
 from langchain.utils.openai_functions import convert_pydantic_to_openai_function
 
 PYTHON_TO_JSON_TYPES = {
@@ -161,11 +163,23 @@ def convert_to_openai_function(
     )
 
 
-def _get_openai_output_parser(
+def get_openai_output_parser(
     functions: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable]],
-    function_names: Sequence[str],
-) -> BaseLLMOutputParser:
-    """Get the appropriate function output parser given the user functions."""
+) -> Union[BaseOutputParser, BaseGenerationOutputParser]:
+    """Get the appropriate function output parser given the user functions.
+
+    Args:
+        functions: Sequence where each element is a dictionary, a pydantic.BaseModel
+            class, or a Python function. If a dictionary is passed in, it is assumed
+            to already be a valid OpenAI function.
+
+    Returns:
+        A PydanticOutputFunctionsParser if functions are Pydantic classes, otherwise
+        a JsonOutputFunctionsParser. If there's only one function and it is
+        not a Pydantic class, then the output parser will automatically extract
+        only the function arguments and not the function name.
+    """
+    function_names = [convert_to_openai_function(f)["name"] for f in functions]
     if isinstance(functions[0], type) and issubclass(functions[0], BaseModel):
         if len(functions) > 1:
             pydantic_schema: Union[Dict, Type[BaseModel]] = {
@@ -173,14 +187,183 @@
             }
         else:
             pydantic_schema = functions[0]
-        output_parser: BaseLLMOutputParser = PydanticOutputFunctionsParser(
-            pydantic_schema=pydantic_schema
-        )
+        output_parser: Union[
+            BaseOutputParser, BaseGenerationOutputParser
+        ] = PydanticOutputFunctionsParser(pydantic_schema=pydantic_schema)
     else:
         output_parser = JsonOutputFunctionsParser(args_only=len(functions) <= 1)
     return output_parser
 
 
+def create_openai_fn_runnable(
+    functions: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable]],
+    llm: Runnable,
+    prompt: BasePromptTemplate,
+    *,
+    output_parser: Optional[Union[BaseOutputParser, BaseGenerationOutputParser]] = None,
+    **kwargs: Any,
+) -> Runnable:
+    """Create a runnable sequence that uses OpenAI functions.
+
+    Args:
+        functions: A sequence of either dictionaries, pydantic.BaseModel classes, or
+            Python functions. If dictionaries are passed in, they are assumed to
+            already be valid OpenAI functions. If only a single
+            function is passed in, then it will be enforced that the model use that
+            function. pydantic.BaseModels and Python functions should have docstrings
+            describing what the function does. For best results, pydantic.BaseModels
+            should have descriptions of the parameters and Python functions should have
+            Google Python style args descriptions in the docstring. Additionally,
+            Python functions should only use primitive types (str, int, float, bool) or
+            pydantic.BaseModels for arguments.
+        llm: Language model to use, assumed to support the OpenAI function-calling API.
+        prompt: BasePromptTemplate to pass to the model.
+        output_parser: BaseLLMOutputParser to use for parsing model outputs. By default
+            will be inferred from the function types. If pydantic.BaseModels are passed
+            in, then the OutputParser will try to parse outputs using those. Otherwise
+            model outputs will simply be parsed as JSON. If multiple functions are
+            passed in and they are not pydantic.BaseModels, the chain output will
+            include both the name of the function that was returned and the arguments
+            to pass to the function.
+
+    Returns:
+        A runnable sequence that will pass in the given functions to the model when run.
+
+    Example:
+        .. code-block:: python
+
+            from typing import Optional
+
+            from langchain.chains.openai_functions import create_openai_fn_runnable
+            from langchain.chat_models import ChatOpenAI
+            from langchain.prompts import ChatPromptTemplate
+            from langchain.pydantic_v1 import BaseModel, Field
+
+
+            class RecordPerson(BaseModel):
+                \"\"\"Record some identifying information about a person.\"\"\"
+
+                name: str = Field(..., description="The person's name")
+                age: int = Field(..., description="The person's age")
+                fav_food: Optional[str] = Field(None, description="The person's favorite food")
+
+
+            class RecordDog(BaseModel):
+                \"\"\"Record some identifying information about a dog.\"\"\"
+
+                name: str = Field(..., description="The dog's name")
+                color: str = Field(..., description="The dog's color")
+                fav_food: Optional[str] = Field(None, description="The dog's favorite food")
+
+
+            llm = ChatOpenAI(model="gpt-4", temperature=0)
+            prompt = ChatPromptTemplate.from_messages(
+                [
+                    ("system", "You are a world class algorithm for recording entities."),
+                    ("human", "Make calls to the relevant function to record the entities in the following input: {input}"),
+                    ("human", "Tip: Make sure to answer in the correct format"),
+                ]
+            )
+            chain = create_openai_fn_runnable([RecordPerson, RecordDog], llm, prompt)
+            chain.invoke({"input": "Harry was a chubby brown beagle who loved chicken"})
+            # -> RecordDog(name="Harry", color="brown", fav_food="chicken")
+    """  # noqa: E501
+    if not functions:
+        raise ValueError("Need to pass in at least one function. Received zero.")
+    openai_functions = [convert_to_openai_function(f) for f in functions]
+    llm_kwargs: Dict[str, Any] = {"functions": openai_functions, **kwargs}
+    if len(openai_functions) == 1:
+        llm_kwargs["function_call"] = {"name": openai_functions[0]["name"]}
+    output_parser = output_parser or get_openai_output_parser(functions)
+    return prompt | llm.bind(**llm_kwargs) | output_parser
+
+
+def create_structured_output_runnable(
+    output_schema: Union[Dict[str, Any], Type[BaseModel]],
+    llm: Runnable,
+    prompt: BasePromptTemplate,
+    *,
+    output_parser: Optional[Union[BaseOutputParser, BaseGenerationOutputParser]] = None,
+    **kwargs: Any,
+) -> Runnable:
+    """Create a runnable that uses an OpenAI function to get a structured output.
+
+    Args:
+        output_schema: Either a dictionary or pydantic.BaseModel class. If a dictionary
+            is passed in, it's assumed to already be a valid JsonSchema.
+            For best results, pydantic.BaseModels should have docstrings describing what
+            the schema represents and descriptions for the parameters.
+        llm: Language model to use, assumed to support the OpenAI function-calling API.
+        prompt: BasePromptTemplate to pass to the model.
+        output_parser: BaseLLMOutputParser to use for parsing model outputs. By default
+            will be inferred from the function types. If pydantic.BaseModels are passed
+            in, then the OutputParser will try to parse outputs using those. Otherwise
+            model outputs will simply be parsed as JSON.
+
+    Returns:
+        A runnable sequence that will pass the given function to the model when run.
+
+    Example:
+        .. code-block:: python
+
+            from typing import Optional
+
+            from langchain.chains.openai_functions import create_structured_output_runnable
+            from langchain.chat_models import ChatOpenAI
+            from langchain.prompts import ChatPromptTemplate
+            from langchain.pydantic_v1 import BaseModel, Field
+
+            class Dog(BaseModel):
+                \"\"\"Identifying information about a dog.\"\"\"
+
+                name: str = Field(..., description="The dog's name")
+                color: str = Field(..., description="The dog's color")
+                fav_food: Optional[str] = Field(None, description="The dog's favorite food")
+
+            llm = ChatOpenAI(model="gpt-3.5-turbo-0613", temperature=0)
+            prompt = ChatPromptTemplate.from_messages(
+                [
+                    ("system", "You are a world class algorithm for extracting information in structured formats."),
+                    ("human", "Use the given format to extract information from the following input: {input}"),
+                    ("human", "Tip: Make sure to answer in the correct format"),
+                ]
+            )
+            chain = create_structured_output_runnable(Dog, llm, prompt)
+            chain.invoke({"input": "Harry was a chubby brown beagle who loved chicken"})
+            # -> Dog(name="Harry", color="brown", fav_food="chicken")
+    """  # noqa: E501
+    if isinstance(output_schema, dict):
+        function: Any = {
+            "name": "output_formatter",
+            "description": (
+                "Output formatter. Should always be used to format your response to the"
+                " user."
+            ),
+            "parameters": output_schema,
+        }
+    else:
+
+        class _OutputFormatter(BaseModel):
+            """Output formatter. Should always be used to format your response to the user."""  # noqa: E501
+
+            output: output_schema  # type: ignore
+
+        function = _OutputFormatter
+        output_parser = output_parser or PydanticAttrOutputFunctionsParser(
+            pydantic_schema=_OutputFormatter, attr_name="output"
+        )
+    return create_openai_fn_runnable(
+        [function],
+        llm,
+        prompt,
+        output_parser=output_parser,
+        **kwargs,
+    )
+
+
+""" --- Legacy --- """
+
+
 def create_openai_fn_chain(
     functions: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable]],
     llm: BaseLanguageModel,
@@ -190,7 +373,7 @@ def create_openai_fn_chain(
     output_parser: Optional[BaseLLMOutputParser] = None,
     **kwargs: Any,
 ) -> LLMChain:
-    """Create an LLM chain that uses OpenAI functions.
+    """[Legacy] Create an LLM chain that uses OpenAI functions.
 
     Args:
         functions: A sequence of either dictionaries, pydantic.BaseModels classes, or
@@ -260,8 +443,7 @@ def create_openai_fn_chain(
     if not functions:
         raise ValueError("Need to pass in at least one function. Received zero.")
     openai_functions = [convert_to_openai_function(f) for f in functions]
-    fn_names = [oai_fn["name"] for oai_fn in openai_functions]
-    output_parser = output_parser or _get_openai_output_parser(functions, fn_names)
+    output_parser = output_parser or get_openai_output_parser(functions)
     llm_kwargs: Dict[str, Any] = {
         "functions": openai_functions,
     }
@@ -287,7 +469,7 @@ def create_structured_output_chain(
     output_parser: Optional[BaseLLMOutputParser] = None,
     **kwargs: Any,
 ) -> LLMChain:
-    """Create an LLMChain that uses an OpenAI function to get a structured output.
+    """[Legacy] Create an LLMChain that uses an OpenAI function to get a structured output.
 
     Args:
         output_schema: Either a dictionary or pydantic.BaseModel class. If a dictionary