mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-22 14:49:29 +00:00
openai[patch]: support structured output and tools (#30581)
Co-authored-by: ccurme <chester.curme@gmail.com>
This commit is contained in:
parent
32f7695809
commit
111dd90a46
@ -1477,6 +1477,7 @@ class BaseChatOpenAI(BaseChatModel):
|
|||||||
] = "function_calling",
|
] = "function_calling",
|
||||||
include_raw: bool = False,
|
include_raw: bool = False,
|
||||||
strict: Optional[bool] = None,
|
strict: Optional[bool] = None,
|
||||||
|
tools: Optional[list] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> Runnable[LanguageModelInput, _DictOrPydantic]:
|
) -> Runnable[LanguageModelInput, _DictOrPydantic]:
|
||||||
"""Model wrapper that returns outputs formatted to match the given schema.
|
"""Model wrapper that returns outputs formatted to match the given schema.
|
||||||
@ -1537,6 +1538,51 @@ class BaseChatOpenAI(BaseChatModel):
|
|||||||
- None:
|
- None:
|
||||||
``strict`` argument will not be passed to the model.
|
``strict`` argument will not be passed to the model.
|
||||||
|
|
||||||
|
tools:
|
||||||
|
A list of tool-like objects to bind to the chat model. Requires that:
|
||||||
|
|
||||||
|
- ``method`` is ``"json_schema"`` (default).
|
||||||
|
- ``strict=True``
|
||||||
|
- ``include_raw=True``
|
||||||
|
|
||||||
|
If a model elects to call a
|
||||||
|
tool, the resulting ``AIMessage`` in ``"raw"`` will include tool calls.
|
||||||
|
|
||||||
|
.. dropdown:: Example
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
from langchain.chat_models import init_chat_model
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
|
class ResponseSchema(BaseModel):
|
||||||
|
response: str
|
||||||
|
|
||||||
|
|
||||||
|
def get_weather(location: str) -> str:
|
||||||
|
\"\"\"Get weather at a location.\"\"\"
|
||||||
|
pass
|
||||||
|
|
||||||
|
llm = init_chat_model("openai:gpt-4o-mini")
|
||||||
|
|
||||||
|
structured_llm = llm.with_structured_output(
|
||||||
|
ResponseSchema,
|
||||||
|
tools=[get_weather],
|
||||||
|
strict=True,
|
||||||
|
include_raw=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
structured_llm.invoke("What's the weather in Boston?")
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
{
|
||||||
|
"raw": AIMessage(content="", tool_calls=[...], ...),
|
||||||
|
"parsing_error": None,
|
||||||
|
"parsed": None,
|
||||||
|
}
|
||||||
|
|
||||||
kwargs: Additional keyword args aren't supported.
|
kwargs: Additional keyword args aren't supported.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@ -1558,6 +1604,9 @@ class BaseChatOpenAI(BaseChatModel):
|
|||||||
|
|
||||||
Support for ``strict`` argument added.
|
Support for ``strict`` argument added.
|
||||||
Support for ``method`` = "json_schema" added.
|
Support for ``method`` = "json_schema" added.
|
||||||
|
|
||||||
|
.. versionchanged:: 0.3.12
|
||||||
|
Support for ``tools`` added.
|
||||||
""" # noqa: E501
|
""" # noqa: E501
|
||||||
if kwargs:
|
if kwargs:
|
||||||
raise ValueError(f"Received unsupported arguments {kwargs}")
|
raise ValueError(f"Received unsupported arguments {kwargs}")
|
||||||
@ -1642,13 +1691,18 @@ class BaseChatOpenAI(BaseChatModel):
|
|||||||
"Received None."
|
"Received None."
|
||||||
)
|
)
|
||||||
response_format = _convert_to_openai_response_format(schema, strict=strict)
|
response_format = _convert_to_openai_response_format(schema, strict=strict)
|
||||||
llm = self.bind(
|
bind_kwargs = dict(
|
||||||
response_format=response_format,
|
response_format=response_format,
|
||||||
ls_structured_output_format={
|
ls_structured_output_format={
|
||||||
"kwargs": {"method": method, "strict": strict},
|
"kwargs": {"method": method, "strict": strict},
|
||||||
"schema": convert_to_openai_tool(schema),
|
"schema": convert_to_openai_tool(schema),
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
if tools:
|
||||||
|
bind_kwargs["tools"] = [
|
||||||
|
convert_to_openai_tool(t, strict=strict) for t in tools
|
||||||
|
]
|
||||||
|
llm = self.bind(**bind_kwargs)
|
||||||
if is_pydantic_schema:
|
if is_pydantic_schema:
|
||||||
output_parser = RunnableLambda(
|
output_parser = RunnableLambda(
|
||||||
partial(_oai_structured_outputs_parser, schema=cast(type, schema))
|
partial(_oai_structured_outputs_parser, schema=cast(type, schema))
|
||||||
@ -2776,7 +2830,7 @@ def _convert_to_openai_response_format(
|
|||||||
|
|
||||||
def _oai_structured_outputs_parser(
|
def _oai_structured_outputs_parser(
|
||||||
ai_msg: AIMessage, schema: Type[_BM]
|
ai_msg: AIMessage, schema: Type[_BM]
|
||||||
) -> PydanticBaseModel:
|
) -> Optional[PydanticBaseModel]:
|
||||||
if parsed := ai_msg.additional_kwargs.get("parsed"):
|
if parsed := ai_msg.additional_kwargs.get("parsed"):
|
||||||
if isinstance(parsed, dict):
|
if isinstance(parsed, dict):
|
||||||
return schema(**parsed)
|
return schema(**parsed)
|
||||||
@ -2784,6 +2838,8 @@ def _oai_structured_outputs_parser(
|
|||||||
return parsed
|
return parsed
|
||||||
elif ai_msg.additional_kwargs.get("refusal"):
|
elif ai_msg.additional_kwargs.get("refusal"):
|
||||||
raise OpenAIRefusalError(ai_msg.additional_kwargs["refusal"])
|
raise OpenAIRefusalError(ai_msg.additional_kwargs["refusal"])
|
||||||
|
elif ai_msg.tool_calls:
|
||||||
|
return None
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Structured Output response does not have a 'parsed' field nor a 'refusal' "
|
"Structured Output response does not have a 'parsed' field nor a 'refusal' "
|
||||||
|
@ -1265,3 +1265,38 @@ def test_structured_output_and_tools() -> None:
|
|||||||
assert len(full.tool_calls) == 1
|
assert len(full.tool_calls) == 1
|
||||||
tool_call = full.tool_calls[0]
|
tool_call = full.tool_calls[0]
|
||||||
assert tool_call["name"] == "GenerateUsername"
|
assert tool_call["name"] == "GenerateUsername"
|
||||||
|
|
||||||
|
|
||||||
|
def test_tools_and_structured_output() -> None:
|
||||||
|
class ResponseFormat(BaseModel):
|
||||||
|
response: str
|
||||||
|
explanation: str
|
||||||
|
|
||||||
|
llm = ChatOpenAI(model="gpt-4o-mini").with_structured_output(
|
||||||
|
ResponseFormat, strict=True, include_raw=True, tools=[GenerateUsername]
|
||||||
|
)
|
||||||
|
|
||||||
|
expected_keys = {"raw", "parsing_error", "parsed"}
|
||||||
|
query = "Hello"
|
||||||
|
tool_query = "Generate a user name for Alice, black hair. Use the tool."
|
||||||
|
# Test invoke
|
||||||
|
## Engage structured output
|
||||||
|
response = llm.invoke(query)
|
||||||
|
assert isinstance(response["parsed"], ResponseFormat)
|
||||||
|
## Engage tool calling
|
||||||
|
response_tools = llm.invoke(tool_query)
|
||||||
|
ai_msg = response_tools["raw"]
|
||||||
|
assert isinstance(ai_msg, AIMessage)
|
||||||
|
assert ai_msg.tool_calls
|
||||||
|
assert response_tools["parsed"] is None
|
||||||
|
|
||||||
|
# Test stream
|
||||||
|
aggregated: dict = {}
|
||||||
|
for chunk in llm.stream(tool_query):
|
||||||
|
assert isinstance(chunk, dict)
|
||||||
|
assert all(key in expected_keys for key in chunk)
|
||||||
|
aggregated = {**aggregated, **chunk}
|
||||||
|
assert all(key in aggregated for key in expected_keys)
|
||||||
|
assert isinstance(aggregated["raw"], AIMessage)
|
||||||
|
assert aggregated["raw"].tool_calls
|
||||||
|
assert aggregated["parsed"] is None
|
||||||
|
Loading…
Reference in New Issue
Block a user