core[patch], openai[patch]: enable strict tool calling (#25111)

Introduced
https://openai.com/index/introducing-structured-outputs-in-the-api/
This commit is contained in:
Bagatur 2024-08-06 14:21:06 -07:00 committed by GitHub
parent 5d10139fc7
commit 78403a3746
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 241 additions and 859 deletions

View File

@ -56,23 +56,16 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 2, "execution_count": 1,
"id": "e817fe2e-4f1d-4533-b19e-2400b1cf6ce8", "id": "e817fe2e-4f1d-4533-b19e-2400b1cf6ce8",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [],
{
"name": "stdin",
"output_type": "stream",
"text": [
"Enter your OpenAI API key: ········\n"
]
}
],
"source": [ "source": [
"import getpass\n", "import getpass\n",
"import os\n", "import os\n",
"\n", "\n",
"os.environ[\"OPENAI_API_KEY\"] = getpass.getpass(\"Enter your OpenAI API key: \")" "if not os.environ.get(\"OPENAI_API_KEY\"):\n",
" os.environ[\"OPENAI_API_KEY\"] = getpass.getpass(\"Enter your OpenAI API key: \")"
] ]
}, },
{ {
@ -126,7 +119,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 1, "execution_count": 2,
"id": "522686de", "id": "522686de",
"metadata": { "metadata": {
"tags": [] "tags": []
@ -281,12 +274,12 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 9, "execution_count": 4,
"id": "b7ea7690-ec7a-4337-b392-e87d1f39a6ec", "id": "b7ea7690-ec7a-4337-b392-e87d1f39a6ec",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"from langchain_core.pydantic_v1 import BaseModel, Field\n", "from pydantic import BaseModel, Field\n",
"\n", "\n",
"\n", "\n",
"class GetWeather(BaseModel):\n", "class GetWeather(BaseModel):\n",
@ -322,6 +315,45 @@
"ai_msg" "ai_msg"
] ]
}, },
{
"cell_type": "markdown",
"id": "67b0f63d-15e6-45e0-9e86-2852ddcff54f",
"metadata": {},
"source": [
"### ``strict=True``\n",
"\n",
".. info Requires ``langchain-openai==0.1.21rc1``\n",
"\n",
" As of Aug 6, 2024, OpenAI supports a `strict` argument when calling tools that will enforce that the tool argument schema is respected by the model. See more here: https://platform.openai.com/docs/guides/function-calling\n",
"\n",
" **Note**: If ``strict=True`` the tool definition will also be validated, and a subset of JSON schema are accepted. Crucially, schema cannot have optional args (those with default values). Read the full docs on what types of schema are supported here: https://platform.openai.com/docs/guides/structured-outputs/supported-schemas. "
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "dc8ac4f1-4039-4392-90c1-2d8331cd6910",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'call_VYEfpPDh3npMQ95J9EWmWvSn', 'function': {'arguments': '{\"location\":\"San Francisco, CA\"}', 'name': 'GetWeather'}, 'type': 'function'}]}, response_metadata={'token_usage': {'completion_tokens': 17, 'prompt_tokens': 68, 'total_tokens': 85}, 'model_name': 'gpt-4o-2024-05-13', 'system_fingerprint': 'fp_3aa7262c27', 'finish_reason': 'tool_calls', 'logprobs': None}, id='run-a4c6749b-adbb-45c7-8b17-8d6835d5c443-0', tool_calls=[{'name': 'GetWeather', 'args': {'location': 'San Francisco, CA'}, 'id': 'call_VYEfpPDh3npMQ95J9EWmWvSn', 'type': 'tool_call'}], usage_metadata={'input_tokens': 68, 'output_tokens': 17, 'total_tokens': 85})"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"llm_with_tools = llm.bind_tools([GetWeather], strict=True)\n",
"ai_msg = llm_with_tools.invoke(\n",
" \"what is the weather like in San Francisco\",\n",
")\n",
"ai_msg"
]
},
{ {
"cell_type": "markdown", "cell_type": "markdown",
"id": "768d1ae4-4b1a-48eb-a329-c8d5051067a3", "id": "768d1ae4-4b1a-48eb-a329-c8d5051067a3",
@ -412,9 +444,9 @@
], ],
"metadata": { "metadata": {
"kernelspec": { "kernelspec": {
"display_name": "poetry-venv-2", "display_name": "poetry-venv-311",
"language": "python", "language": "python",
"name": "poetry-venv-2" "name": "poetry-venv-311"
}, },
"language_info": { "language_info": {
"codemirror_mode": { "codemirror_mode": {

View File

@ -322,6 +322,8 @@ def format_tool_to_openai_tool(tool: BaseTool) -> ToolDescription:
def convert_to_openai_function( def convert_to_openai_function(
function: Union[Dict[str, Any], Type, Callable, BaseTool], function: Union[Dict[str, Any], Type, Callable, BaseTool],
*,
strict: Optional[bool] = None,
) -> Dict[str, Any]: ) -> Dict[str, Any]:
"""Convert a raw function/class to an OpenAI function. """Convert a raw function/class to an OpenAI function.
@ -330,6 +332,9 @@ def convert_to_openai_function(
Tool object, or a Python function. If a dictionary is passed in, it is Tool object, or a Python function. If a dictionary is passed in, it is
assumed to already be a valid OpenAI function or a JSON schema with assumed to already be a valid OpenAI function or a JSON schema with
top-level 'title' and 'description' keys specified. top-level 'title' and 'description' keys specified.
strict: If True, model output is guaranteed to exactly match the JSON Schema
provided in the function definition. If None, ``strict`` argument will not
be included in function definition.
Returns: Returns:
A dict version of the passed in function which is compatible with the OpenAI A dict version of the passed in function which is compatible with the OpenAI
@ -344,25 +349,27 @@ def convert_to_openai_function(
if isinstance(function, dict) and all( if isinstance(function, dict) and all(
k in function for k in ("name", "description", "parameters") k in function for k in ("name", "description", "parameters")
): ):
return function oai_function = function
# a JSON schema with title and description # a JSON schema with title and description
elif isinstance(function, dict) and all( elif isinstance(function, dict) and all(
k in function for k in ("title", "description", "properties") k in function for k in ("title", "description", "properties")
): ):
function = function.copy() function = function.copy()
return { oai_function = {
"name": function.pop("title"), "name": function.pop("title"),
"description": function.pop("description"), "description": function.pop("description"),
"parameters": function, "parameters": function,
} }
elif isinstance(function, type) and is_basemodel_subclass(function): elif isinstance(function, type) and is_basemodel_subclass(function):
return cast(Dict, convert_pydantic_to_openai_function(function)) oai_function = cast(Dict, convert_pydantic_to_openai_function(function))
elif is_typeddict(function): elif is_typeddict(function):
return cast(Dict, _convert_typed_dict_to_openai_function(cast(Type, function))) oai_function = cast(
Dict, _convert_typed_dict_to_openai_function(cast(Type, function))
)
elif isinstance(function, BaseTool): elif isinstance(function, BaseTool):
return cast(Dict, format_tool_to_openai_function(function)) oai_function = cast(Dict, format_tool_to_openai_function(function))
elif callable(function): elif callable(function):
return cast(Dict, convert_python_function_to_openai_function(function)) oai_function = cast(Dict, convert_python_function_to_openai_function(function))
else: else:
raise ValueError( raise ValueError(
f"Unsupported function\n\n{function}\n\nFunctions must be passed in" f"Unsupported function\n\n{function}\n\nFunctions must be passed in"
@ -371,9 +378,18 @@ def convert_to_openai_function(
" 'title' and 'description' keys." " 'title' and 'description' keys."
) )
if strict is not None:
oai_function["strict"] = strict
# As of 08/06/24, OpenAI requires that additionalProperties be supplied and set
# to False if strict is True.
oai_function["parameters"]["additionalProperties"] = False
return oai_function
def convert_to_openai_tool( def convert_to_openai_tool(
tool: Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool], tool: Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool],
*,
strict: Optional[bool] = None,
) -> Dict[str, Any]: ) -> Dict[str, Any]:
"""Convert a raw function/class to an OpenAI tool. """Convert a raw function/class to an OpenAI tool.
@ -382,6 +398,9 @@ def convert_to_openai_tool(
BaseTool. If a dictionary is passed in, it is assumed to already be a valid BaseTool. If a dictionary is passed in, it is assumed to already be a valid
OpenAI tool, OpenAI function, or a JSON schema with top-level 'title' and OpenAI tool, OpenAI function, or a JSON schema with top-level 'title' and
'description' keys specified. 'description' keys specified.
strict: If True, model output is guaranteed to exactly match the JSON Schema
provided in the function definition. If None, ``strict`` argument will not
be included in tool definition.
Returns: Returns:
A dict version of the passed in tool which is compatible with the A dict version of the passed in tool which is compatible with the
@ -389,8 +408,9 @@ def convert_to_openai_tool(
""" """
if isinstance(tool, dict) and tool.get("type") == "function" and "function" in tool: if isinstance(tool, dict) and tool.get("type") == "function" and "function" in tool:
return tool return tool
function = convert_to_openai_function(tool) oai_function = convert_to_openai_function(tool, strict=strict)
return {"type": "function", "function": function} oai_tool: Dict[str, Any] = {"type": "function", "function": oai_function}
return oai_tool
def tool_example_to_messages( def tool_example_to_messages(

View File

@ -113,7 +113,11 @@ def test_configurable() -> None:
"tools": [ "tools": [
{ {
"type": "function", "type": "function",
"function": {"name": "foo", "description": "foo", "parameters": {}}, "function": {
"name": "foo",
"description": "foo",
"parameters": {},
},
} }
] ]
}, },

View File

@ -59,817 +59,3 @@
# name: test_person_with_kwargs # name: test_person_with_kwargs
'{"lc":1,"type":"constructor","id":["tests","unit_tests","load","test_dump","Person"],"kwargs":{"secret":{"lc":1,"type":"secret","id":["SECRET"]},"you_can_see_me":"hello"}}' '{"lc":1,"type":"constructor","id":["tests","unit_tests","load","test_dump","Person"],"kwargs":{"secret":{"lc":1,"type":"secret","id":["SECRET"]},"you_can_see_me":"hello"}}'
# --- # ---
# name: test_serialize_llmchain
'''
{
"lc": 1,
"type": "constructor",
"id": [
"langchain",
"chains",
"llm",
"LLMChain"
],
"kwargs": {
"prompt": {
"lc": 1,
"type": "constructor",
"id": [
"langchain",
"prompts",
"prompt",
"PromptTemplate"
],
"kwargs": {
"input_variables": [
"name"
],
"template": "hello {name}!",
"template_format": "f-string"
},
"name": "PromptTemplate",
"graph": {
"nodes": [
{
"id": 0,
"type": "schema",
"data": "PromptInput"
},
{
"id": 1,
"type": "runnable",
"data": {
"id": [
"langchain",
"prompts",
"prompt",
"PromptTemplate"
],
"name": "PromptTemplate"
}
},
{
"id": 2,
"type": "schema",
"data": "PromptTemplateOutput"
}
],
"edges": [
{
"source": 0,
"target": 1
},
{
"source": 1,
"target": 2
}
]
}
},
"llm": {
"lc": 1,
"type": "constructor",
"id": [
"langchain",
"llms",
"openai",
"OpenAI"
],
"kwargs": {
"model_name": "davinci",
"temperature": 0.5,
"max_tokens": 256,
"top_p": 1,
"n": 1,
"best_of": 1,
"openai_api_key": {
"lc": 1,
"type": "secret",
"id": [
"OPENAI_API_KEY"
]
},
"openai_proxy": "",
"batch_size": 20,
"max_retries": 2,
"disallowed_special": "all"
},
"name": "OpenAI",
"graph": {
"nodes": [
{
"id": 0,
"type": "schema",
"data": "OpenAIInput"
},
{
"id": 1,
"type": "runnable",
"data": {
"id": [
"langchain",
"llms",
"openai",
"OpenAI"
],
"name": "OpenAI"
}
},
{
"id": 2,
"type": "schema",
"data": "OpenAIOutput"
}
],
"edges": [
{
"source": 0,
"target": 1
},
{
"source": 1,
"target": 2
}
]
}
},
"output_key": "text",
"output_parser": {
"lc": 1,
"type": "constructor",
"id": [
"langchain",
"schema",
"output_parser",
"StrOutputParser"
],
"kwargs": {},
"name": "StrOutputParser",
"graph": {
"nodes": [
{
"id": 0,
"type": "schema",
"data": "StrOutputParserInput"
},
{
"id": 1,
"type": "runnable",
"data": {
"id": [
"langchain",
"schema",
"output_parser",
"StrOutputParser"
],
"name": "StrOutputParser"
}
},
{
"id": 2,
"type": "schema",
"data": "StrOutputParserOutput"
}
],
"edges": [
{
"source": 0,
"target": 1
},
{
"source": 1,
"target": 2
}
]
}
},
"return_final_only": true
},
"name": "LLMChain",
"graph": {
"nodes": [
{
"id": 0,
"type": "schema",
"data": "ChainInput"
},
{
"id": 1,
"type": "runnable",
"data": {
"id": [
"langchain",
"chains",
"llm",
"LLMChain"
],
"name": "LLMChain"
}
},
{
"id": 2,
"type": "schema",
"data": "ChainOutput"
}
],
"edges": [
{
"source": 0,
"target": 1
},
{
"source": 1,
"target": 2
}
]
}
}
'''
# ---
# name: test_serialize_llmchain_chat
'''
{
"lc": 1,
"type": "constructor",
"id": [
"langchain",
"chains",
"llm",
"LLMChain"
],
"kwargs": {
"prompt": {
"lc": 1,
"type": "constructor",
"id": [
"langchain",
"prompts",
"chat",
"ChatPromptTemplate"
],
"kwargs": {
"input_variables": [
"name"
],
"messages": [
{
"lc": 1,
"type": "constructor",
"id": [
"langchain",
"prompts",
"chat",
"HumanMessagePromptTemplate"
],
"kwargs": {
"prompt": {
"lc": 1,
"type": "constructor",
"id": [
"langchain",
"prompts",
"prompt",
"PromptTemplate"
],
"kwargs": {
"input_variables": [
"name"
],
"template": "hello {name}!",
"template_format": "f-string"
},
"name": "PromptTemplate",
"graph": {
"nodes": [
{
"id": 0,
"type": "schema",
"data": "PromptInput"
},
{
"id": 1,
"type": "runnable",
"data": {
"id": [
"langchain",
"prompts",
"prompt",
"PromptTemplate"
],
"name": "PromptTemplate"
}
},
{
"id": 2,
"type": "schema",
"data": "PromptTemplateOutput"
}
],
"edges": [
{
"source": 0,
"target": 1
},
{
"source": 1,
"target": 2
}
]
}
}
}
}
]
},
"name": "ChatPromptTemplate",
"graph": {
"nodes": [
{
"id": 0,
"type": "schema",
"data": "PromptInput"
},
{
"id": 1,
"type": "runnable",
"data": {
"id": [
"langchain",
"prompts",
"chat",
"ChatPromptTemplate"
],
"name": "ChatPromptTemplate"
}
},
{
"id": 2,
"type": "schema",
"data": "ChatPromptTemplateOutput"
}
],
"edges": [
{
"source": 0,
"target": 1
},
{
"source": 1,
"target": 2
}
]
}
},
"llm": {
"lc": 1,
"type": "constructor",
"id": [
"langchain",
"chat_models",
"openai",
"ChatOpenAI"
],
"kwargs": {
"model_name": "davinci",
"temperature": 0.5,
"openai_api_key": {
"lc": 1,
"type": "secret",
"id": [
"OPENAI_API_KEY"
]
},
"openai_proxy": "",
"max_retries": 2,
"n": 1
},
"name": "ChatOpenAI",
"graph": {
"nodes": [
{
"id": 0,
"type": "schema",
"data": "ChatOpenAIInput"
},
{
"id": 1,
"type": "runnable",
"data": {
"id": [
"langchain",
"chat_models",
"openai",
"ChatOpenAI"
],
"name": "ChatOpenAI"
}
},
{
"id": 2,
"type": "schema",
"data": "ChatOpenAIOutput"
}
],
"edges": [
{
"source": 0,
"target": 1
},
{
"source": 1,
"target": 2
}
]
}
},
"output_key": "text",
"output_parser": {
"lc": 1,
"type": "constructor",
"id": [
"langchain",
"schema",
"output_parser",
"StrOutputParser"
],
"kwargs": {},
"name": "StrOutputParser",
"graph": {
"nodes": [
{
"id": 0,
"type": "schema",
"data": "StrOutputParserInput"
},
{
"id": 1,
"type": "runnable",
"data": {
"id": [
"langchain",
"schema",
"output_parser",
"StrOutputParser"
],
"name": "StrOutputParser"
}
},
{
"id": 2,
"type": "schema",
"data": "StrOutputParserOutput"
}
],
"edges": [
{
"source": 0,
"target": 1
},
{
"source": 1,
"target": 2
}
]
}
},
"return_final_only": true
},
"name": "LLMChain",
"graph": {
"nodes": [
{
"id": 0,
"type": "schema",
"data": "ChainInput"
},
{
"id": 1,
"type": "runnable",
"data": {
"id": [
"langchain",
"chains",
"llm",
"LLMChain"
],
"name": "LLMChain"
}
},
{
"id": 2,
"type": "schema",
"data": "ChainOutput"
}
],
"edges": [
{
"source": 0,
"target": 1
},
{
"source": 1,
"target": 2
}
]
}
}
'''
# ---
# name: test_serialize_llmchain_with_non_serializable_arg
'''
{
"lc": 1,
"type": "constructor",
"id": [
"langchain",
"chains",
"llm",
"LLMChain"
],
"kwargs": {
"prompt": {
"lc": 1,
"type": "constructor",
"id": [
"langchain",
"prompts",
"prompt",
"PromptTemplate"
],
"kwargs": {
"input_variables": [
"name"
],
"template": "hello {name}!",
"template_format": "f-string"
},
"name": "PromptTemplate",
"graph": {
"nodes": [
{
"id": 0,
"type": "schema",
"data": "PromptInput"
},
{
"id": 1,
"type": "runnable",
"data": {
"id": [
"langchain",
"prompts",
"prompt",
"PromptTemplate"
],
"name": "PromptTemplate"
}
},
{
"id": 2,
"type": "schema",
"data": "PromptTemplateOutput"
}
],
"edges": [
{
"source": 0,
"target": 1
},
{
"source": 1,
"target": 2
}
]
}
},
"llm": {
"lc": 1,
"type": "constructor",
"id": [
"langchain",
"llms",
"openai",
"OpenAI"
],
"kwargs": {
"model_name": "davinci",
"temperature": 0.5,
"max_tokens": 256,
"top_p": 1,
"n": 1,
"best_of": 1,
"openai_api_key": {
"lc": 1,
"type": "secret",
"id": [
"OPENAI_API_KEY"
]
},
"openai_proxy": "",
"batch_size": 20,
"max_retries": 2,
"disallowed_special": "all"
},
"name": "OpenAI",
"graph": {
"nodes": [
{
"id": 0,
"type": "schema",
"data": "OpenAIInput"
},
{
"id": 1,
"type": "runnable",
"data": {
"id": [
"langchain",
"llms",
"openai",
"OpenAI"
],
"name": "OpenAI"
}
},
{
"id": 2,
"type": "schema",
"data": "OpenAIOutput"
}
],
"edges": [
{
"source": 0,
"target": 1
},
{
"source": 1,
"target": 2
}
]
}
},
"output_key": "text",
"output_parser": {
"lc": 1,
"type": "constructor",
"id": [
"langchain",
"schema",
"output_parser",
"StrOutputParser"
],
"kwargs": {},
"name": "StrOutputParser",
"graph": {
"nodes": [
{
"id": 0,
"type": "schema",
"data": "StrOutputParserInput"
},
{
"id": 1,
"type": "runnable",
"data": {
"id": [
"langchain",
"schema",
"output_parser",
"StrOutputParser"
],
"name": "StrOutputParser"
}
},
{
"id": 2,
"type": "schema",
"data": "StrOutputParserOutput"
}
],
"edges": [
{
"source": 0,
"target": 1
},
{
"source": 1,
"target": 2
}
]
}
},
"return_final_only": true
},
"name": "LLMChain",
"graph": {
"nodes": [
{
"id": 0,
"type": "schema",
"data": "ChainInput"
},
{
"id": 1,
"type": "runnable",
"data": {
"id": [
"langchain",
"chains",
"llm",
"LLMChain"
],
"name": "LLMChain"
}
},
{
"id": 2,
"type": "schema",
"data": "ChainOutput"
}
],
"edges": [
{
"source": 0,
"target": 1
},
{
"source": 1,
"target": 2
}
]
}
}
'''
# ---
# name: test_serialize_openai_llm
'''
{
"lc": 1,
"type": "constructor",
"id": [
"langchain",
"llms",
"openai",
"OpenAI"
],
"kwargs": {
"model_name": "davinci",
"temperature": 0.7,
"max_tokens": 256,
"top_p": 1,
"n": 1,
"best_of": 1,
"openai_api_key": {
"lc": 1,
"type": "secret",
"id": [
"OPENAI_API_KEY"
]
},
"openai_proxy": "",
"batch_size": 20,
"max_retries": 2,
"disallowed_special": "all"
},
"name": "OpenAI",
"graph": {
"nodes": [
{
"id": 0,
"type": "schema",
"data": "OpenAIInput"
},
{
"id": 1,
"type": "runnable",
"data": {
"id": [
"langchain",
"llms",
"openai",
"OpenAI"
],
"name": "OpenAI"
}
},
{
"id": 2,
"type": "schema",
"data": "OpenAIOutput"
}
],
"edges": [
{
"source": 0,
"target": 1
},
{
"source": 1,
"target": 2
}
]
}
}
'''
# ---

View File

@ -386,7 +386,7 @@ class BaseChatOpenAI(BaseChatModel):
) )
return values return values
@root_validator() @root_validator(pre=False, skip_on_failure=True)
def validate_environment(cls, values: Dict) -> Dict: def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment.""" """Validate that api key and python package exists in environment."""
if values["n"] < 1: if values["n"] < 1:
@ -464,6 +464,7 @@ class BaseChatOpenAI(BaseChatModel):
values["async_client"] = openai.AsyncOpenAI( values["async_client"] = openai.AsyncOpenAI(
**client_params, **async_specific **client_params, **async_specific
).chat.completions ).chat.completions
return values return values
@property @property
@ -952,12 +953,17 @@ class BaseChatOpenAI(BaseChatModel):
tool_choice: Optional[ tool_choice: Optional[
Union[dict, str, Literal["auto", "none", "required", "any"], bool] Union[dict, str, Literal["auto", "none", "required", "any"], bool]
] = None, ] = None,
strict: Optional[bool] = None,
**kwargs: Any, **kwargs: Any,
) -> Runnable[LanguageModelInput, BaseMessage]: ) -> Runnable[LanguageModelInput, BaseMessage]:
"""Bind tool-like objects to this chat model. """Bind tool-like objects to this chat model.
Assumes model is compatible with OpenAI tool-calling API. Assumes model is compatible with OpenAI tool-calling API.
.. versionchanged:: 0.1.21
Support for ``strict`` argument added.
Args: Args:
tools: A list of tool definitions to bind to this chat model. tools: A list of tool definitions to bind to this chat model.
Supports any tool definition handled by Supports any tool definition handled by
@ -970,11 +976,23 @@ class BaseChatOpenAI(BaseChatModel):
- ``"any"`` or ``"required"`` or ``True``: force at least one tool to be called. - ``"any"`` or ``"required"`` or ``True``: force at least one tool to be called.
- dict of the form ``{"type": "function", "function": {"name": <<tool_name>>}}``: calls <<tool_name>> tool. - dict of the form ``{"type": "function", "function": {"name": <<tool_name>>}}``: calls <<tool_name>> tool.
- ``False`` or ``None``: no effect, default OpenAI behavior. - ``False`` or ``None``: no effect, default OpenAI behavior.
strict: If True, model output is guaranteed to exactly match the JSON Schema
provided in the tool definition. If True, the input schema will be
validated according to
https://platform.openai.com/docs/guides/structured-outputs/supported-schemas.
If False, input schema will not be validated and model output will not
be validated.
If None, ``strict`` argument will not be passed to the model.
.. versionadded:: 0.1.21
kwargs: Any additional parameters are passed directly to kwargs: Any additional parameters are passed directly to
``self.bind(**kwargs)``. ``self.bind(**kwargs)``.
""" # noqa: E501 """ # noqa: E501
formatted_tools = [convert_to_openai_tool(tool) for tool in tools] formatted_tools = [
convert_to_openai_tool(tool, strict=strict) for tool in tools
]
if tool_choice: if tool_choice:
if isinstance(tool_choice, str): if isinstance(tool_choice, str):
# tool_choice is a tool/function name # tool_choice is a tool/function name
@ -1018,6 +1036,7 @@ class BaseChatOpenAI(BaseChatModel):
*, *,
method: Literal["function_calling", "json_mode"] = "function_calling", method: Literal["function_calling", "json_mode"] = "function_calling",
include_raw: Literal[True] = True, include_raw: Literal[True] = True,
strict: Optional[bool] = None,
**kwargs: Any, **kwargs: Any,
) -> Runnable[LanguageModelInput, _AllReturnType]: ... ) -> Runnable[LanguageModelInput, _AllReturnType]: ...
@ -1028,6 +1047,7 @@ class BaseChatOpenAI(BaseChatModel):
*, *,
method: Literal["function_calling", "json_mode"] = "function_calling", method: Literal["function_calling", "json_mode"] = "function_calling",
include_raw: Literal[False] = False, include_raw: Literal[False] = False,
strict: Optional[bool] = None,
**kwargs: Any, **kwargs: Any,
) -> Runnable[LanguageModelInput, _DictOrPydantic]: ... ) -> Runnable[LanguageModelInput, _DictOrPydantic]: ...
@ -1037,10 +1057,15 @@ class BaseChatOpenAI(BaseChatModel):
*, *,
method: Literal["function_calling", "json_mode"] = "function_calling", method: Literal["function_calling", "json_mode"] = "function_calling",
include_raw: bool = False, include_raw: bool = False,
strict: Optional[bool] = None,
**kwargs: Any, **kwargs: Any,
) -> Runnable[LanguageModelInput, _DictOrPydantic]: ) -> Runnable[LanguageModelInput, _DictOrPydantic]:
"""Model wrapper that returns outputs formatted to match the given schema. """Model wrapper that returns outputs formatted to match the given schema.
.. versionchanged:: 0.1.21
Support for ``strict`` argument added.
Args: Args:
schema: schema:
The output schema. Can be passed in as: The output schema. Can be passed in as:
@ -1060,7 +1085,7 @@ class BaseChatOpenAI(BaseChatModel):
Added support for TypedDict class. Added support for TypedDict class.
method: method:
The method for steering model generation, either "function_calling" The method for steering model generation, one of "function_calling"
or "json_mode". If "function_calling" then the schema will be converted or "json_mode". If "function_calling" then the schema will be converted
to an OpenAI function and the returned model will make use of the to an OpenAI function and the returned model will make use of the
function-calling API. If "json_mode" then OpenAI's JSON mode will be function-calling API. If "json_mode" then OpenAI's JSON mode will be
@ -1073,6 +1098,22 @@ class BaseChatOpenAI(BaseChatModel):
response will be returned. If an error occurs during output parsing it response will be returned. If an error occurs during output parsing it
will be caught and returned as well. The final output is always a dict will be caught and returned as well. The final output is always a dict
with keys "raw", "parsed", and "parsing_error". with keys "raw", "parsed", and "parsing_error".
strict: If True and ``method`` = "function_calling", model output is
guaranteed to exactly match the schema
If True, the input schema will also be
validated according to
https://platform.openai.com/docs/guides/structured-outputs/supported-schemas.
If False, input schema will not be validated and model output will not
be validated.
If None, ``strict`` argument will not be passed to the model.
.. versionadded:: 0.1.21
.. note:: Planned breaking change in version `0.2.0`
``strict`` will default to True when ``method`` is
"function_calling" as of version `0.2.0`.
kwargs: Additional keyword args aren't supported.
Returns: Returns:
A Runnable that takes same inputs as a :class:`langchain_core.language_models.chat.BaseChatModel`. A Runnable that takes same inputs as a :class:`langchain_core.language_models.chat.BaseChatModel`.
@ -1087,7 +1128,16 @@ class BaseChatOpenAI(BaseChatModel):
- ``"parsed"``: None if there was a parsing error, otherwise the type depends on the ``schema`` as described above. - ``"parsed"``: None if there was a parsing error, otherwise the type depends on the ``schema`` as described above.
- ``"parsing_error"``: Optional[BaseException] - ``"parsing_error"``: Optional[BaseException]
Example: schema=Pydantic class, method="function_calling", include_raw=False: Example: schema=Pydantic class, method="function_calling", include_raw=False, strict=True:
.. note:: Valid schemas when using ``strict`` = True
OpenAI has a number of restrictions on what types of schemas can be
provided if ``strict`` = True. When using Pydantic, our model cannot
specify any Field metadata (like min/max constraints) and fields cannot
have default values.
See all constraints here: https://platform.openai.com/docs/guides/structured-outputs/supported-schemas
.. code-block:: python .. code-block:: python
from typing import Optional from typing import Optional
@ -1100,16 +1150,15 @@ class BaseChatOpenAI(BaseChatModel):
'''An answer to the user question along with justification for the answer.''' '''An answer to the user question along with justification for the answer.'''
answer: str answer: str
# If we provide default values and/or descriptions for fields, these will be passed
# to the model. This is an important part of improving a model's ability to
# correctly return structured outputs.
justification: Optional[str] = Field( justification: Optional[str] = Field(
default=None, description="A justification for the answer." default=..., description="A justification for the answer."
) )
llm = ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0) llm = ChatOpenAI(model="gpt-4o", temperature=0)
structured_llm = llm.with_structured_output(AnswerWithJustification) structured_llm = llm.with_structured_output(
AnswerWithJustification, strict=True
)
structured_llm.invoke( structured_llm.invoke(
"What weighs more a pound of bricks or a pound of feathers" "What weighs more a pound of bricks or a pound of feathers"
@ -1134,7 +1183,7 @@ class BaseChatOpenAI(BaseChatModel):
justification: str justification: str
llm = ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0) llm = ChatOpenAI(model="gpt-4o", temperature=0)
structured_llm = llm.with_structured_output( structured_llm = llm.with_structured_output(
AnswerWithJustification, include_raw=True AnswerWithJustification, include_raw=True
) )
@ -1167,7 +1216,7 @@ class BaseChatOpenAI(BaseChatModel):
] ]
llm = ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0) llm = ChatOpenAI(model="gpt-4o", temperature=0)
structured_llm = llm.with_structured_output(AnswerWithJustification) structured_llm = llm.with_structured_output(AnswerWithJustification)
structured_llm.invoke( structured_llm.invoke(
@ -1196,7 +1245,7 @@ class BaseChatOpenAI(BaseChatModel):
} }
} }
llm = ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0) llm = ChatOpenAI(model="gpt-4o", temperature=0)
structured_llm = llm.with_structured_output(oai_schema) structured_llm = llm.with_structured_output(oai_schema)
structured_llm.invoke( structured_llm.invoke(
@ -1217,7 +1266,7 @@ class BaseChatOpenAI(BaseChatModel):
answer: str answer: str
justification: str justification: str
llm = ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0) llm = ChatOpenAI(model="gpt-4o", temperature=0)
structured_llm = llm.with_structured_output( structured_llm = llm.with_structured_output(
AnswerWithJustification, AnswerWithJustification,
method="json_mode", method="json_mode",
@ -1226,11 +1275,11 @@ class BaseChatOpenAI(BaseChatModel):
structured_llm.invoke( structured_llm.invoke(
"Answer the following question. " "Answer the following question. "
"Make sure to return a JSON blob with keys 'answer' and 'justification'.\n\n" "Make sure to return a JSON blob with keys 'answer' and 'justification'.\\n\\n"
"What's heavier a pound of bricks or a pound of feathers?" "What's heavier a pound of bricks or a pound of feathers?"
) )
# -> { # -> {
# 'raw': AIMessage(content='{\n "answer": "They are both the same weight.",\n "justification": "Both a pound of bricks and a pound of feathers weigh one pound. The difference lies in the volume and density of the materials, not the weight." \n}'), # 'raw': AIMessage(content='{\\n "answer": "They are both the same weight.",\\n "justification": "Both a pound of bricks and a pound of feathers weigh one pound. The difference lies in the volume and density of the materials, not the weight." \\n}'),
# 'parsed': AnswerWithJustification(answer='They are both the same weight.', justification='Both a pound of bricks and a pound of feathers weigh one pound. The difference lies in the volume and density of the materials, not the weight.'), # 'parsed': AnswerWithJustification(answer='They are both the same weight.', justification='Both a pound of bricks and a pound of feathers weigh one pound. The difference lies in the volume and density of the materials, not the weight.'),
# 'parsing_error': None # 'parsing_error': None
# } # }
@ -1242,11 +1291,11 @@ class BaseChatOpenAI(BaseChatModel):
structured_llm.invoke( structured_llm.invoke(
"Answer the following question. " "Answer the following question. "
"Make sure to return a JSON blob with keys 'answer' and 'justification'.\n\n" "Make sure to return a JSON blob with keys 'answer' and 'justification'.\\n\\n"
"What's heavier a pound of bricks or a pound of feathers?" "What's heavier a pound of bricks or a pound of feathers?"
) )
# -> { # -> {
# 'raw': AIMessage(content='{\n "answer": "They are both the same weight.",\n "justification": "Both a pound of bricks and a pound of feathers weigh one pound. The difference lies in the volume and density of the materials, not the weight." \n}'), # 'raw': AIMessage(content='{\\n "answer": "They are both the same weight.",\\n "justification": "Both a pound of bricks and a pound of feathers weigh one pound. The difference lies in the volume and density of the materials, not the weight." \\n}'),
# 'parsed': { # 'parsed': {
# 'answer': 'They are both the same weight.', # 'answer': 'They are both the same weight.',
# 'justification': 'Both a pound of bricks and a pound of feathers weigh one pound. The difference lies in the volume and density of the materials, not the weight.' # 'justification': 'Both a pound of bricks and a pound of feathers weigh one pound. The difference lies in the volume and density of the materials, not the weight.'
@ -1256,6 +1305,10 @@ class BaseChatOpenAI(BaseChatModel):
""" # noqa: E501 """ # noqa: E501
if kwargs: if kwargs:
raise ValueError(f"Received unsupported arguments {kwargs}") raise ValueError(f"Received unsupported arguments {kwargs}")
if strict is not None and method != "function_calling":
raise ValueError(
"Argument `strict` is only supported for `method`='function_calling'"
)
is_pydantic_schema = _is_pydantic_class(schema) is_pydantic_schema = _is_pydantic_class(schema)
if method == "function_calling": if method == "function_calling":
if schema is None: if schema is None:
@ -1265,7 +1318,10 @@ class BaseChatOpenAI(BaseChatModel):
) )
tool_name = convert_to_openai_tool(schema)["function"]["name"] tool_name = convert_to_openai_tool(schema)["function"]["name"]
llm = self.bind_tools( llm = self.bind_tools(
[schema], tool_choice=tool_name, parallel_tool_calls=False [schema],
tool_choice=tool_name,
parallel_tool_calls=False,
strict=strict,
) )
if is_pydantic_schema: if is_pydantic_schema:
output_parser: OutputParserLike = PydanticToolsParser( output_parser: OutputParserLike = PydanticToolsParser(
@ -1498,7 +1554,10 @@ class ChatOpenAI(BaseChatOpenAI):
) )
llm_with_tools = llm.bind_tools([GetWeather, GetPopulation]) llm_with_tools = llm.bind_tools(
[GetWeather, GetPopulation]
# strict = True # enforce tool args schema is respected
)
ai_msg = llm_with_tools.invoke( ai_msg = llm_with_tools.invoke(
"Which city is hotter today and which is bigger: LA or NY?" "Which city is hotter today and which is bigger: LA or NY?"
) )

View File

@ -4,6 +4,7 @@ import base64
from typing import Any, AsyncIterator, List, Optional, cast from typing import Any, AsyncIterator, List, Optional, cast
import httpx import httpx
import openai
import pytest import pytest
from langchain_core.callbacks import CallbackManager from langchain_core.callbacks import CallbackManager
from langchain_core.messages import ( from langchain_core.messages import (
@ -19,6 +20,12 @@ from langchain_core.messages import (
from langchain_core.outputs import ChatGeneration, ChatResult, LLMResult from langchain_core.outputs import ChatGeneration, ChatResult, LLMResult
from langchain_core.prompts import ChatPromptTemplate from langchain_core.prompts import ChatPromptTemplate
from langchain_core.pydantic_v1 import BaseModel, Field from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_standard_tests.integration_tests.chat_models import (
_validate_tool_call_message,
)
from langchain_standard_tests.integration_tests.chat_models import (
magic_function as invalid_magic_function,
)
from langchain_openai import ChatOpenAI from langchain_openai import ChatOpenAI
from tests.unit_tests.fake.callbacks import FakeCallbackHandler from tests.unit_tests.fake.callbacks import FakeCallbackHandler
@ -750,3 +757,77 @@ def test_image_token_counting_png() -> None:
] ]
actual = model.get_num_tokens_from_messages([message]) actual = model.get_num_tokens_from_messages([message])
assert expected == actual assert expected == actual
def test_tool_calling_strict() -> None:
"""Test tool calling with strict=True."""
class magic_function(BaseModel):
"""Applies a magic function to an input."""
input: int
model = ChatOpenAI(model="gpt-4o", temperature=0)
model_with_tools = model.bind_tools([magic_function], strict=True)
# invalid_magic_function adds metadata to schema that isn't supported by OpenAI.
model_with_invalid_tool_schema = model.bind_tools(
[invalid_magic_function], strict=True
)
# Test invoke
query = "What is the value of magic_function(3)? Use the tool."
response = model_with_tools.invoke(query)
_validate_tool_call_message(response)
# Test invalid tool schema
with pytest.raises(openai.BadRequestError):
model_with_invalid_tool_schema.invoke(query)
# Test stream
full: Optional[BaseMessageChunk] = None
for chunk in model_with_tools.stream(query):
full = chunk if full is None else full + chunk # type: ignore
assert isinstance(full, AIMessage)
_validate_tool_call_message(full)
# Test invalid tool schema
with pytest.raises(openai.BadRequestError):
next(model_with_invalid_tool_schema.stream(query))
def test_structured_output_strict() -> None:
"""Test to verify structured output with strict=True."""
from pydantic import BaseModel as BaseModelProper
from pydantic import Field as FieldProper
model = ChatOpenAI(model="gpt-4o", temperature=0)
class Joke(BaseModelProper):
"""Joke to tell user."""
setup: str = FieldProper(description="question to set up a joke")
punchline: str = FieldProper(description="answer to resolve the joke")
# Pydantic class
# Type ignoring since the interface only officially supports pydantic 1
# or pydantic.v1.BaseModel but not pydantic.BaseModel from pydantic 2.
# We'll need to do a pass updating the type signatures.
chat = model.with_structured_output(Joke, strict=True) # type: ignore[arg-type]
result = chat.invoke("Tell me a joke about cats.")
assert isinstance(result, Joke)
for chunk in chat.stream("Tell me a joke about cats."):
assert isinstance(chunk, Joke)
# Schema
chat = model.with_structured_output(Joke.model_json_schema(), strict=True)
result = chat.invoke("Tell me a joke about cats.")
assert isinstance(result, dict)
assert set(result.keys()) == {"setup", "punchline"}
for chunk in chat.stream("Tell me a joke about cats."):
assert isinstance(chunk, dict)
assert isinstance(chunk, dict) # for mypy
assert set(chunk.keys()) == {"setup", "punchline"}