openai[patch]: support structured output and tools (#30581)

Co-authored-by: ccurme <chester.curme@gmail.com>
This commit is contained in:
Bagatur 2025-04-02 06:14:02 -07:00 committed by GitHub
parent 32f7695809
commit 111dd90a46
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 93 additions and 2 deletions

View File

@ -1477,6 +1477,7 @@ class BaseChatOpenAI(BaseChatModel):
] = "function_calling",
include_raw: bool = False,
strict: Optional[bool] = None,
tools: Optional[list] = None,
**kwargs: Any,
) -> Runnable[LanguageModelInput, _DictOrPydantic]:
"""Model wrapper that returns outputs formatted to match the given schema.
@ -1537,6 +1538,51 @@ class BaseChatOpenAI(BaseChatModel):
- None:
``strict`` argument will not be passed to the model.
tools:
A list of tool-like objects to bind to the chat model. Requires that:
- ``method`` is ``"json_schema"`` (default).
- ``strict=True``
- ``include_raw=True``
If a model elects to call a
tool, the resulting ``AIMessage`` in ``"raw"`` will include tool calls.
.. dropdown:: Example
.. code-block:: python
from langchain.chat_models import init_chat_model
from pydantic import BaseModel
class ResponseSchema(BaseModel):
response: str
def get_weather(location: str) -> str:
\"\"\"Get weather at a location.\"\"\"
pass
llm = init_chat_model("openai:gpt-4o-mini")
structured_llm = llm.with_structured_output(
ResponseSchema,
tools=[get_weather],
strict=True,
include_raw=True,
)
structured_llm.invoke("What's the weather in Boston?")
.. code-block:: python
{
"raw": AIMessage(content="", tool_calls=[...], ...),
"parsing_error": None,
"parsed": None,
}
kwargs: Additional keyword args aren't supported.
Returns:
@ -1558,6 +1604,9 @@ class BaseChatOpenAI(BaseChatModel):
Support for ``strict`` argument added.
Support for ``method`` = "json_schema" added.
.. versionchanged:: 0.3.12
Support for ``tools`` added.
""" # noqa: E501
if kwargs:
raise ValueError(f"Received unsupported arguments {kwargs}")
@ -1642,13 +1691,18 @@ class BaseChatOpenAI(BaseChatModel):
"Received None."
)
response_format = _convert_to_openai_response_format(schema, strict=strict)
llm = self.bind(
bind_kwargs = dict(
response_format=response_format,
ls_structured_output_format={
"kwargs": {"method": method, "strict": strict},
"schema": convert_to_openai_tool(schema),
},
)
if tools:
bind_kwargs["tools"] = [
convert_to_openai_tool(t, strict=strict) for t in tools
]
llm = self.bind(**bind_kwargs)
if is_pydantic_schema:
output_parser = RunnableLambda(
partial(_oai_structured_outputs_parser, schema=cast(type, schema))
@ -2776,7 +2830,7 @@ def _convert_to_openai_response_format(
def _oai_structured_outputs_parser(
ai_msg: AIMessage, schema: Type[_BM]
) -> PydanticBaseModel:
) -> Optional[PydanticBaseModel]:
if parsed := ai_msg.additional_kwargs.get("parsed"):
if isinstance(parsed, dict):
return schema(**parsed)
@ -2784,6 +2838,8 @@ def _oai_structured_outputs_parser(
return parsed
elif ai_msg.additional_kwargs.get("refusal"):
raise OpenAIRefusalError(ai_msg.additional_kwargs["refusal"])
elif ai_msg.tool_calls:
return None
else:
raise ValueError(
"Structured Output response does not have a 'parsed' field nor a 'refusal' "

View File

@ -1265,3 +1265,38 @@ def test_structured_output_and_tools() -> None:
assert len(full.tool_calls) == 1
tool_call = full.tool_calls[0]
assert tool_call["name"] == "GenerateUsername"
def test_tools_and_structured_output() -> None:
    """Exercise ``with_structured_output`` when ``tools`` are bound as well.

    A plain query is expected to come back parsed into the response schema,
    while a query that asks for the tool is expected to produce tool calls on
    the raw message with ``parsed`` set to ``None``. Both ``invoke`` and
    ``stream`` are covered.
    """

    # Deliberately no docstring on the schema class: it is passed to
    # ``with_structured_output`` as the schema, so only the original fields
    # are declared here.
    class ResponseFormat(BaseModel):
        response: str
        explanation: str

    structured_llm = ChatOpenAI(model="gpt-4o-mini").with_structured_output(
        ResponseFormat,
        strict=True,
        include_raw=True,
        tools=[GenerateUsername],
    )
    expected_keys = {"raw", "parsing_error", "parsed"}
    query = "Hello"
    tool_query = "Generate a user name for Alice, black hair. Use the tool."

    # invoke, plain query: the answer should be parsed into the schema.
    structured_result = structured_llm.invoke(query)
    assert isinstance(structured_result["parsed"], ResponseFormat)

    # invoke, tool query: the raw message should carry the tool calls and
    # nothing should have been parsed.
    tool_result = structured_llm.invoke(tool_query)
    raw_message = tool_result["raw"]
    assert isinstance(raw_message, AIMessage)
    assert raw_message.tool_calls
    assert tool_result["parsed"] is None

    # stream: every chunk is a dict using only the expected keys; merge them
    # so that later chunks override earlier ones.
    merged: dict = {}
    for piece in structured_llm.stream(tool_query):
        assert isinstance(piece, dict)
        assert all(name in expected_keys for name in piece)
        merged.update(piece)

    # Once the stream is exhausted, all three keys must have been seen.
    assert all(name in merged for name in expected_keys)
    streamed_raw = merged["raw"]
    assert isinstance(streamed_raw, AIMessage)
    assert streamed_raw.tool_calls
    assert merged["parsed"] is None