feat(core): allow _format_output to pass through list of ToolOutputMixin instances (#36963)

This commit is contained in:
Hunter Lovell
2026-04-23 10:49:46 -07:00
committed by GitHub
parent bb77a4229f
commit 9a671d7919
2 changed files with 96 additions and 1 deletions

View File

@@ -58,6 +58,7 @@ from langchain_core.tools.base import (
InjectedToolCallId,
SchemaAnnotationError,
_DirectlyInjectedToolArg,
_format_output,
_is_message_content_block,
_is_message_content_type,
get_all_basemodel_annotations,
@@ -128,6 +129,22 @@ class _MockStructuredTool(BaseTool):
raise NotImplementedError
class _FakeOutput(ToolOutputMixin):
"""Minimal ToolOutputMixin subclass used only in tests."""
def __init__(self, value: int) -> None:
self.value = value
def __eq__(self, other: object) -> bool:
return isinstance(other, _FakeOutput) and self.value == other.value
def __hash__(self) -> int:
return hash(self.value)
def __repr__(self) -> str:
return f"_FakeOutput({self.value})"
def test_structured_args() -> None:
"""Test functionality with structured arguments."""
structured_api = _MockStructuredTool()
@@ -3653,3 +3670,74 @@ def test_tool_default_factory_not_required() -> None:
schema = convert_to_openai_tool(some_func)
params = schema["function"]["parameters"]
assert "names" not in params.get("required", [])
def test_format_output_list_of_tool_messages() -> None:
"""A list of ToolMessages passes through unchanged."""
msgs = [
ToolMessage("a", tool_call_id="1", name="t"),
ToolMessage("b", tool_call_id="2", name="t"),
]
result = _format_output(
msgs, artifact=None, tool_call_id="0", name="t", status="success"
)
assert result is msgs
def test_format_output_list_of_custom_mixin_instances() -> None:
"""A list of custom ToolOutputMixin subclass instances passes through."""
items = [_FakeOutput(1), _FakeOutput(2)]
result = _format_output(
items, artifact=None, tool_call_id="0", name="t", status="success"
)
assert result is items
def test_format_output_mixed_mixin_subclasses() -> None:
"""A list mixing ToolMessage and custom ToolOutputMixin passes through."""
items: list[ToolOutputMixin] = [
ToolMessage("a", tool_call_id="1", name="t"),
_FakeOutput(42),
]
result = _format_output(
items, artifact=None, tool_call_id="0", name="t", status="success"
)
assert result is items
def test_format_output_list_with_non_mixin_element() -> None:
"""A list containing a non-ToolOutputMixin falls through to stringify."""
items = [ToolMessage("a", tool_call_id="1", name="t"), "oops"]
result = _format_output(
items, artifact=None, tool_call_id="0", name="t", status="success"
)
assert isinstance(result, ToolMessage)
assert result.tool_call_id == "0"
def test_format_output_empty_list() -> None:
"""An empty list falls through to stringify-and-wrap."""
result = _format_output(
[], artifact=None, tool_call_id="0", name="t", status="success"
)
assert isinstance(result, ToolMessage)
assert result.tool_call_id == "0"
def test_tool_invoke_returns_list_of_mixin() -> None:
"""End-to-end: a tool returning a list of ToolOutputMixin via invoke."""
@tool
def multi(x: int) -> list:
"""Return multiple outputs."""
return [
ToolMessage(f"result-{i}", tool_call_id=f"sub-{i}", name="multi")
for i in range(x)
]
result = multi.invoke(
{"type": "tool_call", "args": {"x": 3}, "name": "multi", "id": "outer"}
)
assert isinstance(result, list)
assert len(result) == 3
assert all(isinstance(m, ToolMessage) for m in result)