mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-15 07:36:08 +00:00
refactoring
This commit is contained in:
parent
2f8470d7f2
commit
2290984cfa
@ -27,40 +27,44 @@ def test_structured_output(method: str) -> None:
|
||||
query = "Tell me a joke about cats."
|
||||
|
||||
# Pydantic
|
||||
structured_llm = llm.with_structured_output(Joke, method=method) # type: ignore[arg-type]
|
||||
result = structured_llm.invoke(query)
|
||||
assert isinstance(result, Joke)
|
||||
if method == "function_calling":
|
||||
structured_llm = llm.with_structured_output(Joke, method="function_calling")
|
||||
result = structured_llm.invoke(query)
|
||||
assert isinstance(result, Joke)
|
||||
|
||||
for chunk in structured_llm.stream(query):
|
||||
assert isinstance(chunk, Joke)
|
||||
for chunk in structured_llm.stream(query):
|
||||
assert isinstance(chunk, Joke)
|
||||
|
||||
# JSON Schema
|
||||
structured_llm = llm.with_structured_output(Joke.model_json_schema(), method=method) # type: ignore[arg-type]
|
||||
result = structured_llm.invoke(query)
|
||||
assert isinstance(result, dict)
|
||||
assert set(result.keys()) == {"setup", "punchline"}
|
||||
if method == "json_schema":
|
||||
structured_llm = llm.with_structured_output(
|
||||
Joke.model_json_schema(), method="json_schema"
|
||||
)
|
||||
result = structured_llm.invoke(query)
|
||||
assert isinstance(result, dict)
|
||||
assert set(result.keys()) == {"setup", "punchline"}
|
||||
|
||||
for chunk in structured_llm.stream(query):
|
||||
for chunk in structured_llm.stream(query):
|
||||
assert isinstance(chunk, dict)
|
||||
assert isinstance(chunk, dict)
|
||||
assert isinstance(chunk, dict)
|
||||
assert set(chunk.keys()) == {"setup", "punchline"}
|
||||
assert set(chunk.keys()) == {"setup", "punchline"}
|
||||
|
||||
# Typed Dict
|
||||
class JokeSchema(TypedDict):
|
||||
"""Joke to tell user."""
|
||||
# Typed Dict
|
||||
class JokeSchema(TypedDict):
|
||||
"""Joke to tell user."""
|
||||
|
||||
setup: Annotated[str, "question to set up a joke"]
|
||||
punchline: Annotated[str, "answer to resolve the joke"]
|
||||
setup: Annotated[str, "question to set up a joke"]
|
||||
punchline: Annotated[str, "answer to resolve the joke"]
|
||||
|
||||
structured_llm = llm.with_structured_output(JokeSchema, method=method) # type: ignore[arg-type]
|
||||
result = structured_llm.invoke(query)
|
||||
assert isinstance(result, dict)
|
||||
assert set(result.keys()) == {"setup", "punchline"}
|
||||
structured_llm = llm.with_structured_output(JokeSchema, method="json_schema")
|
||||
result = structured_llm.invoke(query)
|
||||
assert isinstance(result, dict)
|
||||
assert set(result.keys()) == {"setup", "punchline"}
|
||||
|
||||
for chunk in structured_llm.stream(query):
|
||||
for chunk in structured_llm.stream(query):
|
||||
assert isinstance(chunk, dict)
|
||||
assert isinstance(chunk, dict)
|
||||
assert isinstance(chunk, dict)
|
||||
assert set(chunk.keys()) == {"setup", "punchline"}
|
||||
assert set(chunk.keys()) == {"setup", "punchline"}
|
||||
|
||||
|
||||
@pytest.mark.parametrize(("model"), [(DEFAULT_MODEL_NAME)])
|
||||
|
@ -34,6 +34,9 @@ from langchain_core.messages.content_blocks import (
|
||||
create_plaintext_block,
|
||||
create_text_block,
|
||||
create_video_block,
|
||||
is_reasoning_block,
|
||||
is_text_block,
|
||||
is_tool_call_block,
|
||||
)
|
||||
from langchain_core.messages.v1 import AIMessage, AIMessageChunk, HumanMessage
|
||||
from langchain_core.tools import tool
|
||||
@ -90,7 +93,7 @@ def _validate_tool_call_message(message: AIMessage) -> None:
|
||||
tool_call_blocks = [
|
||||
block
|
||||
for block in message.content
|
||||
if isinstance(block, dict) and block.get("type") == "tool_call"
|
||||
if isinstance(block, dict) and is_tool_call_block(block)
|
||||
]
|
||||
assert len(tool_call_blocks) >= 1
|
||||
|
||||
@ -170,7 +173,7 @@ class ChatModelV1IntegrationTests(ChatModelV1Tests):
|
||||
text_blocks = [
|
||||
block
|
||||
for block in result.content
|
||||
if isinstance(block, dict) and block.get("type") == "text"
|
||||
if isinstance(block, dict) and is_text_block(block)
|
||||
]
|
||||
assert len(text_blocks) > 0
|
||||
if result.text:
|
||||
@ -244,7 +247,7 @@ class ChatModelV1IntegrationTests(ChatModelV1Tests):
|
||||
reasoning_blocks = [
|
||||
block
|
||||
for block in result.content
|
||||
if isinstance(block, dict) and block.get("type") == "reasoning"
|
||||
if isinstance(block, dict) and is_reasoning_block(block)
|
||||
]
|
||||
assert len(reasoning_blocks) > 0
|
||||
|
||||
@ -266,7 +269,7 @@ class ChatModelV1IntegrationTests(ChatModelV1Tests):
|
||||
for block in result.content:
|
||||
if (
|
||||
isinstance(block, dict)
|
||||
and block.get("type") == "text"
|
||||
and is_text_block(block)
|
||||
and "annotations" in block
|
||||
):
|
||||
annotations = cast("list[dict[str, Any]]", block.get("annotations", []))
|
||||
@ -343,7 +346,7 @@ class ChatModelV1IntegrationTests(ChatModelV1Tests):
|
||||
|
||||
def test_tool_calling_with_content_blocks(self, model: BaseChatModelV1) -> None:
|
||||
"""Test tool calling with content blocks."""
|
||||
if not self.supports_enhanced_tool_calls:
|
||||
if not self.has_tool_calling:
|
||||
pytest.skip("Model does not support tool calls.")
|
||||
|
||||
@tool
|
||||
|
@ -20,6 +20,9 @@ from langchain_core.messages.content_blocks import (
|
||||
create_image_block,
|
||||
create_non_standard_block,
|
||||
create_text_block,
|
||||
is_reasoning_block,
|
||||
is_text_block,
|
||||
is_tool_call_block,
|
||||
)
|
||||
from langchain_core.messages.v1 import AIMessage, HumanMessage
|
||||
from langchain_core.tools import tool
|
||||
@ -55,7 +58,25 @@ class ChatModelV1Tests(BaseStandardTests):
|
||||
# Content Block Support Properties
|
||||
@property
|
||||
def supports_content_blocks_v1(self) -> bool:
|
||||
"""Whether the model supports content blocks v1 format."""
|
||||
"""Whether the model supports content blocks v1 format.
|
||||
|
||||
Defualts to True. This should not be overridden by a ChatV1 subclass. You may
|
||||
override the following properties to enable specific content block support.
|
||||
Each defaults to False:
|
||||
|
||||
- ``supports_reasoning_content_blocks``
|
||||
- ``supports_plaintext_content_blocks``
|
||||
- ``supports_file_content_blocks``
|
||||
- ``supports_image_content_blocks``
|
||||
- ``supports_audio_content_blocks``
|
||||
- ``supports_video_content_blocks``
|
||||
- ``supports_citations``
|
||||
- ``supports_web_search_blocks``
|
||||
- ``supports_enhanced_tool_calls``
|
||||
- ``supports_invalid_tool_calls``
|
||||
- ``supports_tool_call_chunks``
|
||||
|
||||
"""
|
||||
return True
|
||||
|
||||
@property
|
||||
@ -65,7 +86,11 @@ class ChatModelV1Tests(BaseStandardTests):
|
||||
|
||||
@property
|
||||
def supports_text_content_blocks(self) -> bool:
|
||||
"""Whether the model supports ``TextContentBlock``."""
|
||||
"""Whether the model supports ``TextContentBlock``.
|
||||
|
||||
This is a minimum requirement for v1 chat models.
|
||||
|
||||
"""
|
||||
return self.supports_content_blocks_v1
|
||||
|
||||
@property
|
||||
@ -108,21 +133,11 @@ class ChatModelV1Tests(BaseStandardTests):
|
||||
"""Whether the model supports ``WebSearchCall``/``WebSearchResult`` blocks."""
|
||||
return False
|
||||
|
||||
@property
|
||||
def supports_enhanced_tool_calls(self) -> bool:
|
||||
"""Whether the model supports ``ToolCall`` format with content blocks."""
|
||||
return self.has_tool_calling and self.supports_content_blocks_v1
|
||||
|
||||
@property
|
||||
def supports_invalid_tool_calls(self) -> bool:
|
||||
"""Whether the model can handle ``InvalidToolCall`` blocks."""
|
||||
return False
|
||||
|
||||
@property
|
||||
def supports_tool_call_chunks(self) -> bool:
|
||||
"""Whether the model supports streaming ``ToolCallChunk`` blocks."""
|
||||
return self.supports_enhanced_tool_calls
|
||||
|
||||
|
||||
class ChatModelV1UnitTests(ChatModelV1Tests):
|
||||
"""Unit tests for chat models with content blocks v1 support.
|
||||
@ -289,7 +304,7 @@ class ChatModelV1UnitTests(ChatModelV1Tests):
|
||||
``TextContentBlock`` objects instead of plain strings.
|
||||
"""
|
||||
if not self.supports_text_content_blocks:
|
||||
pytest.skip("Model does not support TextContentBlock.")
|
||||
pytest.skip("Model does not support TextContentBlock (rare!)")
|
||||
|
||||
text_block = create_text_block("Hello, world!")
|
||||
message = HumanMessage(content=[text_block])
|
||||
@ -303,7 +318,9 @@ class ChatModelV1UnitTests(ChatModelV1Tests):
|
||||
if not (
|
||||
self.supports_text_content_blocks and self.supports_image_content_blocks
|
||||
):
|
||||
pytest.skip("Model does not support mixed content blocks.")
|
||||
pytest.skip(
|
||||
"Model doesn't support mixed content blocks (concurrent text and image)"
|
||||
)
|
||||
|
||||
content_blocks: list[ContentBlock] = [
|
||||
create_text_block("Describe this image:"),
|
||||
@ -332,7 +349,7 @@ class ChatModelV1UnitTests(ChatModelV1Tests):
|
||||
reasoning_blocks = [
|
||||
block
|
||||
for block in result.content
|
||||
if isinstance(block, dict) and block.get("type") == "reasoning"
|
||||
if isinstance(block, dict) and is_reasoning_block(block)
|
||||
]
|
||||
assert len(reasoning_blocks) > 0
|
||||
|
||||
@ -351,7 +368,7 @@ class ChatModelV1UnitTests(ChatModelV1Tests):
|
||||
for block in content_list:
|
||||
if (
|
||||
isinstance(block, dict)
|
||||
and block.get("type") == "text"
|
||||
and is_text_block(block)
|
||||
and "annotations" in block
|
||||
and isinstance(block.get("annotations"), list)
|
||||
and len(cast(list, block.get("annotations", []))) > 0
|
||||
@ -394,7 +411,7 @@ class ChatModelV1UnitTests(ChatModelV1Tests):
|
||||
self, model: BaseChatModelV1
|
||||
) -> None:
|
||||
"""Test enhanced tool calling with content blocks format."""
|
||||
if not self.supports_enhanced_tool_calls:
|
||||
if not self.has_tool_calling:
|
||||
pytest.skip("Model does not support enhanced tool calls.")
|
||||
|
||||
@tool
|
||||
@ -413,7 +430,7 @@ class ChatModelV1UnitTests(ChatModelV1Tests):
|
||||
tool_call_blocks = [
|
||||
block
|
||||
for block in result.content
|
||||
if isinstance(block, dict) and block.get("type") == "tool_call"
|
||||
if isinstance(block, dict) and is_tool_call_block(block)
|
||||
]
|
||||
assert len(tool_call_blocks) > 0
|
||||
# Backwards compat?
|
||||
|
Loading…
Reference in New Issue
Block a user