mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-27 05:20:34 +00:00
core[patch], openai[patch]: enable strict tool calling (#25111)
Introduced https://openai.com/index/introducing-structured-outputs-in-the-api/
This commit is contained in:
parent
5d10139fc7
commit
78403a3746
@ -56,23 +56,16 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"execution_count": 1,
|
||||
"id": "e817fe2e-4f1d-4533-b19e-2400b1cf6ce8",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdin",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Enter your OpenAI API key: ········\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import getpass\n",
|
||||
"import os\n",
|
||||
"\n",
|
||||
"os.environ[\"OPENAI_API_KEY\"] = getpass.getpass(\"Enter your OpenAI API key: \")"
|
||||
"if not os.environ.get(\"OPENAI_API_KEY\"):\n",
|
||||
" os.environ[\"OPENAI_API_KEY\"] = getpass.getpass(\"Enter your OpenAI API key: \")"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -126,7 +119,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"execution_count": 2,
|
||||
"id": "522686de",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
@ -281,12 +274,12 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"execution_count": 4,
|
||||
"id": "b7ea7690-ec7a-4337-b392-e87d1f39a6ec",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain_core.pydantic_v1 import BaseModel, Field\n",
|
||||
"from pydantic import BaseModel, Field\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"class GetWeather(BaseModel):\n",
|
||||
@ -322,6 +315,45 @@
|
||||
"ai_msg"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "67b0f63d-15e6-45e0-9e86-2852ddcff54f",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### ``strict=True``\n",
|
||||
"\n",
|
||||
".. info Requires ``langchain-openai==0.1.21rc1``\n",
|
||||
"\n",
|
||||
" As of Aug 6, 2024, OpenAI supports a `strict` argument when calling tools that will enforce that the tool argument schema is respected by the model. See more here: https://platform.openai.com/docs/guides/function-calling\n",
|
||||
"\n",
|
||||
" **Note**: If ``strict=True`` the tool definition will also be validated, and a subset of JSON schema are accepted. Crucially, schema cannot have optional args (those with default values). Read the full docs on what types of schema are supported here: https://platform.openai.com/docs/guides/structured-outputs/supported-schemas. "
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"id": "dc8ac4f1-4039-4392-90c1-2d8331cd6910",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'call_VYEfpPDh3npMQ95J9EWmWvSn', 'function': {'arguments': '{\"location\":\"San Francisco, CA\"}', 'name': 'GetWeather'}, 'type': 'function'}]}, response_metadata={'token_usage': {'completion_tokens': 17, 'prompt_tokens': 68, 'total_tokens': 85}, 'model_name': 'gpt-4o-2024-05-13', 'system_fingerprint': 'fp_3aa7262c27', 'finish_reason': 'tool_calls', 'logprobs': None}, id='run-a4c6749b-adbb-45c7-8b17-8d6835d5c443-0', tool_calls=[{'name': 'GetWeather', 'args': {'location': 'San Francisco, CA'}, 'id': 'call_VYEfpPDh3npMQ95J9EWmWvSn', 'type': 'tool_call'}], usage_metadata={'input_tokens': 68, 'output_tokens': 17, 'total_tokens': 85})"
|
||||
]
|
||||
},
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"llm_with_tools = llm.bind_tools([GetWeather], strict=True)\n",
|
||||
"ai_msg = llm_with_tools.invoke(\n",
|
||||
" \"what is the weather like in San Francisco\",\n",
|
||||
")\n",
|
||||
"ai_msg"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "768d1ae4-4b1a-48eb-a329-c8d5051067a3",
|
||||
@ -412,9 +444,9 @@
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "poetry-venv-2",
|
||||
"display_name": "poetry-venv-311",
|
||||
"language": "python",
|
||||
"name": "poetry-venv-2"
|
||||
"name": "poetry-venv-311"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
|
@ -322,6 +322,8 @@ def format_tool_to_openai_tool(tool: BaseTool) -> ToolDescription:
|
||||
|
||||
def convert_to_openai_function(
|
||||
function: Union[Dict[str, Any], Type, Callable, BaseTool],
|
||||
*,
|
||||
strict: Optional[bool] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Convert a raw function/class to an OpenAI function.
|
||||
|
||||
@ -330,6 +332,9 @@ def convert_to_openai_function(
|
||||
Tool object, or a Python function. If a dictionary is passed in, it is
|
||||
assumed to already be a valid OpenAI function or a JSON schema with
|
||||
top-level 'title' and 'description' keys specified.
|
||||
strict: If True, model output is guaranteed to exactly match the JSON Schema
|
||||
provided in the function definition. If None, ``strict`` argument will not
|
||||
be included in function definition.
|
||||
|
||||
Returns:
|
||||
A dict version of the passed in function which is compatible with the OpenAI
|
||||
@ -344,25 +349,27 @@ def convert_to_openai_function(
|
||||
if isinstance(function, dict) and all(
|
||||
k in function for k in ("name", "description", "parameters")
|
||||
):
|
||||
return function
|
||||
oai_function = function
|
||||
# a JSON schema with title and description
|
||||
elif isinstance(function, dict) and all(
|
||||
k in function for k in ("title", "description", "properties")
|
||||
):
|
||||
function = function.copy()
|
||||
return {
|
||||
oai_function = {
|
||||
"name": function.pop("title"),
|
||||
"description": function.pop("description"),
|
||||
"parameters": function,
|
||||
}
|
||||
elif isinstance(function, type) and is_basemodel_subclass(function):
|
||||
return cast(Dict, convert_pydantic_to_openai_function(function))
|
||||
oai_function = cast(Dict, convert_pydantic_to_openai_function(function))
|
||||
elif is_typeddict(function):
|
||||
return cast(Dict, _convert_typed_dict_to_openai_function(cast(Type, function)))
|
||||
oai_function = cast(
|
||||
Dict, _convert_typed_dict_to_openai_function(cast(Type, function))
|
||||
)
|
||||
elif isinstance(function, BaseTool):
|
||||
return cast(Dict, format_tool_to_openai_function(function))
|
||||
oai_function = cast(Dict, format_tool_to_openai_function(function))
|
||||
elif callable(function):
|
||||
return cast(Dict, convert_python_function_to_openai_function(function))
|
||||
oai_function = cast(Dict, convert_python_function_to_openai_function(function))
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported function\n\n{function}\n\nFunctions must be passed in"
|
||||
@ -371,9 +378,18 @@ def convert_to_openai_function(
|
||||
" 'title' and 'description' keys."
|
||||
)
|
||||
|
||||
if strict is not None:
|
||||
oai_function["strict"] = strict
|
||||
# As of 08/06/24, OpenAI requires that additionalProperties be supplied and set
|
||||
# to False if strict is True.
|
||||
oai_function["parameters"]["additionalProperties"] = False
|
||||
return oai_function
|
||||
|
||||
|
||||
def convert_to_openai_tool(
|
||||
tool: Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool],
|
||||
*,
|
||||
strict: Optional[bool] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Convert a raw function/class to an OpenAI tool.
|
||||
|
||||
@ -382,6 +398,9 @@ def convert_to_openai_tool(
|
||||
BaseTool. If a dictionary is passed in, it is assumed to already be a valid
|
||||
OpenAI tool, OpenAI function, or a JSON schema with top-level 'title' and
|
||||
'description' keys specified.
|
||||
strict: If True, model output is guaranteed to exactly match the JSON Schema
|
||||
provided in the function definition. If None, ``strict`` argument will not
|
||||
be included in tool definition.
|
||||
|
||||
Returns:
|
||||
A dict version of the passed in tool which is compatible with the
|
||||
@ -389,8 +408,9 @@ def convert_to_openai_tool(
|
||||
"""
|
||||
if isinstance(tool, dict) and tool.get("type") == "function" and "function" in tool:
|
||||
return tool
|
||||
function = convert_to_openai_function(tool)
|
||||
return {"type": "function", "function": function}
|
||||
oai_function = convert_to_openai_function(tool, strict=strict)
|
||||
oai_tool: Dict[str, Any] = {"type": "function", "function": oai_function}
|
||||
return oai_tool
|
||||
|
||||
|
||||
def tool_example_to_messages(
|
||||
|
@ -113,7 +113,11 @@ def test_configurable() -> None:
|
||||
"tools": [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {"name": "foo", "description": "foo", "parameters": {}},
|
||||
"function": {
|
||||
"name": "foo",
|
||||
"description": "foo",
|
||||
"parameters": {},
|
||||
},
|
||||
}
|
||||
]
|
||||
},
|
||||
|
@ -59,817 +59,3 @@
|
||||
# name: test_person_with_kwargs
|
||||
'{"lc":1,"type":"constructor","id":["tests","unit_tests","load","test_dump","Person"],"kwargs":{"secret":{"lc":1,"type":"secret","id":["SECRET"]},"you_can_see_me":"hello"}}'
|
||||
# ---
|
||||
# name: test_serialize_llmchain
|
||||
'''
|
||||
{
|
||||
"lc": 1,
|
||||
"type": "constructor",
|
||||
"id": [
|
||||
"langchain",
|
||||
"chains",
|
||||
"llm",
|
||||
"LLMChain"
|
||||
],
|
||||
"kwargs": {
|
||||
"prompt": {
|
||||
"lc": 1,
|
||||
"type": "constructor",
|
||||
"id": [
|
||||
"langchain",
|
||||
"prompts",
|
||||
"prompt",
|
||||
"PromptTemplate"
|
||||
],
|
||||
"kwargs": {
|
||||
"input_variables": [
|
||||
"name"
|
||||
],
|
||||
"template": "hello {name}!",
|
||||
"template_format": "f-string"
|
||||
},
|
||||
"name": "PromptTemplate",
|
||||
"graph": {
|
||||
"nodes": [
|
||||
{
|
||||
"id": 0,
|
||||
"type": "schema",
|
||||
"data": "PromptInput"
|
||||
},
|
||||
{
|
||||
"id": 1,
|
||||
"type": "runnable",
|
||||
"data": {
|
||||
"id": [
|
||||
"langchain",
|
||||
"prompts",
|
||||
"prompt",
|
||||
"PromptTemplate"
|
||||
],
|
||||
"name": "PromptTemplate"
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": 2,
|
||||
"type": "schema",
|
||||
"data": "PromptTemplateOutput"
|
||||
}
|
||||
],
|
||||
"edges": [
|
||||
{
|
||||
"source": 0,
|
||||
"target": 1
|
||||
},
|
||||
{
|
||||
"source": 1,
|
||||
"target": 2
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
"llm": {
|
||||
"lc": 1,
|
||||
"type": "constructor",
|
||||
"id": [
|
||||
"langchain",
|
||||
"llms",
|
||||
"openai",
|
||||
"OpenAI"
|
||||
],
|
||||
"kwargs": {
|
||||
"model_name": "davinci",
|
||||
"temperature": 0.5,
|
||||
"max_tokens": 256,
|
||||
"top_p": 1,
|
||||
"n": 1,
|
||||
"best_of": 1,
|
||||
"openai_api_key": {
|
||||
"lc": 1,
|
||||
"type": "secret",
|
||||
"id": [
|
||||
"OPENAI_API_KEY"
|
||||
]
|
||||
},
|
||||
"openai_proxy": "",
|
||||
"batch_size": 20,
|
||||
"max_retries": 2,
|
||||
"disallowed_special": "all"
|
||||
},
|
||||
"name": "OpenAI",
|
||||
"graph": {
|
||||
"nodes": [
|
||||
{
|
||||
"id": 0,
|
||||
"type": "schema",
|
||||
"data": "OpenAIInput"
|
||||
},
|
||||
{
|
||||
"id": 1,
|
||||
"type": "runnable",
|
||||
"data": {
|
||||
"id": [
|
||||
"langchain",
|
||||
"llms",
|
||||
"openai",
|
||||
"OpenAI"
|
||||
],
|
||||
"name": "OpenAI"
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": 2,
|
||||
"type": "schema",
|
||||
"data": "OpenAIOutput"
|
||||
}
|
||||
],
|
||||
"edges": [
|
||||
{
|
||||
"source": 0,
|
||||
"target": 1
|
||||
},
|
||||
{
|
||||
"source": 1,
|
||||
"target": 2
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
"output_key": "text",
|
||||
"output_parser": {
|
||||
"lc": 1,
|
||||
"type": "constructor",
|
||||
"id": [
|
||||
"langchain",
|
||||
"schema",
|
||||
"output_parser",
|
||||
"StrOutputParser"
|
||||
],
|
||||
"kwargs": {},
|
||||
"name": "StrOutputParser",
|
||||
"graph": {
|
||||
"nodes": [
|
||||
{
|
||||
"id": 0,
|
||||
"type": "schema",
|
||||
"data": "StrOutputParserInput"
|
||||
},
|
||||
{
|
||||
"id": 1,
|
||||
"type": "runnable",
|
||||
"data": {
|
||||
"id": [
|
||||
"langchain",
|
||||
"schema",
|
||||
"output_parser",
|
||||
"StrOutputParser"
|
||||
],
|
||||
"name": "StrOutputParser"
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": 2,
|
||||
"type": "schema",
|
||||
"data": "StrOutputParserOutput"
|
||||
}
|
||||
],
|
||||
"edges": [
|
||||
{
|
||||
"source": 0,
|
||||
"target": 1
|
||||
},
|
||||
{
|
||||
"source": 1,
|
||||
"target": 2
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
"return_final_only": true
|
||||
},
|
||||
"name": "LLMChain",
|
||||
"graph": {
|
||||
"nodes": [
|
||||
{
|
||||
"id": 0,
|
||||
"type": "schema",
|
||||
"data": "ChainInput"
|
||||
},
|
||||
{
|
||||
"id": 1,
|
||||
"type": "runnable",
|
||||
"data": {
|
||||
"id": [
|
||||
"langchain",
|
||||
"chains",
|
||||
"llm",
|
||||
"LLMChain"
|
||||
],
|
||||
"name": "LLMChain"
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": 2,
|
||||
"type": "schema",
|
||||
"data": "ChainOutput"
|
||||
}
|
||||
],
|
||||
"edges": [
|
||||
{
|
||||
"source": 0,
|
||||
"target": 1
|
||||
},
|
||||
{
|
||||
"source": 1,
|
||||
"target": 2
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
'''
|
||||
# ---
|
||||
# name: test_serialize_llmchain_chat
|
||||
'''
|
||||
{
|
||||
"lc": 1,
|
||||
"type": "constructor",
|
||||
"id": [
|
||||
"langchain",
|
||||
"chains",
|
||||
"llm",
|
||||
"LLMChain"
|
||||
],
|
||||
"kwargs": {
|
||||
"prompt": {
|
||||
"lc": 1,
|
||||
"type": "constructor",
|
||||
"id": [
|
||||
"langchain",
|
||||
"prompts",
|
||||
"chat",
|
||||
"ChatPromptTemplate"
|
||||
],
|
||||
"kwargs": {
|
||||
"input_variables": [
|
||||
"name"
|
||||
],
|
||||
"messages": [
|
||||
{
|
||||
"lc": 1,
|
||||
"type": "constructor",
|
||||
"id": [
|
||||
"langchain",
|
||||
"prompts",
|
||||
"chat",
|
||||
"HumanMessagePromptTemplate"
|
||||
],
|
||||
"kwargs": {
|
||||
"prompt": {
|
||||
"lc": 1,
|
||||
"type": "constructor",
|
||||
"id": [
|
||||
"langchain",
|
||||
"prompts",
|
||||
"prompt",
|
||||
"PromptTemplate"
|
||||
],
|
||||
"kwargs": {
|
||||
"input_variables": [
|
||||
"name"
|
||||
],
|
||||
"template": "hello {name}!",
|
||||
"template_format": "f-string"
|
||||
},
|
||||
"name": "PromptTemplate",
|
||||
"graph": {
|
||||
"nodes": [
|
||||
{
|
||||
"id": 0,
|
||||
"type": "schema",
|
||||
"data": "PromptInput"
|
||||
},
|
||||
{
|
||||
"id": 1,
|
||||
"type": "runnable",
|
||||
"data": {
|
||||
"id": [
|
||||
"langchain",
|
||||
"prompts",
|
||||
"prompt",
|
||||
"PromptTemplate"
|
||||
],
|
||||
"name": "PromptTemplate"
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": 2,
|
||||
"type": "schema",
|
||||
"data": "PromptTemplateOutput"
|
||||
}
|
||||
],
|
||||
"edges": [
|
||||
{
|
||||
"source": 0,
|
||||
"target": 1
|
||||
},
|
||||
{
|
||||
"source": 1,
|
||||
"target": 2
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
},
|
||||
"name": "ChatPromptTemplate",
|
||||
"graph": {
|
||||
"nodes": [
|
||||
{
|
||||
"id": 0,
|
||||
"type": "schema",
|
||||
"data": "PromptInput"
|
||||
},
|
||||
{
|
||||
"id": 1,
|
||||
"type": "runnable",
|
||||
"data": {
|
||||
"id": [
|
||||
"langchain",
|
||||
"prompts",
|
||||
"chat",
|
||||
"ChatPromptTemplate"
|
||||
],
|
||||
"name": "ChatPromptTemplate"
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": 2,
|
||||
"type": "schema",
|
||||
"data": "ChatPromptTemplateOutput"
|
||||
}
|
||||
],
|
||||
"edges": [
|
||||
{
|
||||
"source": 0,
|
||||
"target": 1
|
||||
},
|
||||
{
|
||||
"source": 1,
|
||||
"target": 2
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
"llm": {
|
||||
"lc": 1,
|
||||
"type": "constructor",
|
||||
"id": [
|
||||
"langchain",
|
||||
"chat_models",
|
||||
"openai",
|
||||
"ChatOpenAI"
|
||||
],
|
||||
"kwargs": {
|
||||
"model_name": "davinci",
|
||||
"temperature": 0.5,
|
||||
"openai_api_key": {
|
||||
"lc": 1,
|
||||
"type": "secret",
|
||||
"id": [
|
||||
"OPENAI_API_KEY"
|
||||
]
|
||||
},
|
||||
"openai_proxy": "",
|
||||
"max_retries": 2,
|
||||
"n": 1
|
||||
},
|
||||
"name": "ChatOpenAI",
|
||||
"graph": {
|
||||
"nodes": [
|
||||
{
|
||||
"id": 0,
|
||||
"type": "schema",
|
||||
"data": "ChatOpenAIInput"
|
||||
},
|
||||
{
|
||||
"id": 1,
|
||||
"type": "runnable",
|
||||
"data": {
|
||||
"id": [
|
||||
"langchain",
|
||||
"chat_models",
|
||||
"openai",
|
||||
"ChatOpenAI"
|
||||
],
|
||||
"name": "ChatOpenAI"
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": 2,
|
||||
"type": "schema",
|
||||
"data": "ChatOpenAIOutput"
|
||||
}
|
||||
],
|
||||
"edges": [
|
||||
{
|
||||
"source": 0,
|
||||
"target": 1
|
||||
},
|
||||
{
|
||||
"source": 1,
|
||||
"target": 2
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
"output_key": "text",
|
||||
"output_parser": {
|
||||
"lc": 1,
|
||||
"type": "constructor",
|
||||
"id": [
|
||||
"langchain",
|
||||
"schema",
|
||||
"output_parser",
|
||||
"StrOutputParser"
|
||||
],
|
||||
"kwargs": {},
|
||||
"name": "StrOutputParser",
|
||||
"graph": {
|
||||
"nodes": [
|
||||
{
|
||||
"id": 0,
|
||||
"type": "schema",
|
||||
"data": "StrOutputParserInput"
|
||||
},
|
||||
{
|
||||
"id": 1,
|
||||
"type": "runnable",
|
||||
"data": {
|
||||
"id": [
|
||||
"langchain",
|
||||
"schema",
|
||||
"output_parser",
|
||||
"StrOutputParser"
|
||||
],
|
||||
"name": "StrOutputParser"
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": 2,
|
||||
"type": "schema",
|
||||
"data": "StrOutputParserOutput"
|
||||
}
|
||||
],
|
||||
"edges": [
|
||||
{
|
||||
"source": 0,
|
||||
"target": 1
|
||||
},
|
||||
{
|
||||
"source": 1,
|
||||
"target": 2
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
"return_final_only": true
|
||||
},
|
||||
"name": "LLMChain",
|
||||
"graph": {
|
||||
"nodes": [
|
||||
{
|
||||
"id": 0,
|
||||
"type": "schema",
|
||||
"data": "ChainInput"
|
||||
},
|
||||
{
|
||||
"id": 1,
|
||||
"type": "runnable",
|
||||
"data": {
|
||||
"id": [
|
||||
"langchain",
|
||||
"chains",
|
||||
"llm",
|
||||
"LLMChain"
|
||||
],
|
||||
"name": "LLMChain"
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": 2,
|
||||
"type": "schema",
|
||||
"data": "ChainOutput"
|
||||
}
|
||||
],
|
||||
"edges": [
|
||||
{
|
||||
"source": 0,
|
||||
"target": 1
|
||||
},
|
||||
{
|
||||
"source": 1,
|
||||
"target": 2
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
'''
|
||||
# ---
|
||||
# name: test_serialize_llmchain_with_non_serializable_arg
|
||||
'''
|
||||
{
|
||||
"lc": 1,
|
||||
"type": "constructor",
|
||||
"id": [
|
||||
"langchain",
|
||||
"chains",
|
||||
"llm",
|
||||
"LLMChain"
|
||||
],
|
||||
"kwargs": {
|
||||
"prompt": {
|
||||
"lc": 1,
|
||||
"type": "constructor",
|
||||
"id": [
|
||||
"langchain",
|
||||
"prompts",
|
||||
"prompt",
|
||||
"PromptTemplate"
|
||||
],
|
||||
"kwargs": {
|
||||
"input_variables": [
|
||||
"name"
|
||||
],
|
||||
"template": "hello {name}!",
|
||||
"template_format": "f-string"
|
||||
},
|
||||
"name": "PromptTemplate",
|
||||
"graph": {
|
||||
"nodes": [
|
||||
{
|
||||
"id": 0,
|
||||
"type": "schema",
|
||||
"data": "PromptInput"
|
||||
},
|
||||
{
|
||||
"id": 1,
|
||||
"type": "runnable",
|
||||
"data": {
|
||||
"id": [
|
||||
"langchain",
|
||||
"prompts",
|
||||
"prompt",
|
||||
"PromptTemplate"
|
||||
],
|
||||
"name": "PromptTemplate"
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": 2,
|
||||
"type": "schema",
|
||||
"data": "PromptTemplateOutput"
|
||||
}
|
||||
],
|
||||
"edges": [
|
||||
{
|
||||
"source": 0,
|
||||
"target": 1
|
||||
},
|
||||
{
|
||||
"source": 1,
|
||||
"target": 2
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
"llm": {
|
||||
"lc": 1,
|
||||
"type": "constructor",
|
||||
"id": [
|
||||
"langchain",
|
||||
"llms",
|
||||
"openai",
|
||||
"OpenAI"
|
||||
],
|
||||
"kwargs": {
|
||||
"model_name": "davinci",
|
||||
"temperature": 0.5,
|
||||
"max_tokens": 256,
|
||||
"top_p": 1,
|
||||
"n": 1,
|
||||
"best_of": 1,
|
||||
"openai_api_key": {
|
||||
"lc": 1,
|
||||
"type": "secret",
|
||||
"id": [
|
||||
"OPENAI_API_KEY"
|
||||
]
|
||||
},
|
||||
"openai_proxy": "",
|
||||
"batch_size": 20,
|
||||
"max_retries": 2,
|
||||
"disallowed_special": "all"
|
||||
},
|
||||
"name": "OpenAI",
|
||||
"graph": {
|
||||
"nodes": [
|
||||
{
|
||||
"id": 0,
|
||||
"type": "schema",
|
||||
"data": "OpenAIInput"
|
||||
},
|
||||
{
|
||||
"id": 1,
|
||||
"type": "runnable",
|
||||
"data": {
|
||||
"id": [
|
||||
"langchain",
|
||||
"llms",
|
||||
"openai",
|
||||
"OpenAI"
|
||||
],
|
||||
"name": "OpenAI"
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": 2,
|
||||
"type": "schema",
|
||||
"data": "OpenAIOutput"
|
||||
}
|
||||
],
|
||||
"edges": [
|
||||
{
|
||||
"source": 0,
|
||||
"target": 1
|
||||
},
|
||||
{
|
||||
"source": 1,
|
||||
"target": 2
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
"output_key": "text",
|
||||
"output_parser": {
|
||||
"lc": 1,
|
||||
"type": "constructor",
|
||||
"id": [
|
||||
"langchain",
|
||||
"schema",
|
||||
"output_parser",
|
||||
"StrOutputParser"
|
||||
],
|
||||
"kwargs": {},
|
||||
"name": "StrOutputParser",
|
||||
"graph": {
|
||||
"nodes": [
|
||||
{
|
||||
"id": 0,
|
||||
"type": "schema",
|
||||
"data": "StrOutputParserInput"
|
||||
},
|
||||
{
|
||||
"id": 1,
|
||||
"type": "runnable",
|
||||
"data": {
|
||||
"id": [
|
||||
"langchain",
|
||||
"schema",
|
||||
"output_parser",
|
||||
"StrOutputParser"
|
||||
],
|
||||
"name": "StrOutputParser"
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": 2,
|
||||
"type": "schema",
|
||||
"data": "StrOutputParserOutput"
|
||||
}
|
||||
],
|
||||
"edges": [
|
||||
{
|
||||
"source": 0,
|
||||
"target": 1
|
||||
},
|
||||
{
|
||||
"source": 1,
|
||||
"target": 2
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
"return_final_only": true
|
||||
},
|
||||
"name": "LLMChain",
|
||||
"graph": {
|
||||
"nodes": [
|
||||
{
|
||||
"id": 0,
|
||||
"type": "schema",
|
||||
"data": "ChainInput"
|
||||
},
|
||||
{
|
||||
"id": 1,
|
||||
"type": "runnable",
|
||||
"data": {
|
||||
"id": [
|
||||
"langchain",
|
||||
"chains",
|
||||
"llm",
|
||||
"LLMChain"
|
||||
],
|
||||
"name": "LLMChain"
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": 2,
|
||||
"type": "schema",
|
||||
"data": "ChainOutput"
|
||||
}
|
||||
],
|
||||
"edges": [
|
||||
{
|
||||
"source": 0,
|
||||
"target": 1
|
||||
},
|
||||
{
|
||||
"source": 1,
|
||||
"target": 2
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
'''
|
||||
# ---
|
||||
# name: test_serialize_openai_llm
|
||||
'''
|
||||
{
|
||||
"lc": 1,
|
||||
"type": "constructor",
|
||||
"id": [
|
||||
"langchain",
|
||||
"llms",
|
||||
"openai",
|
||||
"OpenAI"
|
||||
],
|
||||
"kwargs": {
|
||||
"model_name": "davinci",
|
||||
"temperature": 0.7,
|
||||
"max_tokens": 256,
|
||||
"top_p": 1,
|
||||
"n": 1,
|
||||
"best_of": 1,
|
||||
"openai_api_key": {
|
||||
"lc": 1,
|
||||
"type": "secret",
|
||||
"id": [
|
||||
"OPENAI_API_KEY"
|
||||
]
|
||||
},
|
||||
"openai_proxy": "",
|
||||
"batch_size": 20,
|
||||
"max_retries": 2,
|
||||
"disallowed_special": "all"
|
||||
},
|
||||
"name": "OpenAI",
|
||||
"graph": {
|
||||
"nodes": [
|
||||
{
|
||||
"id": 0,
|
||||
"type": "schema",
|
||||
"data": "OpenAIInput"
|
||||
},
|
||||
{
|
||||
"id": 1,
|
||||
"type": "runnable",
|
||||
"data": {
|
||||
"id": [
|
||||
"langchain",
|
||||
"llms",
|
||||
"openai",
|
||||
"OpenAI"
|
||||
],
|
||||
"name": "OpenAI"
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": 2,
|
||||
"type": "schema",
|
||||
"data": "OpenAIOutput"
|
||||
}
|
||||
],
|
||||
"edges": [
|
||||
{
|
||||
"source": 0,
|
||||
"target": 1
|
||||
},
|
||||
{
|
||||
"source": 1,
|
||||
"target": 2
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
'''
|
||||
# ---
|
||||
|
@ -386,7 +386,7 @@ class BaseChatOpenAI(BaseChatModel):
|
||||
)
|
||||
return values
|
||||
|
||||
@root_validator()
|
||||
@root_validator(pre=False, skip_on_failure=True)
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
"""Validate that api key and python package exists in environment."""
|
||||
if values["n"] < 1:
|
||||
@ -464,6 +464,7 @@ class BaseChatOpenAI(BaseChatModel):
|
||||
values["async_client"] = openai.AsyncOpenAI(
|
||||
**client_params, **async_specific
|
||||
).chat.completions
|
||||
|
||||
return values
|
||||
|
||||
@property
|
||||
@ -952,12 +953,17 @@ class BaseChatOpenAI(BaseChatModel):
|
||||
tool_choice: Optional[
|
||||
Union[dict, str, Literal["auto", "none", "required", "any"], bool]
|
||||
] = None,
|
||||
strict: Optional[bool] = None,
|
||||
**kwargs: Any,
|
||||
) -> Runnable[LanguageModelInput, BaseMessage]:
|
||||
"""Bind tool-like objects to this chat model.
|
||||
|
||||
Assumes model is compatible with OpenAI tool-calling API.
|
||||
|
||||
.. versionchanged:: 0.1.21
|
||||
|
||||
Support for ``strict`` argument added.
|
||||
|
||||
Args:
|
||||
tools: A list of tool definitions to bind to this chat model.
|
||||
Supports any tool definition handled by
|
||||
@ -970,11 +976,23 @@ class BaseChatOpenAI(BaseChatModel):
|
||||
- ``"any"`` or ``"required"`` or ``True``: force at least one tool to be called.
|
||||
- dict of the form ``{"type": "function", "function": {"name": <<tool_name>>}}``: calls <<tool_name>> tool.
|
||||
- ``False`` or ``None``: no effect, default OpenAI behavior.
|
||||
strict: If True, model output is guaranteed to exactly match the JSON Schema
|
||||
provided in the tool definition. If True, the input schema will be
|
||||
validated according to
|
||||
https://platform.openai.com/docs/guides/structured-outputs/supported-schemas.
|
||||
If False, input schema will not be validated and model output will not
|
||||
be validated.
|
||||
If None, ``strict`` argument will not be passed to the model.
|
||||
|
||||
.. versionadded:: 0.1.21
|
||||
|
||||
kwargs: Any additional parameters are passed directly to
|
||||
``self.bind(**kwargs)``.
|
||||
""" # noqa: E501
|
||||
|
||||
formatted_tools = [convert_to_openai_tool(tool) for tool in tools]
|
||||
formatted_tools = [
|
||||
convert_to_openai_tool(tool, strict=strict) for tool in tools
|
||||
]
|
||||
if tool_choice:
|
||||
if isinstance(tool_choice, str):
|
||||
# tool_choice is a tool/function name
|
||||
@ -1018,6 +1036,7 @@ class BaseChatOpenAI(BaseChatModel):
|
||||
*,
|
||||
method: Literal["function_calling", "json_mode"] = "function_calling",
|
||||
include_raw: Literal[True] = True,
|
||||
strict: Optional[bool] = None,
|
||||
**kwargs: Any,
|
||||
) -> Runnable[LanguageModelInput, _AllReturnType]: ...
|
||||
|
||||
@ -1028,6 +1047,7 @@ class BaseChatOpenAI(BaseChatModel):
|
||||
*,
|
||||
method: Literal["function_calling", "json_mode"] = "function_calling",
|
||||
include_raw: Literal[False] = False,
|
||||
strict: Optional[bool] = None,
|
||||
**kwargs: Any,
|
||||
) -> Runnable[LanguageModelInput, _DictOrPydantic]: ...
|
||||
|
||||
@ -1037,10 +1057,15 @@ class BaseChatOpenAI(BaseChatModel):
|
||||
*,
|
||||
method: Literal["function_calling", "json_mode"] = "function_calling",
|
||||
include_raw: bool = False,
|
||||
strict: Optional[bool] = None,
|
||||
**kwargs: Any,
|
||||
) -> Runnable[LanguageModelInput, _DictOrPydantic]:
|
||||
"""Model wrapper that returns outputs formatted to match the given schema.
|
||||
|
||||
.. versionchanged:: 0.1.21
|
||||
|
||||
Support for ``strict`` argument added.
|
||||
|
||||
Args:
|
||||
schema:
|
||||
The output schema. Can be passed in as:
|
||||
@ -1060,7 +1085,7 @@ class BaseChatOpenAI(BaseChatModel):
|
||||
Added support for TypedDict class.
|
||||
|
||||
method:
|
||||
The method for steering model generation, either "function_calling"
|
||||
The method for steering model generation, one of "function_calling"
|
||||
or "json_mode". If "function_calling" then the schema will be converted
|
||||
to an OpenAI function and the returned model will make use of the
|
||||
function-calling API. If "json_mode" then OpenAI's JSON mode will be
|
||||
@ -1073,6 +1098,22 @@ class BaseChatOpenAI(BaseChatModel):
|
||||
response will be returned. If an error occurs during output parsing it
|
||||
will be caught and returned as well. The final output is always a dict
|
||||
with keys "raw", "parsed", and "parsing_error".
|
||||
strict: If True and ``method`` = "function_calling", model output is
|
||||
guaranteed to exactly match the schema
|
||||
If True, the input schema will also be
|
||||
validated according to
|
||||
https://platform.openai.com/docs/guides/structured-outputs/supported-schemas.
|
||||
If False, input schema will not be validated and model output will not
|
||||
be validated.
|
||||
If None, ``strict`` argument will not be passed to the model.
|
||||
|
||||
.. versionadded:: 0.1.21
|
||||
|
||||
.. note:: Planned breaking change in version `0.2.0`
|
||||
|
||||
``strict`` will default to True when ``method`` is
|
||||
"function_calling" as of version `0.2.0`.
|
||||
kwargs: Additional keyword args aren't supported.
|
||||
|
||||
Returns:
|
||||
A Runnable that takes same inputs as a :class:`langchain_core.language_models.chat.BaseChatModel`.
|
||||
@ -1087,7 +1128,16 @@ class BaseChatOpenAI(BaseChatModel):
|
||||
- ``"parsed"``: None if there was a parsing error, otherwise the type depends on the ``schema`` as described above.
|
||||
- ``"parsing_error"``: Optional[BaseException]
|
||||
|
||||
Example: schema=Pydantic class, method="function_calling", include_raw=False:
|
||||
Example: schema=Pydantic class, method="function_calling", include_raw=False, strict=True:
|
||||
.. note:: Valid schemas when using ``strict`` = True
|
||||
|
||||
OpenAI has a number of restrictions on what types of schemas can be
|
||||
provided if ``strict`` = True. When using Pydantic, our model cannot
|
||||
specify any Field metadata (like min/max constraints) and fields cannot
|
||||
have default values.
|
||||
|
||||
See all constraints here: https://platform.openai.com/docs/guides/structured-outputs/supported-schemas
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from typing import Optional
|
||||
@ -1100,16 +1150,15 @@ class BaseChatOpenAI(BaseChatModel):
|
||||
'''An answer to the user question along with justification for the answer.'''
|
||||
|
||||
answer: str
|
||||
# If we provide default values and/or descriptions for fields, these will be passed
|
||||
# to the model. This is an important part of improving a model's ability to
|
||||
# correctly return structured outputs.
|
||||
justification: Optional[str] = Field(
|
||||
default=None, description="A justification for the answer."
|
||||
default=..., description="A justification for the answer."
|
||||
)
|
||||
|
||||
|
||||
llm = ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0)
|
||||
structured_llm = llm.with_structured_output(AnswerWithJustification)
|
||||
llm = ChatOpenAI(model="gpt-4o", temperature=0)
|
||||
structured_llm = llm.with_structured_output(
|
||||
AnswerWithJustification, strict=True
|
||||
)
|
||||
|
||||
structured_llm.invoke(
|
||||
"What weighs more a pound of bricks or a pound of feathers"
|
||||
@ -1134,7 +1183,7 @@ class BaseChatOpenAI(BaseChatModel):
|
||||
justification: str
|
||||
|
||||
|
||||
llm = ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0)
|
||||
llm = ChatOpenAI(model="gpt-4o", temperature=0)
|
||||
structured_llm = llm.with_structured_output(
|
||||
AnswerWithJustification, include_raw=True
|
||||
)
|
||||
@ -1167,7 +1216,7 @@ class BaseChatOpenAI(BaseChatModel):
|
||||
]
|
||||
|
||||
|
||||
llm = ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0)
|
||||
llm = ChatOpenAI(model="gpt-4o", temperature=0)
|
||||
structured_llm = llm.with_structured_output(AnswerWithJustification)
|
||||
|
||||
structured_llm.invoke(
|
||||
@ -1196,7 +1245,7 @@ class BaseChatOpenAI(BaseChatModel):
|
||||
}
|
||||
}
|
||||
|
||||
llm = ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0)
|
||||
llm = ChatOpenAI(model="gpt-4o", temperature=0)
|
||||
structured_llm = llm.with_structured_output(oai_schema)
|
||||
|
||||
structured_llm.invoke(
|
||||
@ -1217,7 +1266,7 @@ class BaseChatOpenAI(BaseChatModel):
|
||||
answer: str
|
||||
justification: str
|
||||
|
||||
llm = ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0)
|
||||
llm = ChatOpenAI(model="gpt-4o", temperature=0)
|
||||
structured_llm = llm.with_structured_output(
|
||||
AnswerWithJustification,
|
||||
method="json_mode",
|
||||
@ -1226,11 +1275,11 @@ class BaseChatOpenAI(BaseChatModel):
|
||||
|
||||
structured_llm.invoke(
|
||||
"Answer the following question. "
|
||||
"Make sure to return a JSON blob with keys 'answer' and 'justification'.\n\n"
|
||||
"Make sure to return a JSON blob with keys 'answer' and 'justification'.\\n\\n"
|
||||
"What's heavier a pound of bricks or a pound of feathers?"
|
||||
)
|
||||
# -> {
|
||||
# 'raw': AIMessage(content='{\n "answer": "They are both the same weight.",\n "justification": "Both a pound of bricks and a pound of feathers weigh one pound. The difference lies in the volume and density of the materials, not the weight." \n}'),
|
||||
# 'raw': AIMessage(content='{\\n "answer": "They are both the same weight.",\\n "justification": "Both a pound of bricks and a pound of feathers weigh one pound. The difference lies in the volume and density of the materials, not the weight." \\n}'),
|
||||
# 'parsed': AnswerWithJustification(answer='They are both the same weight.', justification='Both a pound of bricks and a pound of feathers weigh one pound. The difference lies in the volume and density of the materials, not the weight.'),
|
||||
# 'parsing_error': None
|
||||
# }
|
||||
@ -1242,11 +1291,11 @@ class BaseChatOpenAI(BaseChatModel):
|
||||
|
||||
structured_llm.invoke(
|
||||
"Answer the following question. "
|
||||
"Make sure to return a JSON blob with keys 'answer' and 'justification'.\n\n"
|
||||
"Make sure to return a JSON blob with keys 'answer' and 'justification'.\\n\\n"
|
||||
"What's heavier a pound of bricks or a pound of feathers?"
|
||||
)
|
||||
# -> {
|
||||
# 'raw': AIMessage(content='{\n "answer": "They are both the same weight.",\n "justification": "Both a pound of bricks and a pound of feathers weigh one pound. The difference lies in the volume and density of the materials, not the weight." \n}'),
|
||||
# 'raw': AIMessage(content='{\\n "answer": "They are both the same weight.",\\n "justification": "Both a pound of bricks and a pound of feathers weigh one pound. The difference lies in the volume and density of the materials, not the weight." \\n}'),
|
||||
# 'parsed': {
|
||||
# 'answer': 'They are both the same weight.',
|
||||
# 'justification': 'Both a pound of bricks and a pound of feathers weigh one pound. The difference lies in the volume and density of the materials, not the weight.'
|
||||
@ -1256,6 +1305,10 @@ class BaseChatOpenAI(BaseChatModel):
|
||||
""" # noqa: E501
|
||||
if kwargs:
|
||||
raise ValueError(f"Received unsupported arguments {kwargs}")
|
||||
if strict is not None and method != "function_calling":
|
||||
raise ValueError(
|
||||
"Argument `strict` is only supported for `method`='function_calling'"
|
||||
)
|
||||
is_pydantic_schema = _is_pydantic_class(schema)
|
||||
if method == "function_calling":
|
||||
if schema is None:
|
||||
@ -1265,7 +1318,10 @@ class BaseChatOpenAI(BaseChatModel):
|
||||
)
|
||||
tool_name = convert_to_openai_tool(schema)["function"]["name"]
|
||||
llm = self.bind_tools(
|
||||
[schema], tool_choice=tool_name, parallel_tool_calls=False
|
||||
[schema],
|
||||
tool_choice=tool_name,
|
||||
parallel_tool_calls=False,
|
||||
strict=strict,
|
||||
)
|
||||
if is_pydantic_schema:
|
||||
output_parser: OutputParserLike = PydanticToolsParser(
|
||||
@ -1498,7 +1554,10 @@ class ChatOpenAI(BaseChatOpenAI):
|
||||
)
|
||||
|
||||
|
||||
llm_with_tools = llm.bind_tools([GetWeather, GetPopulation])
|
||||
llm_with_tools = llm.bind_tools(
|
||||
[GetWeather, GetPopulation]
|
||||
# strict = True # enforce tool args schema is respected
|
||||
)
|
||||
ai_msg = llm_with_tools.invoke(
|
||||
"Which city is hotter today and which is bigger: LA or NY?"
|
||||
)
|
||||
|
@ -4,6 +4,7 @@ import base64
|
||||
from typing import Any, AsyncIterator, List, Optional, cast
|
||||
|
||||
import httpx
|
||||
import openai
|
||||
import pytest
|
||||
from langchain_core.callbacks import CallbackManager
|
||||
from langchain_core.messages import (
|
||||
@ -19,6 +20,12 @@ from langchain_core.messages import (
|
||||
from langchain_core.outputs import ChatGeneration, ChatResult, LLMResult
|
||||
from langchain_core.prompts import ChatPromptTemplate
|
||||
from langchain_core.pydantic_v1 import BaseModel, Field
|
||||
from langchain_standard_tests.integration_tests.chat_models import (
|
||||
_validate_tool_call_message,
|
||||
)
|
||||
from langchain_standard_tests.integration_tests.chat_models import (
|
||||
magic_function as invalid_magic_function,
|
||||
)
|
||||
|
||||
from langchain_openai import ChatOpenAI
|
||||
from tests.unit_tests.fake.callbacks import FakeCallbackHandler
|
||||
@ -750,3 +757,77 @@ def test_image_token_counting_png() -> None:
|
||||
]
|
||||
actual = model.get_num_tokens_from_messages([message])
|
||||
assert expected == actual
|
||||
|
||||
|
||||
def test_tool_calling_strict() -> None:
|
||||
"""Test tool calling with strict=True."""
|
||||
|
||||
class magic_function(BaseModel):
|
||||
"""Applies a magic function to an input."""
|
||||
|
||||
input: int
|
||||
|
||||
model = ChatOpenAI(model="gpt-4o", temperature=0)
|
||||
model_with_tools = model.bind_tools([magic_function], strict=True)
|
||||
|
||||
# invalid_magic_function adds metadata to schema that isn't supported by OpenAI.
|
||||
model_with_invalid_tool_schema = model.bind_tools(
|
||||
[invalid_magic_function], strict=True
|
||||
)
|
||||
|
||||
# Test invoke
|
||||
query = "What is the value of magic_function(3)? Use the tool."
|
||||
response = model_with_tools.invoke(query)
|
||||
_validate_tool_call_message(response)
|
||||
|
||||
# Test invalid tool schema
|
||||
with pytest.raises(openai.BadRequestError):
|
||||
model_with_invalid_tool_schema.invoke(query)
|
||||
|
||||
# Test stream
|
||||
full: Optional[BaseMessageChunk] = None
|
||||
for chunk in model_with_tools.stream(query):
|
||||
full = chunk if full is None else full + chunk # type: ignore
|
||||
assert isinstance(full, AIMessage)
|
||||
_validate_tool_call_message(full)
|
||||
|
||||
# Test invalid tool schema
|
||||
with pytest.raises(openai.BadRequestError):
|
||||
next(model_with_invalid_tool_schema.stream(query))
|
||||
|
||||
|
||||
def test_structured_output_strict() -> None:
|
||||
"""Test to verify structured output with strict=True."""
|
||||
|
||||
from pydantic import BaseModel as BaseModelProper
|
||||
from pydantic import Field as FieldProper
|
||||
|
||||
model = ChatOpenAI(model="gpt-4o", temperature=0)
|
||||
|
||||
class Joke(BaseModelProper):
|
||||
"""Joke to tell user."""
|
||||
|
||||
setup: str = FieldProper(description="question to set up a joke")
|
||||
punchline: str = FieldProper(description="answer to resolve the joke")
|
||||
|
||||
# Pydantic class
|
||||
# Type ignoring since the interface only officially supports pydantic 1
|
||||
# or pydantic.v1.BaseModel but not pydantic.BaseModel from pydantic 2.
|
||||
# We'll need to do a pass updating the type signatures.
|
||||
chat = model.with_structured_output(Joke, strict=True) # type: ignore[arg-type]
|
||||
result = chat.invoke("Tell me a joke about cats.")
|
||||
assert isinstance(result, Joke)
|
||||
|
||||
for chunk in chat.stream("Tell me a joke about cats."):
|
||||
assert isinstance(chunk, Joke)
|
||||
|
||||
# Schema
|
||||
chat = model.with_structured_output(Joke.model_json_schema(), strict=True)
|
||||
result = chat.invoke("Tell me a joke about cats.")
|
||||
assert isinstance(result, dict)
|
||||
assert set(result.keys()) == {"setup", "punchline"}
|
||||
|
||||
for chunk in chat.stream("Tell me a joke about cats."):
|
||||
assert isinstance(chunk, dict)
|
||||
assert isinstance(chunk, dict) # for mypy
|
||||
assert set(chunk.keys()) == {"setup", "punchline"}
|
||||
|
Loading…
Reference in New Issue
Block a user