feat(langchain): support for headless tools

This commit is contained in:
Christian Bromann
2026-03-19 16:24:04 -07:00
parent 3b4cd75a0c
commit d9b22fe892
4 changed files with 396 additions and 1 deletions

View File

@@ -5,18 +5,26 @@ from langchain_core.tools import (
InjectedToolArg,
InjectedToolCallId,
ToolException,
tool,
)
from langchain.tools.headless import (
HEADLESS_TOOL_METADATA_KEY,
HeadlessTool,
create_headless_tool,
tool,
)
from langchain.tools.tool_node import InjectedState, InjectedStore, ToolRuntime
__all__ = [
"HEADLESS_TOOL_METADATA_KEY",
"BaseTool",
"HeadlessTool",
"InjectedState",
"InjectedStore",
"InjectedToolArg",
"InjectedToolCallId",
"ToolException",
"ToolRuntime",
"create_headless_tool",
"tool",
]

View File

@@ -0,0 +1,269 @@
"""Headless tools: schema-only tools that interrupt for out-of-process execution.
Mirrors the LangChain.js `tool` overload from
https://github.com/langchain-ai/langchainjs/pull/10430 — tools defined with
`name`, `description`, and `args_schema` only. When invoked inside a LangGraph
agent, execution pauses with an interrupt payload so a client can run the
implementation and resume.
"""
from __future__ import annotations
from collections.abc import Awaitable, Callable # noqa: TC003
from typing import Annotated, Any, Literal, cast, overload
from langchain_core.runnables import Runnable, RunnableConfig # noqa: TC002
from langchain_core.tools import tool as core_tool
from langchain_core.tools.base import ArgsSchema, BaseTool, InjectedToolCallId
from langchain_core.tools.structured import StructuredTool
from langchain_core.utils.pydantic import is_basemodel_subclass, is_pydantic_v2_subclass
from langgraph.types import interrupt
from pydantic import BaseModel, create_model
# Metadata flag for clients that need to detect headless tools (e.g. SDKs).
# Set on the tool as `{HEADLESS_TOOL_METADATA_KEY: True}` in `create_headless_tool`.
HEADLESS_TOOL_METADATA_KEY = "headless_tool"
def _args_schema_with_injected_tool_call_id(
    tool_name: str,
    args_schema: type[BaseModel],
) -> type[BaseModel]:
    """Derive a schema that adds an injected `tool_call_id` field.

    The extra field is hidden from the model-facing tool schema (via
    `InjectedToolCallId`) but populated at invocation time, so interrupt
    payloads can carry the originating tool call id.

    Args:
        tool_name: Base name for the generated schema type.
        args_schema: Original Pydantic model for tool arguments.

    Returns:
        A new model type including `tool_call_id` injection.
    """
    injected_field = (Annotated[str | None, InjectedToolCallId], None)
    return create_model(
        f"{tool_name}HeadlessInput",
        __base__=args_schema,
        tool_call_id=injected_field,
    )
def _make_headless_coroutine(tool_name: str) -> Callable[..., Awaitable[Any]]:
    """Return an async callable that interrupts instead of running tool logic.

    The coroutine pops the injected `tool_call_id` from its keyword arguments
    and calls `interrupt` with a tool-call payload, so a client outside the
    graph can execute the implementation and resume with the result.
    """

    async def _headless_coroutine(_config: RunnableConfig, **tool_args: Any) -> Any:
        call_id = tool_args.pop("tool_call_id", None)
        payload = {
            "type": "tool",
            "tool_call": {
                "id": call_id,
                "name": tool_name,
                "args": tool_args,
            },
        }
        return interrupt(payload)

    return _headless_coroutine
class HeadlessTool(StructuredTool):
    """A `StructuredTool` with no local implementation.

    Its coroutine triggers a LangGraph interrupt rather than executing, so the
    actual tool body runs out of process and the graph resumes with the
    client-supplied result.
    """
def create_headless_tool(
    *,
    name: str,
    description: str,
    args_schema: ArgsSchema,
    return_direct: bool = False,
    response_format: Literal["content", "content_and_artifact"] = "content",
    extras: dict[str, Any] | None = None,
) -> HeadlessTool:
    """Build a headless tool from a name, description, and argument schema.

    Args:
        name: Tool name exposed to the model.
        description: Tool description exposed to the model.
        args_schema: Pydantic model or JSON-schema dict for arguments.
        return_direct: Whether the tool result should end the agent turn.
        response_format: Same as `StructuredTool`.
        extras: Optional provider-specific extras merged into tool extras.

    Returns:
        A `HeadlessTool` whose coroutine raises a LangGraph interrupt when run.

    Raises:
        TypeError: If `args_schema` is not a Pydantic model or dict.
    """
    is_dict_schema = isinstance(args_schema, dict)
    if not is_dict_schema and not is_basemodel_subclass(args_schema):
        msg = "args_schema must be a Pydantic BaseModel subclass or a dict schema."
        raise TypeError(msg)

    # Only Pydantic v2 models receive the injected `tool_call_id` field; dict
    # schemas and v1 models are passed through unchanged.
    schema_for_tool: ArgsSchema = args_schema
    if not is_dict_schema and is_pydantic_v2_subclass(args_schema):
        schema_for_tool = _args_schema_with_injected_tool_call_id(name, args_schema)

    return HeadlessTool(
        name=name,
        func=None,
        coroutine=_make_headless_coroutine(name),
        description=description,
        args_schema=schema_for_tool,
        return_direct=return_direct,
        response_format=response_format,
        metadata={HEADLESS_TOOL_METADATA_KEY: True},
        extras=extras,
    )
