standard-tests[patch]: add async test for structured output (#26527)

This commit is contained in:
ccurme 2024-09-16 11:15:23 -04:00 committed by GitHub
parent 1ab181f514
commit 88bc15d69b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -44,6 +44,13 @@ def magic_function_no_args() -> int:
return 5
class Joke(BaseModel):
"""Joke to tell user."""
setup: str = Field(description="question to set up a joke")
punchline: str = Field(description="answer to resolve the joke")
def _validate_tool_call_message(message: BaseMessage) -> None:
assert isinstance(message, AIMessage)
assert len(message.tool_calls) == 1
@ -240,12 +247,6 @@ class ChatModelIntegrationTests(ChatModelTests):
if not self.has_tool_calling:
pytest.skip("Test requires tool calling.")
class Joke(BaseModel):
"""Joke to tell user."""
setup: str = Field(description="question to set up a joke")
punchline: str = Field(description="answer to resolve the joke")
# Pydantic class
# Type ignoring since the interface only officially supports pydantic 1
# or pydantic.v1.BaseModel but not pydantic.BaseModel from pydantic 2.
@ -268,6 +269,33 @@ class ChatModelIntegrationTests(ChatModelTests):
assert isinstance(chunk, dict) # for mypy
assert set(chunk.keys()) == {"setup", "punchline"}
async def test_structured_output_async(self, model: BaseChatModel) -> None:
"""Test to verify structured output with a Pydantic model."""
if not self.has_tool_calling:
pytest.skip("Test requires tool calling.")
# Pydantic class
# Type ignoring since the interface only officially supports pydantic 1
# or pydantic.v1.BaseModel but not pydantic.BaseModel from pydantic 2.
# We'll need to do a pass updating the type signatures.
chat = model.with_structured_output(Joke) # type: ignore[arg-type]
result = await chat.ainvoke("Tell me a joke about cats.")
assert isinstance(result, Joke)
async for chunk in chat.astream("Tell me a joke about cats."):
assert isinstance(chunk, Joke)
# Schema
chat = model.with_structured_output(Joke.model_json_schema())
result = await chat.ainvoke("Tell me a joke about cats.")
assert isinstance(result, dict)
assert set(result.keys()) == {"setup", "punchline"}
async for chunk in chat.astream("Tell me a joke about cats."):
assert isinstance(chunk, dict)
assert isinstance(chunk, dict) # for mypy
assert set(chunk.keys()) == {"setup", "punchline"}
@pytest.mark.skipif(PYDANTIC_MAJOR_VERSION != 2, reason="Test requires pydantic 2.")
def test_structured_output_pydantic_2_v1(self, model: BaseChatModel) -> None:
"""Test to verify compatibility with pydantic.v1.BaseModel.