tests[patch]: improve coverage of structured output tests (#29478)

This commit is contained in:
ccurme 2025-01-29 14:52:09 -05:00 committed by GitHub
parent c79274cb7c
commit 284c935b08
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1,6 +1,6 @@
import base64 import base64
import json import json
from typing import List, Optional, cast from typing import Any, List, Literal, Optional, cast
import httpx import httpx
import pytest import pytest
@ -29,7 +29,9 @@ from langchain_tests.unit_tests.chat_models import (
from langchain_tests.utils.pydantic import PYDANTIC_MAJOR_VERSION from langchain_tests.utils.pydantic import PYDANTIC_MAJOR_VERSION
def _get_joke_class() -> type[BaseModel]: def _get_joke_class(
schema_type: Literal["pydantic", "typeddict", "json_schema"],
) -> Any:
""" """
:private: :private:
""" """
@ -40,7 +42,28 @@ def _get_joke_class() -> type[BaseModel]:
setup: str = Field(description="question to set up a joke") setup: str = Field(description="question to set up a joke")
punchline: str = Field(description="answer to resolve the joke") punchline: str = Field(description="answer to resolve the joke")
return Joke def validate_joke(result: Any) -> bool:
return isinstance(result, Joke)
class JokeDict(TypedDict):
"""Joke to tell user."""
setup: Annotated[str, ..., "question to set up a joke"]
punchline: Annotated[str, ..., "answer to resolve the joke"]
def validate_joke_dict(result: Any) -> bool:
return all(key in ["setup", "punchline"] for key in result.keys())
if schema_type == "pydantic":
return Joke, validate_joke
elif schema_type == "typeddict":
return JokeDict, validate_joke_dict
elif schema_type == "json_schema":
return Joke.model_json_schema(), validate_joke_dict
else:
raise ValueError("Invalid schema type")
class _MagicFunctionSchema(BaseModel): class _MagicFunctionSchema(BaseModel):
@ -1151,7 +1174,8 @@ class ChatModelIntegrationTests(ChatModelTests):
assert tool_call["args"].get("answer_style") assert tool_call["args"].get("answer_style")
assert tool_call["type"] == "tool_call" assert tool_call["type"] == "tool_call"
def test_structured_output(self, model: BaseChatModel) -> None: @pytest.mark.parametrize("schema_type", ["pydantic", "typeddict", "json_schema"])
def test_structured_output(self, model: BaseChatModel, schema_type: str) -> None:
"""Test to verify structured output is generated both on invoke and stream. """Test to verify structured output is generated both on invoke and stream.
This test is optional and should be skipped if the model does not support This test is optional and should be skipped if the model does not support
@ -1181,29 +1205,19 @@ class ChatModelIntegrationTests(ChatModelTests):
if not self.has_tool_calling: if not self.has_tool_calling:
pytest.skip("Test requires tool calling.") pytest.skip("Test requires tool calling.")
Joke = _get_joke_class() schema, validation_function = _get_joke_class(schema_type) # type: ignore[arg-type]
# Pydantic class chat = model.with_structured_output(schema, **self.structured_output_kwargs)
chat = model.with_structured_output(Joke, **self.structured_output_kwargs)
result = chat.invoke("Tell me a joke about cats.") result = chat.invoke("Tell me a joke about cats.")
assert isinstance(result, Joke) validation_function(result)
for chunk in chat.stream("Tell me a joke about cats."): for chunk in chat.stream("Tell me a joke about cats."):
assert isinstance(chunk, Joke) validation_function(chunk)
assert chunk
# Schema @pytest.mark.parametrize("schema_type", ["pydantic", "typeddict", "json_schema"])
chat = model.with_structured_output( async def test_structured_output_async(
Joke.model_json_schema(), **self.structured_output_kwargs self, model: BaseChatModel, schema_type: str
) ) -> None:
result = chat.invoke("Tell me a joke about cats.")
assert isinstance(result, dict)
assert set(result.keys()) == {"setup", "punchline"}
for chunk in chat.stream("Tell me a joke about cats."):
assert isinstance(chunk, dict)
assert isinstance(chunk, dict) # for mypy
assert set(chunk.keys()) == {"setup", "punchline"}
async def test_structured_output_async(self, model: BaseChatModel) -> None:
"""Test to verify structured output is generated both on invoke and stream. """Test to verify structured output is generated both on invoke and stream.
This test is optional and should be skipped if the model does not support This test is optional and should be skipped if the model does not support
@ -1233,28 +1247,14 @@ class ChatModelIntegrationTests(ChatModelTests):
if not self.has_tool_calling: if not self.has_tool_calling:
pytest.skip("Test requires tool calling.") pytest.skip("Test requires tool calling.")
Joke = _get_joke_class() schema, validation_function = _get_joke_class(schema_type) # type: ignore[arg-type]
chat = model.with_structured_output(schema, **self.structured_output_kwargs)
# Pydantic class
chat = model.with_structured_output(Joke, **self.structured_output_kwargs)
result = await chat.ainvoke("Tell me a joke about cats.") result = await chat.ainvoke("Tell me a joke about cats.")
assert isinstance(result, Joke) validation_function(result)
async for chunk in chat.astream("Tell me a joke about cats."): async for chunk in chat.astream("Tell me a joke about cats."):
assert isinstance(chunk, Joke) validation_function(chunk)
assert chunk
# Schema
chat = model.with_structured_output(
Joke.model_json_schema(), **self.structured_output_kwargs
)
result = await chat.ainvoke("Tell me a joke about cats.")
assert isinstance(result, dict)
assert set(result.keys()) == {"setup", "punchline"}
async for chunk in chat.astream("Tell me a joke about cats."):
assert isinstance(chunk, dict)
assert isinstance(chunk, dict) # for mypy
assert set(chunk.keys()) == {"setup", "punchline"}
@pytest.mark.skipif(PYDANTIC_MAJOR_VERSION != 2, reason="Test requires pydantic 2.") @pytest.mark.skipif(PYDANTIC_MAJOR_VERSION != 2, reason="Test requires pydantic 2.")
def test_structured_output_pydantic_2_v1(self, model: BaseChatModel) -> None: def test_structured_output_pydantic_2_v1(self, model: BaseChatModel) -> None: