mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-20 05:43:55 +00:00
openai[patch]: get output_type when using with_structured_output (#26307)
- This allows pydantic to correctly resolve annotations necessary when using openai new param `json_schema` Resolves issue: #26250 --------- Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com> Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
parent
0f2b32ffa9
commit
7fc9e99e21
@ -65,7 +65,6 @@ from langchain_core.messages import (
|
|||||||
from langchain_core.messages.ai import UsageMetadata
|
from langchain_core.messages.ai import UsageMetadata
|
||||||
from langchain_core.messages.tool import tool_call_chunk
|
from langchain_core.messages.tool import tool_call_chunk
|
||||||
from langchain_core.output_parsers import JsonOutputParser, PydanticOutputParser
|
from langchain_core.output_parsers import JsonOutputParser, PydanticOutputParser
|
||||||
from langchain_core.output_parsers.base import OutputParserLike
|
|
||||||
from langchain_core.output_parsers.openai_tools import (
|
from langchain_core.output_parsers.openai_tools import (
|
||||||
JsonOutputKeyToolsParser,
|
JsonOutputKeyToolsParser,
|
||||||
PydanticToolsParser,
|
PydanticToolsParser,
|
||||||
@ -1421,7 +1420,7 @@ class BaseChatOpenAI(BaseChatModel):
|
|||||||
strict=strict,
|
strict=strict,
|
||||||
)
|
)
|
||||||
if is_pydantic_schema:
|
if is_pydantic_schema:
|
||||||
output_parser: OutputParserLike = PydanticToolsParser(
|
output_parser: Runnable = PydanticToolsParser(
|
||||||
tools=[schema], # type: ignore[list-item]
|
tools=[schema], # type: ignore[list-item]
|
||||||
first_tool_only=True, # type: ignore[list-item]
|
first_tool_only=True, # type: ignore[list-item]
|
||||||
)
|
)
|
||||||
@ -1445,11 +1444,12 @@ class BaseChatOpenAI(BaseChatModel):
|
|||||||
strict = strict if strict is not None else True
|
strict = strict if strict is not None else True
|
||||||
response_format = _convert_to_openai_response_format(schema, strict=strict)
|
response_format = _convert_to_openai_response_format(schema, strict=strict)
|
||||||
llm = self.bind(response_format=response_format)
|
llm = self.bind(response_format=response_format)
|
||||||
output_parser = (
|
if is_pydantic_schema:
|
||||||
cast(Runnable, _oai_structured_outputs_parser)
|
output_parser = _oai_structured_outputs_parser.with_types(
|
||||||
if is_pydantic_schema
|
output_type=cast(type, schema)
|
||||||
else JsonOutputParser()
|
)
|
||||||
)
|
else:
|
||||||
|
output_parser = JsonOutputParser()
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Unrecognized method argument. Expected one of 'function_calling' or "
|
f"Unrecognized method argument. Expected one of 'function_calling' or "
|
||||||
|
@ -18,6 +18,7 @@ from langchain_core.messages import (
|
|||||||
)
|
)
|
||||||
from langchain_core.messages.ai import UsageMetadata
|
from langchain_core.messages.ai import UsageMetadata
|
||||||
from langchain_core.pydantic_v1 import BaseModel
|
from langchain_core.pydantic_v1 import BaseModel
|
||||||
|
from pydantic import BaseModel as BaseModelV2
|
||||||
|
|
||||||
from langchain_openai import ChatOpenAI
|
from langchain_openai import ChatOpenAI
|
||||||
from langchain_openai.chat_models.base import (
|
from langchain_openai.chat_models.base import (
|
||||||
@ -694,3 +695,31 @@ def test_get_num_tokens_from_messages() -> None:
|
|||||||
expected = 176
|
expected = 176
|
||||||
actual = llm.get_num_tokens_from_messages(messages)
|
actual = llm.get_num_tokens_from_messages(messages)
|
||||||
assert expected == actual
|
assert expected == actual
|
||||||
|
|
||||||
|
|
||||||
|
class Foo(BaseModel):
|
||||||
|
bar: int
|
||||||
|
|
||||||
|
|
||||||
|
class FooV2(BaseModelV2):
|
||||||
|
bar: int
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("schema", [Foo, FooV2])
|
||||||
|
def test_schema_from_with_structured_output(schema: Type) -> None:
|
||||||
|
"""Test schema from with_structured_output."""
|
||||||
|
|
||||||
|
llm = ChatOpenAI()
|
||||||
|
|
||||||
|
structured_llm = llm.with_structured_output(
|
||||||
|
schema, method="json_schema", strict=True
|
||||||
|
)
|
||||||
|
|
||||||
|
expected = {
|
||||||
|
"properties": {"bar": {"title": "Bar", "type": "integer"}},
|
||||||
|
"required": ["bar"],
|
||||||
|
"title": schema.__name__,
|
||||||
|
"type": "object",
|
||||||
|
}
|
||||||
|
actual = structured_llm.get_output_schema().schema()
|
||||||
|
assert actual == expected
|
||||||
|
Loading…
Reference in New Issue
Block a user