mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-11 05:45:01 +00:00
tests[patch]: improve coverage of structured output tests (#29478)
This commit is contained in:
parent
c79274cb7c
commit
284c935b08
@ -1,6 +1,6 @@
|
||||
import base64
|
||||
import json
|
||||
from typing import List, Optional, cast
|
||||
from typing import Any, List, Literal, Optional, cast
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
@ -29,7 +29,9 @@ from langchain_tests.unit_tests.chat_models import (
|
||||
from langchain_tests.utils.pydantic import PYDANTIC_MAJOR_VERSION
|
||||
|
||||
|
||||
def _get_joke_class() -> type[BaseModel]:
|
||||
def _get_joke_class(
|
||||
schema_type: Literal["pydantic", "typeddict", "json_schema"],
|
||||
) -> Any:
|
||||
"""
|
||||
:private:
|
||||
"""
|
||||
@ -40,7 +42,28 @@ def _get_joke_class() -> type[BaseModel]:
|
||||
setup: str = Field(description="question to set up a joke")
|
||||
punchline: str = Field(description="answer to resolve the joke")
|
||||
|
||||
return Joke
|
||||
def validate_joke(result: Any) -> bool:
|
||||
return isinstance(result, Joke)
|
||||
|
||||
class JokeDict(TypedDict):
|
||||
"""Joke to tell user."""
|
||||
|
||||
setup: Annotated[str, ..., "question to set up a joke"]
|
||||
punchline: Annotated[str, ..., "answer to resolve the joke"]
|
||||
|
||||
def validate_joke_dict(result: Any) -> bool:
|
||||
return all(key in ["setup", "punchline"] for key in result.keys())
|
||||
|
||||
if schema_type == "pydantic":
|
||||
return Joke, validate_joke
|
||||
|
||||
elif schema_type == "typeddict":
|
||||
return JokeDict, validate_joke_dict
|
||||
|
||||
elif schema_type == "json_schema":
|
||||
return Joke.model_json_schema(), validate_joke_dict
|
||||
else:
|
||||
raise ValueError("Invalid schema type")
|
||||
|
||||
|
||||
class _MagicFunctionSchema(BaseModel):
|
||||
@ -1151,7 +1174,8 @@ class ChatModelIntegrationTests(ChatModelTests):
|
||||
assert tool_call["args"].get("answer_style")
|
||||
assert tool_call["type"] == "tool_call"
|
||||
|
||||
def test_structured_output(self, model: BaseChatModel) -> None:
|
||||
@pytest.mark.parametrize("schema_type", ["pydantic", "typeddict", "json_schema"])
|
||||
def test_structured_output(self, model: BaseChatModel, schema_type: str) -> None:
|
||||
"""Test to verify structured output is generated both on invoke and stream.
|
||||
|
||||
This test is optional and should be skipped if the model does not support
|
||||
@ -1181,29 +1205,19 @@ class ChatModelIntegrationTests(ChatModelTests):
|
||||
if not self.has_tool_calling:
|
||||
pytest.skip("Test requires tool calling.")
|
||||
|
||||
Joke = _get_joke_class()
|
||||
# Pydantic class
|
||||
chat = model.with_structured_output(Joke, **self.structured_output_kwargs)
|
||||
schema, validation_function = _get_joke_class(schema_type) # type: ignore[arg-type]
|
||||
chat = model.with_structured_output(schema, **self.structured_output_kwargs)
|
||||
result = chat.invoke("Tell me a joke about cats.")
|
||||
assert isinstance(result, Joke)
|
||||
validation_function(result)
|
||||
|
||||
for chunk in chat.stream("Tell me a joke about cats."):
|
||||
assert isinstance(chunk, Joke)
|
||||
validation_function(chunk)
|
||||
assert chunk
|
||||
|
||||
# Schema
|
||||
chat = model.with_structured_output(
|
||||
Joke.model_json_schema(), **self.structured_output_kwargs
|
||||
)
|
||||
result = chat.invoke("Tell me a joke about cats.")
|
||||
assert isinstance(result, dict)
|
||||
assert set(result.keys()) == {"setup", "punchline"}
|
||||
|
||||
for chunk in chat.stream("Tell me a joke about cats."):
|
||||
assert isinstance(chunk, dict)
|
||||
assert isinstance(chunk, dict) # for mypy
|
||||
assert set(chunk.keys()) == {"setup", "punchline"}
|
||||
|
||||
async def test_structured_output_async(self, model: BaseChatModel) -> None:
|
||||
@pytest.mark.parametrize("schema_type", ["pydantic", "typeddict", "json_schema"])
|
||||
async def test_structured_output_async(
|
||||
self, model: BaseChatModel, schema_type: str
|
||||
) -> None:
|
||||
"""Test to verify structured output is generated both on invoke and stream.
|
||||
|
||||
This test is optional and should be skipped if the model does not support
|
||||
@ -1233,28 +1247,14 @@ class ChatModelIntegrationTests(ChatModelTests):
|
||||
if not self.has_tool_calling:
|
||||
pytest.skip("Test requires tool calling.")
|
||||
|
||||
Joke = _get_joke_class()
|
||||
|
||||
# Pydantic class
|
||||
chat = model.with_structured_output(Joke, **self.structured_output_kwargs)
|
||||
schema, validation_function = _get_joke_class(schema_type) # type: ignore[arg-type]
|
||||
chat = model.with_structured_output(schema, **self.structured_output_kwargs)
|
||||
result = await chat.ainvoke("Tell me a joke about cats.")
|
||||
assert isinstance(result, Joke)
|
||||
validation_function(result)
|
||||
|
||||
async for chunk in chat.astream("Tell me a joke about cats."):
|
||||
assert isinstance(chunk, Joke)
|
||||
|
||||
# Schema
|
||||
chat = model.with_structured_output(
|
||||
Joke.model_json_schema(), **self.structured_output_kwargs
|
||||
)
|
||||
result = await chat.ainvoke("Tell me a joke about cats.")
|
||||
assert isinstance(result, dict)
|
||||
assert set(result.keys()) == {"setup", "punchline"}
|
||||
|
||||
async for chunk in chat.astream("Tell me a joke about cats."):
|
||||
assert isinstance(chunk, dict)
|
||||
assert isinstance(chunk, dict) # for mypy
|
||||
assert set(chunk.keys()) == {"setup", "punchline"}
|
||||
validation_function(chunk)
|
||||
assert chunk
|
||||
|
||||
@pytest.mark.skipif(PYDANTIC_MAJOR_VERSION != 2, reason="Test requires pydantic 2.")
|
||||
def test_structured_output_pydantic_2_v1(self, model: BaseChatModel) -> None:
|
||||
|
Loading…
Reference in New Issue
Block a user