openai[patch]: get output_type when using with_structured_output (#26307)

- This allows pydantic to correctly resolve annotations necessary when
using openai new param `json_schema`

Resolves issue: #26250

---------

Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>
Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
liuhetian 2024-09-14 02:42:01 +08:00 committed by GitHub
parent 0f2b32ffa9
commit 7fc9e99e21
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 36 additions and 7 deletions

View File

@@ -65,7 +65,6 @@ from langchain_core.messages import (
 from langchain_core.messages.ai import UsageMetadata
 from langchain_core.messages.tool import tool_call_chunk
 from langchain_core.output_parsers import JsonOutputParser, PydanticOutputParser
-from langchain_core.output_parsers.base import OutputParserLike
 from langchain_core.output_parsers.openai_tools import (
     JsonOutputKeyToolsParser,
     PydanticToolsParser,
@@ -1421,7 +1420,7 @@ class BaseChatOpenAI(BaseChatModel):
                 strict=strict,
             )
             if is_pydantic_schema:
-                output_parser: OutputParserLike = PydanticToolsParser(
+                output_parser: Runnable = PydanticToolsParser(
                     tools=[schema],  # type: ignore[list-item]
                     first_tool_only=True,  # type: ignore[list-item]
                 )
@@ -1445,11 +1444,12 @@ class BaseChatOpenAI(BaseChatModel):
             strict = strict if strict is not None else True
             response_format = _convert_to_openai_response_format(schema, strict=strict)
             llm = self.bind(response_format=response_format)
-            output_parser = (
-                cast(Runnable, _oai_structured_outputs_parser)
-                if is_pydantic_schema
-                else JsonOutputParser()
-            )
+            if is_pydantic_schema:
+                output_parser = _oai_structured_outputs_parser.with_types(
+                    output_type=cast(type, schema)
+                )
+            else:
+                output_parser = JsonOutputParser()
         else:
             raise ValueError(
                 f"Unrecognized method argument. Expected one of 'function_calling' or "

View File

@@ -18,6 +18,7 @@ from langchain_core.messages import (
 )
 from langchain_core.messages.ai import UsageMetadata
 from langchain_core.pydantic_v1 import BaseModel
+from pydantic import BaseModel as BaseModelV2

 from langchain_openai import ChatOpenAI
 from langchain_openai.chat_models.base import (
@@ -694,3 +695,31 @@ def test_get_num_tokens_from_messages() -> None:
     expected = 176
     actual = llm.get_num_tokens_from_messages(messages)
     assert expected == actual
# Test schema declared via the langchain_core pydantic v1 shim.
# NOTE(review): intentionally left without a docstring — pydantic appears to
# surface a model docstring as the schema "description", which would change
# the expected JSON schema asserted in the test below; confirm before adding one.
class Foo(BaseModel):
    bar: int
# Same shape as Foo but built on the native pydantic (v2) BaseModel, so the
# parametrized test below covers both pydantic generations.
class FooV2(BaseModelV2):
    bar: int
@pytest.mark.parametrize("schema", [Foo, FooV2])
def test_schema_from_with_structured_output(schema: Type) -> None:
    """with_structured_output(method="json_schema") should expose the wrapped
    model's JSON schema as the runnable's output schema, for both pydantic
    v1-shim and native pydantic v2 models."""
    structured_llm = ChatOpenAI().with_structured_output(
        schema, method="json_schema", strict=True
    )
    actual = structured_llm.get_output_schema().schema()
    expected = {
        "title": schema.__name__,
        "type": "object",
        "properties": {"bar": {"title": "Bar", "type": "integer"}},
        "required": ["bar"],
    }
    assert actual == expected