From 20b72a044c95a69bb248b80ef6c2a557f9adb99d Mon Sep 17 00:00:00 2001 From: Eugene Yurtsev Date: Tue, 23 Jul 2024 09:01:26 -0400 Subject: [PATCH] standard-tests: Add BaseModel variations tests to with_structured_output (#24527) After this standard tests will test with the following combinations: 1. pydantic.BaseModel 2. pydantic.v1.BaseModel If ran within a matrix, it'll covert both pydantic.BaseModel originating from pydantic 1 and the one defined in pydantic 2. --- .../integration_tests/chat_models.py | 44 ++++++++++++++++++- 1 file changed, 43 insertions(+), 1 deletion(-) diff --git a/libs/standard-tests/langchain_standard_tests/integration_tests/chat_models.py b/libs/standard-tests/langchain_standard_tests/integration_tests/chat_models.py index 511eab620cd..c19baf85cf4 100644 --- a/libs/standard-tests/langchain_standard_tests/integration_tests/chat_models.py +++ b/libs/standard-tests/langchain_standard_tests/integration_tests/chat_models.py @@ -23,6 +23,7 @@ from langchain_standard_tests.unit_tests.chat_models import ( ChatModelTests, my_adder_tool, ) +from langchain_standard_tests.utils.pydantic import PYDANTIC_MAJOR_VERSION @tool @@ -211,10 +212,51 @@ class ChatModelIntegrationTests(ChatModelTests): assert tool_call["type"] == "tool_call" def test_structured_output(self, model: BaseChatModel) -> None: + """Test to verify structured output with a Pydantic model.""" if not self.has_tool_calling: pytest.skip("Test requires tool calling.") - class Joke(BaseModel): + from pydantic import BaseModel as BaseModelProper + from pydantic import Field as FieldProper + + class Joke(BaseModelProper): + """Joke to tell user.""" + + setup: str = FieldProper(description="question to set up a joke") + punchline: str = FieldProper(description="answer to resolve the joke") + + # Pydantic class + # Type ignoring since the interface only officially supports pydantic 1 + # or pydantic.v1.BaseModel but not pydantic.BaseModel from pydantic 2. + # We'll need to do a pass updating the type signatures. + chat = model.with_structured_output(Joke) # type: ignore[arg-type] + result = chat.invoke("Tell me a joke about cats.") + assert isinstance(result, Joke) + + for chunk in chat.stream("Tell me a joke about cats."): + assert isinstance(chunk, Joke) + + # Schema + chat = model.with_structured_output(Joke.schema()) + result = chat.invoke("Tell me a joke about cats.") + assert isinstance(result, dict) + assert set(result.keys()) == {"setup", "punchline"} + + for chunk in chat.stream("Tell me a joke about cats."): + assert isinstance(chunk, dict) + assert isinstance(chunk, dict) # for mypy + assert set(chunk.keys()) == {"setup", "punchline"} + + @pytest.mark.skipif(PYDANTIC_MAJOR_VERSION != 2, "Test requires pydantic 2.") + def test_structured_output_pydantic_2_v1(self, model: BaseChatModel) -> None: + """Test to verify compatibility with pydantic.v1.BaseModel. + + pydantic.v1.BaseModel is available in the pydantic 2 package. + """ + if not self.has_tool_calling: + pytest.skip("Test requires tool calling.") + + class Joke(BaseModel): # Uses langchain_core.pydantic_v1.BaseModel """Joke to tell user.""" setup: str = Field(description="question to set up a joke")