openai[patch]: ChatOpenAI.with_structured_output json_schema support (#25123)

Bagatur
2024-08-07 08:09:07 -07:00
committed by GitHub
parent 0ba125c3cd
commit 09fbce13c5
6 changed files with 895 additions and 603 deletions


@@ -7,6 +7,7 @@ import json
 import logging
 import os
 import sys
+import warnings
 from io import BytesIO
 from math import ceil
 from operator import itemgetter
@@ -27,7 +28,6 @@ from typing import (
     TypeVar,
     Union,
     cast,
-    overload,
 )
 from urllib.parse import urlparse
@@ -74,7 +74,7 @@ from langchain_core.output_parsers.openai_tools import (
 )
 from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
 from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator
-from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough
+from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough, chain
 from langchain_core.runnables.config import run_in_executor
 from langchain_core.tools import BaseTool
 from langchain_core.utils import (
@@ -86,7 +86,11 @@ from langchain_core.utils.function_calling import (
     convert_to_openai_function,
     convert_to_openai_tool,
 )
-from langchain_core.utils.pydantic import is_basemodel_subclass
+from langchain_core.utils.pydantic import (
+    PydanticBaseModel,
+    TypeBaseModel,
+    is_basemodel_subclass,
+)
 from langchain_core.utils.utils import build_extra_kwargs

 logger = logging.getLogger(__name__)
@@ -298,6 +302,8 @@ class _AllReturnType(TypedDict):
 class BaseChatOpenAI(BaseChatModel):
     client: Any = Field(default=None, exclude=True)  #: :meta private:
     async_client: Any = Field(default=None, exclude=True)  #: :meta private:
+    root_client: Any = Field(default=None, exclude=True)  #: :meta private:
+    root_async_client: Any = Field(default=None, exclude=True)  #: :meta private:
     model_name: str = Field(default="gpt-3.5-turbo", alias="model")
     """Model name to use."""
     temperature: float = 0.7
@@ -445,9 +451,8 @@ class BaseChatOpenAI(BaseChatModel):
                 ) from e
                 values["http_client"] = httpx.Client(proxy=values["openai_proxy"])
             sync_specific = {"http_client": values["http_client"]}
-            values["client"] = openai.OpenAI(
-                **client_params, **sync_specific
-            ).chat.completions
+            values["root_client"] = openai.OpenAI(**client_params, **sync_specific)
+            values["client"] = values["root_client"].chat.completions
         if not values.get("async_client"):
             if values["openai_proxy"] and not values["http_async_client"]:
                 try:
@@ -461,10 +466,10 @@ class BaseChatOpenAI(BaseChatModel):
proxy=values["openai_proxy"]
)
async_specific = {"http_client": values["http_async_client"]}
values["async_client"] = openai.AsyncOpenAI(
values["root_async_client"] = openai.AsyncOpenAI(
**client_params, **async_specific
).chat.completions
)
values["async_client"] = values["root_async_client"].chat.completions
return values
@property
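Note: with these hunks, the validator keeps the full OpenAI SDK client on the model alongside the existing chat-completions handles. A minimal sketch of the relationship, assuming the openai v1 SDK (the api_key value is a placeholder):

    import openai

    root_client = openai.OpenAI(api_key="sk-...")  # full SDK client, now stored as `root_client`
    client = root_client.chat.completions          # what `client` already pointed to before

    # Keeping the root client makes the beta structured-outputs endpoint reachable:
    # root_client.beta.chat.completions.parse(...)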
@@ -525,13 +530,32 @@ class BaseChatOpenAI(BaseChatModel):
kwargs["stream"] = True
payload = self._get_request_payload(messages, stop=stop, **kwargs)
default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
base_generation_info = {}
if "response_format" in payload and is_basemodel_subclass(
payload["response_format"]
):
# TODO: Add support for streaming with Pydantic response_format.
warnings.warn("Streaming with Pydantic response_format not yet supported.")
chat_result = self._generate(
messages, stop, run_manager=run_manager, **kwargs
)
msg = chat_result.generations[0].message
yield ChatGenerationChunk(
message=AIMessageChunk(
**msg.dict(exclude={"type", "additional_kwargs"}),
# preserve the "parsed" Pydantic object without converting to dict
additional_kwargs=msg.additional_kwargs,
),
generation_info=chat_result.generations[0].generation_info,
)
return
if self.include_response_headers:
raw_response = self.client.with_raw_response.create(**payload)
response = raw_response.parse()
base_generation_info = {"headers": dict(raw_response.headers)}
else:
response = self.client.create(**payload)
base_generation_info = {}
with response:
is_first_chunk = True
for chunk in response:
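With this hunk, streaming against a Pydantic `response_format` degrades to a single chunk instead of failing. A rough usage sketch, not from this diff (model name and schema are illustrative):

    from pydantic import BaseModel
    from langchain_openai import ChatOpenAI

    class Joke(BaseModel):
        setup: str
        punchline: str

    llm = ChatOpenAI(model="gpt-4o-2024-08-06").bind(response_format=Joke)

    # Warns "Streaming with Pydantic response_format not yet supported.", then
    # yields one chunk carrying the whole generation, "parsed" object included.
    for chunk in llm.stream("Tell me a joke about parrots"):
        print(chunk.additional_kwargs.get("parsed"))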
@@ -594,13 +618,21 @@ class BaseChatOpenAI(BaseChatModel):
             )
             return generate_from_stream(stream_iter)
         payload = self._get_request_payload(messages, stop=stop, **kwargs)
-        if self.include_response_headers:
+        generation_info = None
+        if "response_format" in payload:
+            if self.include_response_headers:
+                warnings.warn(
+                    "Cannot currently include response headers when response_format is "
+                    "specified."
+                )
+            payload.pop("stream")
+            response = self.root_client.beta.chat.completions.parse(**payload)
+        elif self.include_response_headers:
             raw_response = self.client.with_raw_response.create(**payload)
             response = raw_response.parse()
             generation_info = {"headers": dict(raw_response.headers)}
         else:
             response = self.client.create(**payload)
-            generation_info = None
         return self._create_chat_result(response, generation_info)

     def _get_request_payload(
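The new `response_format` branch hands the request to the SDK's beta parse helper rather than the plain create call. In raw SDK terms it behaves roughly like this (prompt and schema are illustrative):

    import openai
    from pydantic import BaseModel

    class Answer(BaseModel):
        value: int

    client = openai.OpenAI()
    completion = client.beta.chat.completions.parse(
        model="gpt-4o-2024-08-06",
        messages=[{"role": "user", "content": "What is 2 + 2?"}],
        response_format=Answer,  # a Pydantic class or a json_schema dict
    )
    message = completion.choices[0].message
    print(message.parsed)   # Answer(value=4) if the model complied
    print(message.refusal)  # a refusal string if it did not, else None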
@@ -625,18 +657,19 @@ class BaseChatOpenAI(BaseChatModel):
         generation_info: Optional[Dict] = None,
     ) -> ChatResult:
         generations = []
-        if not isinstance(response, dict):
-            response = response.model_dump()
+        response_dict = (
+            response if isinstance(response, dict) else response.model_dump()
+        )
         # Sometimes the AI Model calling will get error, we should raise it.
         # Otherwise, the next code 'choices.extend(response["choices"])'
         # will throw a "TypeError: 'NoneType' object is not iterable" error
         # to mask the true error. Because 'response["choices"]' is None.
-        if response.get("error"):
-            raise ValueError(response.get("error"))
+        if response_dict.get("error"):
+            raise ValueError(response_dict.get("error"))
-        token_usage = response.get("usage", {})
-        for res in response["choices"]:
+        token_usage = response_dict.get("usage", {})
+        for res in response_dict["choices"]:
             message = _convert_dict_to_message(res["message"])
             if token_usage and isinstance(message, AIMessage):
                 message.usage_metadata = {
@@ -656,9 +689,19 @@ class BaseChatOpenAI(BaseChatModel):
             generations.append(gen)
         llm_output = {
             "token_usage": token_usage,
-            "model_name": response.get("model", self.model_name),
-            "system_fingerprint": response.get("system_fingerprint", ""),
+            "model_name": response_dict.get("model", self.model_name),
+            "system_fingerprint": response_dict.get("system_fingerprint", ""),
         }
+        if isinstance(response, openai.BaseModel) and getattr(
+            response, "choices", None
+        ):
+            message = response.choices[0].message  # type: ignore[attr-defined]
+            if hasattr(message, "parsed"):
+                generations[0].message.additional_kwargs["parsed"] = message.parsed
+            if hasattr(message, "refusal"):
+                generations[0].message.additional_kwargs["refusal"] = message.refusal
         return ChatResult(generations=generations, llm_output=llm_output)

     async def _astream(
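Because `parsed` and `refusal` are copied into `additional_kwargs`, callers can read the structured result straight off the returned message. A hedged illustration, assuming `llm` is bound with a Pydantic `response_format` as in the earlier sketch:

    msg = llm.invoke("Tell me a joke about parrots")
    parsed = msg.additional_kwargs.get("parsed")    # Pydantic instance, or None
    refusal = msg.additional_kwargs.get("refusal")  # refusal text, or None
    if parsed is not None:
        print(parsed.setup, "/", parsed.punchline)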
@@ -671,13 +714,31 @@ class BaseChatOpenAI(BaseChatModel):
kwargs["stream"] = True
payload = self._get_request_payload(messages, stop=stop, **kwargs)
default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
base_generation_info = {}
if "response_format" in payload and is_basemodel_subclass(
payload["response_format"]
):
# TODO: Add support for streaming with Pydantic response_format.
warnings.warn("Streaming with Pydantic response_format not yet supported.")
chat_result = await self._agenerate(
messages, stop, run_manager=run_manager, **kwargs
)
msg = chat_result.generations[0].message
yield ChatGenerationChunk(
message=AIMessageChunk(
**msg.dict(exclude={"type", "additional_kwargs"}),
# preserve the "parsed" Pydantic object without converting to dict
additional_kwargs=msg.additional_kwargs,
),
generation_info=chat_result.generations[0].generation_info,
)
return
if self.include_response_headers:
raw_response = self.async_client.with_raw_response.create(**payload)
response = raw_response.parse()
base_generation_info = {"headers": dict(raw_response.headers)}
else:
response = await self.async_client.create(**payload)
base_generation_info = {}
async with response:
is_first_chunk = True
async for chunk in response:
@@ -745,13 +806,23 @@ class BaseChatOpenAI(BaseChatModel):
             )
             return await agenerate_from_stream(stream_iter)
         payload = self._get_request_payload(messages, stop=stop, **kwargs)
-        if self.include_response_headers:
+        generation_info = None
+        if "response_format" in payload:
+            if self.include_response_headers:
+                warnings.warn(
+                    "Cannot currently include response headers when response_format is "
+                    "specified."
+                )
+            payload.pop("stream")
+            response = await self.root_async_client.beta.chat.completions.parse(
+                **payload
+            )
+        elif self.include_response_headers:
             raw_response = await self.async_client.with_raw_response.create(**payload)
             response = raw_response.parse()
             generation_info = {"headers": dict(raw_response.headers)}
         else:
             response = await self.async_client.create(**payload)
-            generation_info = None
         return await run_in_executor(
             None, self._create_chat_result, response, generation_info
         )
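The async path mirrors the sync one, so the parse-based flow is also available through the async API. An illustrative snippet, assuming the same bound `llm` as above:

    import asyncio

    async def main() -> None:
        msg = await llm.ainvoke("Tell me a joke about parrots")
        print(msg.additional_kwargs.get("parsed"))

    asyncio.run(main())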
@@ -1028,34 +1099,13 @@ class BaseChatOpenAI(BaseChatModel):
kwargs["tool_choice"] = tool_choice
return super().bind(tools=formatted_tools, **kwargs)
# TODO: Fix typing.
@overload # type: ignore[override]
def with_structured_output(
self,
schema: Optional[_DictOrPydanticClass] = None,
*,
method: Literal["function_calling", "json_mode"] = "function_calling",
include_raw: Literal[True] = True,
strict: Optional[bool] = None,
**kwargs: Any,
) -> Runnable[LanguageModelInput, _AllReturnType]: ...
@overload
def with_structured_output(
self,
schema: Optional[_DictOrPydanticClass] = None,
*,
method: Literal["function_calling", "json_mode"] = "function_calling",
include_raw: Literal[False] = False,
strict: Optional[bool] = None,
**kwargs: Any,
) -> Runnable[LanguageModelInput, _DictOrPydantic]: ...
def with_structured_output(
self,
schema: Optional[_DictOrPydanticClass] = None,
*,
method: Literal["function_calling", "json_mode"] = "function_calling",
method: Literal[
"function_calling", "json_mode", "json_schema"
] = "function_calling",
include_raw: bool = False,
strict: Optional[bool] = None,
**kwargs: Any,
@@ -1065,10 +1115,12 @@ class BaseChatOpenAI(BaseChatModel):
         .. versionchanged:: 0.1.21

             Support for ``strict`` argument added.
+            Support for ``method`` = "json_schema" added.

         Args:
             schema:
                 The output schema. Can be passed in as:
+
                 - an OpenAI function/tool schema,
                 - a JSON Schema,
                 - a TypedDict class (support added in 0.1.20),
@@ -1085,12 +1137,36 @@ class BaseChatOpenAI(BaseChatModel):
                 Added support for TypedDict class.

             method:
-                The method for steering model generation, one of "function_calling"
-                or "json_mode". If "function_calling" then the schema will be converted
-                to an OpenAI function and the returned model will make use of the
-                function-calling API. If "json_mode" then OpenAI's JSON mode will be
-                used. Note that if using "json_mode" then you must include instructions
-                for formatting the output into the desired schema into the model call.
+                The method for steering model generation, one of:
+
+                - "function_calling":
+                    Uses OpenAI's tool-calling (formerly called function calling)
+                    API: https://platform.openai.com/docs/guides/function-calling
+                - "json_schema":
+                    Uses OpenAI's Structured Outputs API:
+                    https://platform.openai.com/docs/guides/structured-outputs.
+                    Supported for "gpt-4o-mini", "gpt-4o-2024-08-06", and later
+                    models.
+                - "json_mode":
+                    Uses OpenAI's JSON mode. Note that with JSON mode you must
+                    include instructions for formatting the output into the
+                    desired schema in the model call:
+                    https://platform.openai.com/docs/guides/structured-outputs/json-mode
+
+                Learn more about the differences between the methods and which models
+                support which methods here:
+
+                - https://platform.openai.com/docs/guides/structured-outputs/structured-outputs-vs-json-mode
+                - https://platform.openai.com/docs/guides/structured-outputs/function-calling-vs-response-format
+
+                .. versionchanged:: 0.1.21
+
+                    Added support for "json_schema".
+
+                .. note:: Planned breaking change in version `0.2.0`
+
+                    ``method`` default will be changed to "json_schema" from
+                    "function_calling".
+
             include_raw:
                 If False then only the parsed structured output is returned. If
                 an error occurs during model output parsing it will be raised. If True
@@ -1098,14 +1174,20 @@ class BaseChatOpenAI(BaseChatModel):
                 response will be returned. If an error occurs during output parsing it
                 will be caught and returned as well. The final output is always a dict
                 with keys "raw", "parsed", and "parsing_error".
-            strict: If True and ``method`` = "function_calling", model output is
-                guaranteed to exactly match the schema
-                If True, the input schema will also be
-                validated according to
-                https://platform.openai.com/docs/guides/structured-outputs/supported-schemas.
-                If False, input schema will not be validated and model output will not
-                be validated.
-                If None, ``strict`` argument will not be passed to the model.
+            strict:
+
+                - True:
+                    Model output is guaranteed to exactly match the schema.
+                    The input schema will also be validated according to
+                    https://platform.openai.com/docs/guides/structured-outputs/supported-schemas.
+                - False:
+                    Input schema will not be validated and model output will not be
+                    validated.
+                - None:
+                    ``strict`` argument will not be passed to the model.
+
+                If ``method`` is "json_schema" defaults to True. If ``method`` is
+                "function_calling" or "json_mode" defaults to None. Can only be
+                non-null if ``method`` is "function_calling" or "json_schema".

                 .. versionadded:: 0.1.21
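Taken together, the docstring above describes usage along these lines; a minimal sketch of the new json_schema path (model name is illustrative):

    from pydantic import BaseModel, Field
    from langchain_openai import ChatOpenAI

    class Joke(BaseModel):
        """Joke to tell user."""

        setup: str = Field(description="The setup of the joke")
        punchline: str = Field(description="The punchline of the joke")

    llm = ChatOpenAI(model="gpt-4o-2024-08-06", temperature=0)
    structured_llm = llm.with_structured_output(Joke, method="json_schema")
    # strict defaults to True for method="json_schema"
    joke = structured_llm.invoke("Tell me a joke about cats")
    print(joke.setup, "/", joke.punchline)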
@@ -1124,9 +1206,10 @@ class BaseChatOpenAI(BaseChatModel):
             Otherwise, if ``include_raw`` is False then Runnable outputs a dict.
             If ``include_raw`` is True, then Runnable outputs a dict with keys:
-            - ``"raw"``: BaseMessage
-            - ``"parsed"``: None if there was a parsing error, otherwise the type depends on the ``schema`` as described above.
-            - ``"parsing_error"``: Optional[BaseException]
+
+            - "raw": BaseMessage
+            - "parsed": None if there was a parsing error, otherwise the type depends on the ``schema`` as described above.
+            - "parsing_error": Optional[BaseException]

         Example: schema=Pydantic class, method="function_calling", include_raw=False, strict=True:

             .. note:: Valid schemas when using ``strict`` = True
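With ``include_raw=True`` a parsing failure is returned rather than raised; an illustrative continuation of the sketch above:

    structured_llm = llm.with_structured_output(
        Joke, method="json_schema", include_raw=True
    )
    result = structured_llm.invoke("Tell me a joke about cats")
    # result == {"raw": AIMessage(...),
    #            "parsed": Joke(...) or None,
    #            "parsing_error": None or the caught exception}
    print(result["parsed"])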
@@ -1305,15 +1388,15 @@ class BaseChatOpenAI(BaseChatModel):
""" # noqa: E501
if kwargs:
raise ValueError(f"Received unsupported arguments {kwargs}")
if strict is not None and method != "function_calling":
if strict is not None and method == "json_mode":
raise ValueError(
"Argument `strict` is only supported for `method`='function_calling'"
"Argument `strict` is not supported with `method`='json_mode'"
)
is_pydantic_schema = _is_pydantic_class(schema)
if method == "function_calling":
if schema is None:
raise ValueError(
"schema must be specified when method is 'function_calling'. "
"schema must be specified when method is not 'json_mode'. "
"Received None."
)
tool_name = convert_to_openai_tool(schema)["function"]["name"]
@@ -1339,6 +1422,20 @@ class BaseChatOpenAI(BaseChatModel):
                 if is_pydantic_schema
                 else JsonOutputParser()
             )
+        elif method == "json_schema":
+            if schema is None:
+                raise ValueError(
+                    "schema must be specified when method is not 'json_mode'. "
+                    "Received None."
+                )
+            strict = strict if strict is not None else True
+            response_format = _convert_to_openai_response_format(schema, strict=strict)
+            llm = self.bind(response_format=response_format)
+            output_parser = (
+                cast(Runnable, _oai_structured_outputs_parser)
+                if is_pydantic_schema
+                else JsonOutputParser()
+            )
         else:
             raise ValueError(
                 f"Unrecognized method argument. Expected one of 'function_calling' or "
@@ -1975,3 +2072,40 @@ def _resize(width: int, height: int) -> Tuple[int, int]:
         height = (width * 768) // height
         width = 768
     return width, height
+
+
+def _convert_to_openai_response_format(
+    schema: Union[Dict[str, Any], Type], strict: bool
+) -> Union[Dict, TypeBaseModel]:
+    if isinstance(schema, type) and is_basemodel_subclass(schema):
+        return schema
+    else:
+        function = convert_to_openai_function(schema, strict=strict)
+        function["schema"] = function.pop("parameters")
+        return {"type": "json_schema", "json_schema": function}
+
+
+@chain
+def _oai_structured_outputs_parser(ai_msg: AIMessage) -> PydanticBaseModel:
+    if ai_msg.additional_kwargs.get("parsed"):
+        return ai_msg.additional_kwargs["parsed"]
+    elif ai_msg.additional_kwargs.get("refusal"):
+        raise OpenAIRefusalError(ai_msg.additional_kwargs["refusal"])
+    else:
+        raise ValueError(
+            "Structured Output response does not have a 'parsed' field nor a 'refusal' "
+            "field."
+        )
+
+
+class OpenAIRefusalError(Exception):
+    """Error raised when OpenAI Structured Outputs API returns a refusal.
+
+    When using OpenAI's Structured Outputs API with user-generated input, the model
+    may occasionally refuse to fulfill the request for safety reasons.
+
+    See here for more on refusals:
+    https://platform.openai.com/docs/guides/structured-outputs/refusals
+
+    .. versionadded:: 0.1.21
+    """