mirror of
https://github.com/hwchase17/langchain.git
synced 2026-06-09 10:17:00 +00:00
feat(core): allow _format_output to pass through list of ToolOutputMixin instances (#36963)
This commit is contained in:
@@ -1265,8 +1265,15 @@ def _format_output(
|
|||||||
status: The execution status.
|
status: The execution status.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The formatted output, either as a `ToolMessage` or the original content.
|
The formatted output, either as a `ToolMessage`, the original content,
|
||||||
|
or an unchanged list of `ToolOutputMixin` instances.
|
||||||
"""
|
"""
|
||||||
|
if (
|
||||||
|
isinstance(content, list)
|
||||||
|
and content
|
||||||
|
and all(isinstance(item, ToolOutputMixin) for item in content)
|
||||||
|
):
|
||||||
|
return content
|
||||||
if isinstance(content, ToolOutputMixin) or tool_call_id is None:
|
if isinstance(content, ToolOutputMixin) or tool_call_id is None:
|
||||||
return content
|
return content
|
||||||
if not _is_message_content_type(content):
|
if not _is_message_content_type(content):
|
||||||
|
|||||||
@@ -58,6 +58,7 @@ from langchain_core.tools.base import (
|
|||||||
InjectedToolCallId,
|
InjectedToolCallId,
|
||||||
SchemaAnnotationError,
|
SchemaAnnotationError,
|
||||||
_DirectlyInjectedToolArg,
|
_DirectlyInjectedToolArg,
|
||||||
|
_format_output,
|
||||||
_is_message_content_block,
|
_is_message_content_block,
|
||||||
_is_message_content_type,
|
_is_message_content_type,
|
||||||
get_all_basemodel_annotations,
|
get_all_basemodel_annotations,
|
||||||
@@ -128,6 +129,22 @@ class _MockStructuredTool(BaseTool):
|
|||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeOutput(ToolOutputMixin):
|
||||||
|
"""Minimal ToolOutputMixin subclass used only in tests."""
|
||||||
|
|
||||||
|
def __init__(self, value: int) -> None:
|
||||||
|
self.value = value
|
||||||
|
|
||||||
|
def __eq__(self, other: object) -> bool:
|
||||||
|
return isinstance(other, _FakeOutput) and self.value == other.value
|
||||||
|
|
||||||
|
def __hash__(self) -> int:
|
||||||
|
return hash(self.value)
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
return f"_FakeOutput({self.value})"
|
||||||
|
|
||||||
|
|
||||||
def test_structured_args() -> None:
|
def test_structured_args() -> None:
|
||||||
"""Test functionality with structured arguments."""
|
"""Test functionality with structured arguments."""
|
||||||
structured_api = _MockStructuredTool()
|
structured_api = _MockStructuredTool()
|
||||||
@@ -3653,3 +3670,74 @@ def test_tool_default_factory_not_required() -> None:
|
|||||||
schema = convert_to_openai_tool(some_func)
|
schema = convert_to_openai_tool(some_func)
|
||||||
params = schema["function"]["parameters"]
|
params = schema["function"]["parameters"]
|
||||||
assert "names" not in params.get("required", [])
|
assert "names" not in params.get("required", [])
|
||||||
|
|
||||||
|
|
||||||
|
def test_format_output_list_of_tool_messages() -> None:
|
||||||
|
"""A list of ToolMessages passes through unchanged."""
|
||||||
|
msgs = [
|
||||||
|
ToolMessage("a", tool_call_id="1", name="t"),
|
||||||
|
ToolMessage("b", tool_call_id="2", name="t"),
|
||||||
|
]
|
||||||
|
result = _format_output(
|
||||||
|
msgs, artifact=None, tool_call_id="0", name="t", status="success"
|
||||||
|
)
|
||||||
|
assert result is msgs
|
||||||
|
|
||||||
|
|
||||||
|
def test_format_output_list_of_custom_mixin_instances() -> None:
|
||||||
|
"""A list of custom ToolOutputMixin subclass instances passes through."""
|
||||||
|
items = [_FakeOutput(1), _FakeOutput(2)]
|
||||||
|
result = _format_output(
|
||||||
|
items, artifact=None, tool_call_id="0", name="t", status="success"
|
||||||
|
)
|
||||||
|
assert result is items
|
||||||
|
|
||||||
|
|
||||||
|
def test_format_output_mixed_mixin_subclasses() -> None:
|
||||||
|
"""A list mixing ToolMessage and custom ToolOutputMixin passes through."""
|
||||||
|
items: list[ToolOutputMixin] = [
|
||||||
|
ToolMessage("a", tool_call_id="1", name="t"),
|
||||||
|
_FakeOutput(42),
|
||||||
|
]
|
||||||
|
result = _format_output(
|
||||||
|
items, artifact=None, tool_call_id="0", name="t", status="success"
|
||||||
|
)
|
||||||
|
assert result is items
|
||||||
|
|
||||||
|
|
||||||
|
def test_format_output_list_with_non_mixin_element() -> None:
|
||||||
|
"""A list containing a non-ToolOutputMixin falls through to stringify."""
|
||||||
|
items = [ToolMessage("a", tool_call_id="1", name="t"), "oops"]
|
||||||
|
result = _format_output(
|
||||||
|
items, artifact=None, tool_call_id="0", name="t", status="success"
|
||||||
|
)
|
||||||
|
assert isinstance(result, ToolMessage)
|
||||||
|
assert result.tool_call_id == "0"
|
||||||
|
|
||||||
|
|
||||||
|
def test_format_output_empty_list() -> None:
|
||||||
|
"""An empty list falls through to stringify-and-wrap."""
|
||||||
|
result = _format_output(
|
||||||
|
[], artifact=None, tool_call_id="0", name="t", status="success"
|
||||||
|
)
|
||||||
|
assert isinstance(result, ToolMessage)
|
||||||
|
assert result.tool_call_id == "0"
|
||||||
|
|
||||||
|
|
||||||
|
def test_tool_invoke_returns_list_of_mixin() -> None:
|
||||||
|
"""End-to-end: a tool returning a list of ToolOutputMixin via invoke."""
|
||||||
|
|
||||||
|
@tool
|
||||||
|
def multi(x: int) -> list:
|
||||||
|
"""Return multiple outputs."""
|
||||||
|
return [
|
||||||
|
ToolMessage(f"result-{i}", tool_call_id=f"sub-{i}", name="multi")
|
||||||
|
for i in range(x)
|
||||||
|
]
|
||||||
|
|
||||||
|
result = multi.invoke(
|
||||||
|
{"type": "tool_call", "args": {"x": 3}, "name": "multi", "id": "outer"}
|
||||||
|
)
|
||||||
|
assert isinstance(result, list)
|
||||||
|
assert len(result) == 3
|
||||||
|
assert all(isinstance(m, ToolMessage) for m in result)
|
||||||
|
|||||||
Reference in New Issue
Block a user