# Headless overload: a schema-only definition (no implementation) returns a
# `HeadlessTool` that interrupts when executed.
@overload
def tool(
    *,
    name: str,
    description: str,
    args_schema: ArgsSchema,
    return_direct: bool = False,
    response_format: Literal["content", "content_and_artifact"] = "content",
    extras: dict[str, Any] | None = None,
) -> HeadlessTool: ...


# Core overload: explicit name plus a runnable implementation.
@overload
def tool(
    name_or_callable: str,
    runnable: Runnable[Any, Any],
    *,
    description: str | None = None,
    return_direct: bool = False,
    args_schema: ArgsSchema | None = None,
    infer_schema: bool = True,
    response_format: Literal["content", "content_and_artifact"] = "content",
    parse_docstring: bool = False,
    error_on_invalid_docstring: bool = True,
    extras: dict[str, Any] | None = None,
) -> BaseTool: ...


# Core overload: bare `@tool` decorator applied directly to a function.
@overload
def tool(
    name_or_callable: Callable[..., Any],
    *,
    description: str | None = None,
    return_direct: bool = False,
    args_schema: ArgsSchema | None = None,
    infer_schema: bool = True,
    response_format: Literal["content", "content_and_artifact"] = "content",
    parse_docstring: bool = False,
    error_on_invalid_docstring: bool = True,
    extras: dict[str, Any] | None = None,
) -> BaseTool: ...


# Core overload: `@tool("name")` factory — returns a decorator.
@overload
def tool(
    name_or_callable: str,
    *,
    description: str | None = None,
    return_direct: bool = False,
    args_schema: ArgsSchema | None = None,
    infer_schema: bool = True,
    response_format: Literal["content", "content_and_artifact"] = "content",
    parse_docstring: bool = False,
    error_on_invalid_docstring: bool = True,
    extras: dict[str, Any] | None = None,
) -> Callable[[Callable[..., Any] | Runnable[Any, Any]], BaseTool]: ...


# Core overload: `@tool(...)` factory with keyword options only.
@overload
def tool(
    *,
    description: str | None = None,
    return_direct: bool = False,
    args_schema: ArgsSchema | None = None,
    infer_schema: bool = True,
    response_format: Literal["content", "content_and_artifact"] = "content",
    parse_docstring: bool = False,
    error_on_invalid_docstring: bool = True,
    extras: dict[str, Any] | None = None,
) -> Callable[[Callable[..., Any] | Runnable[Any, Any]], BaseTool]: ...
def tool(
    name_or_callable: str | Callable[..., Any] | None = None,
    runnable: Runnable[Any, Any] | None = None,
    *args: Any,
    name: str | None = None,
    description: str | None = None,
    return_direct: bool = False,
    args_schema: ArgsSchema | None = None,
    infer_schema: bool = True,
    response_format: Literal["content", "content_and_artifact"] = "content",
    parse_docstring: bool = False,
    error_on_invalid_docstring: bool = True,
    extras: dict[str, Any] | None = None,
) -> BaseTool | Callable[[Callable[..., Any] | Runnable[Any, Any]], BaseTool] | HeadlessTool:
    """Create a tool, including headless (interrupting) tools.

    If called with keyword-only `name`, `description`, and `args_schema` and no
    implementation callable, returns a `HeadlessTool` that triggers a LangGraph
    `interrupt` when executed. Otherwise delegates to `langchain_core.tools.tool`.

    Args:
        name_or_callable: Passed through to core `tool` when not using headless mode.
        runnable: Passed through to core `tool`.
        name: Tool name (headless overload only).
        description: Tool description.
        args_schema: Argument schema (`BaseModel` or JSON-schema dict).
        return_direct: Whether to return directly from the tool node.
        infer_schema: Whether to infer schema from a decorated function (core `tool`).
        response_format: Core tool response format.
        parse_docstring: Core `tool` docstring parsing flag.
        error_on_invalid_docstring: Core `tool` flag.
        extras: Optional provider-specific extras.

    Returns:
        A `HeadlessTool`, a `BaseTool`, or a decorator factory from core `tool`.
    """
    # Headless mode only when the caller supplied a complete schema-only
    # definition and nothing that could be an implementation.
    if (
        len(args) == 0
        and name_or_callable is None
        and runnable is None
        and name is not None
        and description is not None
        and args_schema is not None
    ):
        return create_headless_tool(
            name=name,
            description=description,
            args_schema=args_schema,
            return_direct=return_direct,
            response_format=response_format,
            extras=extras,
        )
    # Everything else is the classic `@tool` decorator/factory path.
    # NOTE(review): if `name=` is passed here without the other headless
    # kwargs it is silently dropped (core `tool` never sees it) — confirm
    # whether that should raise instead.
    delegated = core_tool(
        cast("Any", name_or_callable),
        cast("Any", runnable),
        *args,
        description=description,
        return_direct=return_direct,
        args_schema=args_schema,
        infer_schema=infer_schema,
        response_format=response_format,
        parse_docstring=parse_docstring,
        error_on_invalid_docstring=error_on_invalid_docstring,
        extras=extras,
    )
    # Narrow the return union for type-checkers; headless returns were
    # handled above, so `delegated` is never a `HeadlessTool`.
    return cast(
        "BaseTool | Callable[[Callable[..., Any] | Runnable[Any, Any]], BaseTool]",
        delegated,
    )

View File

@@ -0,0 +1,115 @@
"""Tests for headless (interrupting) tools."""
from __future__ import annotations
import asyncio
from typing import Any
from unittest.mock import patch
import pytest
from pydantic import BaseModel, Field
from langchain.tools import (
HEADLESS_TOOL_METADATA_KEY,
HeadlessTool,
create_headless_tool,
tool,
)
class _MessageArgs(BaseModel):
    """Minimal argument schema shared by the headless-tool tests."""

    message: str = Field(..., description="A message.")
def test_create_headless_tool_properties() -> None:
    """`create_headless_tool` sets type, metadata, name, and description."""
    headless = create_headless_tool(
        name="test_tool",
        description="A test headless tool.",
        args_schema=_MessageArgs,
    )
    assert isinstance(headless, HeadlessTool)
    assert headless.metadata == {HEADLESS_TOOL_METADATA_KEY: True}
    assert headless.name == "test_tool"
    assert headless.description == "A test headless tool."
def test_tool_headless_overload() -> None:
    """The unified `tool()` entry point builds a headless tool from kwargs."""
    built = tool(
        name="from_overload",
        description="via unified tool()",
        args_schema=_MessageArgs,
    )
    assert isinstance(built, HeadlessTool)
    assert built.name == "from_overload"
def test_tool_normal_still_returns_structured_tool() -> None:
    """Decorating a plain function keeps the regular (non-headless) path."""

    def get_weather(city: str) -> str:
        """Return a fake forecast for the city."""
        return f"sunny in {city}"

    wrapped = tool(get_weather)
    assert wrapped.name == "get_weather"
    assert not isinstance(wrapped, HeadlessTool)
@pytest.mark.asyncio
async def test_headless_coroutine_calls_interrupt() -> None:
    """Invoking a headless tool forwards the tool call to `interrupt`."""
    headless = create_headless_tool(
        name="interrupt_me",
        description="d",
        args_schema=_MessageArgs,
    )
    tool_call = {
        "type": "tool_call",
        "name": "interrupt_me",
        "id": "call-1",
        "args": {"message": "hi"},
    }
    with patch("langchain.tools.headless.interrupt") as interrupt_mock:
        interrupt_mock.return_value = "resumed"
        result = await headless.ainvoke(tool_call)

    interrupt_mock.assert_called_once()
    payload = interrupt_mock.call_args[0][0]
    assert payload["type"] == "tool"
    call = payload["tool_call"]
    assert call["id"] == "call-1"
    assert call["name"] == "interrupt_me"
    assert call["args"] == {"message": "hi"}
    # The mocked resume value surfaces either bare or as message content.
    assert getattr(result, "content", result) == "resumed"
def test_headless_dict_schema_has_metadata() -> None:
    """A JSON-schema dict is accepted and still tagged as headless."""
    json_schema: dict[str, Any] = {
        "type": "object",
        "properties": {"q": {"type": "string"}},
        "required": ["q"],
    }
    headless = create_headless_tool(
        name="dict_tool",
        description="Uses JSON schema.",
        args_schema=json_schema,
    )
    assert "q" in headless.args
    assert headless.metadata == {HEADLESS_TOOL_METADATA_KEY: True}
def test_invoke_without_graph_context_errors() -> None:
    """Running outside a LangGraph runtime surfaces an error from `interrupt`."""
    headless = create_headless_tool(
        name="t",
        description="d",
        args_schema=_MessageArgs,
    )
    tool_call = {
        "type": "tool_call",
        "name": "t",
        "id": "x",
        "args": {"message": "m"},
    }
    with pytest.raises((RuntimeError, KeyError)):
        asyncio.run(headless.ainvoke(tool_call))

View File

@@ -1,13 +1,16 @@
from langchain import tools
# Public surface of `langchain.tools`, including the headless-tool additions.
EXPECTED_ALL = {
    "BaseTool",
    "HEADLESS_TOOL_METADATA_KEY",
    "HeadlessTool",
    "InjectedState",
    "InjectedStore",
    "InjectedToolArg",
    "InjectedToolCallId",
    "ToolException",
    "ToolRuntime",
    "create_headless_tool",
    "tool",
}