From d9b22fe8926ad290b663cba339d30252849f8a03 Mon Sep 17 00:00:00 2001 From: Christian Bromann Date: Thu, 19 Mar 2026 16:24:04 -0700 Subject: [PATCH] feat(langchain): support for headless tools --- libs/langchain_v1/langchain/tools/__init__.py | 10 +- libs/langchain_v1/langchain/tools/headless.py | 269 ++++++++++++++++++ .../tests/unit_tests/tools/test_headless.py | 115 ++++++++ .../tests/unit_tests/tools/test_imports.py | 3 + 4 files changed, 396 insertions(+), 1 deletion(-) create mode 100644 libs/langchain_v1/langchain/tools/headless.py create mode 100644 libs/langchain_v1/tests/unit_tests/tools/test_headless.py diff --git a/libs/langchain_v1/langchain/tools/__init__.py b/libs/langchain_v1/langchain/tools/__init__.py index 903ed78a86d..1ef11cc5622 100644 --- a/libs/langchain_v1/langchain/tools/__init__.py +++ b/libs/langchain_v1/langchain/tools/__init__.py @@ -5,18 +5,26 @@ from langchain_core.tools import ( InjectedToolArg, InjectedToolCallId, ToolException, - tool, ) +from langchain.tools.headless import ( + HEADLESS_TOOL_METADATA_KEY, + HeadlessTool, + create_headless_tool, + tool, +) from langchain.tools.tool_node import InjectedState, InjectedStore, ToolRuntime __all__ = [ + "HEADLESS_TOOL_METADATA_KEY", "BaseTool", + "HeadlessTool", "InjectedState", "InjectedStore", "InjectedToolArg", "InjectedToolCallId", "ToolException", "ToolRuntime", + "create_headless_tool", "tool", ] diff --git a/libs/langchain_v1/langchain/tools/headless.py b/libs/langchain_v1/langchain/tools/headless.py new file mode 100644 index 00000000000..836012d1842 --- /dev/null +++ b/libs/langchain_v1/langchain/tools/headless.py @@ -0,0 +1,269 @@ +"""Headless tools: schema-only tools that interrupt for out-of-process execution. + +Mirrors the LangChain.js `tool` overload from +https://github.com/langchain-ai/langchainjs/pull/10430 — tools defined with +`name`, `description`, and `args_schema` only. When invoked inside a LangGraph +agent, execution pauses with an interrupt payload so a client can run the +implementation and resume. +""" + +from __future__ import annotations + +from collections.abc import Awaitable, Callable # noqa: TC003 +from typing import Annotated, Any, Literal, cast, overload + +from langchain_core.runnables import Runnable, RunnableConfig # noqa: TC002 +from langchain_core.tools import tool as core_tool +from langchain_core.tools.base import ArgsSchema, BaseTool, InjectedToolCallId +from langchain_core.tools.structured import StructuredTool +from langchain_core.utils.pydantic import is_basemodel_subclass, is_pydantic_v2_subclass +from langgraph.types import interrupt +from pydantic import BaseModel, create_model + +# Metadata flag for clients that need to detect headless tools (e.g. SDKs). +HEADLESS_TOOL_METADATA_KEY = "headless_tool" + + +def _args_schema_with_injected_tool_call_id( + tool_name: str, + args_schema: type[BaseModel], +) -> type[BaseModel]: + """Extend a user args model with an injected `tool_call_id` field. + + The field is stripped from the model-facing tool schema but populated at + invocation time so interrupt payloads can include the tool call id. + + Args: + tool_name: Base name for the generated schema type. + args_schema: Original Pydantic model for tool arguments. + + Returns: + A new model type including `tool_call_id` injection. + """ + model_name = f"{tool_name}HeadlessInput" + return create_model( + model_name, + __base__=args_schema, + tool_call_id=( + Annotated[str | None, InjectedToolCallId], + None, + ), + ) + + +def _make_headless_coroutine(tool_name: str) -> Callable[..., Awaitable[Any]]: + async def _headless_coroutine( + _config: RunnableConfig, + **kwargs: Any, + ) -> Any: + tool_call_id = kwargs.pop("tool_call_id", None) + return interrupt( + { + "type": "tool", + "tool_call": { + "id": tool_call_id, + "name": tool_name, + "args": kwargs, + }, + } + ) + + return _headless_coroutine + + +class HeadlessTool(StructuredTool): + """Structured tool that interrupts instead of executing locally.""" + + +def create_headless_tool( + *, + name: str, + description: str, + args_schema: ArgsSchema, + return_direct: bool = False, + response_format: Literal["content", "content_and_artifact"] = "content", + extras: dict[str, Any] | None = None, +) -> HeadlessTool: + """Create a headless tool from a name, description, and argument schema. + + Args: + name: Tool name exposed to the model. + description: Tool description exposed to the model. + args_schema: Pydantic model or JSON-schema dict for arguments. + return_direct: Whether the tool result should end the agent turn. + response_format: Same as `StructuredTool`. + extras: Optional provider-specific extras merged into tool extras. + + Returns: + A `HeadlessTool` whose coroutine raises a LangGraph interrupt when run. + + Raises: + TypeError: If `args_schema` is not a Pydantic model or dict. + """ + if isinstance(args_schema, dict): + schema_for_tool: ArgsSchema = args_schema + elif is_basemodel_subclass(args_schema): + if is_pydantic_v2_subclass(args_schema): + schema_for_tool = _args_schema_with_injected_tool_call_id(name, args_schema) + else: + schema_for_tool = args_schema + else: + msg = "args_schema must be a Pydantic BaseModel subclass or a dict schema." + raise TypeError(msg) + + metadata = {HEADLESS_TOOL_METADATA_KEY: True} + coroutine = _make_headless_coroutine(name) + return HeadlessTool( + name=name, + func=None, + coroutine=coroutine, + description=description, + args_schema=schema_for_tool, + return_direct=return_direct, + response_format=response_format, + metadata=metadata, + extras=extras, + ) + + +@overload +def tool( + *, + name: str, + description: str, + args_schema: ArgsSchema, + return_direct: bool = False, + response_format: Literal["content", "content_and_artifact"] = "content", + extras: dict[str, Any] | None = None, +) -> HeadlessTool: ... + + +@overload +def tool( + name_or_callable: str, + runnable: Runnable[Any, Any], + *, + description: str | None = None, + return_direct: bool = False, + args_schema: ArgsSchema | None = None, + infer_schema: bool = True, + response_format: Literal["content", "content_and_artifact"] = "content", + parse_docstring: bool = False, + error_on_invalid_docstring: bool = True, + extras: dict[str, Any] | None = None, +) -> BaseTool: ... + + +@overload +def tool( + name_or_callable: Callable[..., Any], + *, + description: str | None = None, + return_direct: bool = False, + args_schema: ArgsSchema | None = None, + infer_schema: bool = True, + response_format: Literal["content", "content_and_artifact"] = "content", + parse_docstring: bool = False, + error_on_invalid_docstring: bool = True, + extras: dict[str, Any] | None = None, +) -> BaseTool: ... + + +@overload +def tool( + name_or_callable: str, + *, + description: str | None = None, + return_direct: bool = False, + args_schema: ArgsSchema | None = None, + infer_schema: bool = True, + response_format: Literal["content", "content_and_artifact"] = "content", + parse_docstring: bool = False, + error_on_invalid_docstring: bool = True, + extras: dict[str, Any] | None = None, +) -> Callable[[Callable[..., Any] | Runnable[Any, Any]], BaseTool]: ... + + +@overload +def tool( + *, + description: str | None = None, + return_direct: bool = False, + args_schema: ArgsSchema | None = None, + infer_schema: bool = True, + response_format: Literal["content", "content_and_artifact"] = "content", + parse_docstring: bool = False, + error_on_invalid_docstring: bool = True, + extras: dict[str, Any] | None = None, +) -> Callable[[Callable[..., Any] | Runnable[Any, Any]], BaseTool]: ... + + +def tool( + name_or_callable: str | Callable[..., Any] | None = None, + runnable: Runnable[Any, Any] | None = None, + *args: Any, + name: str | None = None, + description: str | None = None, + return_direct: bool = False, + args_schema: ArgsSchema | None = None, + infer_schema: bool = True, + response_format: Literal["content", "content_and_artifact"] = "content", + parse_docstring: bool = False, + error_on_invalid_docstring: bool = True, + extras: dict[str, Any] | None = None, +) -> BaseTool | Callable[[Callable[..., Any] | Runnable[Any, Any]], BaseTool] | HeadlessTool: + """Create a tool, including headless (interrupting) tools. + + If called with keyword-only `name`, `description`, and `args_schema` and no + implementation callable, returns a `HeadlessTool` that triggers a LangGraph + `interrupt` when executed. Otherwise delegates to `langchain_core.tools.tool`. + + Args: + name_or_callable: Passed through to core `tool` when not using headless mode. + runnable: Passed through to core `tool`. + name: Tool name (headless overload only). + description: Tool description. + args_schema: Argument schema (`BaseModel` or JSON-schema dict). + return_direct: Whether to return directly from the tool node. + infer_schema: Whether to infer schema from a decorated function (core `tool`). + response_format: Core tool response format. + parse_docstring: Core `tool` docstring parsing flag. + error_on_invalid_docstring: Core `tool` flag. + extras: Optional provider-specific extras. + + Returns: + A `HeadlessTool`, a `BaseTool`, or a decorator factory from core `tool`. + """ + if ( + len(args) == 0 + and name_or_callable is None + and runnable is None + and name is not None + and description is not None + and args_schema is not None + ): + return create_headless_tool( + name=name, + description=description, + args_schema=args_schema, + return_direct=return_direct, + response_format=response_format, + extras=extras, + ) + delegated = core_tool( + cast("Any", name_or_callable), + cast("Any", runnable), + *args, + description=description, + return_direct=return_direct, + args_schema=args_schema, + infer_schema=infer_schema, + response_format=response_format, + parse_docstring=parse_docstring, + error_on_invalid_docstring=error_on_invalid_docstring, + extras=extras, + ) + return cast( + "BaseTool | Callable[[Callable[..., Any] | Runnable[Any, Any]], BaseTool]", + delegated, + ) diff --git a/libs/langchain_v1/tests/unit_tests/tools/test_headless.py b/libs/langchain_v1/tests/unit_tests/tools/test_headless.py new file mode 100644 index 00000000000..98fb372f1f8 --- /dev/null +++ b/libs/langchain_v1/tests/unit_tests/tools/test_headless.py @@ -0,0 +1,115 @@ +"""Tests for headless (interrupting) tools.""" + +from __future__ import annotations + +import asyncio +from typing import Any +from unittest.mock import patch + +import pytest +from pydantic import BaseModel, Field + +from langchain.tools import ( + HEADLESS_TOOL_METADATA_KEY, + HeadlessTool, + create_headless_tool, + tool, +) + + +class _MessageArgs(BaseModel): + message: str = Field(..., description="A message.") + + +def test_create_headless_tool_properties() -> None: + t = create_headless_tool( + name="test_tool", + description="A test headless tool.", + args_schema=_MessageArgs, + ) + assert isinstance(t, HeadlessTool) + assert t.name == "test_tool" + assert t.description == "A test headless tool." + assert t.metadata == {HEADLESS_TOOL_METADATA_KEY: True} + + +def test_tool_headless_overload() -> None: + t = tool( + name="from_overload", + description="via unified tool()", + args_schema=_MessageArgs, + ) + assert isinstance(t, HeadlessTool) + assert t.name == "from_overload" + + +def test_tool_normal_still_returns_structured_tool() -> None: + def get_weather(city: str) -> str: + """Return a fake forecast for the city.""" + return f"sunny in {city}" + + w = tool(get_weather) + assert not isinstance(w, HeadlessTool) + assert w.name == "get_weather" + + +@pytest.mark.asyncio +async def test_headless_coroutine_calls_interrupt() -> None: + ht = create_headless_tool( + name="interrupt_me", + description="d", + args_schema=_MessageArgs, + ) + with patch("langchain.tools.headless.interrupt") as mock_interrupt: + mock_interrupt.return_value = "resumed" + + result = await ht.ainvoke( + { + "type": "tool_call", + "name": "interrupt_me", + "id": "call-1", + "args": {"message": "hi"}, + } + ) + + mock_interrupt.assert_called_once() + payload = mock_interrupt.call_args[0][0] + assert payload["type"] == "tool" + assert payload["tool_call"]["id"] == "call-1" + assert payload["tool_call"]["name"] == "interrupt_me" + assert payload["tool_call"]["args"] == {"message": "hi"} + assert getattr(result, "content", result) == "resumed" + + +def test_headless_dict_schema_has_metadata() -> None: + schema: dict[str, Any] = { + "type": "object", + "properties": {"q": {"type": "string"}}, + "required": ["q"], + } + ht = create_headless_tool( + name="dict_tool", + description="Uses JSON schema.", + args_schema=schema, + ) + assert ht.metadata == {HEADLESS_TOOL_METADATA_KEY: True} + assert "q" in ht.args + + +def test_invoke_without_graph_context_errors() -> None: + ht = create_headless_tool( + name="t", + description="d", + args_schema=_MessageArgs, + ) + with pytest.raises((RuntimeError, KeyError)): + asyncio.run( + ht.ainvoke( + { + "type": "tool_call", + "name": "t", + "id": "x", + "args": {"message": "m"}, + } + ) + ) diff --git a/libs/langchain_v1/tests/unit_tests/tools/test_imports.py b/libs/langchain_v1/tests/unit_tests/tools/test_imports.py index 17b66c6f771..e26c83cae5c 100644 --- a/libs/langchain_v1/tests/unit_tests/tools/test_imports.py +++ b/libs/langchain_v1/tests/unit_tests/tools/test_imports.py @@ -1,13 +1,16 @@ from langchain import tools EXPECTED_ALL = { + "HEADLESS_TOOL_METADATA_KEY", "BaseTool", + "HeadlessTool", "InjectedState", "InjectedStore", "InjectedToolArg", "InjectedToolCallId", "ToolException", "ToolRuntime", + "create_headless_tool", "tool", }