mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-15 15:46:47 +00:00
refactoring
This commit is contained in:
parent
2f8470d7f2
commit
2290984cfa
@ -27,7 +27,8 @@ def test_structured_output(method: str) -> None:
|
|||||||
query = "Tell me a joke about cats."
|
query = "Tell me a joke about cats."
|
||||||
|
|
||||||
# Pydantic
|
# Pydantic
|
||||||
structured_llm = llm.with_structured_output(Joke, method=method) # type: ignore[arg-type]
|
if method == "function_calling":
|
||||||
|
structured_llm = llm.with_structured_output(Joke, method="function_calling")
|
||||||
result = structured_llm.invoke(query)
|
result = structured_llm.invoke(query)
|
||||||
assert isinstance(result, Joke)
|
assert isinstance(result, Joke)
|
||||||
|
|
||||||
@ -35,7 +36,10 @@ def test_structured_output(method: str) -> None:
|
|||||||
assert isinstance(chunk, Joke)
|
assert isinstance(chunk, Joke)
|
||||||
|
|
||||||
# JSON Schema
|
# JSON Schema
|
||||||
structured_llm = llm.with_structured_output(Joke.model_json_schema(), method=method) # type: ignore[arg-type]
|
if method == "json_schema":
|
||||||
|
structured_llm = llm.with_structured_output(
|
||||||
|
Joke.model_json_schema(), method="json_schema"
|
||||||
|
)
|
||||||
result = structured_llm.invoke(query)
|
result = structured_llm.invoke(query)
|
||||||
assert isinstance(result, dict)
|
assert isinstance(result, dict)
|
||||||
assert set(result.keys()) == {"setup", "punchline"}
|
assert set(result.keys()) == {"setup", "punchline"}
|
||||||
@ -52,7 +56,7 @@ def test_structured_output(method: str) -> None:
|
|||||||
setup: Annotated[str, "question to set up a joke"]
|
setup: Annotated[str, "question to set up a joke"]
|
||||||
punchline: Annotated[str, "answer to resolve the joke"]
|
punchline: Annotated[str, "answer to resolve the joke"]
|
||||||
|
|
||||||
structured_llm = llm.with_structured_output(JokeSchema, method=method) # type: ignore[arg-type]
|
structured_llm = llm.with_structured_output(JokeSchema, method="json_schema")
|
||||||
result = structured_llm.invoke(query)
|
result = structured_llm.invoke(query)
|
||||||
assert isinstance(result, dict)
|
assert isinstance(result, dict)
|
||||||
assert set(result.keys()) == {"setup", "punchline"}
|
assert set(result.keys()) == {"setup", "punchline"}
|
||||||
|
@ -34,6 +34,9 @@ from langchain_core.messages.content_blocks import (
|
|||||||
create_plaintext_block,
|
create_plaintext_block,
|
||||||
create_text_block,
|
create_text_block,
|
||||||
create_video_block,
|
create_video_block,
|
||||||
|
is_reasoning_block,
|
||||||
|
is_text_block,
|
||||||
|
is_tool_call_block,
|
||||||
)
|
)
|
||||||
from langchain_core.messages.v1 import AIMessage, AIMessageChunk, HumanMessage
|
from langchain_core.messages.v1 import AIMessage, AIMessageChunk, HumanMessage
|
||||||
from langchain_core.tools import tool
|
from langchain_core.tools import tool
|
||||||
@ -90,7 +93,7 @@ def _validate_tool_call_message(message: AIMessage) -> None:
|
|||||||
tool_call_blocks = [
|
tool_call_blocks = [
|
||||||
block
|
block
|
||||||
for block in message.content
|
for block in message.content
|
||||||
if isinstance(block, dict) and block.get("type") == "tool_call"
|
if isinstance(block, dict) and is_tool_call_block(block)
|
||||||
]
|
]
|
||||||
assert len(tool_call_blocks) >= 1
|
assert len(tool_call_blocks) >= 1
|
||||||
|
|
||||||
@ -170,7 +173,7 @@ class ChatModelV1IntegrationTests(ChatModelV1Tests):
|
|||||||
text_blocks = [
|
text_blocks = [
|
||||||
block
|
block
|
||||||
for block in result.content
|
for block in result.content
|
||||||
if isinstance(block, dict) and block.get("type") == "text"
|
if isinstance(block, dict) and is_text_block(block)
|
||||||
]
|
]
|
||||||
assert len(text_blocks) > 0
|
assert len(text_blocks) > 0
|
||||||
if result.text:
|
if result.text:
|
||||||
@ -244,7 +247,7 @@ class ChatModelV1IntegrationTests(ChatModelV1Tests):
|
|||||||
reasoning_blocks = [
|
reasoning_blocks = [
|
||||||
block
|
block
|
||||||
for block in result.content
|
for block in result.content
|
||||||
if isinstance(block, dict) and block.get("type") == "reasoning"
|
if isinstance(block, dict) and is_reasoning_block(block)
|
||||||
]
|
]
|
||||||
assert len(reasoning_blocks) > 0
|
assert len(reasoning_blocks) > 0
|
||||||
|
|
||||||
@ -266,7 +269,7 @@ class ChatModelV1IntegrationTests(ChatModelV1Tests):
|
|||||||
for block in result.content:
|
for block in result.content:
|
||||||
if (
|
if (
|
||||||
isinstance(block, dict)
|
isinstance(block, dict)
|
||||||
and block.get("type") == "text"
|
and is_text_block(block)
|
||||||
and "annotations" in block
|
and "annotations" in block
|
||||||
):
|
):
|
||||||
annotations = cast("list[dict[str, Any]]", block.get("annotations", []))
|
annotations = cast("list[dict[str, Any]]", block.get("annotations", []))
|
||||||
@ -343,7 +346,7 @@ class ChatModelV1IntegrationTests(ChatModelV1Tests):
|
|||||||
|
|
||||||
def test_tool_calling_with_content_blocks(self, model: BaseChatModelV1) -> None:
|
def test_tool_calling_with_content_blocks(self, model: BaseChatModelV1) -> None:
|
||||||
"""Test tool calling with content blocks."""
|
"""Test tool calling with content blocks."""
|
||||||
if not self.supports_enhanced_tool_calls:
|
if not self.has_tool_calling:
|
||||||
pytest.skip("Model does not support tool calls.")
|
pytest.skip("Model does not support tool calls.")
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
|
@ -20,6 +20,9 @@ from langchain_core.messages.content_blocks import (
|
|||||||
create_image_block,
|
create_image_block,
|
||||||
create_non_standard_block,
|
create_non_standard_block,
|
||||||
create_text_block,
|
create_text_block,
|
||||||
|
is_reasoning_block,
|
||||||
|
is_text_block,
|
||||||
|
is_tool_call_block,
|
||||||
)
|
)
|
||||||
from langchain_core.messages.v1 import AIMessage, HumanMessage
|
from langchain_core.messages.v1 import AIMessage, HumanMessage
|
||||||
from langchain_core.tools import tool
|
from langchain_core.tools import tool
|
||||||
@ -55,7 +58,25 @@ class ChatModelV1Tests(BaseStandardTests):
|
|||||||
# Content Block Support Properties
|
# Content Block Support Properties
|
||||||
@property
|
@property
|
||||||
def supports_content_blocks_v1(self) -> bool:
|
def supports_content_blocks_v1(self) -> bool:
|
||||||
"""Whether the model supports content blocks v1 format."""
|
"""Whether the model supports content blocks v1 format.
|
||||||
|
|
||||||
|
Defualts to True. This should not be overridden by a ChatV1 subclass. You may
|
||||||
|
override the following properties to enable specific content block support.
|
||||||
|
Each defaults to False:
|
||||||
|
|
||||||
|
- ``supports_reasoning_content_blocks``
|
||||||
|
- ``supports_plaintext_content_blocks``
|
||||||
|
- ``supports_file_content_blocks``
|
||||||
|
- ``supports_image_content_blocks``
|
||||||
|
- ``supports_audio_content_blocks``
|
||||||
|
- ``supports_video_content_blocks``
|
||||||
|
- ``supports_citations``
|
||||||
|
- ``supports_web_search_blocks``
|
||||||
|
- ``supports_enhanced_tool_calls``
|
||||||
|
- ``supports_invalid_tool_calls``
|
||||||
|
- ``supports_tool_call_chunks``
|
||||||
|
|
||||||
|
"""
|
||||||
return True
|
return True
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@ -65,7 +86,11 @@ class ChatModelV1Tests(BaseStandardTests):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def supports_text_content_blocks(self) -> bool:
|
def supports_text_content_blocks(self) -> bool:
|
||||||
"""Whether the model supports ``TextContentBlock``."""
|
"""Whether the model supports ``TextContentBlock``.
|
||||||
|
|
||||||
|
This is a minimum requirement for v1 chat models.
|
||||||
|
|
||||||
|
"""
|
||||||
return self.supports_content_blocks_v1
|
return self.supports_content_blocks_v1
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@ -108,21 +133,11 @@ class ChatModelV1Tests(BaseStandardTests):
|
|||||||
"""Whether the model supports ``WebSearchCall``/``WebSearchResult`` blocks."""
|
"""Whether the model supports ``WebSearchCall``/``WebSearchResult`` blocks."""
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@property
|
|
||||||
def supports_enhanced_tool_calls(self) -> bool:
|
|
||||||
"""Whether the model supports ``ToolCall`` format with content blocks."""
|
|
||||||
return self.has_tool_calling and self.supports_content_blocks_v1
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def supports_invalid_tool_calls(self) -> bool:
|
def supports_invalid_tool_calls(self) -> bool:
|
||||||
"""Whether the model can handle ``InvalidToolCall`` blocks."""
|
"""Whether the model can handle ``InvalidToolCall`` blocks."""
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@property
|
|
||||||
def supports_tool_call_chunks(self) -> bool:
|
|
||||||
"""Whether the model supports streaming ``ToolCallChunk`` blocks."""
|
|
||||||
return self.supports_enhanced_tool_calls
|
|
||||||
|
|
||||||
|
|
||||||
class ChatModelV1UnitTests(ChatModelV1Tests):
|
class ChatModelV1UnitTests(ChatModelV1Tests):
|
||||||
"""Unit tests for chat models with content blocks v1 support.
|
"""Unit tests for chat models with content blocks v1 support.
|
||||||
@ -289,7 +304,7 @@ class ChatModelV1UnitTests(ChatModelV1Tests):
|
|||||||
``TextContentBlock`` objects instead of plain strings.
|
``TextContentBlock`` objects instead of plain strings.
|
||||||
"""
|
"""
|
||||||
if not self.supports_text_content_blocks:
|
if not self.supports_text_content_blocks:
|
||||||
pytest.skip("Model does not support TextContentBlock.")
|
pytest.skip("Model does not support TextContentBlock (rare!)")
|
||||||
|
|
||||||
text_block = create_text_block("Hello, world!")
|
text_block = create_text_block("Hello, world!")
|
||||||
message = HumanMessage(content=[text_block])
|
message = HumanMessage(content=[text_block])
|
||||||
@ -303,7 +318,9 @@ class ChatModelV1UnitTests(ChatModelV1Tests):
|
|||||||
if not (
|
if not (
|
||||||
self.supports_text_content_blocks and self.supports_image_content_blocks
|
self.supports_text_content_blocks and self.supports_image_content_blocks
|
||||||
):
|
):
|
||||||
pytest.skip("Model does not support mixed content blocks.")
|
pytest.skip(
|
||||||
|
"Model doesn't support mixed content blocks (concurrent text and image)"
|
||||||
|
)
|
||||||
|
|
||||||
content_blocks: list[ContentBlock] = [
|
content_blocks: list[ContentBlock] = [
|
||||||
create_text_block("Describe this image:"),
|
create_text_block("Describe this image:"),
|
||||||
@ -332,7 +349,7 @@ class ChatModelV1UnitTests(ChatModelV1Tests):
|
|||||||
reasoning_blocks = [
|
reasoning_blocks = [
|
||||||
block
|
block
|
||||||
for block in result.content
|
for block in result.content
|
||||||
if isinstance(block, dict) and block.get("type") == "reasoning"
|
if isinstance(block, dict) and is_reasoning_block(block)
|
||||||
]
|
]
|
||||||
assert len(reasoning_blocks) > 0
|
assert len(reasoning_blocks) > 0
|
||||||
|
|
||||||
@ -351,7 +368,7 @@ class ChatModelV1UnitTests(ChatModelV1Tests):
|
|||||||
for block in content_list:
|
for block in content_list:
|
||||||
if (
|
if (
|
||||||
isinstance(block, dict)
|
isinstance(block, dict)
|
||||||
and block.get("type") == "text"
|
and is_text_block(block)
|
||||||
and "annotations" in block
|
and "annotations" in block
|
||||||
and isinstance(block.get("annotations"), list)
|
and isinstance(block.get("annotations"), list)
|
||||||
and len(cast(list, block.get("annotations", []))) > 0
|
and len(cast(list, block.get("annotations", []))) > 0
|
||||||
@ -394,7 +411,7 @@ class ChatModelV1UnitTests(ChatModelV1Tests):
|
|||||||
self, model: BaseChatModelV1
|
self, model: BaseChatModelV1
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test enhanced tool calling with content blocks format."""
|
"""Test enhanced tool calling with content blocks format."""
|
||||||
if not self.supports_enhanced_tool_calls:
|
if not self.has_tool_calling:
|
||||||
pytest.skip("Model does not support enhanced tool calls.")
|
pytest.skip("Model does not support enhanced tool calls.")
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
@ -413,7 +430,7 @@ class ChatModelV1UnitTests(ChatModelV1Tests):
|
|||||||
tool_call_blocks = [
|
tool_call_blocks = [
|
||||||
block
|
block
|
||||||
for block in result.content
|
for block in result.content
|
||||||
if isinstance(block, dict) and block.get("type") == "tool_call"
|
if isinstance(block, dict) and is_tool_call_block(block)
|
||||||
]
|
]
|
||||||
assert len(tool_call_blocks) > 0
|
assert len(tool_call_blocks) > 0
|
||||||
# Backwards compat?
|
# Backwards compat?
|
||||||
|
Loading…
Reference in New Issue
Block a user