mirror of
https://github.com/hwchase17/langchain.git
synced 2026-07-01 06:42:37 +00:00
fix(core): accept sequence tool error content (#38005)
`handle_tool_error` callables can now return structured message content as any valid sequence, not just a mutable `list`. Valid structured sequences are normalized to the `ToolMessage` content shape at the tool output boundary, while invalid content still falls back to stringification. ## Changes - Widened `ToolExceptionHandlerOutput` from `list[str | dict[str, Any]]` to `Sequence[MessageContentBlock]` so handlers returning `list[dict[str, Any]]` or tuple content blocks type-check cleanly. - Added `_normalize_message_content` to validate structured message content and convert valid non-string sequences to the `list` shape expected by `ToolMessage`. - Preserved existing stringification behavior for invalid structured content blocks instead of treating failed normalization as `None`. - Removed the now-unused `_is_message_content_type` helper; output formatting validates content directly through `_normalize_message_content`.
This commit is contained in:
@@ -9,7 +9,7 @@ import logging
|
||||
import typing
|
||||
import warnings
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Callable # noqa: TC003
|
||||
from collections.abc import Callable, Sequence
|
||||
from inspect import signature
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
@@ -69,7 +69,6 @@ from langchain_core.utils.pydantic import (
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import uuid
|
||||
from collections.abc import Sequence
|
||||
|
||||
FILTERED_ARGS = ("run_manager", "callbacks")
|
||||
TOOL_MESSAGE_BLOCK_TYPES = (
|
||||
@@ -398,12 +397,19 @@ class ToolException(Exception): # noqa: N818
|
||||
|
||||
|
||||
ArgsSchema = TypeBaseModel | dict[str, Any]
|
||||
ToolExceptionHandlerOutput = str | list[str | dict[str, Any]]
|
||||
MessageContentBlock = str | dict[str, Any]
|
||||
"""A single message content block: plain text or a structured block.
|
||||
|
||||
A dict block is only considered valid at runtime when its `type` key is one of
|
||||
`TOOL_MESSAGE_BLOCK_TYPES` (see `_is_message_content_block`); the static type
|
||||
intentionally stays broad because block payloads vary by provider format.
|
||||
"""
|
||||
ToolExceptionHandlerOutput = str | Sequence[MessageContentBlock]
|
||||
"""Content returned by a `handle_tool_error` callable.
|
||||
|
||||
Error handlers may return plain text or structured message content blocks.
|
||||
When the original tool call includes a `tool_call_id`, this content is used
|
||||
as the content of a `ToolMessage` with `status="error"`.
|
||||
Error handlers may return plain text or a sequence of structured message
|
||||
content blocks. When the original tool call includes a `tool_call_id`, this
|
||||
content is normalized to the content of a `ToolMessage` with `status="error"`.
|
||||
"""
|
||||
|
||||
_EMPTY_SET: frozenset[str] = frozenset()
|
||||
@@ -1302,8 +1308,8 @@ def _format_output(
|
||||
return content
|
||||
if isinstance(content, ToolOutputMixin) or tool_call_id is None:
|
||||
return content
|
||||
if not _is_message_content_type(content):
|
||||
content = _stringify(content)
|
||||
normalized_content = _normalize_message_content(content)
|
||||
content = _stringify(content) if normalized_content is None else normalized_content
|
||||
return ToolMessage(
|
||||
content,
|
||||
artifact=artifact,
|
||||
@@ -1313,20 +1319,28 @@ def _format_output(
|
||||
)
|
||||
|
||||
|
||||
def _is_message_content_type(obj: Any) -> bool:
|
||||
"""Check if object is valid message content format.
|
||||
def _normalize_message_content(obj: Any) -> str | list[MessageContentBlock] | None:
|
||||
"""Coerce valid message content to the shape expected by `ToolMessage`.
|
||||
|
||||
Validates content for OpenAI or Anthropic format tool messages.
|
||||
A string passes through unchanged; any `Sequence` of valid content blocks
|
||||
(e.g. a list or tuple) is materialized into a `list`. Returning `None`
|
||||
signals the caller (`_format_output`) that `obj` is not message content and
|
||||
should be stringified instead.
|
||||
|
||||
Args:
|
||||
obj: The object to check.
|
||||
obj: The object to normalize.
|
||||
|
||||
Returns:
|
||||
`True` if the object is valid message content, `False` otherwise.
|
||||
The normalized content, or `None` if `obj` is not valid message content.
|
||||
"""
|
||||
return isinstance(obj, str) or (
|
||||
isinstance(obj, list) and all(_is_message_content_block(e) for e in obj)
|
||||
)
|
||||
if isinstance(obj, str):
|
||||
return obj
|
||||
# Validate lazily before materializing: `all` short-circuits on the first
|
||||
# invalid element, so a large non-content sequence (e.g. `range(10**12)`)
|
||||
# falls back to stringification without allocating it.
|
||||
if isinstance(obj, Sequence) and all(_is_message_content_block(e) for e in obj):
|
||||
return list(obj)
|
||||
return None
|
||||
|
||||
|
||||
def _is_message_content_block(obj: Any) -> bool:
|
||||
|
||||
@@ -60,7 +60,7 @@ from langchain_core.tools.base import (
|
||||
_DirectlyInjectedToolArg,
|
||||
_format_output,
|
||||
_is_message_content_block,
|
||||
_is_message_content_type,
|
||||
_normalize_message_content,
|
||||
get_all_basemodel_annotations,
|
||||
)
|
||||
from langchain_core.utils.function_calling import (
|
||||
@@ -827,9 +827,9 @@ def test_exception_handling_callable() -> None:
|
||||
|
||||
|
||||
def test_exception_handling_callable_message_content_blocks() -> None:
|
||||
expected: list[str | dict[str, Any]] = [{"type": "text", "text": "handled error"}]
|
||||
expected: list[dict[str, Any]] = [{"type": "text", "text": "handled error"}]
|
||||
|
||||
def handling(e: ToolException) -> list[str | dict[str, Any]]:
|
||||
def handling(e: ToolException) -> list[dict[str, Any]]:
|
||||
return expected
|
||||
|
||||
tool_ = _FakeExceptionTool(handle_tool_error=handling)
|
||||
@@ -843,6 +843,40 @@ def test_exception_handling_callable_message_content_blocks() -> None:
|
||||
assert actual.tool_call_id == "call_1"
|
||||
|
||||
|
||||
def test_exception_handling_callable_message_content_blocks_sequence() -> None:
|
||||
content = ({"type": "text", "text": "handled error"},)
|
||||
|
||||
def handling(e: ToolException) -> tuple[dict[str, Any], ...]:
|
||||
return content
|
||||
|
||||
tool_ = _FakeExceptionTool(handle_tool_error=handling)
|
||||
actual = tool_.invoke(
|
||||
{"type": "tool_call", "args": {}, "name": "exception", "id": "call_1"}
|
||||
)
|
||||
|
||||
assert isinstance(actual, ToolMessage)
|
||||
assert actual.content == list(content)
|
||||
assert actual.status == "error"
|
||||
assert actual.tool_call_id == "call_1"
|
||||
|
||||
|
||||
def test_exception_handling_callable_invalid_blocks_stringified() -> None:
|
||||
# A sequence whose elements are not valid content blocks is not message
|
||||
# content, so it falls back to a JSON-stringified ToolMessage.
|
||||
def handling(e: ToolException) -> list[dict[str, Any]]:
|
||||
return [{"text": "foo"}] # missing 'type' -> not a valid block
|
||||
|
||||
tool_ = _FakeExceptionTool(handle_tool_error=handling)
|
||||
actual = tool_.invoke(
|
||||
{"type": "tool_call", "args": {}, "name": "exception", "id": "call_1"}
|
||||
)
|
||||
|
||||
assert isinstance(actual, ToolMessage)
|
||||
assert actual.content == '[{"text": "foo"}]'
|
||||
assert actual.status == "error"
|
||||
assert actual.tool_call_id == "call_1"
|
||||
|
||||
|
||||
def test_exception_handling_non_tool_exception() -> None:
|
||||
tool_ = _FakeExceptionTool(exception=ValueError("some error"))
|
||||
with pytest.raises(ValueError, match="some error"):
|
||||
@@ -875,9 +909,9 @@ async def test_async_exception_handling_callable() -> None:
|
||||
|
||||
|
||||
async def test_async_exception_handling_callable_message_content_blocks() -> None:
|
||||
expected: list[str | dict[str, Any]] = [{"type": "text", "text": "handled error"}]
|
||||
expected: list[dict[str, Any]] = [{"type": "text", "text": "handled error"}]
|
||||
|
||||
def handling(e: ToolException) -> list[str | dict[str, Any]]:
|
||||
def handling(e: ToolException) -> list[dict[str, Any]]:
|
||||
return expected
|
||||
|
||||
tool_ = _FakeExceptionTool(handle_tool_error=handling)
|
||||
@@ -891,6 +925,25 @@ async def test_async_exception_handling_callable_message_content_blocks() -> Non
|
||||
assert actual.tool_call_id == "call_1"
|
||||
|
||||
|
||||
async def test_async_exception_handling_callable_message_content_blocks_sequence() -> (
|
||||
None
|
||||
):
|
||||
content = ({"type": "text", "text": "handled error"},)
|
||||
|
||||
def handling(e: ToolException) -> tuple[dict[str, Any], ...]:
|
||||
return content
|
||||
|
||||
tool_ = _FakeExceptionTool(handle_tool_error=handling)
|
||||
actual = await tool_.ainvoke(
|
||||
{"type": "tool_call", "args": {}, "name": "exception", "id": "call_1"}
|
||||
)
|
||||
|
||||
assert isinstance(actual, ToolMessage)
|
||||
assert actual.content == list(content)
|
||||
assert actual.status == "error"
|
||||
assert actual.tool_call_id == "call_1"
|
||||
|
||||
|
||||
async def test_async_exception_handling_non_tool_exception() -> None:
|
||||
tool_ = _FakeExceptionTool(exception=ValueError("some error"))
|
||||
with pytest.raises(ValueError, match="some error"):
|
||||
@@ -2184,11 +2237,19 @@ def test__is_message_content_block(obj: Any, *, expected: bool) -> None:
|
||||
[
|
||||
("foo", True),
|
||||
(valid_tool_result_blocks, True),
|
||||
(tuple(valid_tool_result_blocks), True),
|
||||
([], True), # empty sequences are vacuously valid content
|
||||
((), True),
|
||||
(invalid_tool_result_blocks, False),
|
||||
(tuple(invalid_tool_result_blocks), False),
|
||||
(({"type": "text", "text": "ok"}, {"text": "bad"}), False), # mixed
|
||||
# Large non-content sequence: must reject lazily without materializing
|
||||
# (would hang/OOM if validation allocated the sequence first).
|
||||
(range(10**12), False),
|
||||
],
|
||||
)
|
||||
def test__is_message_content_type(obj: Any, *, expected: bool) -> None:
|
||||
assert _is_message_content_type(obj) is expected
|
||||
def test_normalize_message_content_validity(obj: Any, *, expected: bool) -> None:
|
||||
assert (_normalize_message_content(obj) is not None) is expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize("use_v1_namespace", [True, False])
|
||||
@@ -3750,11 +3811,12 @@ def test_format_output_list_with_non_mixin_element() -> None:
|
||||
|
||||
|
||||
def test_format_output_empty_list() -> None:
|
||||
"""An empty list falls through to stringify-and-wrap."""
|
||||
"""An empty list is vacuously valid content and wrapped unchanged."""
|
||||
result = _format_output(
|
||||
[], artifact=None, tool_call_id="0", name="t", status="success"
|
||||
)
|
||||
assert isinstance(result, ToolMessage)
|
||||
assert result.content == []
|
||||
assert result.tool_call_id == "0"
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user