mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-19 05:13:46 +00:00
openai[patch]: get output_type when using with_structured_output (#26307)
- This allows pydantic to correctly resolve annotations necessary when using openai new param `json_schema` Resolves issue: #26250 --------- Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com> Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
parent
0f2b32ffa9
commit
7fc9e99e21
@ -65,7 +65,6 @@ from langchain_core.messages import (
|
||||
from langchain_core.messages.ai import UsageMetadata
|
||||
from langchain_core.messages.tool import tool_call_chunk
|
||||
from langchain_core.output_parsers import JsonOutputParser, PydanticOutputParser
|
||||
from langchain_core.output_parsers.base import OutputParserLike
|
||||
from langchain_core.output_parsers.openai_tools import (
|
||||
JsonOutputKeyToolsParser,
|
||||
PydanticToolsParser,
|
||||
@ -1421,7 +1420,7 @@ class BaseChatOpenAI(BaseChatModel):
|
||||
strict=strict,
|
||||
)
|
||||
if is_pydantic_schema:
|
||||
output_parser: OutputParserLike = PydanticToolsParser(
|
||||
output_parser: Runnable = PydanticToolsParser(
|
||||
tools=[schema], # type: ignore[list-item]
|
||||
first_tool_only=True, # type: ignore[list-item]
|
||||
)
|
||||
@ -1445,11 +1444,12 @@ class BaseChatOpenAI(BaseChatModel):
|
||||
strict = strict if strict is not None else True
|
||||
response_format = _convert_to_openai_response_format(schema, strict=strict)
|
||||
llm = self.bind(response_format=response_format)
|
||||
output_parser = (
|
||||
cast(Runnable, _oai_structured_outputs_parser)
|
||||
if is_pydantic_schema
|
||||
else JsonOutputParser()
|
||||
)
|
||||
if is_pydantic_schema:
|
||||
output_parser = _oai_structured_outputs_parser.with_types(
|
||||
output_type=cast(type, schema)
|
||||
)
|
||||
else:
|
||||
output_parser = JsonOutputParser()
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unrecognized method argument. Expected one of 'function_calling' or "
|
||||
|
@ -18,6 +18,7 @@ from langchain_core.messages import (
|
||||
)
|
||||
from langchain_core.messages.ai import UsageMetadata
|
||||
from langchain_core.pydantic_v1 import BaseModel
|
||||
from pydantic import BaseModel as BaseModelV2
|
||||
|
||||
from langchain_openai import ChatOpenAI
|
||||
from langchain_openai.chat_models.base import (
|
||||
@ -694,3 +695,31 @@ def test_get_num_tokens_from_messages() -> None:
|
||||
expected = 176
|
||||
actual = llm.get_num_tokens_from_messages(messages)
|
||||
assert expected == actual
|
||||
|
||||
|
||||
class Foo(BaseModel):
|
||||
bar: int
|
||||
|
||||
|
||||
class FooV2(BaseModelV2):
|
||||
bar: int
|
||||
|
||||
|
||||
@pytest.mark.parametrize("schema", [Foo, FooV2])
|
||||
def test_schema_from_with_structured_output(schema: Type) -> None:
|
||||
"""Test schema from with_structured_output."""
|
||||
|
||||
llm = ChatOpenAI()
|
||||
|
||||
structured_llm = llm.with_structured_output(
|
||||
schema, method="json_schema", strict=True
|
||||
)
|
||||
|
||||
expected = {
|
||||
"properties": {"bar": {"title": "Bar", "type": "integer"}},
|
||||
"required": ["bar"],
|
||||
"title": schema.__name__,
|
||||
"type": "object",
|
||||
}
|
||||
actual = structured_llm.get_output_schema().schema()
|
||||
assert actual == expected
|
||||
|
Loading…
Reference in New Issue
Block a user