openai[patch]: ChatOpenAI.with_structured_output json_schema support (#25123)

Author: Bagatur
Date: 2024-08-07 08:09:07 -07:00 (committed via GitHub)
Parent: 0ba125c3cd
Commit: 09fbce13c5
6 changed files with 895 additions and 603 deletions
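For orientation before the diff, a minimal usage sketch of what this commit enables (not part of the commit; the Joke model, model name, and prompt are illustrative and mirror the integration test further down):

    from pydantic import BaseModel, Field
    from langchain_openai import ChatOpenAI

    class Joke(BaseModel):
        """Joke to tell user."""

        setup: str = Field(description="question to set up a joke")
        punchline: str = Field(description="answer to resolve the joke")

    # method="json_schema" routes the request through OpenAI's Structured Outputs
    # API (the beta chat.completions.parse endpoint) instead of tool calling.
    llm = ChatOpenAI(model="gpt-4o-2024-08-06", temperature=0)
    structured_llm = llm.with_structured_output(Joke, method="json_schema")
    structured_llm.invoke("Tell me a joke about cats.")  # -> Joke(setup=..., punchline=...)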

[File: langchain_openai chat model implementation (BaseChatOpenAI)]

@@ -7,6 +7,7 @@ import json
 import logging
 import os
 import sys
+import warnings
 from io import BytesIO
 from math import ceil
 from operator import itemgetter
@@ -27,7 +28,6 @@ from typing import (
     TypeVar,
     Union,
     cast,
-    overload,
 )
 from urllib.parse import urlparse
@@ -74,7 +74,7 @@ from langchain_core.output_parsers.openai_tools import (
 )
 from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
 from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator
-from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough
+from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough, chain
 from langchain_core.runnables.config import run_in_executor
 from langchain_core.tools import BaseTool
 from langchain_core.utils import (
@@ -86,7 +86,11 @@ from langchain_core.utils.function_calling import (
     convert_to_openai_function,
     convert_to_openai_tool,
 )
-from langchain_core.utils.pydantic import is_basemodel_subclass
+from langchain_core.utils.pydantic import (
+    PydanticBaseModel,
+    TypeBaseModel,
+    is_basemodel_subclass,
+)
 from langchain_core.utils.utils import build_extra_kwargs
 
 logger = logging.getLogger(__name__)
@@ -298,6 +302,8 @@ class _AllReturnType(TypedDict):
 class BaseChatOpenAI(BaseChatModel):
     client: Any = Field(default=None, exclude=True)  #: :meta private:
     async_client: Any = Field(default=None, exclude=True)  #: :meta private:
+    root_client: Any = Field(default=None, exclude=True)  #: :meta private:
+    root_async_client: Any = Field(default=None, exclude=True)  #: :meta private:
     model_name: str = Field(default="gpt-3.5-turbo", alias="model")
     """Model name to use."""
     temperature: float = 0.7
@@ -445,9 +451,8 @@ class BaseChatOpenAI(BaseChatModel):
                     ) from e
                 values["http_client"] = httpx.Client(proxy=values["openai_proxy"])
             sync_specific = {"http_client": values["http_client"]}
-            values["client"] = openai.OpenAI(
-                **client_params, **sync_specific
-            ).chat.completions
+            values["root_client"] = openai.OpenAI(**client_params, **sync_specific)
+            values["client"] = values["root_client"].chat.completions
         if not values.get("async_client"):
             if values["openai_proxy"] and not values["http_async_client"]:
                 try:
@@ -461,10 +466,10 @@ class BaseChatOpenAI(BaseChatModel):
                         proxy=values["openai_proxy"]
                     )
             async_specific = {"http_client": values["http_async_client"]}
-            values["async_client"] = openai.AsyncOpenAI(
-                **client_params, **async_specific
-            ).chat.completions
+            values["root_async_client"] = openai.AsyncOpenAI(
+                **client_params, **async_specific
+            )
+            values["async_client"] = values["root_async_client"].chat.completions
         return values
 
     @property
@@ -525,13 +530,32 @@ class BaseChatOpenAI(BaseChatModel):
         kwargs["stream"] = True
         payload = self._get_request_payload(messages, stop=stop, **kwargs)
         default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
+        base_generation_info = {}
+
+        if "response_format" in payload and is_basemodel_subclass(
+            payload["response_format"]
+        ):
+            # TODO: Add support for streaming with Pydantic response_format.
+            warnings.warn("Streaming with Pydantic response_format not yet supported.")
+            chat_result = self._generate(
+                messages, stop, run_manager=run_manager, **kwargs
+            )
+            msg = chat_result.generations[0].message
+            yield ChatGenerationChunk(
+                message=AIMessageChunk(
+                    **msg.dict(exclude={"type", "additional_kwargs"}),
+                    # preserve the "parsed" Pydantic object without converting to dict
+                    additional_kwargs=msg.additional_kwargs,
+                ),
+                generation_info=chat_result.generations[0].generation_info,
+            )
+            return
         if self.include_response_headers:
             raw_response = self.client.with_raw_response.create(**payload)
             response = raw_response.parse()
             base_generation_info = {"headers": dict(raw_response.headers)}
         else:
             response = self.client.create(**payload)
-            base_generation_info = {}
         with response:
             is_first_chunk = True
             for chunk in response:
@@ -594,13 +618,21 @@ class BaseChatOpenAI(BaseChatModel):
             )
             return generate_from_stream(stream_iter)
         payload = self._get_request_payload(messages, stop=stop, **kwargs)
-        if self.include_response_headers:
+        generation_info = None
+        if "response_format" in payload:
+            if self.include_response_headers:
+                warnings.warn(
+                    "Cannot currently include response headers when response_format is "
+                    "specified."
+                )
+            payload.pop("stream")
+            response = self.root_client.beta.chat.completions.parse(**payload)
+        elif self.include_response_headers:
             raw_response = self.client.with_raw_response.create(**payload)
             response = raw_response.parse()
             generation_info = {"headers": dict(raw_response.headers)}
         else:
             response = self.client.create(**payload)
-            generation_info = None
         return self._create_chat_result(response, generation_info)
 
     def _get_request_payload(
@@ -625,18 +657,19 @@ class BaseChatOpenAI(BaseChatModel):
         generation_info: Optional[Dict] = None,
     ) -> ChatResult:
         generations = []
-        if not isinstance(response, dict):
-            response = response.model_dump()
+        response_dict = (
+            response if isinstance(response, dict) else response.model_dump()
+        )
         # Sometimes the AI Model calling will get error, we should raise it.
         # Otherwise, the next code 'choices.extend(response["choices"])'
         # will throw a "TypeError: 'NoneType' object is not iterable" error
         # to mask the true error. Because 'response["choices"]' is None.
-        if response.get("error"):
-            raise ValueError(response.get("error"))
-        token_usage = response.get("usage", {})
-        for res in response["choices"]:
+        if response_dict.get("error"):
+            raise ValueError(response_dict.get("error"))
+        token_usage = response_dict.get("usage", {})
+        for res in response_dict["choices"]:
             message = _convert_dict_to_message(res["message"])
             if token_usage and isinstance(message, AIMessage):
                 message.usage_metadata = {
@@ -656,9 +689,19 @@ class BaseChatOpenAI(BaseChatModel):
             generations.append(gen)
         llm_output = {
             "token_usage": token_usage,
-            "model_name": response.get("model", self.model_name),
-            "system_fingerprint": response.get("system_fingerprint", ""),
+            "model_name": response_dict.get("model", self.model_name),
+            "system_fingerprint": response_dict.get("system_fingerprint", ""),
         }
+        if isinstance(response, openai.BaseModel) and getattr(
+            response, "choices", None
+        ):
+            message = response.choices[0].message  # type: ignore[attr-defined]
+            if hasattr(message, "parsed"):
+                generations[0].message.additional_kwargs["parsed"] = message.parsed
+            if hasattr(message, "refusal"):
+                generations[0].message.additional_kwargs["refusal"] = message.refusal
+
         return ChatResult(generations=generations, llm_output=llm_output)
 
     async def _astream(
@@ -671,13 +714,31 @@ class BaseChatOpenAI(BaseChatModel):
         kwargs["stream"] = True
         payload = self._get_request_payload(messages, stop=stop, **kwargs)
         default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
+        base_generation_info = {}
+        if "response_format" in payload and is_basemodel_subclass(
+            payload["response_format"]
+        ):
+            # TODO: Add support for streaming with Pydantic response_format.
+            warnings.warn("Streaming with Pydantic response_format not yet supported.")
+            chat_result = await self._agenerate(
+                messages, stop, run_manager=run_manager, **kwargs
+            )
+            msg = chat_result.generations[0].message
+            yield ChatGenerationChunk(
+                message=AIMessageChunk(
+                    **msg.dict(exclude={"type", "additional_kwargs"}),
+                    # preserve the "parsed" Pydantic object without converting to dict
+                    additional_kwargs=msg.additional_kwargs,
+                ),
+                generation_info=chat_result.generations[0].generation_info,
+            )
+            return
         if self.include_response_headers:
             raw_response = self.async_client.with_raw_response.create(**payload)
             response = raw_response.parse()
             base_generation_info = {"headers": dict(raw_response.headers)}
         else:
             response = await self.async_client.create(**payload)
-            base_generation_info = {}
         async with response:
             is_first_chunk = True
             async for chunk in response:
@@ -745,13 +806,23 @@ class BaseChatOpenAI(BaseChatModel):
             )
             return await agenerate_from_stream(stream_iter)
         payload = self._get_request_payload(messages, stop=stop, **kwargs)
-        if self.include_response_headers:
+        generation_info = None
+        if "response_format" in payload:
+            if self.include_response_headers:
+                warnings.warn(
+                    "Cannot currently include response headers when response_format is "
+                    "specified."
+                )
+            payload.pop("stream")
+            response = await self.root_async_client.beta.chat.completions.parse(
+                **payload
+            )
+        elif self.include_response_headers:
             raw_response = await self.async_client.with_raw_response.create(**payload)
             response = raw_response.parse()
             generation_info = {"headers": dict(raw_response.headers)}
         else:
             response = await self.async_client.create(**payload)
-            generation_info = None
         return await run_in_executor(
             None, self._create_chat_result, response, generation_info
         )
@@ -1028,34 +1099,13 @@ class BaseChatOpenAI(BaseChatModel):
             kwargs["tool_choice"] = tool_choice
         return super().bind(tools=formatted_tools, **kwargs)
 
-    # TODO: Fix typing.
-    @overload  # type: ignore[override]
-    def with_structured_output(
-        self,
-        schema: Optional[_DictOrPydanticClass] = None,
-        *,
-        method: Literal["function_calling", "json_mode"] = "function_calling",
-        include_raw: Literal[True] = True,
-        strict: Optional[bool] = None,
-        **kwargs: Any,
-    ) -> Runnable[LanguageModelInput, _AllReturnType]: ...
-
-    @overload
-    def with_structured_output(
-        self,
-        schema: Optional[_DictOrPydanticClass] = None,
-        *,
-        method: Literal["function_calling", "json_mode"] = "function_calling",
-        include_raw: Literal[False] = False,
-        strict: Optional[bool] = None,
-        **kwargs: Any,
-    ) -> Runnable[LanguageModelInput, _DictOrPydantic]: ...
-
     def with_structured_output(
         self,
         schema: Optional[_DictOrPydanticClass] = None,
         *,
-        method: Literal["function_calling", "json_mode"] = "function_calling",
+        method: Literal[
+            "function_calling", "json_mode", "json_schema"
+        ] = "function_calling",
         include_raw: bool = False,
         strict: Optional[bool] = None,
         **kwargs: Any,
@@ -1065,10 +1115,12 @@ class BaseChatOpenAI(BaseChatModel):
         .. versionchanged:: 0.1.21
 
             Support for ``strict`` argument added.
+            Support for ``method`` = "json_schema" added.
 
         Args:
             schema:
                 The output schema. Can be passed in as:
+
                 - an OpenAI function/tool schema,
                 - a JSON Schema,
                 - a TypedDict class (support added in 0.1.20),
@@ -1085,12 +1137,36 @@ class BaseChatOpenAI(BaseChatModel):
                 Added support for TypedDict class.
 
             method:
-                The method for steering model generation, one of "function_calling"
-                or "json_mode". If "function_calling" then the schema will be converted
-                to an OpenAI function and the returned model will make use of the
-                function-calling API. If "json_mode" then OpenAI's JSON mode will be
-                used. Note that if using "json_mode" then you must include instructions
-                for formatting the output into the desired schema into the model call.
+                The method for steering model generation, one of:
+
+                - "function_calling":
+                    Uses OpenAI's tool-calling (formerly called function calling)
+                    API: https://platform.openai.com/docs/guides/function-calling
+                - "json_schema":
+                    Uses OpenAI's Structured Output API:
+                    https://platform.openai.com/docs/guides/structured-outputs.
+                    Supported for "gpt-4o-mini", "gpt-4o-2024-08-06", and later
+                    models.
+                - "json_mode":
+                    Uses OpenAI's JSON mode. Note that if using JSON mode then you
+                    must include instructions for formatting the output into the
+                    desired schema into the model call:
+                    https://platform.openai.com/docs/guides/structured-outputs/json-mode
+
+                Learn more about the differences between the methods and which models
+                support which methods here:
+
+                - https://platform.openai.com/docs/guides/structured-outputs/structured-outputs-vs-json-mode
+                - https://platform.openai.com/docs/guides/structured-outputs/function-calling-vs-response-format
+
+                .. versionchanged:: 0.1.21
+
+                    Added support for "json_schema".
+
+                .. note:: Planned breaking change in version `0.2.0`
+
+                    ``method`` default will be changed to "json_schema" from
+                    "function_calling".
+
             include_raw:
                 If False then only the parsed structured output is returned. If
                 an error occurs during model output parsing it will be raised. If True
@@ -1098,14 +1174,20 @@ class BaseChatOpenAI(BaseChatModel):
                 response will be returned. If an error occurs during output parsing it
                 will be caught and returned as well. The final output is always a dict
                 with keys "raw", "parsed", and "parsing_error".
-            strict: If True and ``method`` = "function_calling", model output is
-                guaranteed to exactly match the schema.
-                If True, the input schema will also be
-                validated according to
-                https://platform.openai.com/docs/guides/structured-outputs/supported-schemas.
-                If False, input schema will not be validated and model output will not
-                be validated.
-                If None, ``strict`` argument will not be passed to the model.
+            strict:
+
+                - True:
+                    Model output is guaranteed to exactly match the schema.
+                    The input schema will also be validated according to
+                    https://platform.openai.com/docs/guides/structured-outputs/supported-schemas.
+                - False:
+                    Input schema will not be validated and model output will not be
+                    validated.
+                - None:
+                    ``strict`` argument will not be passed to the model.
+
+                If ``method`` is "json_schema" defaults to True. If ``method`` is
+                "function_calling" or "json_mode" defaults to None. Can only be
+                non-null if ``method`` is "function_calling" or "json_schema".
 
                 .. versionadded:: 0.1.21
@@ -1124,9 +1206,10 @@ class BaseChatOpenAI(BaseChatModel):
             Otherwise, if ``include_raw`` is False then Runnable outputs a dict.
 
             If ``include_raw`` is True, then Runnable outputs a dict with keys:
-            - ``"raw"``: BaseMessage
-            - ``"parsed"``: None if there was a parsing error, otherwise the type depends on the ``schema`` as described above.
-            - ``"parsing_error"``: Optional[BaseException]
+
+            - "raw": BaseMessage
+            - "parsed": None if there was a parsing error, otherwise the type depends on the ``schema`` as described above.
+            - "parsing_error": Optional[BaseException]
 
         Example: schema=Pydantic class, method="function_calling", include_raw=False, strict=True:
             .. note:: Valid schemas when using ``strict`` = True
@@ -1305,15 +1388,15 @@ class BaseChatOpenAI(BaseChatModel):
         """  # noqa: E501
         if kwargs:
             raise ValueError(f"Received unsupported arguments {kwargs}")
-        if strict is not None and method != "function_calling":
+        if strict is not None and method == "json_mode":
             raise ValueError(
-                "Argument `strict` is only supported for `method`='function_calling'"
+                "Argument `strict` is not supported with `method`='json_mode'"
             )
         is_pydantic_schema = _is_pydantic_class(schema)
         if method == "function_calling":
             if schema is None:
                 raise ValueError(
-                    "schema must be specified when method is 'function_calling'. "
+                    "schema must be specified when method is not 'json_mode'. "
                     "Received None."
                 )
             tool_name = convert_to_openai_tool(schema)["function"]["name"]
@@ -1339,6 +1422,20 @@ class BaseChatOpenAI(BaseChatModel):
                 if is_pydantic_schema
                 else JsonOutputParser()
             )
+        elif method == "json_schema":
+            if schema is None:
+                raise ValueError(
+                    "schema must be specified when method is not 'json_mode'. "
+                    "Received None."
+                )
+            strict = strict if strict is not None else True
+            response_format = _convert_to_openai_response_format(schema, strict=strict)
+            llm = self.bind(response_format=response_format)
+            output_parser = (
+                cast(Runnable, _oai_structured_outputs_parser)
+                if is_pydantic_schema
+                else JsonOutputParser()
+            )
         else:
             raise ValueError(
                 f"Unrecognized method argument. Expected one of 'function_calling' or "
@@ -1975,3 +2072,40 @@ def _resize(width: int, height: int) -> Tuple[int, int]:
         height = (width * 768) // height
         width = 768
     return width, height
+
+
+def _convert_to_openai_response_format(
+    schema: Union[Dict[str, Any], Type], strict: bool
+) -> Union[Dict, TypeBaseModel]:
+    if isinstance(schema, type) and is_basemodel_subclass(schema):
+        return schema
+    else:
+        function = convert_to_openai_function(schema, strict=strict)
+        function["schema"] = function.pop("parameters")
+        return {"type": "json_schema", "json_schema": function}
+
+
+@chain
+def _oai_structured_outputs_parser(ai_msg: AIMessage) -> PydanticBaseModel:
+    if ai_msg.additional_kwargs.get("parsed"):
+        return ai_msg.additional_kwargs["parsed"]
+    elif ai_msg.additional_kwargs.get("refusal"):
+        raise OpenAIRefusalError(ai_msg.additional_kwargs["refusal"])
+    else:
+        raise ValueError(
+            "Structured Output response does not have a 'parsed' field nor a 'refusal' "
+            "field."
+        )
+
+
+class OpenAIRefusalError(Exception):
+    """Error raised when OpenAI Structured Outputs API returns a refusal.
+
+    When using OpenAI's Structured Outputs API with user-generated input, the model
+    may occasionally refuse to fulfill the request for safety reasons.
+
+    See here for more on refusals:
+    https://platform.openai.com/docs/guides/structured-outputs/refusals
+
+    .. versionadded:: 0.1.21
+    """

[File: poetry.lock]

@@ -1527,4 +1527,4 @@ watchmedo = ["PyYAML (>=3.10)"]
 [metadata]
 lock-version = "2.0"
 python-versions = ">=3.8.1,<4.0"
-content-hash = "23d99a41f0cff5bf1869e8e18ac953a9802b0d1912eedcecca624650c1ff3af6"
+content-hash = "a08bed7f2e62b3f6c7fc52a31c2529b44d4e5adcc55aba5047be027596fdb31f"

[File: pyproject.toml]

@@ -24,7 +24,7 @@ ignore_missing_imports = true
 [tool.poetry.dependencies]
 python = ">=3.8.1,<4.0"
 langchain-core = { version = "^0.2.29rc1", allow-prereleases=true }
-openai = "^1.32.0"
+openai = "^1.40.0"
 tiktoken = ">=0.7,<1"
 
 [tool.ruff.lint]

[File: ChatOpenAI integration tests]

@@ -1,7 +1,7 @@
 """Test ChatOpenAI chat model."""
 
 import base64
-from typing import Any, AsyncIterator, List, Optional, cast
+from typing import Any, AsyncIterator, List, Literal, Optional, cast
 
 import httpx
 import openai
@@ -796,13 +796,21 @@ def test_tool_calling_strict() -> None:
         next(model_with_invalid_tool_schema.stream(query))
 
 
-def test_structured_output_strict() -> None:
+@pytest.mark.parametrize(
+    ("model", "method", "strict"),
+    [("gpt-4o", "function_calling", True), ("gpt-4o-2024-08-06", "json_schema", None)],
+)
+def test_structured_output_strict(
+    model: str,
+    method: Literal["function_calling", "json_schema"],
+    strict: Optional[bool],
+) -> None:
     """Test to verify structured output with strict=True."""
     from pydantic import BaseModel as BaseModelProper
     from pydantic import Field as FieldProper
 
-    model = ChatOpenAI(model="gpt-4o", temperature=0)
+    llm = ChatOpenAI(model=model, temperature=0)
 
     class Joke(BaseModelProper):
         """Joke to tell user."""
@@ -814,7 +822,7 @@ def test_structured_output_strict() -> None:
     # Type ignoring since the interface only officially supports pydantic 1
     # or pydantic.v1.BaseModel but not pydantic.BaseModel from pydantic 2.
    # We'll need to do a pass updating the type signatures.
-    chat = model.with_structured_output(Joke, strict=True)  # type: ignore[arg-type]
+    chat = llm.with_structured_output(Joke, method=method, strict=strict)
     result = chat.invoke("Tell me a joke about cats.")
     assert isinstance(result, Joke)
@@ -822,7 +830,9 @@ def test_structured_output_strict() -> None:
         assert isinstance(chunk, Joke)
 
     # Schema
-    chat = model.with_structured_output(Joke.model_json_schema(), strict=True)
+    chat = llm.with_structured_output(
+        Joke.model_json_schema(), method=method, strict=strict
+    )
     result = chat.invoke("Tell me a joke about cats.")
     assert isinstance(result, dict)
     assert set(result.keys()) == {"setup", "punchline"}
@@ -831,3 +841,27 @@ def test_structured_output_strict() -> None:
         assert isinstance(chunk, dict)
     assert isinstance(chunk, dict)  # for mypy
     assert set(chunk.keys()) == {"setup", "punchline"}
+
+    # Invalid schema with optional fields:
+    class InvalidJoke(BaseModelProper):
+        """Joke to tell user."""
+
+        setup: str = FieldProper(description="question to set up a joke")
+        # Invalid field, can't have default value.
+        punchline: str = FieldProper(
+            default="foo", description="answer to resolve the joke"
+        )
+
+    chat = llm.with_structured_output(InvalidJoke, method=method, strict=strict)
+    with pytest.raises(openai.BadRequestError):
+        chat.invoke("Tell me a joke about cats.")
+    with pytest.raises(openai.BadRequestError):
+        next(chat.stream("Tell me a joke about cats."))
+
+    chat = llm.with_structured_output(
+        InvalidJoke.model_json_schema(), method=method, strict=strict
+    )
+    with pytest.raises(openai.BadRequestError):
+        chat.invoke("Tell me a joke about cats.")
+    with pytest.raises(openai.BadRequestError):
+        next(chat.stream("Tell me a joke about cats."))

[File: ChatOpenAI unit tests]

@@ -1,7 +1,7 @@
 """Test OpenAI Chat API wrapper."""
 
 import json
-from typing import Any, List, Type, Union
+from typing import Any, Dict, List, Literal, Optional, Type, Union
 from unittest.mock import AsyncMock, MagicMock, patch
 
 import pytest
@@ -343,17 +343,32 @@ class MakeASandwich(BaseModel):
         None,
     ],
 )
-def test_bind_tools_tool_choice(tool_choice: Any) -> None:
+@pytest.mark.parametrize("strict", [True, False, None])
+def test_bind_tools_tool_choice(tool_choice: Any, strict: Optional[bool]) -> None:
     """Test passing in manually construct tool call message."""
     llm = ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0)
-    llm.bind_tools(tools=[GenerateUsername, MakeASandwich], tool_choice=tool_choice)
+    llm.bind_tools(
+        tools=[GenerateUsername, MakeASandwich], tool_choice=tool_choice, strict=strict
+    )
 
 
 @pytest.mark.parametrize("schema", [GenerateUsername, GenerateUsername.schema()])
-def test_with_structured_output(schema: Union[Type[BaseModel], dict]) -> None:
+@pytest.mark.parametrize("method", ["json_schema", "function_calling", "json_mode"])
+@pytest.mark.parametrize("include_raw", [True, False])
+@pytest.mark.parametrize("strict", [True, False, None])
+def test_with_structured_output(
+    schema: Union[Type, Dict[str, Any], None],
+    method: Literal["function_calling", "json_mode", "json_schema"],
+    include_raw: bool,
+    strict: Optional[bool],
+) -> None:
     """Test passing in manually construct tool call message."""
+    if method == "json_mode":
+        strict = None
     llm = ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0)
-    llm.with_structured_output(schema)
+    llm.with_structured_output(
+        schema, method=method, strict=strict, include_raw=include_raw
+    )
 
 
 def test_get_num_tokens_from_messages() -> None:

[File: poetry.lock (generated, 1157 changed lines): diff suppressed because it is too large]