This commit is contained in:
Christian Bromann
2026-03-31 10:11:46 -07:00
parent d9b22fe892
commit a21969543c
4 changed files with 69 additions and 49 deletions

View File

@@ -7,12 +7,7 @@ from langchain_core.tools import (
ToolException,
)
from langchain.tools.headless import (
HEADLESS_TOOL_METADATA_KEY,
HeadlessTool,
create_headless_tool,
tool,
)
from langchain.tools.headless import HEADLESS_TOOL_METADATA_KEY, HeadlessTool, tool
from langchain.tools.tool_node import InjectedState, InjectedStore, ToolRuntime
__all__ = [
@@ -25,6 +20,5 @@ __all__ = [
"InjectedToolCallId",
"ToolException",
"ToolRuntime",
"create_headless_tool",
"tool",
]

View File

@@ -20,7 +20,10 @@ from langchain_core.utils.pydantic import is_basemodel_subclass, is_pydantic_v2_
from langgraph.types import interrupt
from pydantic import BaseModel, create_model
# Metadata flag for clients that need to detect headless tools (e.g. SDKs).
# Metadata on the tool definition for introspection (e.g. listing tools, bind_tools).
# This does not appear as LLM token stream chunks. When a headless tool runs inside a
# LangGraph graph, the interrupt value (see `_headless_interrupt_payload`) is what
# surfaces to clients during streamed graph execution (interrupt events).
HEADLESS_TOOL_METADATA_KEY = "headless_tool"
@@ -51,22 +54,37 @@ def _args_schema_with_injected_tool_call_id(
)
def _headless_interrupt_payload(tool_name: str, **kwargs: Any) -> Any:
"""Build the LangGraph interrupt value for a headless tool call."""
tool_call_id = kwargs.pop("tool_call_id", None)
return interrupt(
{
"type": "tool",
"tool_call": {
"id": tool_call_id,
"name": tool_name,
"args": kwargs,
},
}
)
def _make_headless_sync(tool_name: str) -> Callable[..., Any]:
def _headless_sync(
_config: RunnableConfig,
**kwargs: Any,
) -> Any:
return _headless_interrupt_payload(tool_name, **kwargs)
return _headless_sync
def _make_headless_coroutine(tool_name: str) -> Callable[..., Awaitable[Any]]:
async def _headless_coroutine(
_config: RunnableConfig,
**kwargs: Any,
) -> Any:
tool_call_id = kwargs.pop("tool_call_id", None)
return interrupt(
{
"type": "tool",
"tool_call": {
"id": tool_call_id,
"name": tool_name,
"args": kwargs,
},
}
)
return _headless_interrupt_payload(tool_name, **kwargs)
return _headless_coroutine
@@ -75,7 +93,7 @@ class HeadlessTool(StructuredTool):
"""Structured tool that interrupts instead of executing locally."""
def create_headless_tool(
def _create_headless_tool(
*,
name: str,
description: str,
@@ -84,18 +102,7 @@ def create_headless_tool(
response_format: Literal["content", "content_and_artifact"] = "content",
extras: dict[str, Any] | None = None,
) -> HeadlessTool:
"""Create a headless tool from a name, description, and argument schema.
Args:
name: Tool name exposed to the model.
description: Tool description exposed to the model.
args_schema: Pydantic model or JSON-schema dict for arguments.
return_direct: Whether the tool result should end the agent turn.
response_format: Same as `StructuredTool`.
extras: Optional provider-specific extras merged into tool extras.
Returns:
A `HeadlessTool` whose coroutine raises a LangGraph interrupt when run.
"""Instantiate a headless tool. Prefer the public `tool()` overload for new code.
Raises:
TypeError: If `args_schema` is not a Pydantic model or dict.
@@ -112,10 +119,11 @@ def create_headless_tool(
raise TypeError(msg)
metadata = {HEADLESS_TOOL_METADATA_KEY: True}
sync_fn = _make_headless_sync(name)
coroutine = _make_headless_coroutine(name)
return HeadlessTool(
name=name,
func=None,
func=sync_fn,
coroutine=coroutine,
description=description,
args_schema=schema_for_tool,
@@ -214,9 +222,10 @@ def tool(
) -> BaseTool | Callable[[Callable[..., Any] | Runnable[Any, Any]], BaseTool] | HeadlessTool:
"""Create a tool, including headless (interrupting) tools.
If called with keyword-only `name`, `description`, and `args_schema` and no
implementation callable, returns a `HeadlessTool` that triggers a LangGraph
`interrupt` when executed. Otherwise delegates to `langchain_core.tools.tool`.
This is the supported entry point for headless tools: use keyword-only
`name`, `description`, and `args_schema` with no implementation callable to get
a `HeadlessTool` that calls LangGraph `interrupt` on both sync `invoke` and
async `ainvoke`. Otherwise delegates to `langchain_core.tools.tool`.
Args:
name_or_callable: Passed through to core `tool` when not using headless mode.
@@ -242,7 +251,7 @@ def tool(
and description is not None
and args_schema is not None
):
return create_headless_tool(
return _create_headless_tool(
name=name,
description=description,
args_schema=args_schema,

View File

@@ -9,20 +9,15 @@ from unittest.mock import patch
import pytest
from pydantic import BaseModel, Field
from langchain.tools import (
HEADLESS_TOOL_METADATA_KEY,
HeadlessTool,
create_headless_tool,
tool,
)
from langchain.tools import HEADLESS_TOOL_METADATA_KEY, HeadlessTool, tool
class _MessageArgs(BaseModel):
message: str = Field(..., description="A message.")
def test_create_headless_tool_properties() -> None:
t = create_headless_tool(
def test_headless_tool_properties() -> None:
t = tool(
name="test_tool",
description="A test headless tool.",
args_schema=_MessageArgs,
@@ -55,7 +50,7 @@ def test_tool_normal_still_returns_structured_tool() -> None:
@pytest.mark.asyncio
async def test_headless_coroutine_calls_interrupt() -> None:
ht = create_headless_tool(
ht = tool(
name="interrupt_me",
description="d",
args_schema=_MessageArgs,
@@ -81,13 +76,36 @@ async def test_headless_coroutine_calls_interrupt() -> None:
assert getattr(result, "content", result) == "resumed"
def test_headless_sync_invoke_calls_interrupt() -> None:
"""Sync `invoke` must work (StructuredTool previously had no sync path)."""
ht = tool(
name="sync_interrupt",
description="d",
args_schema=_MessageArgs,
)
with patch("langchain.tools.headless.interrupt") as mock_interrupt:
mock_interrupt.return_value = "ok"
result = ht.invoke(
{
"type": "tool_call",
"name": "sync_interrupt",
"id": "cid-9",
"args": {"message": "sync"},
}
)
mock_interrupt.assert_called_once()
payload = mock_interrupt.call_args[0][0]
assert payload["tool_call"]["id"] == "cid-9"
assert getattr(result, "content", result) == "ok"
def test_headless_dict_schema_has_metadata() -> None:
schema: dict[str, Any] = {
"type": "object",
"properties": {"q": {"type": "string"}},
"required": ["q"],
}
ht = create_headless_tool(
ht = tool(
name="dict_tool",
description="Uses JSON schema.",
args_schema=schema,
@@ -97,7 +115,7 @@ def test_headless_dict_schema_has_metadata() -> None:
def test_invoke_without_graph_context_errors() -> None:
ht = create_headless_tool(
ht = tool(
name="t",
description="d",
args_schema=_MessageArgs,

View File

@@ -10,7 +10,6 @@ EXPECTED_ALL = {
"InjectedToolCallId",
"ToolException",
"ToolRuntime",
"create_headless_tool",
"tool",
}