From 78403a37461d9ef72bcca6976197d995192447f3 Mon Sep 17 00:00:00 2001 From: Bagatur <22008038+baskaryan@users.noreply.github.com> Date: Tue, 6 Aug 2024 14:21:06 -0700 Subject: [PATCH] core[patch], openai[patch]: enable strict tool calling (#25111) Introduced https://openai.com/index/introducing-structured-outputs-in-the-api/ --- docs/docs/integrations/chat/openai.ipynb | 64 +- .../langchain_core/utils/function_calling.py | 36 +- .../tests/unit_tests/chat_models/test_base.py | 6 +- .../load/__snapshots__/test_dump.ambr | 814 ------------------ .../langchain_openai/chat_models/base.py | 99 ++- .../chat_models/test_base.py | 81 ++ 6 files changed, 241 insertions(+), 859 deletions(-) diff --git a/docs/docs/integrations/chat/openai.ipynb b/docs/docs/integrations/chat/openai.ipynb index 4d23b1067fd..8da3239c506 100644 --- a/docs/docs/integrations/chat/openai.ipynb +++ b/docs/docs/integrations/chat/openai.ipynb @@ -56,23 +56,16 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 1, "id": "e817fe2e-4f1d-4533-b19e-2400b1cf6ce8", "metadata": {}, - "outputs": [ - { - "name": "stdin", - "output_type": "stream", - "text": [ - "Enter your OpenAI API key: ········\n" - ] - } - ], + "outputs": [], "source": [ "import getpass\n", "import os\n", "\n", - "os.environ[\"OPENAI_API_KEY\"] = getpass.getpass(\"Enter your OpenAI API key: \")" + "if not os.environ.get(\"OPENAI_API_KEY\"):\n", + " os.environ[\"OPENAI_API_KEY\"] = getpass.getpass(\"Enter your OpenAI API key: \")" ] }, { @@ -126,7 +119,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 2, "id": "522686de", "metadata": { "tags": [] @@ -281,12 +274,12 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 4, "id": "b7ea7690-ec7a-4337-b392-e87d1f39a6ec", "metadata": {}, "outputs": [], "source": [ - "from langchain_core.pydantic_v1 import BaseModel, Field\n", + "from pydantic import BaseModel, Field\n", "\n", "\n", "class GetWeather(BaseModel):\n", @@ -322,6 
+315,45 @@ "ai_msg" ] }, + { + "cell_type": "markdown", + "id": "67b0f63d-15e6-45e0-9e86-2852ddcff54f", + "metadata": {}, + "source": [ + "### ``strict=True``\n", + "\n", + ".. info Requires ``langchain-openai==0.1.21rc1``\n", + "\n", + " As of Aug 6, 2024, OpenAI supports a `strict` argument when calling tools that will enforce that the tool argument schema is respected by the model. See more here: https://platform.openai.com/docs/guides/function-calling\n", + "\n", + " **Note**: If ``strict=True`` the tool definition will also be validated, and only a subset of JSON Schema is accepted. Crucially, schemas cannot have optional args (those with default values). Read the full docs on which types of schema are supported here: https://platform.openai.com/docs/guides/structured-outputs/supported-schemas. " + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "dc8ac4f1-4039-4392-90c1-2d8331cd6910", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'call_VYEfpPDh3npMQ95J9EWmWvSn', 'function': {'arguments': '{\"location\":\"San Francisco, CA\"}', 'name': 'GetWeather'}, 'type': 'function'}]}, response_metadata={'token_usage': {'completion_tokens': 17, 'prompt_tokens': 68, 'total_tokens': 85}, 'model_name': 'gpt-4o-2024-05-13', 'system_fingerprint': 'fp_3aa7262c27', 'finish_reason': 'tool_calls', 'logprobs': None}, id='run-a4c6749b-adbb-45c7-8b17-8d6835d5c443-0', tool_calls=[{'name': 'GetWeather', 'args': {'location': 'San Francisco, CA'}, 'id': 'call_VYEfpPDh3npMQ95J9EWmWvSn', 'type': 'tool_call'}], usage_metadata={'input_tokens': 68, 'output_tokens': 17, 'total_tokens': 85})" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "llm_with_tools = llm.bind_tools([GetWeather], strict=True)\n", + "ai_msg = llm_with_tools.invoke(\n", + " \"what is the weather like in San Francisco\",\n", + ")\n", + "ai_msg" + ] + }, { "cell_type": 
"markdown", "id": "768d1ae4-4b1a-48eb-a329-c8d5051067a3", @@ -412,9 +444,9 @@ ], "metadata": { "kernelspec": { - "display_name": "poetry-venv-2", + "display_name": "poetry-venv-311", "language": "python", - "name": "poetry-venv-2" + "name": "poetry-venv-311" }, "language_info": { "codemirror_mode": { diff --git a/libs/core/langchain_core/utils/function_calling.py b/libs/core/langchain_core/utils/function_calling.py index 10c2a2609df..b726aa00192 100644 --- a/libs/core/langchain_core/utils/function_calling.py +++ b/libs/core/langchain_core/utils/function_calling.py @@ -322,6 +322,8 @@ def format_tool_to_openai_tool(tool: BaseTool) -> ToolDescription: def convert_to_openai_function( function: Union[Dict[str, Any], Type, Callable, BaseTool], + *, + strict: Optional[bool] = None, ) -> Dict[str, Any]: """Convert a raw function/class to an OpenAI function. @@ -330,6 +332,9 @@ def convert_to_openai_function( Tool object, or a Python function. If a dictionary is passed in, it is assumed to already be a valid OpenAI function or a JSON schema with top-level 'title' and 'description' keys specified. + strict: If True, model output is guaranteed to exactly match the JSON Schema + provided in the function definition. If None, ``strict`` argument will not + be included in function definition. 
Returns: A dict version of the passed in function which is compatible with the OpenAI @@ -344,25 +349,27 @@ def convert_to_openai_function( if isinstance(function, dict) and all( k in function for k in ("name", "description", "parameters") ): - return function + oai_function = function # a JSON schema with title and description elif isinstance(function, dict) and all( k in function for k in ("title", "description", "properties") ): function = function.copy() - return { + oai_function = { "name": function.pop("title"), "description": function.pop("description"), "parameters": function, } elif isinstance(function, type) and is_basemodel_subclass(function): - return cast(Dict, convert_pydantic_to_openai_function(function)) + oai_function = cast(Dict, convert_pydantic_to_openai_function(function)) elif is_typeddict(function): - return cast(Dict, _convert_typed_dict_to_openai_function(cast(Type, function))) + oai_function = cast( + Dict, _convert_typed_dict_to_openai_function(cast(Type, function)) + ) elif isinstance(function, BaseTool): - return cast(Dict, format_tool_to_openai_function(function)) + oai_function = cast(Dict, format_tool_to_openai_function(function)) elif callable(function): - return cast(Dict, convert_python_function_to_openai_function(function)) + oai_function = cast(Dict, convert_python_function_to_openai_function(function)) else: raise ValueError( f"Unsupported function\n\n{function}\n\nFunctions must be passed in" @@ -371,9 +378,18 @@ def convert_to_openai_function( " 'title' and 'description' keys." ) + if strict is not None: + oai_function["strict"] = strict + # As of 08/06/24, OpenAI requires that additionalProperties be supplied and set + # to False if strict is True. 
+ oai_function["parameters"]["additionalProperties"] = False + return oai_function + def convert_to_openai_tool( tool: Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool], + *, + strict: Optional[bool] = None, ) -> Dict[str, Any]: """Convert a raw function/class to an OpenAI tool. @@ -382,6 +398,9 @@ def convert_to_openai_tool( BaseTool. If a dictionary is passed in, it is assumed to already be a valid OpenAI tool, OpenAI function, or a JSON schema with top-level 'title' and 'description' keys specified. + strict: If True, model output is guaranteed to exactly match the JSON Schema + provided in the function definition. If None, ``strict`` argument will not + be included in tool definition. Returns: A dict version of the passed in tool which is compatible with the @@ -389,8 +408,9 @@ def convert_to_openai_tool( """ if isinstance(tool, dict) and tool.get("type") == "function" and "function" in tool: return tool - function = convert_to_openai_function(tool) - return {"type": "function", "function": function} + oai_function = convert_to_openai_function(tool, strict=strict) + oai_tool: Dict[str, Any] = {"type": "function", "function": oai_function} + return oai_tool def tool_example_to_messages( diff --git a/libs/langchain/tests/unit_tests/chat_models/test_base.py b/libs/langchain/tests/unit_tests/chat_models/test_base.py index c11469642ed..7fe06d85e10 100644 --- a/libs/langchain/tests/unit_tests/chat_models/test_base.py +++ b/libs/langchain/tests/unit_tests/chat_models/test_base.py @@ -113,7 +113,11 @@ def test_configurable() -> None: "tools": [ { "type": "function", - "function": {"name": "foo", "description": "foo", "parameters": {}}, + "function": { + "name": "foo", + "description": "foo", + "parameters": {}, + }, } ] }, diff --git a/libs/langchain/tests/unit_tests/load/__snapshots__/test_dump.ambr b/libs/langchain/tests/unit_tests/load/__snapshots__/test_dump.ambr index 2f98e081056..faae79d8209 100644 --- 
a/libs/langchain/tests/unit_tests/load/__snapshots__/test_dump.ambr +++ b/libs/langchain/tests/unit_tests/load/__snapshots__/test_dump.ambr @@ -59,817 +59,3 @@ # name: test_person_with_kwargs '{"lc":1,"type":"constructor","id":["tests","unit_tests","load","test_dump","Person"],"kwargs":{"secret":{"lc":1,"type":"secret","id":["SECRET"]},"you_can_see_me":"hello"}}' # --- -# name: test_serialize_llmchain - ''' - { - "lc": 1, - "type": "constructor", - "id": [ - "langchain", - "chains", - "llm", - "LLMChain" - ], - "kwargs": { - "prompt": { - "lc": 1, - "type": "constructor", - "id": [ - "langchain", - "prompts", - "prompt", - "PromptTemplate" - ], - "kwargs": { - "input_variables": [ - "name" - ], - "template": "hello {name}!", - "template_format": "f-string" - }, - "name": "PromptTemplate", - "graph": { - "nodes": [ - { - "id": 0, - "type": "schema", - "data": "PromptInput" - }, - { - "id": 1, - "type": "runnable", - "data": { - "id": [ - "langchain", - "prompts", - "prompt", - "PromptTemplate" - ], - "name": "PromptTemplate" - } - }, - { - "id": 2, - "type": "schema", - "data": "PromptTemplateOutput" - } - ], - "edges": [ - { - "source": 0, - "target": 1 - }, - { - "source": 1, - "target": 2 - } - ] - } - }, - "llm": { - "lc": 1, - "type": "constructor", - "id": [ - "langchain", - "llms", - "openai", - "OpenAI" - ], - "kwargs": { - "model_name": "davinci", - "temperature": 0.5, - "max_tokens": 256, - "top_p": 1, - "n": 1, - "best_of": 1, - "openai_api_key": { - "lc": 1, - "type": "secret", - "id": [ - "OPENAI_API_KEY" - ] - }, - "openai_proxy": "", - "batch_size": 20, - "max_retries": 2, - "disallowed_special": "all" - }, - "name": "OpenAI", - "graph": { - "nodes": [ - { - "id": 0, - "type": "schema", - "data": "OpenAIInput" - }, - { - "id": 1, - "type": "runnable", - "data": { - "id": [ - "langchain", - "llms", - "openai", - "OpenAI" - ], - "name": "OpenAI" - } - }, - { - "id": 2, - "type": "schema", - "data": "OpenAIOutput" - } - ], - "edges": [ - { - "source": 0, 
- "target": 1 - }, - { - "source": 1, - "target": 2 - } - ] - } - }, - "output_key": "text", - "output_parser": { - "lc": 1, - "type": "constructor", - "id": [ - "langchain", - "schema", - "output_parser", - "StrOutputParser" - ], - "kwargs": {}, - "name": "StrOutputParser", - "graph": { - "nodes": [ - { - "id": 0, - "type": "schema", - "data": "StrOutputParserInput" - }, - { - "id": 1, - "type": "runnable", - "data": { - "id": [ - "langchain", - "schema", - "output_parser", - "StrOutputParser" - ], - "name": "StrOutputParser" - } - }, - { - "id": 2, - "type": "schema", - "data": "StrOutputParserOutput" - } - ], - "edges": [ - { - "source": 0, - "target": 1 - }, - { - "source": 1, - "target": 2 - } - ] - } - }, - "return_final_only": true - }, - "name": "LLMChain", - "graph": { - "nodes": [ - { - "id": 0, - "type": "schema", - "data": "ChainInput" - }, - { - "id": 1, - "type": "runnable", - "data": { - "id": [ - "langchain", - "chains", - "llm", - "LLMChain" - ], - "name": "LLMChain" - } - }, - { - "id": 2, - "type": "schema", - "data": "ChainOutput" - } - ], - "edges": [ - { - "source": 0, - "target": 1 - }, - { - "source": 1, - "target": 2 - } - ] - } - } - ''' -# --- -# name: test_serialize_llmchain_chat - ''' - { - "lc": 1, - "type": "constructor", - "id": [ - "langchain", - "chains", - "llm", - "LLMChain" - ], - "kwargs": { - "prompt": { - "lc": 1, - "type": "constructor", - "id": [ - "langchain", - "prompts", - "chat", - "ChatPromptTemplate" - ], - "kwargs": { - "input_variables": [ - "name" - ], - "messages": [ - { - "lc": 1, - "type": "constructor", - "id": [ - "langchain", - "prompts", - "chat", - "HumanMessagePromptTemplate" - ], - "kwargs": { - "prompt": { - "lc": 1, - "type": "constructor", - "id": [ - "langchain", - "prompts", - "prompt", - "PromptTemplate" - ], - "kwargs": { - "input_variables": [ - "name" - ], - "template": "hello {name}!", - "template_format": "f-string" - }, - "name": "PromptTemplate", - "graph": { - "nodes": [ - { - "id": 0, - 
"type": "schema", - "data": "PromptInput" - }, - { - "id": 1, - "type": "runnable", - "data": { - "id": [ - "langchain", - "prompts", - "prompt", - "PromptTemplate" - ], - "name": "PromptTemplate" - } - }, - { - "id": 2, - "type": "schema", - "data": "PromptTemplateOutput" - } - ], - "edges": [ - { - "source": 0, - "target": 1 - }, - { - "source": 1, - "target": 2 - } - ] - } - } - } - } - ] - }, - "name": "ChatPromptTemplate", - "graph": { - "nodes": [ - { - "id": 0, - "type": "schema", - "data": "PromptInput" - }, - { - "id": 1, - "type": "runnable", - "data": { - "id": [ - "langchain", - "prompts", - "chat", - "ChatPromptTemplate" - ], - "name": "ChatPromptTemplate" - } - }, - { - "id": 2, - "type": "schema", - "data": "ChatPromptTemplateOutput" - } - ], - "edges": [ - { - "source": 0, - "target": 1 - }, - { - "source": 1, - "target": 2 - } - ] - } - }, - "llm": { - "lc": 1, - "type": "constructor", - "id": [ - "langchain", - "chat_models", - "openai", - "ChatOpenAI" - ], - "kwargs": { - "model_name": "davinci", - "temperature": 0.5, - "openai_api_key": { - "lc": 1, - "type": "secret", - "id": [ - "OPENAI_API_KEY" - ] - }, - "openai_proxy": "", - "max_retries": 2, - "n": 1 - }, - "name": "ChatOpenAI", - "graph": { - "nodes": [ - { - "id": 0, - "type": "schema", - "data": "ChatOpenAIInput" - }, - { - "id": 1, - "type": "runnable", - "data": { - "id": [ - "langchain", - "chat_models", - "openai", - "ChatOpenAI" - ], - "name": "ChatOpenAI" - } - }, - { - "id": 2, - "type": "schema", - "data": "ChatOpenAIOutput" - } - ], - "edges": [ - { - "source": 0, - "target": 1 - }, - { - "source": 1, - "target": 2 - } - ] - } - }, - "output_key": "text", - "output_parser": { - "lc": 1, - "type": "constructor", - "id": [ - "langchain", - "schema", - "output_parser", - "StrOutputParser" - ], - "kwargs": {}, - "name": "StrOutputParser", - "graph": { - "nodes": [ - { - "id": 0, - "type": "schema", - "data": "StrOutputParserInput" - }, - { - "id": 1, - "type": "runnable", - "data": 
{ - "id": [ - "langchain", - "schema", - "output_parser", - "StrOutputParser" - ], - "name": "StrOutputParser" - } - }, - { - "id": 2, - "type": "schema", - "data": "StrOutputParserOutput" - } - ], - "edges": [ - { - "source": 0, - "target": 1 - }, - { - "source": 1, - "target": 2 - } - ] - } - }, - "return_final_only": true - }, - "name": "LLMChain", - "graph": { - "nodes": [ - { - "id": 0, - "type": "schema", - "data": "ChainInput" - }, - { - "id": 1, - "type": "runnable", - "data": { - "id": [ - "langchain", - "chains", - "llm", - "LLMChain" - ], - "name": "LLMChain" - } - }, - { - "id": 2, - "type": "schema", - "data": "ChainOutput" - } - ], - "edges": [ - { - "source": 0, - "target": 1 - }, - { - "source": 1, - "target": 2 - } - ] - } - } - ''' -# --- -# name: test_serialize_llmchain_with_non_serializable_arg - ''' - { - "lc": 1, - "type": "constructor", - "id": [ - "langchain", - "chains", - "llm", - "LLMChain" - ], - "kwargs": { - "prompt": { - "lc": 1, - "type": "constructor", - "id": [ - "langchain", - "prompts", - "prompt", - "PromptTemplate" - ], - "kwargs": { - "input_variables": [ - "name" - ], - "template": "hello {name}!", - "template_format": "f-string" - }, - "name": "PromptTemplate", - "graph": { - "nodes": [ - { - "id": 0, - "type": "schema", - "data": "PromptInput" - }, - { - "id": 1, - "type": "runnable", - "data": { - "id": [ - "langchain", - "prompts", - "prompt", - "PromptTemplate" - ], - "name": "PromptTemplate" - } - }, - { - "id": 2, - "type": "schema", - "data": "PromptTemplateOutput" - } - ], - "edges": [ - { - "source": 0, - "target": 1 - }, - { - "source": 1, - "target": 2 - } - ] - } - }, - "llm": { - "lc": 1, - "type": "constructor", - "id": [ - "langchain", - "llms", - "openai", - "OpenAI" - ], - "kwargs": { - "model_name": "davinci", - "temperature": 0.5, - "max_tokens": 256, - "top_p": 1, - "n": 1, - "best_of": 1, - "openai_api_key": { - "lc": 1, - "type": "secret", - "id": [ - "OPENAI_API_KEY" - ] - }, - "openai_proxy": "", - 
"batch_size": 20, - "max_retries": 2, - "disallowed_special": "all" - }, - "name": "OpenAI", - "graph": { - "nodes": [ - { - "id": 0, - "type": "schema", - "data": "OpenAIInput" - }, - { - "id": 1, - "type": "runnable", - "data": { - "id": [ - "langchain", - "llms", - "openai", - "OpenAI" - ], - "name": "OpenAI" - } - }, - { - "id": 2, - "type": "schema", - "data": "OpenAIOutput" - } - ], - "edges": [ - { - "source": 0, - "target": 1 - }, - { - "source": 1, - "target": 2 - } - ] - } - }, - "output_key": "text", - "output_parser": { - "lc": 1, - "type": "constructor", - "id": [ - "langchain", - "schema", - "output_parser", - "StrOutputParser" - ], - "kwargs": {}, - "name": "StrOutputParser", - "graph": { - "nodes": [ - { - "id": 0, - "type": "schema", - "data": "StrOutputParserInput" - }, - { - "id": 1, - "type": "runnable", - "data": { - "id": [ - "langchain", - "schema", - "output_parser", - "StrOutputParser" - ], - "name": "StrOutputParser" - } - }, - { - "id": 2, - "type": "schema", - "data": "StrOutputParserOutput" - } - ], - "edges": [ - { - "source": 0, - "target": 1 - }, - { - "source": 1, - "target": 2 - } - ] - } - }, - "return_final_only": true - }, - "name": "LLMChain", - "graph": { - "nodes": [ - { - "id": 0, - "type": "schema", - "data": "ChainInput" - }, - { - "id": 1, - "type": "runnable", - "data": { - "id": [ - "langchain", - "chains", - "llm", - "LLMChain" - ], - "name": "LLMChain" - } - }, - { - "id": 2, - "type": "schema", - "data": "ChainOutput" - } - ], - "edges": [ - { - "source": 0, - "target": 1 - }, - { - "source": 1, - "target": 2 - } - ] - } - } - ''' -# --- -# name: test_serialize_openai_llm - ''' - { - "lc": 1, - "type": "constructor", - "id": [ - "langchain", - "llms", - "openai", - "OpenAI" - ], - "kwargs": { - "model_name": "davinci", - "temperature": 0.7, - "max_tokens": 256, - "top_p": 1, - "n": 1, - "best_of": 1, - "openai_api_key": { - "lc": 1, - "type": "secret", - "id": [ - "OPENAI_API_KEY" - ] - }, - "openai_proxy": "", - 
"batch_size": 20, - "max_retries": 2, - "disallowed_special": "all" - }, - "name": "OpenAI", - "graph": { - "nodes": [ - { - "id": 0, - "type": "schema", - "data": "OpenAIInput" - }, - { - "id": 1, - "type": "runnable", - "data": { - "id": [ - "langchain", - "llms", - "openai", - "OpenAI" - ], - "name": "OpenAI" - } - }, - { - "id": 2, - "type": "schema", - "data": "OpenAIOutput" - } - ], - "edges": [ - { - "source": 0, - "target": 1 - }, - { - "source": 1, - "target": 2 - } - ] - } - } - ''' -# --- diff --git a/libs/partners/openai/langchain_openai/chat_models/base.py b/libs/partners/openai/langchain_openai/chat_models/base.py index 36fdf1d776d..288209623d2 100644 --- a/libs/partners/openai/langchain_openai/chat_models/base.py +++ b/libs/partners/openai/langchain_openai/chat_models/base.py @@ -386,7 +386,7 @@ class BaseChatOpenAI(BaseChatModel): ) return values - @root_validator() + @root_validator(pre=False, skip_on_failure=True) def validate_environment(cls, values: Dict) -> Dict: """Validate that api key and python package exists in environment.""" if values["n"] < 1: @@ -464,6 +464,7 @@ class BaseChatOpenAI(BaseChatModel): values["async_client"] = openai.AsyncOpenAI( **client_params, **async_specific ).chat.completions + return values @property @@ -952,12 +953,17 @@ class BaseChatOpenAI(BaseChatModel): tool_choice: Optional[ Union[dict, str, Literal["auto", "none", "required", "any"], bool] ] = None, + strict: Optional[bool] = None, **kwargs: Any, ) -> Runnable[LanguageModelInput, BaseMessage]: """Bind tool-like objects to this chat model. Assumes model is compatible with OpenAI tool-calling API. + .. versionchanged:: 0.1.21 + + Support for ``strict`` argument added. + Args: tools: A list of tool definitions to bind to this chat model. Supports any tool definition handled by @@ -970,11 +976,23 @@ class BaseChatOpenAI(BaseChatModel): - ``"any"`` or ``"required"`` or ``True``: force at least one tool to be called. 
- dict of the form ``{"type": "function", "function": {"name": <>}}``: calls <> tool. - ``False`` or ``None``: no effect, default OpenAI behavior. + strict: If True, model output is guaranteed to exactly match the JSON Schema + provided in the tool definition. If True, the input schema will be + validated according to + https://platform.openai.com/docs/guides/structured-outputs/supported-schemas. + If False, input schema will not be validated and model output will not + be validated. + If None, ``strict`` argument will not be passed to the model. + + .. versionadded:: 0.1.21 + kwargs: Any additional parameters are passed directly to ``self.bind(**kwargs)``. """ # noqa: E501 - formatted_tools = [convert_to_openai_tool(tool) for tool in tools] + formatted_tools = [ + convert_to_openai_tool(tool, strict=strict) for tool in tools + ] if tool_choice: if isinstance(tool_choice, str): # tool_choice is a tool/function name @@ -1018,6 +1036,7 @@ class BaseChatOpenAI(BaseChatModel): *, method: Literal["function_calling", "json_mode"] = "function_calling", include_raw: Literal[True] = True, + strict: Optional[bool] = None, **kwargs: Any, ) -> Runnable[LanguageModelInput, _AllReturnType]: ... @@ -1028,6 +1047,7 @@ class BaseChatOpenAI(BaseChatModel): *, method: Literal["function_calling", "json_mode"] = "function_calling", include_raw: Literal[False] = False, + strict: Optional[bool] = None, **kwargs: Any, ) -> Runnable[LanguageModelInput, _DictOrPydantic]: ... @@ -1037,10 +1057,15 @@ class BaseChatOpenAI(BaseChatModel): *, method: Literal["function_calling", "json_mode"] = "function_calling", include_raw: bool = False, + strict: Optional[bool] = None, **kwargs: Any, ) -> Runnable[LanguageModelInput, _DictOrPydantic]: """Model wrapper that returns outputs formatted to match the given schema. + .. versionchanged:: 0.1.21 + + Support for ``strict`` argument added. + Args: schema: The output schema. 
Can be passed in as: @@ -1060,7 +1085,7 @@ class BaseChatOpenAI(BaseChatModel): Added support for TypedDict class. method: - The method for steering model generation, either "function_calling" + The method for steering model generation, one of "function_calling" or "json_mode". If "function_calling" then the schema will be converted to an OpenAI function and the returned model will make use of the function-calling API. If "json_mode" then OpenAI's JSON mode will be @@ -1073,6 +1098,22 @@ class BaseChatOpenAI(BaseChatModel): response will be returned. If an error occurs during output parsing it will be caught and returned as well. The final output is always a dict with keys "raw", "parsed", and "parsing_error". + strict: If True and ``method`` = "function_calling", model output is + guaranteed to exactly match the schema + If True, the input schema will also be + validated according to + https://platform.openai.com/docs/guides/structured-outputs/supported-schemas. + If False, input schema will not be validated and model output will not + be validated. + If None, ``strict`` argument will not be passed to the model. + + .. versionadded:: 0.1.21 + + .. note:: Planned breaking change in version `0.2.0` + + ``strict`` will default to True when ``method`` is + "function_calling" as of version `0.2.0`. + kwargs: Additional keyword args aren't supported. Returns: A Runnable that takes same inputs as a :class:`langchain_core.language_models.chat.BaseChatModel`. @@ -1087,7 +1128,16 @@ class BaseChatOpenAI(BaseChatModel): - ``"parsed"``: None if there was a parsing error, otherwise the type depends on the ``schema`` as described above. - ``"parsing_error"``: Optional[BaseException] - Example: schema=Pydantic class, method="function_calling", include_raw=False: + Example: schema=Pydantic class, method="function_calling", include_raw=False, strict=True: + .. 
note:: Valid schemas when using ``strict`` = True + + OpenAI has a number of restrictions on what types of schemas can be + provided if ``strict`` = True. When using Pydantic, our model cannot + specify any Field metadata (like min/max constraints) and fields cannot + have default values. + + See all constraints here: https://platform.openai.com/docs/guides/structured-outputs/supported-schemas + .. code-block:: python from typing import Optional @@ -1100,16 +1150,15 @@ class BaseChatOpenAI(BaseChatModel): '''An answer to the user question along with justification for the answer.''' answer: str - # If we provide default values and/or descriptions for fields, these will be passed - # to the model. This is an important part of improving a model's ability to - # correctly return structured outputs. justification: Optional[str] = Field( - default=None, description="A justification for the answer." + default=..., description="A justification for the answer." ) - llm = ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0) - structured_llm = llm.with_structured_output(AnswerWithJustification) + llm = ChatOpenAI(model="gpt-4o", temperature=0) + structured_llm = llm.with_structured_output( + AnswerWithJustification, strict=True + ) structured_llm.invoke( "What weighs more a pound of bricks or a pound of feathers" @@ -1134,7 +1183,7 @@ class BaseChatOpenAI(BaseChatModel): justification: str - llm = ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0) + llm = ChatOpenAI(model="gpt-4o", temperature=0) structured_llm = llm.with_structured_output( AnswerWithJustification, include_raw=True ) @@ -1167,7 +1216,7 @@ class BaseChatOpenAI(BaseChatModel): ] - llm = ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0) + llm = ChatOpenAI(model="gpt-4o", temperature=0) structured_llm = llm.with_structured_output(AnswerWithJustification) structured_llm.invoke( @@ -1196,7 +1245,7 @@ class BaseChatOpenAI(BaseChatModel): } } - llm = ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0) + llm 
= ChatOpenAI(model="gpt-4o", temperature=0) structured_llm = llm.with_structured_output(oai_schema) structured_llm.invoke( @@ -1217,7 +1266,7 @@ class BaseChatOpenAI(BaseChatModel): answer: str justification: str - llm = ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0) + llm = ChatOpenAI(model="gpt-4o", temperature=0) structured_llm = llm.with_structured_output( AnswerWithJustification, method="json_mode", @@ -1226,11 +1275,11 @@ class BaseChatOpenAI(BaseChatModel): structured_llm.invoke( "Answer the following question. " - "Make sure to return a JSON blob with keys 'answer' and 'justification'.\n\n" + "Make sure to return a JSON blob with keys 'answer' and 'justification'.\\n\\n" "What's heavier a pound of bricks or a pound of feathers?" ) # -> { - # 'raw': AIMessage(content='{\n "answer": "They are both the same weight.",\n "justification": "Both a pound of bricks and a pound of feathers weigh one pound. The difference lies in the volume and density of the materials, not the weight." \n}'), + # 'raw': AIMessage(content='{\\n "answer": "They are both the same weight.",\\n "justification": "Both a pound of bricks and a pound of feathers weigh one pound. The difference lies in the volume and density of the materials, not the weight." \\n}'), # 'parsed': AnswerWithJustification(answer='They are both the same weight.', justification='Both a pound of bricks and a pound of feathers weigh one pound. The difference lies in the volume and density of the materials, not the weight.'), # 'parsing_error': None # } @@ -1242,11 +1291,11 @@ class BaseChatOpenAI(BaseChatModel): structured_llm.invoke( "Answer the following question. " - "Make sure to return a JSON blob with keys 'answer' and 'justification'.\n\n" + "Make sure to return a JSON blob with keys 'answer' and 'justification'.\\n\\n" "What's heavier a pound of bricks or a pound of feathers?" 
) # -> { - # 'raw': AIMessage(content='{\n "answer": "They are both the same weight.",\n "justification": "Both a pound of bricks and a pound of feathers weigh one pound. The difference lies in the volume and density of the materials, not the weight." \n}'), + # 'raw': AIMessage(content='{\\n "answer": "They are both the same weight.",\\n "justification": "Both a pound of bricks and a pound of feathers weigh one pound. The difference lies in the volume and density of the materials, not the weight." \\n}'), # 'parsed': { # 'answer': 'They are both the same weight.', # 'justification': 'Both a pound of bricks and a pound of feathers weigh one pound. The difference lies in the volume and density of the materials, not the weight.' @@ -1256,6 +1305,10 @@ class BaseChatOpenAI(BaseChatModel): """ # noqa: E501 if kwargs: raise ValueError(f"Received unsupported arguments {kwargs}") + if strict is not None and method != "function_calling": + raise ValueError( + "Argument `strict` is only supported for `method`='function_calling'" + ) is_pydantic_schema = _is_pydantic_class(schema) if method == "function_calling": if schema is None: @@ -1265,7 +1318,10 @@ class BaseChatOpenAI(BaseChatModel): ) tool_name = convert_to_openai_tool(schema)["function"]["name"] llm = self.bind_tools( - [schema], tool_choice=tool_name, parallel_tool_calls=False + [schema], + tool_choice=tool_name, + parallel_tool_calls=False, + strict=strict, ) if is_pydantic_schema: output_parser: OutputParserLike = PydanticToolsParser( @@ -1498,7 +1554,10 @@ class ChatOpenAI(BaseChatOpenAI): ) - llm_with_tools = llm.bind_tools([GetWeather, GetPopulation]) + llm_with_tools = llm.bind_tools( + [GetWeather, GetPopulation] + # strict = True # enforce tool args schema is respected + ) ai_msg = llm_with_tools.invoke( "Which city is hotter today and which is bigger: LA or NY?" 
) diff --git a/libs/partners/openai/tests/integration_tests/chat_models/test_base.py b/libs/partners/openai/tests/integration_tests/chat_models/test_base.py index 4b652229af8..1624363ed0e 100644 --- a/libs/partners/openai/tests/integration_tests/chat_models/test_base.py +++ b/libs/partners/openai/tests/integration_tests/chat_models/test_base.py @@ -4,6 +4,7 @@ import base64 from typing import Any, AsyncIterator, List, Optional, cast import httpx +import openai import pytest from langchain_core.callbacks import CallbackManager from langchain_core.messages import ( @@ -19,6 +20,12 @@ from langchain_core.messages import ( from langchain_core.outputs import ChatGeneration, ChatResult, LLMResult from langchain_core.prompts import ChatPromptTemplate from langchain_core.pydantic_v1 import BaseModel, Field +from langchain_standard_tests.integration_tests.chat_models import ( + _validate_tool_call_message, +) +from langchain_standard_tests.integration_tests.chat_models import ( + magic_function as invalid_magic_function, +) from langchain_openai import ChatOpenAI from tests.unit_tests.fake.callbacks import FakeCallbackHandler @@ -750,3 +757,77 @@ def test_image_token_counting_png() -> None: ] actual = model.get_num_tokens_from_messages([message]) assert expected == actual + + +def test_tool_calling_strict() -> None: + """Test tool calling with strict=True.""" + + class magic_function(BaseModel): + """Applies a magic function to an input.""" + + input: int + + model = ChatOpenAI(model="gpt-4o", temperature=0) + model_with_tools = model.bind_tools([magic_function], strict=True) + + # invalid_magic_function adds metadata to schema that isn't supported by OpenAI. + model_with_invalid_tool_schema = model.bind_tools( + [invalid_magic_function], strict=True + ) + + # Test invoke + query = "What is the value of magic_function(3)? Use the tool." 
+ response = model_with_tools.invoke(query) + _validate_tool_call_message(response) + + # Test invalid tool schema + with pytest.raises(openai.BadRequestError): + model_with_invalid_tool_schema.invoke(query) + + # Test stream + full: Optional[BaseMessageChunk] = None + for chunk in model_with_tools.stream(query): + full = chunk if full is None else full + chunk # type: ignore + assert isinstance(full, AIMessage) + _validate_tool_call_message(full) + + # Test invalid tool schema + with pytest.raises(openai.BadRequestError): + next(model_with_invalid_tool_schema.stream(query)) + + +def test_structured_output_strict() -> None: + """Test to verify structured output with strict=True.""" + + from pydantic import BaseModel as BaseModelProper + from pydantic import Field as FieldProper + + model = ChatOpenAI(model="gpt-4o", temperature=0) + + class Joke(BaseModelProper): + """Joke to tell user.""" + + setup: str = FieldProper(description="question to set up a joke") + punchline: str = FieldProper(description="answer to resolve the joke") + + # Pydantic class + # Type ignoring since the interface only officially supports pydantic 1 + # or pydantic.v1.BaseModel but not pydantic.BaseModel from pydantic 2. + # We'll need to do a pass updating the type signatures. + chat = model.with_structured_output(Joke, strict=True) # type: ignore[arg-type] + result = chat.invoke("Tell me a joke about cats.") + assert isinstance(result, Joke) + + for chunk in chat.stream("Tell me a joke about cats."): + assert isinstance(chunk, Joke) + + # Schema + chat = model.with_structured_output(Joke.model_json_schema(), strict=True) + result = chat.invoke("Tell me a joke about cats.") + assert isinstance(result, dict) + assert set(result.keys()) == {"setup", "punchline"} + + for chunk in chat.stream("Tell me a joke about cats."): + assert isinstance(chunk, dict) + assert isinstance(chunk, dict) # for mypy + assert set(chunk.keys()) == {"setup", "punchline"}