diff --git a/libs/core/langchain_core/tools/base.py b/libs/core/langchain_core/tools/base.py index 692bcf50ecb..67c64fb1f17 100644 --- a/libs/core/langchain_core/tools/base.py +++ b/libs/core/langchain_core/tools/base.py @@ -9,7 +9,7 @@ import logging import typing import warnings from abc import ABC, abstractmethod -from collections.abc import Callable # noqa: TC003 +from collections.abc import Callable, Sequence from inspect import signature from typing import ( TYPE_CHECKING, @@ -69,7 +69,6 @@ from langchain_core.utils.pydantic import ( if TYPE_CHECKING: import uuid - from collections.abc import Sequence FILTERED_ARGS = ("run_manager", "callbacks") TOOL_MESSAGE_BLOCK_TYPES = ( @@ -398,12 +397,19 @@ class ToolException(Exception): # noqa: N818 ArgsSchema = TypeBaseModel | dict[str, Any] -ToolExceptionHandlerOutput = str | list[str | dict[str, Any]] +MessageContentBlock = str | dict[str, Any] +"""A single message content block: plain text or a structured block. + +A dict block is only considered valid at runtime when its `type` key is one of +`TOOL_MESSAGE_BLOCK_TYPES` (see `_is_message_content_block`); the static type +intentionally stays broad because block payloads vary by provider format. +""" +ToolExceptionHandlerOutput = str | Sequence[MessageContentBlock] """Content returned by a `handle_tool_error` callable. -Error handlers may return plain text or structured message content blocks. -When the original tool call includes a `tool_call_id`, this content is used -as the content of a `ToolMessage` with `status="error"`. +Error handlers may return plain text or a sequence of structured message +content blocks. When the original tool call includes a `tool_call_id`, this +content is normalized to the content of a `ToolMessage` with `status="error"`. """ _EMPTY_SET: frozenset[str] = frozenset() @@ -1302,8 +1308,8 @@ def _format_output( return content if isinstance(content, ToolOutputMixin) or tool_call_id is None: return content - if not _is_message_content_type(content): - content = _stringify(content) + normalized_content = _normalize_message_content(content) + content = _stringify(content) if normalized_content is None else normalized_content return ToolMessage( content, artifact=artifact, @@ -1313,20 +1319,28 @@ def _format_output( ) -def _is_message_content_type(obj: Any) -> bool: - """Check if object is valid message content format. +def _normalize_message_content(obj: Any) -> str | list[MessageContentBlock] | None: + """Coerce valid message content to the shape expected by `ToolMessage`. - Validates content for OpenAI or Anthropic format tool messages. + A string passes through unchanged; any `Sequence` of valid content blocks + (e.g. a list or tuple) is materialized into a `list`. Returning `None` + signals the caller (`_format_output`) that `obj` is not message content and + should be stringified instead. Args: - obj: The object to check. + obj: The object to normalize. Returns: - `True` if the object is valid message content, `False` otherwise. + The normalized content, or `None` if `obj` is not valid message content. """ - return isinstance(obj, str) or ( - isinstance(obj, list) and all(_is_message_content_block(e) for e in obj) - ) + if isinstance(obj, str): + return obj + # Validate lazily before materializing: `all` short-circuits on the first + # invalid element, so a large non-content sequence (e.g. `range(10**12)`) + # falls back to stringification without allocating it. + if isinstance(obj, Sequence) and all(_is_message_content_block(e) for e in obj): + return list(obj) + return None def _is_message_content_block(obj: Any) -> bool: diff --git a/libs/core/tests/unit_tests/test_tools.py b/libs/core/tests/unit_tests/test_tools.py index 9003ac8320e..274b25874f8 100644 --- a/libs/core/tests/unit_tests/test_tools.py +++ b/libs/core/tests/unit_tests/test_tools.py @@ -60,7 +60,7 @@ from langchain_core.tools.base import ( _DirectlyInjectedToolArg, _format_output, _is_message_content_block, - _is_message_content_type, + _normalize_message_content, get_all_basemodel_annotations, ) from langchain_core.utils.function_calling import ( @@ -827,9 +827,9 @@ def test_exception_handling_callable() -> None: def test_exception_handling_callable_message_content_blocks() -> None: - expected: list[str | dict[str, Any]] = [{"type": "text", "text": "handled error"}] + expected: list[dict[str, Any]] = [{"type": "text", "text": "handled error"}] - def handling(e: ToolException) -> list[str | dict[str, Any]]: + def handling(e: ToolException) -> list[dict[str, Any]]: return expected tool_ = _FakeExceptionTool(handle_tool_error=handling) @@ -843,6 +843,40 @@ def test_exception_handling_callable_message_content_blocks() -> None: assert actual.tool_call_id == "call_1" +def test_exception_handling_callable_message_content_blocks_sequence() -> None: + content = ({"type": "text", "text": "handled error"},) + + def handling(e: ToolException) -> tuple[dict[str, Any], ...]: + return content + + tool_ = _FakeExceptionTool(handle_tool_error=handling) + actual = tool_.invoke( + {"type": "tool_call", "args": {}, "name": "exception", "id": "call_1"} + ) + + assert isinstance(actual, ToolMessage) + assert actual.content == list(content) + assert actual.status == "error" + assert actual.tool_call_id == "call_1" + + +def test_exception_handling_callable_invalid_blocks_stringified() -> None: + # A sequence whose elements are not valid content blocks is not message + # content, so it falls back to a JSON-stringified ToolMessage. + def handling(e: ToolException) -> list[dict[str, Any]]: + return [{"text": "foo"}] # missing 'type' -> not a valid block + + tool_ = _FakeExceptionTool(handle_tool_error=handling) + actual = tool_.invoke( + {"type": "tool_call", "args": {}, "name": "exception", "id": "call_1"} + ) + + assert isinstance(actual, ToolMessage) + assert actual.content == '[{"text": "foo"}]' + assert actual.status == "error" + assert actual.tool_call_id == "call_1" + + def test_exception_handling_non_tool_exception() -> None: tool_ = _FakeExceptionTool(exception=ValueError("some error")) with pytest.raises(ValueError, match="some error"): @@ -875,9 +909,9 @@ async def test_async_exception_handling_callable() -> None: async def test_async_exception_handling_callable_message_content_blocks() -> None: - expected: list[str | dict[str, Any]] = [{"type": "text", "text": "handled error"}] + expected: list[dict[str, Any]] = [{"type": "text", "text": "handled error"}] - def handling(e: ToolException) -> list[str | dict[str, Any]]: + def handling(e: ToolException) -> list[dict[str, Any]]: return expected tool_ = _FakeExceptionTool(handle_tool_error=handling) @@ -891,6 +925,25 @@ async def test_async_exception_handling_callable_message_content_blocks() -> Non assert actual.tool_call_id == "call_1" +async def test_async_exception_handling_callable_message_content_blocks_sequence() -> ( + None +): + content = ({"type": "text", "text": "handled error"},) + + def handling(e: ToolException) -> tuple[dict[str, Any], ...]: + return content + + tool_ = _FakeExceptionTool(handle_tool_error=handling) + actual = await tool_.ainvoke( + {"type": "tool_call", "args": {}, "name": "exception", "id": "call_1"} + ) + + assert isinstance(actual, ToolMessage) + assert actual.content == list(content) + assert actual.status == "error" + assert actual.tool_call_id == "call_1" + + async def test_async_exception_handling_non_tool_exception() -> None: tool_ = _FakeExceptionTool(exception=ValueError("some error")) with pytest.raises(ValueError, match="some error"): @@ -2184,11 +2237,19 @@ def test__is_message_content_block(obj: Any, *, expected: bool) -> None: [ ("foo", True), (valid_tool_result_blocks, True), + (tuple(valid_tool_result_blocks), True), + ([], True), # empty sequences are vacuously valid content + ((), True), (invalid_tool_result_blocks, False), + (tuple(invalid_tool_result_blocks), False), + (({"type": "text", "text": "ok"}, {"text": "bad"}), False), # mixed + # Large non-content sequence: must reject lazily without materializing + # (would hang/OOM if validation allocated the sequence first). + (range(10**12), False), ], ) -def test__is_message_content_type(obj: Any, *, expected: bool) -> None: - assert _is_message_content_type(obj) is expected +def test_normalize_message_content_validity(obj: Any, *, expected: bool) -> None: + assert (_normalize_message_content(obj) is not None) is expected @pytest.mark.parametrize("use_v1_namespace", [True, False]) @@ -3750,11 +3811,12 @@ def test_format_output_list_with_non_mixin_element() -> None: def test_format_output_empty_list() -> None: - """An empty list falls through to stringify-and-wrap.""" + """An empty list is vacuously valid content and wrapped unchanged.""" result = _format_output( [], artifact=None, tool_call_id="0", name="t", status="success" ) assert isinstance(result, ToolMessage) + assert result.content == [] assert result.tool_call_id == "0"