mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-21 22:29:51 +00:00
openai[patch]: support structured output and tools (#30581)
Co-authored-by: ccurme <chester.curme@gmail.com>
This commit is contained in:
parent
32f7695809
commit
111dd90a46
@ -1477,6 +1477,7 @@ class BaseChatOpenAI(BaseChatModel):
|
||||
] = "function_calling",
|
||||
include_raw: bool = False,
|
||||
strict: Optional[bool] = None,
|
||||
tools: Optional[list] = None,
|
||||
**kwargs: Any,
|
||||
) -> Runnable[LanguageModelInput, _DictOrPydantic]:
|
||||
"""Model wrapper that returns outputs formatted to match the given schema.
|
||||
@ -1537,6 +1538,51 @@ class BaseChatOpenAI(BaseChatModel):
|
||||
- None:
|
||||
``strict`` argument will not be passed to the model.
|
||||
|
||||
tools:
|
||||
A list of tool-like objects to bind to the chat model. Requires that:
|
||||
|
||||
- ``method`` is ``"json_schema"`` (default).
|
||||
- ``strict=True``
|
||||
- ``include_raw=True``
|
||||
|
||||
If a model elects to call a
|
||||
tool, the resulting ``AIMessage`` in ``"raw"`` will include tool calls.
|
||||
|
||||
.. dropdown:: Example
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from langchain.chat_models import init_chat_model
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class ResponseSchema(BaseModel):
|
||||
response: str
|
||||
|
||||
|
||||
def get_weather(location: str) -> str:
|
||||
\"\"\"Get weather at a location.\"\"\"
|
||||
pass
|
||||
|
||||
llm = init_chat_model("openai:gpt-4o-mini")
|
||||
|
||||
structured_llm = llm.with_structured_output(
|
||||
ResponseSchema,
|
||||
tools=[get_weather],
|
||||
strict=True,
|
||||
include_raw=True,
|
||||
)
|
||||
|
||||
structured_llm.invoke("What's the weather in Boston?")
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
{
|
||||
"raw": AIMessage(content="", tool_calls=[...], ...),
|
||||
"parsing_error": None,
|
||||
"parsed": None,
|
||||
}
|
||||
|
||||
kwargs: Additional keyword args aren't supported.
|
||||
|
||||
Returns:
|
||||
@ -1558,6 +1604,9 @@ class BaseChatOpenAI(BaseChatModel):
|
||||
|
||||
Support for ``strict`` argument added.
|
||||
Support for ``method`` = "json_schema" added.
|
||||
|
||||
.. versionchanged:: 0.3.12
|
||||
Support for ``tools`` added.
|
||||
""" # noqa: E501
|
||||
if kwargs:
|
||||
raise ValueError(f"Received unsupported arguments {kwargs}")
|
||||
@ -1642,13 +1691,18 @@ class BaseChatOpenAI(BaseChatModel):
|
||||
"Received None."
|
||||
)
|
||||
response_format = _convert_to_openai_response_format(schema, strict=strict)
|
||||
llm = self.bind(
|
||||
bind_kwargs = dict(
|
||||
response_format=response_format,
|
||||
ls_structured_output_format={
|
||||
"kwargs": {"method": method, "strict": strict},
|
||||
"schema": convert_to_openai_tool(schema),
|
||||
},
|
||||
)
|
||||
if tools:
|
||||
bind_kwargs["tools"] = [
|
||||
convert_to_openai_tool(t, strict=strict) for t in tools
|
||||
]
|
||||
llm = self.bind(**bind_kwargs)
|
||||
if is_pydantic_schema:
|
||||
output_parser = RunnableLambda(
|
||||
partial(_oai_structured_outputs_parser, schema=cast(type, schema))
|
||||
@ -2776,7 +2830,7 @@ def _convert_to_openai_response_format(
|
||||
|
||||
def _oai_structured_outputs_parser(
|
||||
ai_msg: AIMessage, schema: Type[_BM]
|
||||
) -> PydanticBaseModel:
|
||||
) -> Optional[PydanticBaseModel]:
|
||||
if parsed := ai_msg.additional_kwargs.get("parsed"):
|
||||
if isinstance(parsed, dict):
|
||||
return schema(**parsed)
|
||||
@ -2784,6 +2838,8 @@ def _oai_structured_outputs_parser(
|
||||
return parsed
|
||||
elif ai_msg.additional_kwargs.get("refusal"):
|
||||
raise OpenAIRefusalError(ai_msg.additional_kwargs["refusal"])
|
||||
elif ai_msg.tool_calls:
|
||||
return None
|
||||
else:
|
||||
raise ValueError(
|
||||
"Structured Output response does not have a 'parsed' field nor a 'refusal' "
|
||||
|
@ -1265,3 +1265,38 @@ def test_structured_output_and_tools() -> None:
|
||||
assert len(full.tool_calls) == 1
|
||||
tool_call = full.tool_calls[0]
|
||||
assert tool_call["name"] == "GenerateUsername"
|
||||
|
||||
|
||||
def test_tools_and_structured_output() -> None:
|
||||
class ResponseFormat(BaseModel):
|
||||
response: str
|
||||
explanation: str
|
||||
|
||||
llm = ChatOpenAI(model="gpt-4o-mini").with_structured_output(
|
||||
ResponseFormat, strict=True, include_raw=True, tools=[GenerateUsername]
|
||||
)
|
||||
|
||||
expected_keys = {"raw", "parsing_error", "parsed"}
|
||||
query = "Hello"
|
||||
tool_query = "Generate a user name for Alice, black hair. Use the tool."
|
||||
# Test invoke
|
||||
## Engage structured output
|
||||
response = llm.invoke(query)
|
||||
assert isinstance(response["parsed"], ResponseFormat)
|
||||
## Engage tool calling
|
||||
response_tools = llm.invoke(tool_query)
|
||||
ai_msg = response_tools["raw"]
|
||||
assert isinstance(ai_msg, AIMessage)
|
||||
assert ai_msg.tool_calls
|
||||
assert response_tools["parsed"] is None
|
||||
|
||||
# Test stream
|
||||
aggregated: dict = {}
|
||||
for chunk in llm.stream(tool_query):
|
||||
assert isinstance(chunk, dict)
|
||||
assert all(key in expected_keys for key in chunk)
|
||||
aggregated = {**aggregated, **chunk}
|
||||
assert all(key in aggregated for key in expected_keys)
|
||||
assert isinstance(aggregated["raw"], AIMessage)
|
||||
assert aggregated["raw"].tool_calls
|
||||
assert aggregated["parsed"] is None
|
||||
|
Loading…
Reference in New Issue
Block a user