refactoring

This commit is contained in:
Mason Daugherty 2025-08-04 15:54:07 -04:00
parent 2f8470d7f2
commit 2290984cfa
No known key found for this signature in database
3 changed files with 71 additions and 47 deletions

View File

@ -27,40 +27,44 @@ def test_structured_output(method: str) -> None:
query = "Tell me a joke about cats."
# Pydantic
structured_llm = llm.with_structured_output(Joke, method=method) # type: ignore[arg-type]
result = structured_llm.invoke(query)
assert isinstance(result, Joke)
if method == "function_calling":
structured_llm = llm.with_structured_output(Joke, method="function_calling")
result = structured_llm.invoke(query)
assert isinstance(result, Joke)
for chunk in structured_llm.stream(query):
assert isinstance(chunk, Joke)
for chunk in structured_llm.stream(query):
assert isinstance(chunk, Joke)
# JSON Schema
structured_llm = llm.with_structured_output(Joke.model_json_schema(), method=method) # type: ignore[arg-type]
result = structured_llm.invoke(query)
assert isinstance(result, dict)
assert set(result.keys()) == {"setup", "punchline"}
if method == "json_schema":
structured_llm = llm.with_structured_output(
Joke.model_json_schema(), method="json_schema"
)
result = structured_llm.invoke(query)
assert isinstance(result, dict)
assert set(result.keys()) == {"setup", "punchline"}
for chunk in structured_llm.stream(query):
for chunk in structured_llm.stream(query):
assert isinstance(chunk, dict)
assert isinstance(chunk, dict)
assert isinstance(chunk, dict)
assert set(chunk.keys()) == {"setup", "punchline"}
assert set(chunk.keys()) == {"setup", "punchline"}
# Typed Dict
class JokeSchema(TypedDict):
"""Joke to tell user."""
# Typed Dict
class JokeSchema(TypedDict):
"""Joke to tell user."""
setup: Annotated[str, "question to set up a joke"]
punchline: Annotated[str, "answer to resolve the joke"]
setup: Annotated[str, "question to set up a joke"]
punchline: Annotated[str, "answer to resolve the joke"]
structured_llm = llm.with_structured_output(JokeSchema, method=method) # type: ignore[arg-type]
result = structured_llm.invoke(query)
assert isinstance(result, dict)
assert set(result.keys()) == {"setup", "punchline"}
structured_llm = llm.with_structured_output(JokeSchema, method="json_schema")
result = structured_llm.invoke(query)
assert isinstance(result, dict)
assert set(result.keys()) == {"setup", "punchline"}
for chunk in structured_llm.stream(query):
for chunk in structured_llm.stream(query):
assert isinstance(chunk, dict)
assert isinstance(chunk, dict)
assert isinstance(chunk, dict)
assert set(chunk.keys()) == {"setup", "punchline"}
assert set(chunk.keys()) == {"setup", "punchline"}
@pytest.mark.parametrize(("model"), [(DEFAULT_MODEL_NAME)])

View File

@ -34,6 +34,9 @@ from langchain_core.messages.content_blocks import (
create_plaintext_block,
create_text_block,
create_video_block,
is_reasoning_block,
is_text_block,
is_tool_call_block,
)
from langchain_core.messages.v1 import AIMessage, AIMessageChunk, HumanMessage
from langchain_core.tools import tool
@ -90,7 +93,7 @@ def _validate_tool_call_message(message: AIMessage) -> None:
tool_call_blocks = [
block
for block in message.content
if isinstance(block, dict) and block.get("type") == "tool_call"
if isinstance(block, dict) and is_tool_call_block(block)
]
assert len(tool_call_blocks) >= 1
@ -170,7 +173,7 @@ class ChatModelV1IntegrationTests(ChatModelV1Tests):
text_blocks = [
block
for block in result.content
if isinstance(block, dict) and block.get("type") == "text"
if isinstance(block, dict) and is_text_block(block)
]
assert len(text_blocks) > 0
if result.text:
@ -244,7 +247,7 @@ class ChatModelV1IntegrationTests(ChatModelV1Tests):
reasoning_blocks = [
block
for block in result.content
if isinstance(block, dict) and block.get("type") == "reasoning"
if isinstance(block, dict) and is_reasoning_block(block)
]
assert len(reasoning_blocks) > 0
@ -266,7 +269,7 @@ class ChatModelV1IntegrationTests(ChatModelV1Tests):
for block in result.content:
if (
isinstance(block, dict)
and block.get("type") == "text"
and is_text_block(block)
and "annotations" in block
):
annotations = cast("list[dict[str, Any]]", block.get("annotations", []))
@ -343,7 +346,7 @@ class ChatModelV1IntegrationTests(ChatModelV1Tests):
def test_tool_calling_with_content_blocks(self, model: BaseChatModelV1) -> None:
"""Test tool calling with content blocks."""
if not self.supports_enhanced_tool_calls:
if not self.has_tool_calling:
pytest.skip("Model does not support tool calls.")
@tool

