From a21969543c388ba9d926e73ec4d344b23a137610 Mon Sep 17 00:00:00 2001 From: Christian Bromann Date: Tue, 31 Mar 2026 10:11:46 -0700 Subject: [PATCH] cr --- libs/langchain_v1/langchain/tools/__init__.py | 8 +-- libs/langchain_v1/langchain/tools/headless.py | 69 +++++++++++-------- .../tests/unit_tests/tools/test_headless.py | 40 ++++++++--- .../tests/unit_tests/tools/test_imports.py | 1 - 4 files changed, 69 insertions(+), 49 deletions(-) diff --git a/libs/langchain_v1/langchain/tools/__init__.py b/libs/langchain_v1/langchain/tools/__init__.py index 1ef11cc5622..de2fb16d991 100644 --- a/libs/langchain_v1/langchain/tools/__init__.py +++ b/libs/langchain_v1/langchain/tools/__init__.py @@ -7,12 +7,7 @@ from langchain_core.tools import ( ToolException, ) -from langchain.tools.headless import ( - HEADLESS_TOOL_METADATA_KEY, - HeadlessTool, - create_headless_tool, - tool, -) +from langchain.tools.headless import HEADLESS_TOOL_METADATA_KEY, HeadlessTool, tool from langchain.tools.tool_node import InjectedState, InjectedStore, ToolRuntime __all__ = [ @@ -25,6 +20,5 @@ __all__ = [ "InjectedToolCallId", "ToolException", "ToolRuntime", - "create_headless_tool", "tool", ] diff --git a/libs/langchain_v1/langchain/tools/headless.py b/libs/langchain_v1/langchain/tools/headless.py index 836012d1842..26e60f9f836 100644 --- a/libs/langchain_v1/langchain/tools/headless.py +++ b/libs/langchain_v1/langchain/tools/headless.py @@ -20,7 +20,10 @@ from langchain_core.utils.pydantic import is_basemodel_subclass, is_pydantic_v2_ from langgraph.types import interrupt from pydantic import BaseModel, create_model -# Metadata flag for clients that need to detect headless tools (e.g. SDKs). +# Metadata on the tool definition for introspection (e.g. listing tools, bind_tools). +# This does not appear as LLM token stream chunks. When a headless tool runs inside a +# LangGraph graph, the interrupt value (see `_headless_interrupt_payload`) is what +# surfaces to clients during streamed graph execution (interrupt events). HEADLESS_TOOL_METADATA_KEY = "headless_tool" @@ -51,22 +54,37 @@ def _args_schema_with_injected_tool_call_id( ) +def _headless_interrupt_payload(tool_name: str, **kwargs: Any) -> Any: + """Build the LangGraph interrupt value for a headless tool call.""" + tool_call_id = kwargs.pop("tool_call_id", None) + return interrupt( + { + "type": "tool", + "tool_call": { + "id": tool_call_id, + "name": tool_name, + "args": kwargs, + }, + } + ) + + +def _make_headless_sync(tool_name: str) -> Callable[..., Any]: + def _headless_sync( + _config: RunnableConfig, + **kwargs: Any, + ) -> Any: + return _headless_interrupt_payload(tool_name, **kwargs) + + return _headless_sync + + def _make_headless_coroutine(tool_name: str) -> Callable[..., Awaitable[Any]]: async def _headless_coroutine( _config: RunnableConfig, **kwargs: Any, ) -> Any: - tool_call_id = kwargs.pop("tool_call_id", None) - return interrupt( - { - "type": "tool", - "tool_call": { - "id": tool_call_id, - "name": tool_name, - "args": kwargs, - }, - } - ) + return _headless_interrupt_payload(tool_name, **kwargs) return _headless_coroutine @@ -75,7 +93,7 @@ class HeadlessTool(StructuredTool): """Structured tool that interrupts instead of executing locally.""" -def create_headless_tool( +def _create_headless_tool( *, name: str, description: str, @@ -84,18 +102,7 @@ def create_headless_tool( response_format: Literal["content", "content_and_artifact"] = "content", extras: dict[str, Any] | None = None, ) -> HeadlessTool: - """Create a headless tool from a name, description, and argument schema. - - Args: - name: Tool name exposed to the model. - description: Tool description exposed to the model. - args_schema: Pydantic model or JSON-schema dict for arguments. - return_direct: Whether the tool result should end the agent turn. - response_format: Same as `StructuredTool`. - extras: Optional provider-specific extras merged into tool extras. - - Returns: - A `HeadlessTool` whose coroutine raises a LangGraph interrupt when run. + """Instantiate a headless tool. Prefer the public `tool()` overload for new code. Raises: TypeError: If `args_schema` is not a Pydantic model or dict. @@ -112,10 +119,11 @@ def create_headless_tool( raise TypeError(msg) metadata = {HEADLESS_TOOL_METADATA_KEY: True} + sync_fn = _make_headless_sync(name) coroutine = _make_headless_coroutine(name) return HeadlessTool( name=name, - func=None, + func=sync_fn, coroutine=coroutine, description=description, args_schema=schema_for_tool, @@ -214,9 +222,10 @@ def tool( ) -> BaseTool | Callable[[Callable[..., Any] | Runnable[Any, Any]], BaseTool] | HeadlessTool: """Create a tool, including headless (interrupting) tools. - If called with keyword-only `name`, `description`, and `args_schema` and no - implementation callable, returns a `HeadlessTool` that triggers a LangGraph - `interrupt` when executed. Otherwise delegates to `langchain_core.tools.tool`. + This is the supported entry point for headless tools: use keyword-only + `name`, `description`, and `args_schema` with no implementation callable to get + a `HeadlessTool` that calls LangGraph `interrupt` on both sync `invoke` and + async `ainvoke`. Otherwise delegates to `langchain_core.tools.tool`. Args: name_or_callable: Passed through to core `tool` when not using headless mode. @@ -242,7 +251,7 @@ def tool( and description is not None and args_schema is not None ): - return create_headless_tool( + return _create_headless_tool( name=name, description=description, args_schema=args_schema, diff --git a/libs/langchain_v1/tests/unit_tests/tools/test_headless.py b/libs/langchain_v1/tests/unit_tests/tools/test_headless.py index 98fb372f1f8..3d2416a2b7c 100644 --- a/libs/langchain_v1/tests/unit_tests/tools/test_headless.py +++ b/libs/langchain_v1/tests/unit_tests/tools/test_headless.py @@ -9,20 +9,15 @@ from unittest.mock import patch import pytest from pydantic import BaseModel, Field -from langchain.tools import ( - HEADLESS_TOOL_METADATA_KEY, - HeadlessTool, - create_headless_tool, - tool, -) +from langchain.tools import HEADLESS_TOOL_METADATA_KEY, HeadlessTool, tool class _MessageArgs(BaseModel): message: str = Field(..., description="A message.") -def test_create_headless_tool_properties() -> None: - t = create_headless_tool( +def test_headless_tool_properties() -> None: + t = tool( name="test_tool", description="A test headless tool.", args_schema=_MessageArgs, @@ -55,7 +50,7 @@ def test_tool_normal_still_returns_structured_tool() -> None: @pytest.mark.asyncio async def test_headless_coroutine_calls_interrupt() -> None: - ht = create_headless_tool( + ht = tool( name="interrupt_me", description="d", args_schema=_MessageArgs, @@ -81,13 +76,36 @@ async def test_headless_coroutine_calls_interrupt() -> None: assert getattr(result, "content", result) == "resumed" +def test_headless_sync_invoke_calls_interrupt() -> None: + """Sync `invoke` must work (StructuredTool previously had no sync path).""" + ht = tool( + name="sync_interrupt", + description="d", + args_schema=_MessageArgs, + ) + with patch("langchain.tools.headless.interrupt") as mock_interrupt: + mock_interrupt.return_value = "ok" + result = ht.invoke( + { + "type": "tool_call", + "name": "sync_interrupt", + "id": "cid-9", + "args": {"message": "sync"}, + } + ) + mock_interrupt.assert_called_once() + payload = mock_interrupt.call_args[0][0] + assert payload["tool_call"]["id"] == "cid-9" + assert getattr(result, "content", result) == "ok" + + def test_headless_dict_schema_has_metadata() -> None: schema: dict[str, Any] = { "type": "object", "properties": {"q": {"type": "string"}}, "required": ["q"], } - ht = create_headless_tool( + ht = tool( name="dict_tool", description="Uses JSON schema.", args_schema=schema, @@ -97,7 +115,7 @@ def test_headless_dict_schema_has_metadata() -> None: def test_invoke_without_graph_context_errors() -> None: - ht = create_headless_tool( + ht = tool( name="t", description="d", args_schema=_MessageArgs, diff --git a/libs/langchain_v1/tests/unit_tests/tools/test_imports.py b/libs/langchain_v1/tests/unit_tests/tools/test_imports.py index e26c83cae5c..f3895906956 100644 --- a/libs/langchain_v1/tests/unit_tests/tools/test_imports.py +++ b/libs/langchain_v1/tests/unit_tests/tools/test_imports.py @@ -10,7 +10,6 @@ EXPECTED_ALL = { "InjectedToolCallId", "ToolException", "ToolRuntime", - "create_headless_tool", "tool", }