openai[patch]: get output_type when using with_structured_output (#26307)

- This allows Pydantic to correctly resolve the annotations needed when
  using OpenAI's new `json_schema` param.

Resolves issue: #26250
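
A minimal usage sketch of the path this change touches (the model name and the `Answer` schema are illustrative, not part of the diff):

```python
from pydantic import BaseModel

from langchain_openai import ChatOpenAI


class Answer(BaseModel):
    value: int


# method="json_schema" uses OpenAI's structured-output response_format; with
# this change the output parser also carries the Pydantic type, so the chain's
# output schema resolves to `Answer` instead of failing on unresolved
# annotations (the #26250 failure mode).
structured_llm = ChatOpenAI(model="gpt-4o-mini").with_structured_output(
    Answer, method="json_schema", strict=True
)
print(structured_llm.get_output_schema().schema())
```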

---------

Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>
Co-authored-by: Bagatur <baskaryan@gmail.com>
liuhetian authored 2024-09-14 02:42:01 +08:00, committed by GitHub
commit 7fc9e99e21 (parent 0f2b32ffa9)
2 changed files with 36 additions and 7 deletions


@@ -65,7 +65,6 @@ from langchain_core.messages import (
 from langchain_core.messages.ai import UsageMetadata
 from langchain_core.messages.tool import tool_call_chunk
 from langchain_core.output_parsers import JsonOutputParser, PydanticOutputParser
-from langchain_core.output_parsers.base import OutputParserLike
 from langchain_core.output_parsers.openai_tools import (
     JsonOutputKeyToolsParser,
     PydanticToolsParser,
@@ -1421,7 +1420,7 @@ class BaseChatOpenAI(BaseChatModel):
                 strict=strict,
             )
             if is_pydantic_schema:
-                output_parser: OutputParserLike = PydanticToolsParser(
+                output_parser: Runnable = PydanticToolsParser(
                     tools=[schema],  # type: ignore[list-item]
                     first_tool_only=True,  # type: ignore[list-item]
                 )
@@ -1445,11 +1444,12 @@ class BaseChatOpenAI(BaseChatModel):
             strict = strict if strict is not None else True
             response_format = _convert_to_openai_response_format(schema, strict=strict)
             llm = self.bind(response_format=response_format)
-            output_parser = (
-                cast(Runnable, _oai_structured_outputs_parser)
-                if is_pydantic_schema
-                else JsonOutputParser()
-            )
+            if is_pydantic_schema:
+                output_parser = _oai_structured_outputs_parser.with_types(
+                    output_type=cast(type, schema)
+                )
+            else:
+                output_parser = JsonOutputParser()
         else:
             raise ValueError(
                 f"Unrecognized method argument. Expected one of 'function_calling' or "

@@ -18,6 +18,7 @@ from langchain_core.messages import (
 )
 from langchain_core.messages.ai import UsageMetadata
 from langchain_core.pydantic_v1 import BaseModel
+from pydantic import BaseModel as BaseModelV2

 from langchain_openai import ChatOpenAI
 from langchain_openai.chat_models.base import (
@@ -694,3 +695,31 @@ def test_get_num_tokens_from_messages() -> None:
     expected = 176
     actual = llm.get_num_tokens_from_messages(messages)
     assert expected == actual
+
+
+class Foo(BaseModel):
+    bar: int
+
+
+class FooV2(BaseModelV2):
+    bar: int
+
+
+@pytest.mark.parametrize("schema", [Foo, FooV2])
+def test_schema_from_with_structured_output(schema: Type) -> None:
+    """Test schema from with_structured_output."""
+
+    llm = ChatOpenAI()
+
+    structured_llm = llm.with_structured_output(
+        schema, method="json_schema", strict=True
+    )
+
+    expected = {
+        "properties": {"bar": {"title": "Bar", "type": "integer"}},
+        "required": ["bar"],
+        "title": schema.__name__,
+        "type": "object",
+    }
+    actual = structured_llm.get_output_schema().schema()
+    assert actual == expected