tests[patch]: improve coverage of structured output tests (#29478)

This commit is contained in:
ccurme 2025-01-29 14:52:09 -05:00 committed by GitHub
parent c79274cb7c
commit 284c935b08
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1,6 +1,6 @@
import base64
import json
from typing import List, Optional, cast
from typing import Any, List, Literal, Optional, cast
import httpx
import pytest
@ -29,7 +29,9 @@ from langchain_tests.unit_tests.chat_models import (
from langchain_tests.utils.pydantic import PYDANTIC_MAJOR_VERSION
def _get_joke_class() -> type[BaseModel]:
def _get_joke_class(
schema_type: Literal["pydantic", "typeddict", "json_schema"],
) -> Any:
"""
:private:
"""
@ -40,7 +42,28 @@ def _get_joke_class() -> type[BaseModel]:
setup: str = Field(description="question to set up a joke")
punchline: str = Field(description="answer to resolve the joke")
return Joke
def validate_joke(result: Any) -> bool:
return isinstance(result, Joke)
class JokeDict(TypedDict):
"""Joke to tell user."""
setup: Annotated[str, ..., "question to set up a joke"]
punchline: Annotated[str, ..., "answer to resolve the joke"]
def validate_joke_dict(result: Any) -> bool:
return all(key in ["setup", "punchline"] for key in result.keys())
if schema_type == "pydantic":
return Joke, validate_joke
elif schema_type == "typeddict":
return JokeDict, validate_joke_dict
elif schema_type == "json_schema":
return Joke.model_json_schema(), validate_joke_dict
else:
raise ValueError("Invalid schema type")
class _MagicFunctionSchema(BaseModel):
@ -1151,7 +1174,8 @@ class ChatModelIntegrationTests(ChatModelTests):
assert tool_call["args"].get("answer_style")
assert tool_call["type"] == "tool_call"
def test_structured_output(self, model: BaseChatModel) -> None:
@pytest.mark.parametrize("schema_type", ["pydantic", "typeddict", "json_schema"])
def test_structured_output(self, model: BaseChatModel, schema_type: str) -> None:
"""Test to verify structured output is generated both on invoke and stream.
This test is optional and should be skipped if the model does not support
@ -1181,29 +1205,19 @@ class ChatModelIntegrationTests(ChatModelTests):
if not self.has_tool_calling:
pytest.skip("Test requires tool calling.")
Joke = _get_joke_class()
# Pydantic class
chat = model.with_structured_output(Joke, **self.structured_output_kwargs)
schema, validation_function = _get_joke_class(schema_type) # type: ignore[arg-type]
chat = model.with_structured_output(schema, **self.structured_output_kwargs)
result = chat.invoke("Tell me a joke about cats.")
assert isinstance(result, Joke)
validation_function(result)
for chunk in chat.stream("Tell me a joke about cats."):
assert isinstance(chunk, Joke)
validation_function(chunk)
assert chunk
# Schema
chat = model.with_structured_output(
Joke.model_json_schema(), **self.structured_output_kwargs
)
result = chat.invoke("Tell me a joke about cats.")
assert isinstance(result, dict)
assert set(result.keys()) == {"setup", "punchline"}
for chunk in chat.stream("Tell me a joke about cats."):
assert isinstance(chunk, dict)
assert isinstance(chunk, dict) # for mypy
assert set(chunk.keys()) == {"setup", "punchline"}
async def test_structured_output_async(self, model: BaseChatModel) -> None:
@pytest.mark.parametrize("schema_type", ["pydantic", "typeddict", "json_schema"])
async def test_structured_output_async(
self, model: BaseChatModel, schema_type: str
) -> None:
"""Test to verify structured output is generated both on invoke and stream.
This test is optional and should be skipped if the model does not support
@ -1233,28 +1247,14 @@ class ChatModelIntegrationTests(ChatModelTests):
if not self.has_tool_calling:
pytest.skip("Test requires tool calling.")
Joke = _get_joke_class()
# Pydantic class
chat = model.with_structured_output(Joke, **self.structured_output_kwargs)
schema, validation_function = _get_joke_class(schema_type) # type: ignore[arg-type]
chat = model.with_structured_output(schema, **self.structured_output_kwargs)
result = await chat.ainvoke("Tell me a joke about cats.")
assert isinstance(result, Joke)
validation_function(result)
async for chunk in chat.astream("Tell me a joke about cats."):
assert isinstance(chunk, Joke)
# Schema
chat = model.with_structured_output(
Joke.model_json_schema(), **self.structured_output_kwargs
)
result = await chat.ainvoke("Tell me a joke about cats.")
assert isinstance(result, dict)
assert set(result.keys()) == {"setup", "punchline"}
async for chunk in chat.astream("Tell me a joke about cats."):
assert isinstance(chunk, dict)
assert isinstance(chunk, dict) # for mypy
assert set(chunk.keys()) == {"setup", "punchline"}
validation_function(chunk)
assert chunk
@pytest.mark.skipif(PYDANTIC_MAJOR_VERSION != 2, reason="Test requires pydantic 2.")
def test_structured_output_pydantic_2_v1(self, model: BaseChatModel) -> None: