openai[patch]: ChatOpenAI.with_structured_output json_schema support (#25123)
This commit is contained in:
parent 0ba125c3cd
commit 09fbce13c5
libs/partners/openai/langchain_openai/chat_models/base.py

@@ -7,6 +7,7 @@ import json
 import logging
 import os
 import sys
+import warnings
 from io import BytesIO
 from math import ceil
 from operator import itemgetter
@@ -27,7 +28,6 @@ from typing import (
     TypeVar,
     Union,
     cast,
-    overload,
 )
 from urllib.parse import urlparse
 
@@ -74,7 +74,7 @@ from langchain_core.output_parsers.openai_tools import (
 )
 from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
 from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator
-from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough
+from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough, chain
 from langchain_core.runnables.config import run_in_executor
 from langchain_core.tools import BaseTool
 from langchain_core.utils import (
@@ -86,7 +86,11 @@ from langchain_core.utils.function_calling import (
     convert_to_openai_function,
     convert_to_openai_tool,
 )
-from langchain_core.utils.pydantic import is_basemodel_subclass
+from langchain_core.utils.pydantic import (
+    PydanticBaseModel,
+    TypeBaseModel,
+    is_basemodel_subclass,
+)
 from langchain_core.utils.utils import build_extra_kwargs
 
 logger = logging.getLogger(__name__)
@@ -298,6 +302,8 @@ class _AllReturnType(TypedDict):
 class BaseChatOpenAI(BaseChatModel):
     client: Any = Field(default=None, exclude=True)  #: :meta private:
     async_client: Any = Field(default=None, exclude=True)  #: :meta private:
+    root_client: Any = Field(default=None, exclude=True)  #: :meta private:
+    root_async_client: Any = Field(default=None, exclude=True)  #: :meta private:
     model_name: str = Field(default="gpt-3.5-turbo", alias="model")
     """Model name to use."""
     temperature: float = 0.7
@@ -445,9 +451,8 @@ class BaseChatOpenAI(BaseChatModel):
                 ) from e
                 values["http_client"] = httpx.Client(proxy=values["openai_proxy"])
             sync_specific = {"http_client": values["http_client"]}
-            values["client"] = openai.OpenAI(
-                **client_params, **sync_specific
-            ).chat.completions
+            values["root_client"] = openai.OpenAI(**client_params, **sync_specific)
+            values["client"] = values["root_client"].chat.completions
         if not values.get("async_client"):
             if values["openai_proxy"] and not values["http_async_client"]:
                 try:
@@ -461,10 +466,10 @@ class BaseChatOpenAI(BaseChatModel):
                     proxy=values["openai_proxy"]
                 )
             async_specific = {"http_client": values["http_async_client"]}
-            values["async_client"] = openai.AsyncOpenAI(
+            values["root_async_client"] = openai.AsyncOpenAI(
                 **client_params, **async_specific
-            ).chat.completions
+            )
+            values["async_client"] = values["root_async_client"].chat.completions
         return values
 
     @property
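For context, the distinction this introduces: `root_client` holds the full `openai.OpenAI` client, whose `beta` namespace exposes the structured-outputs `parse` endpoint used below, while `client` remains the `chat.completions` sub-client. A minimal sketch, assuming the openai v1 SDK and an `OPENAI_API_KEY` in the environment:

```python
import openai

# Full SDK client; also exposes the beta namespace used for
# beta.chat.completions.parse further down in this diff.
root_client = openai.OpenAI()

# The pre-existing `client` attribute is just the chat-completions sub-client,
# so existing create(...) call sites keep working unchanged.
client = root_client.chat.completions
```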
@@ -525,13 +530,32 @@ class BaseChatOpenAI(BaseChatModel):
         kwargs["stream"] = True
         payload = self._get_request_payload(messages, stop=stop, **kwargs)
         default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
+        base_generation_info = {}
+
+        if "response_format" in payload and is_basemodel_subclass(
+            payload["response_format"]
+        ):
+            # TODO: Add support for streaming with Pydantic response_format.
+            warnings.warn("Streaming with Pydantic response_format not yet supported.")
+            chat_result = self._generate(
+                messages, stop, run_manager=run_manager, **kwargs
+            )
+            msg = chat_result.generations[0].message
+            yield ChatGenerationChunk(
+                message=AIMessageChunk(
+                    **msg.dict(exclude={"type", "additional_kwargs"}),
+                    # preserve the "parsed" Pydantic object without converting to dict
+                    additional_kwargs=msg.additional_kwargs,
+                ),
+                generation_info=chat_result.generations[0].generation_info,
+            )
+            return
         if self.include_response_headers:
             raw_response = self.client.with_raw_response.create(**payload)
             response = raw_response.parse()
             base_generation_info = {"headers": dict(raw_response.headers)}
         else:
             response = self.client.create(**payload)
-            base_generation_info = {}
         with response:
             is_first_chunk = True
             for chunk in response:
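With this change, streaming a Pydantic `response_format` falls back to a single non-streamed generation yielded as one chunk, with the parsed object preserved in `additional_kwargs`. A usage sketch (model name and schema are illustrative, not from the diff):

```python
from langchain_openai import ChatOpenAI
from pydantic import BaseModel


class Joke(BaseModel):
    setup: str
    punchline: str


llm = ChatOpenAI(model="gpt-4o-2024-08-06").bind(response_format=Joke)

# Emits "Streaming with Pydantic response_format not yet supported." and then
# yields exactly one chunk carrying the whole response.
for chunk in llm.stream("Tell me a joke"):
    print(chunk.additional_kwargs.get("parsed"))  # Joke instance, not a dict
```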
@@ -594,13 +618,21 @@ class BaseChatOpenAI(BaseChatModel):
             )
             return generate_from_stream(stream_iter)
         payload = self._get_request_payload(messages, stop=stop, **kwargs)
-        if self.include_response_headers:
+        generation_info = None
+        if "response_format" in payload:
+            if self.include_response_headers:
+                warnings.warn(
+                    "Cannot currently include response headers when response_format is "
+                    "specified."
+                )
+            payload.pop("stream")
+            response = self.root_client.beta.chat.completions.parse(**payload)
+        elif self.include_response_headers:
             raw_response = self.client.with_raw_response.create(**payload)
             response = raw_response.parse()
             generation_info = {"headers": dict(raw_response.headers)}
         else:
             response = self.client.create(**payload)
-            generation_info = None
         return self._create_chat_result(response, generation_info)
 
     def _get_request_payload(
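The non-streaming path now routes through the SDK's structured-outputs endpoint whenever `response_format` is in the payload. Roughly the equivalent raw SDK call (a sketch, assuming `openai>=1.40`; model and schema are illustrative):

```python
import openai
from pydantic import BaseModel


class Joke(BaseModel):
    setup: str
    punchline: str


client = openai.OpenAI()
completion = client.beta.chat.completions.parse(
    model="gpt-4o-2024-08-06",
    messages=[{"role": "user", "content": "Tell me a joke"}],
    response_format=Joke,  # SDK builds the json_schema payload and parses the reply
)
message = completion.choices[0].message
print(message.parsed)   # Joke instance, or None if the model refused
print(message.refusal)  # refusal string, or None
```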
@@ -625,18 +657,19 @@ class BaseChatOpenAI(BaseChatModel):
         generation_info: Optional[Dict] = None,
     ) -> ChatResult:
         generations = []
-        if not isinstance(response, dict):
-            response = response.model_dump()
 
+        response_dict = (
+            response if isinstance(response, dict) else response.model_dump()
+        )
         # Sometimes the AI Model calling will get error, we should raise it.
         # Otherwise, the next code 'choices.extend(response["choices"])'
         # will throw a "TypeError: 'NoneType' object is not iterable" error
         # to mask the true error. Because 'response["choices"]' is None.
-        if response.get("error"):
-            raise ValueError(response.get("error"))
+        if response_dict.get("error"):
+            raise ValueError(response_dict.get("error"))
 
-        token_usage = response.get("usage", {})
-        for res in response["choices"]:
+        token_usage = response_dict.get("usage", {})
+        for res in response_dict["choices"]:
             message = _convert_dict_to_message(res["message"])
             if token_usage and isinstance(message, AIMessage):
                 message.usage_metadata = {
@@ -656,9 +689,19 @@ class BaseChatOpenAI(BaseChatModel):
             generations.append(gen)
         llm_output = {
             "token_usage": token_usage,
-            "model_name": response.get("model", self.model_name),
-            "system_fingerprint": response.get("system_fingerprint", ""),
+            "model_name": response_dict.get("model", self.model_name),
+            "system_fingerprint": response_dict.get("system_fingerprint", ""),
         }
+
+        if isinstance(response, openai.BaseModel) and getattr(
+            response, "choices", None
+        ):
+            message = response.choices[0].message  # type: ignore[attr-defined]
+            if hasattr(message, "parsed"):
+                generations[0].message.additional_kwargs["parsed"] = message.parsed
+            if hasattr(message, "refusal"):
+                generations[0].message.additional_kwargs["refusal"] = message.refusal
+
         return ChatResult(generations=generations, llm_output=llm_output)
 
     async def _astream(
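Since `_create_chat_result` now copies the SDK's `parsed` and `refusal` fields onto the message, they can be inspected with `include_raw=True`. A sketch under the same illustrative assumptions as above:

```python
from langchain_openai import ChatOpenAI
from pydantic import BaseModel


class Joke(BaseModel):
    setup: str
    punchline: str


structured_llm = ChatOpenAI(model="gpt-4o-2024-08-06").with_structured_output(
    Joke, method="json_schema", include_raw=True
)
result = structured_llm.invoke("Tell me a joke")
raw_msg = result["raw"]  # the underlying AIMessage
print(raw_msg.additional_kwargs.get("parsed"))   # Joke object, kept un-dumped
print(raw_msg.additional_kwargs.get("refusal"))  # str if the model refused, else None
print(result["parsed"], result["parsing_error"])
```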
@@ -671,13 +714,31 @@ class BaseChatOpenAI(BaseChatModel):
         kwargs["stream"] = True
         payload = self._get_request_payload(messages, stop=stop, **kwargs)
         default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
+        base_generation_info = {}
+        if "response_format" in payload and is_basemodel_subclass(
+            payload["response_format"]
+        ):
+            # TODO: Add support for streaming with Pydantic response_format.
+            warnings.warn("Streaming with Pydantic response_format not yet supported.")
+            chat_result = await self._agenerate(
+                messages, stop, run_manager=run_manager, **kwargs
+            )
+            msg = chat_result.generations[0].message
+            yield ChatGenerationChunk(
+                message=AIMessageChunk(
+                    **msg.dict(exclude={"type", "additional_kwargs"}),
+                    # preserve the "parsed" Pydantic object without converting to dict
+                    additional_kwargs=msg.additional_kwargs,
+                ),
+                generation_info=chat_result.generations[0].generation_info,
+            )
+            return
         if self.include_response_headers:
             raw_response = self.async_client.with_raw_response.create(**payload)
             response = raw_response.parse()
             base_generation_info = {"headers": dict(raw_response.headers)}
         else:
             response = await self.async_client.create(**payload)
-            base_generation_info = {}
         async with response:
             is_first_chunk = True
             async for chunk in response:
@@ -745,13 +806,23 @@ class BaseChatOpenAI(BaseChatModel):
             )
             return await agenerate_from_stream(stream_iter)
         payload = self._get_request_payload(messages, stop=stop, **kwargs)
-        if self.include_response_headers:
+        generation_info = None
+        if "response_format" in payload:
+            if self.include_response_headers:
+                warnings.warn(
+                    "Cannot currently include response headers when response_format is "
+                    "specified."
+                )
+            payload.pop("stream")
+            response = await self.root_async_client.beta.chat.completions.parse(
+                **payload
+            )
+        elif self.include_response_headers:
             raw_response = await self.async_client.with_raw_response.create(**payload)
             response = raw_response.parse()
             generation_info = {"headers": dict(raw_response.headers)}
         else:
             response = await self.async_client.create(**payload)
-            generation_info = None
         return await run_in_executor(
             None, self._create_chat_result, response, generation_info
         )
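The async path mirrors the sync one through `root_async_client`, so `ainvoke` behaves the same. A brief sketch (same illustrative model and schema):

```python
import asyncio

from langchain_openai import ChatOpenAI
from pydantic import BaseModel


class Joke(BaseModel):
    setup: str
    punchline: str


async def main() -> None:
    structured_llm = ChatOpenAI(model="gpt-4o-2024-08-06").with_structured_output(
        Joke, method="json_schema"
    )
    # Routed through root_async_client.beta.chat.completions.parse
    print(await structured_llm.ainvoke("Tell me a joke"))


asyncio.run(main())
```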
@@ -1028,34 +1099,13 @@ class BaseChatOpenAI(BaseChatModel):
             kwargs["tool_choice"] = tool_choice
         return super().bind(tools=formatted_tools, **kwargs)
 
-    # TODO: Fix typing.
-    @overload  # type: ignore[override]
     def with_structured_output(
         self,
         schema: Optional[_DictOrPydanticClass] = None,
         *,
-        method: Literal["function_calling", "json_mode"] = "function_calling",
-        include_raw: Literal[True] = True,
-        strict: Optional[bool] = None,
-        **kwargs: Any,
-    ) -> Runnable[LanguageModelInput, _AllReturnType]: ...
-
-    @overload
-    def with_structured_output(
-        self,
-        schema: Optional[_DictOrPydanticClass] = None,
-        *,
-        method: Literal["function_calling", "json_mode"] = "function_calling",
-        include_raw: Literal[False] = False,
-        strict: Optional[bool] = None,
-        **kwargs: Any,
-    ) -> Runnable[LanguageModelInput, _DictOrPydantic]: ...
-
-    def with_structured_output(
-        self,
-        schema: Optional[_DictOrPydanticClass] = None,
-        *,
-        method: Literal["function_calling", "json_mode"] = "function_calling",
+        method: Literal[
+            "function_calling", "json_mode", "json_schema"
+        ] = "function_calling",
         include_raw: bool = False,
         strict: Optional[bool] = None,
         **kwargs: Any,
@@ -1065,10 +1115,12 @@ class BaseChatOpenAI(BaseChatModel):
         .. versionchanged:: 0.1.21
 
             Support for ``strict`` argument added.
+            Support for ``method`` = "json_schema" added.
 
         Args:
             schema:
                 The output schema. Can be passed in as:
 
                 - an OpenAI function/tool schema,
                 - a JSON Schema,
                 - a TypedDict class (support added in 0.1.20),
@@ -1085,12 +1137,36 @@ class BaseChatOpenAI(BaseChatModel):
                 Added support for TypedDict class.
 
             method:
-                The method for steering model generation, one of "function_calling"
-                or "json_mode". If "function_calling" then the schema will be converted
-                to an OpenAI function and the returned model will make use of the
-                function-calling API. If "json_mode" then OpenAI's JSON mode will be
-                used. Note that if using "json_mode" then you must include instructions
-                for formatting the output into the desired schema into the model call.
+                The method for steering model generation, one of:
+
+                - "function_calling":
+                    Uses OpenAI's tool-calling (formerly called function calling)
+                    API: https://platform.openai.com/docs/guides/function-calling
+                - "json_schema":
+                    Uses OpenAI's Structured Output API:
+                    https://platform.openai.com/docs/guides/structured-outputs.
+                    Supported for "gpt-4o-mini", "gpt-4o-2024-08-06", and later
+                    models.
+                - "json_mode":
+                    Uses OpenAI's JSON mode. Note that if using JSON mode then you
+                    must include instructions for formatting the output into the
+                    desired schema into the model call:
+                    https://platform.openai.com/docs/guides/structured-outputs/json-mode
+
+                Learn more about the differences between the methods and which models
+                support which methods here:
+
+                - https://platform.openai.com/docs/guides/structured-outputs/structured-outputs-vs-json-mode
+                - https://platform.openai.com/docs/guides/structured-outputs/function-calling-vs-response-format
+
+                .. versionchanged:: 0.1.21
+
+                    Added support for "json_schema".
+
+                .. note:: Planned breaking change in version `0.2.0`
+
+                    ``method`` default will be changed to "json_schema" from
+                    "function_calling".
             include_raw:
                 If False then only the parsed structured output is returned. If
                 an error occurs during model output parsing it will be raised. If True
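In code, the three `method` values documented above correspond to the following calls (a sketch; schema and model names are illustrative):

```python
from langchain_openai import ChatOpenAI
from pydantic import BaseModel


class Joke(BaseModel):
    setup: str
    punchline: str


llm = ChatOpenAI(model="gpt-4o-2024-08-06")

# Tool-calling: the schema is converted to an OpenAI tool definition.
s1 = llm.with_structured_output(Joke, method="function_calling")

# Structured Outputs API: schema enforced server side; strict defaults to True.
s2 = llm.with_structured_output(Joke, method="json_schema")

# JSON mode: the prompt itself must describe the desired JSON shape.
s3 = llm.with_structured_output(Joke, method="json_mode")
```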
@@ -1098,14 +1174,20 @@ class BaseChatOpenAI(BaseChatModel):
                 response will be returned. If an error occurs during output parsing it
                 will be caught and returned as well. The final output is always a dict
                 with keys "raw", "parsed", and "parsing_error".
-            strict: If True and ``method`` = "function_calling", model output is
-                guaranteed to exactly match the schema
-                If True, the input schema will also be
-                validated according to
+            strict:
+                - True:
+                    Model output is guaranteed to exactly match the schema.
+                    The input schema will also be validated according to
                     https://platform.openai.com/docs/guides/structured-outputs/supported-schemas.
-                If False, input schema will not be validated and model output will not
-                be validated.
-                If None, ``strict`` argument will not be passed to the model.
+                - False:
+                    Input schema will not be validated and model output will not be
+                    validated.
+                - None:
+                    ``strict`` argument will not be passed to the model.
+
+                If ``method`` is "json_schema" defaults to True. If ``method`` is
+                "function_calling" or "json_mode" defaults to None. Can only be
+                non-null if ``method`` is "function_calling" or "json_schema".
 
             .. versionadded:: 0.1.21
 
@@ -1124,9 +1206,10 @@ class BaseChatOpenAI(BaseChatModel):
             Otherwise, if ``include_raw`` is False then Runnable outputs a dict.
 
             If ``include_raw`` is True, then Runnable outputs a dict with keys:
-            - ``"raw"``: BaseMessage
-            - ``"parsed"``: None if there was a parsing error, otherwise the type depends on the ``schema`` as described above.
-            - ``"parsing_error"``: Optional[BaseException]
+
+            - "raw": BaseMessage
+            - "parsed": None if there was a parsing error, otherwise the type depends on the ``schema`` as described above.
+            - "parsing_error": Optional[BaseException]
 
         Example: schema=Pydantic class, method="function_calling", include_raw=False, strict=True:
             .. note:: Valid schemas when using ``strict`` = True
@@ -1305,15 +1388,15 @@ class BaseChatOpenAI(BaseChatModel):
         """  # noqa: E501
         if kwargs:
             raise ValueError(f"Received unsupported arguments {kwargs}")
-        if strict is not None and method != "function_calling":
+        if strict is not None and method == "json_mode":
             raise ValueError(
-                "Argument `strict` is only supported for `method`='function_calling'"
+                "Argument `strict` is not supported with `method`='json_mode'"
             )
         is_pydantic_schema = _is_pydantic_class(schema)
         if method == "function_calling":
             if schema is None:
                 raise ValueError(
-                    "schema must be specified when method is 'function_calling'. "
+                    "schema must be specified when method is not 'json_mode'. "
                     "Received None."
                 )
             tool_name = convert_to_openai_tool(schema)["function"]["name"]
@@ -1339,6 +1422,20 @@ class BaseChatOpenAI(BaseChatModel):
                 if is_pydantic_schema
                 else JsonOutputParser()
             )
+        elif method == "json_schema":
+            if schema is None:
+                raise ValueError(
+                    "schema must be specified when method is not 'json_mode'. "
+                    "Received None."
+                )
+            strict = strict if strict is not None else True
+            response_format = _convert_to_openai_response_format(schema, strict=strict)
+            llm = self.bind(response_format=response_format)
+            output_parser = (
+                cast(Runnable, _oai_structured_outputs_parser)
+                if is_pydantic_schema
+                else JsonOutputParser()
+            )
         else:
             raise ValueError(
                 f"Unrecognized method argument. Expected one of 'function_calling' or "
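Putting the new `elif method == "json_schema"` branch together, an end-to-end sketch (schema is illustrative; note that under strict mode fields may not carry default values, per the integration test further down):

```python
from langchain_openai import ChatOpenAI
from pydantic import BaseModel, Field


class Person(BaseModel):
    """Record about a person."""

    name: str = Field(description="The person's name")
    age: int = Field(description="The person's age")


llm = ChatOpenAI(model="gpt-4o-2024-08-06", temperature=0)
structured_llm = llm.with_structured_output(Person, method="json_schema")

# strict was None, so it defaults to True; the bound response_format is the
# Person class itself, and _oai_structured_outputs_parser returns the parsed
# Person (or raises OpenAIRefusalError on a refusal).
person = structured_llm.invoke("Anna is 28 years old.")
```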
@@ -1975,3 +2072,40 @@ def _resize(width: int, height: int) -> Tuple[int, int]:
         height = (width * 768) // height
         width = 768
     return width, height
+
+
+def _convert_to_openai_response_format(
+    schema: Union[Dict[str, Any], Type], strict: bool
+) -> Union[Dict, TypeBaseModel]:
+    if isinstance(schema, type) and is_basemodel_subclass(schema):
+        return schema
+    else:
+        function = convert_to_openai_function(schema, strict=strict)
+        function["schema"] = function.pop("parameters")
+        return {"type": "json_schema", "json_schema": function}
+
+
+@chain
+def _oai_structured_outputs_parser(ai_msg: AIMessage) -> PydanticBaseModel:
+    if ai_msg.additional_kwargs.get("parsed"):
+        return ai_msg.additional_kwargs["parsed"]
+    elif ai_msg.additional_kwargs.get("refusal"):
+        raise OpenAIRefusalError(ai_msg.additional_kwargs["refusal"])
+    else:
+        raise ValueError(
+            "Structured Output response does not have a 'parsed' field nor a 'refusal' "
+            "field."
+        )
+
+
+class OpenAIRefusalError(Exception):
+    """Error raised when OpenAI Structured Outputs API returns a refusal.
+
+    When using OpenAI's Structured Outputs API with user-generated input, the model
+    may occasionally refuse to fulfill the request for safety reasons.
+
+    See here for more on refusals:
+    https://platform.openai.com/docs/guides/structured-outputs/refusals
+
+    .. versionadded:: 0.1.21
+    """
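For a non-Pydantic schema, `_convert_to_openai_response_format` wraps the converted function schema into the `response_format` payload shape OpenAI expects. Roughly (a sketch using the private helper; the output shown is indicative, not captured from a run):

```python
from langchain_openai.chat_models.base import _convert_to_openai_response_format

schema = {
    "title": "Joke",
    "description": "Joke to tell user.",
    "type": "object",
    "properties": {
        "setup": {"type": "string"},
        "punchline": {"type": "string"},
    },
    "required": ["setup", "punchline"],
}

response_format = _convert_to_openai_response_format(schema, strict=True)
# Indicatively:
# {
#     "type": "json_schema",
#     "json_schema": {
#         "name": "Joke",
#         "description": "Joke to tell user.",
#         "strict": True,
#         "schema": {...},  # the JSON Schema, renamed from "parameters"
#     },
# }
```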
libs/partners/openai/poetry.lock (generated, 2 lines changed)

@@ -1527,4 +1527,4 @@ watchmedo = ["PyYAML (>=3.10)"]
 [metadata]
 lock-version = "2.0"
 python-versions = ">=3.8.1,<4.0"
-content-hash = "23d99a41f0cff5bf1869e8e18ac953a9802b0d1912eedcecca624650c1ff3af6"
+content-hash = "a08bed7f2e62b3f6c7fc52a31c2529b44d4e5adcc55aba5047be027596fdb31f"
libs/partners/openai/pyproject.toml

@@ -24,7 +24,7 @@ ignore_missing_imports = true
 [tool.poetry.dependencies]
 python = ">=3.8.1,<4.0"
 langchain-core = { version = "^0.2.29rc1", allow-prereleases=true }
-openai = "^1.32.0"
+openai = "^1.40.0"
 tiktoken = ">=0.7,<1"
 
 [tool.ruff.lint]
libs/partners/openai/tests/integration_tests/chat_models/test_base.py

@@ -1,7 +1,7 @@
 """Test ChatOpenAI chat model."""
 
 import base64
-from typing import Any, AsyncIterator, List, Optional, cast
+from typing import Any, AsyncIterator, List, Literal, Optional, cast
 
 import httpx
 import openai
@@ -796,13 +796,21 @@ def test_tool_calling_strict() -> None:
         next(model_with_invalid_tool_schema.stream(query))
 
 
-def test_structured_output_strict() -> None:
+@pytest.mark.parametrize(
+    ("model", "method", "strict"),
+    [("gpt-4o", "function_calling", True), ("gpt-4o-2024-08-06", "json_schema", None)],
+)
+def test_structured_output_strict(
+    model: str,
+    method: Literal["function_calling", "json_schema"],
+    strict: Optional[bool],
+) -> None:
     """Test to verify structured output with strict=True."""
 
     from pydantic import BaseModel as BaseModelProper
     from pydantic import Field as FieldProper
 
-    model = ChatOpenAI(model="gpt-4o", temperature=0)
+    llm = ChatOpenAI(model=model, temperature=0)
 
     class Joke(BaseModelProper):
         """Joke to tell user."""
@@ -814,7 +822,7 @@ def test_structured_output_strict() -> None:
     # Type ignoring since the interface only officially supports pydantic 1
     # or pydantic.v1.BaseModel but not pydantic.BaseModel from pydantic 2.
     # We'll need to do a pass updating the type signatures.
-    chat = model.with_structured_output(Joke, strict=True)  # type: ignore[arg-type]
+    chat = llm.with_structured_output(Joke, method=method, strict=strict)
     result = chat.invoke("Tell me a joke about cats.")
     assert isinstance(result, Joke)
 
@@ -822,7 +830,9 @@ def test_structured_output_strict() -> None:
         assert isinstance(chunk, Joke)
 
     # Schema
-    chat = model.with_structured_output(Joke.model_json_schema(), strict=True)
+    chat = llm.with_structured_output(
+        Joke.model_json_schema(), method=method, strict=strict
+    )
     result = chat.invoke("Tell me a joke about cats.")
     assert isinstance(result, dict)
     assert set(result.keys()) == {"setup", "punchline"}
@@ -831,3 +841,27 @@ def test_structured_output_strict() -> None:
         assert isinstance(chunk, dict)
         assert isinstance(chunk, dict)  # for mypy
         assert set(chunk.keys()) == {"setup", "punchline"}
+
+    # Invalid schema with optional fields:
+    class InvalidJoke(BaseModelProper):
+        """Joke to tell user."""
+
+        setup: str = FieldProper(description="question to set up a joke")
+        # Invalid field, can't have default value.
+        punchline: str = FieldProper(
+            default="foo", description="answer to resolve the joke"
+        )
+
+    chat = llm.with_structured_output(InvalidJoke, method=method, strict=strict)
+    with pytest.raises(openai.BadRequestError):
+        chat.invoke("Tell me a joke about cats.")
+    with pytest.raises(openai.BadRequestError):
+        next(chat.stream("Tell me a joke about cats."))
+
+    chat = llm.with_structured_output(
+        InvalidJoke.model_json_schema(), method=method, strict=strict
+    )
+    with pytest.raises(openai.BadRequestError):
+        chat.invoke("Tell me a joke about cats.")
+    with pytest.raises(openai.BadRequestError):
+        next(chat.stream("Tell me a joke about cats."))
libs/partners/openai/tests/unit_tests/chat_models/test_base.py

@@ -1,7 +1,7 @@
 """Test OpenAI Chat API wrapper."""
 
 import json
-from typing import Any, List, Type, Union
+from typing import Any, Dict, List, Literal, Optional, Type, Union
 from unittest.mock import AsyncMock, MagicMock, patch
 
 import pytest
@@ -343,17 +343,32 @@ class MakeASandwich(BaseModel):
         None,
     ],
 )
-def test_bind_tools_tool_choice(tool_choice: Any) -> None:
+@pytest.mark.parametrize("strict", [True, False, None])
+def test_bind_tools_tool_choice(tool_choice: Any, strict: Optional[bool]) -> None:
     """Test passing in manually construct tool call message."""
     llm = ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0)
-    llm.bind_tools(tools=[GenerateUsername, MakeASandwich], tool_choice=tool_choice)
+    llm.bind_tools(
+        tools=[GenerateUsername, MakeASandwich], tool_choice=tool_choice, strict=strict
+    )
 
 
 @pytest.mark.parametrize("schema", [GenerateUsername, GenerateUsername.schema()])
-def test_with_structured_output(schema: Union[Type[BaseModel], dict]) -> None:
+@pytest.mark.parametrize("method", ["json_schema", "function_calling", "json_mode"])
+@pytest.mark.parametrize("include_raw", [True, False])
+@pytest.mark.parametrize("strict", [True, False, None])
+def test_with_structured_output(
+    schema: Union[Type, Dict[str, Any], None],
+    method: Literal["function_calling", "json_mode", "json_schema"],
+    include_raw: bool,
+    strict: Optional[bool],
+) -> None:
     """Test passing in manually construct tool call message."""
+    if method == "json_mode":
+        strict = None
     llm = ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0)
-    llm.with_structured_output(schema)
+    llm.with_structured_output(
+        schema, method=method, strict=strict, include_raw=include_raw
+    )
 
 
 def test_get_num_tokens_from_messages() -> None:
poetry.lock (generated, 1157 lines changed): file diff suppressed because it is too large.