View File

@ -20,6 +20,9 @@ from langchain_core.messages.content_blocks import (
create_image_block,
create_non_standard_block,
create_text_block,
is_reasoning_block,
is_text_block,
is_tool_call_block,
)
from langchain_core.messages.v1 import AIMessage, HumanMessage
from langchain_core.tools import tool
@ -55,7 +58,25 @@ class ChatModelV1Tests(BaseStandardTests):
# Content Block Support Properties
@property
def supports_content_blocks_v1(self) -> bool:
"""Whether the model supports content blocks v1 format."""
"""Whether the model supports content blocks v1 format.
Defualts to True. This should not be overridden by a ChatV1 subclass. You may
override the following properties to enable specific content block support.
Each defaults to False:
- ``supports_reasoning_content_blocks``
- ``supports_plaintext_content_blocks``
- ``supports_file_content_blocks``
- ``supports_image_content_blocks``
- ``supports_audio_content_blocks``
- ``supports_video_content_blocks``
- ``supports_citations``
- ``supports_web_search_blocks``
- ``supports_enhanced_tool_calls``
- ``supports_invalid_tool_calls``
- ``supports_tool_call_chunks``
"""
return True
@property
@ -65,7 +86,11 @@ class ChatModelV1Tests(BaseStandardTests):
@property
def supports_text_content_blocks(self) -> bool:
"""Whether the model supports ``TextContentBlock``."""
"""Whether the model supports ``TextContentBlock``.
This is a minimum requirement for v1 chat models.
"""
return self.supports_content_blocks_v1
@property
@ -108,21 +133,11 @@ class ChatModelV1Tests(BaseStandardTests):
"""Whether the model supports ``WebSearchCall``/``WebSearchResult`` blocks."""
return False
@property
def supports_enhanced_tool_calls(self) -> bool:
"""Whether the model supports ``ToolCall`` format with content blocks."""
return self.has_tool_calling and self.supports_content_blocks_v1
@property
def supports_invalid_tool_calls(self) -> bool:
"""Whether the model can handle ``InvalidToolCall`` blocks."""
return False
@property
def supports_tool_call_chunks(self) -> bool:
"""Whether the model supports streaming ``ToolCallChunk`` blocks."""
return self.supports_enhanced_tool_calls
class ChatModelV1UnitTests(ChatModelV1Tests):
"""Unit tests for chat models with content blocks v1 support.
@ -289,7 +304,7 @@ class ChatModelV1UnitTests(ChatModelV1Tests):
``TextContentBlock`` objects instead of plain strings.
"""
if not self.supports_text_content_blocks:
pytest.skip("Model does not support TextContentBlock.")
pytest.skip("Model does not support TextContentBlock (rare!)")
text_block = create_text_block("Hello, world!")
message = HumanMessage(content=[text_block])
@ -303,7 +318,9 @@ class ChatModelV1UnitTests(ChatModelV1Tests):
if not (
self.supports_text_content_blocks and self.supports_image_content_blocks
):
pytest.skip("Model does not support mixed content blocks.")
pytest.skip(
"Model doesn't support mixed content blocks (concurrent text and image)"
)
content_blocks: list[ContentBlock] = [
create_text_block("Describe this image:"),
@ -332,7 +349,7 @@ class ChatModelV1UnitTests(ChatModelV1Tests):
reasoning_blocks = [
block
for block in result.content
if isinstance(block, dict) and block.get("type") == "reasoning"
if isinstance(block, dict) and is_reasoning_block(block)
]
assert len(reasoning_blocks) > 0
@ -351,7 +368,7 @@ class ChatModelV1UnitTests(ChatModelV1Tests):
for block in content_list:
if (
isinstance(block, dict)
and block.get("type") == "text"
and is_text_block(block)
and "annotations" in block
and isinstance(block.get("annotations"), list)
and len(cast(list, block.get("annotations", []))) > 0
@ -394,7 +411,7 @@ class ChatModelV1UnitTests(ChatModelV1Tests):
self, model: BaseChatModelV1
) -> None:
"""Test enhanced tool calling with content blocks format."""
if not self.supports_enhanced_tool_calls:
if not self.has_tool_calling:
pytest.skip("Model does not support enhanced tool calls.")
@tool
@ -413,7 +430,7 @@ class ChatModelV1UnitTests(ChatModelV1Tests):
tool_call_blocks = [
block
for block in result.content
if isinstance(block, dict) and block.get("type") == "tool_call"
if isinstance(block, dict) and is_tool_call_block(block)
]
assert len(tool_call_blocks) > 0
# Backwards compat?