mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-18 12:58:59 +00:00
standard-tests: Add BaseModel variations tests to with_structured_output (#24527)
After this standard tests will test with the following combinations: 1. pydantic.BaseModel 2. pydantic.v1.BaseModel If ran within a matrix, it'll covert both pydantic.BaseModel originating from pydantic 1 and the one defined in pydantic 2.
This commit is contained in:
parent
70c71efcab
commit
20b72a044c
@ -23,6 +23,7 @@ from langchain_standard_tests.unit_tests.chat_models import (
|
||||
ChatModelTests,
|
||||
my_adder_tool,
|
||||
)
|
||||
from langchain_standard_tests.utils.pydantic import PYDANTIC_MAJOR_VERSION
|
||||
|
||||
|
||||
@tool
|
||||
@ -211,10 +212,51 @@ class ChatModelIntegrationTests(ChatModelTests):
|
||||
assert tool_call["type"] == "tool_call"
|
||||
|
||||
def test_structured_output(self, model: BaseChatModel) -> None:
|
||||
"""Test to verify structured output with a Pydantic model."""
|
||||
if not self.has_tool_calling:
|
||||
pytest.skip("Test requires tool calling.")
|
||||
|
||||
class Joke(BaseModel):
|
||||
from pydantic import BaseModel as BaseModelProper
|
||||
from pydantic import Field as FieldProper
|
||||
|
||||
class Joke(BaseModelProper):
|
||||
"""Joke to tell user."""
|
||||
|
||||
setup: str = FieldProper(description="question to set up a joke")
|
||||
punchline: str = FieldProper(description="answer to resolve the joke")
|
||||
|
||||
# Pydantic class
|
||||
# Type ignoring since the interface only officially supports pydantic 1
|
||||
# or pydantic.v1.BaseModel but not pydantic.BaseModel from pydantic 2.
|
||||
# We'll need to do a pass updating the type signatures.
|
||||
chat = model.with_structured_output(Joke) # type: ignore[arg-type]
|
||||
result = chat.invoke("Tell me a joke about cats.")
|
||||
assert isinstance(result, Joke)
|
||||
|
||||
for chunk in chat.stream("Tell me a joke about cats."):
|
||||
assert isinstance(chunk, Joke)
|
||||
|
||||
# Schema
|
||||
chat = model.with_structured_output(Joke.schema())
|
||||
result = chat.invoke("Tell me a joke about cats.")
|
||||
assert isinstance(result, dict)
|
||||
assert set(result.keys()) == {"setup", "punchline"}
|
||||
|
||||
for chunk in chat.stream("Tell me a joke about cats."):
|
||||
assert isinstance(chunk, dict)
|
||||
assert isinstance(chunk, dict) # for mypy
|
||||
assert set(chunk.keys()) == {"setup", "punchline"}
|
||||
|
||||
@pytest.mark.skipif(PYDANTIC_MAJOR_VERSION != 2, "Test requires pydantic 2.")
|
||||
def test_structured_output_pydantic_2_v1(self, model: BaseChatModel) -> None:
|
||||
"""Test to verify compatibility with pydantic.v1.BaseModel.
|
||||
|
||||
pydantic.v1.BaseModel is available in the pydantic 2 package.
|
||||
"""
|
||||
if not self.has_tool_calling:
|
||||
pytest.skip("Test requires tool calling.")
|
||||
|
||||
class Joke(BaseModel): # Uses langchain_core.pydantic_v1.BaseModel
|
||||
"""Joke to tell user."""
|
||||
|
||||
setup: str = Field(description="question to set up a joke")
|
||||
|
Loading…
Reference in New Issue
Block a user