mirror of
https://github.com/hwchase17/langchain.git
synced 2026-04-12 23:42:51 +00:00
cr
This commit is contained in:
@@ -7,12 +7,7 @@ from langchain_core.tools import (
|
||||
ToolException,
|
||||
)
|
||||
|
||||
from langchain.tools.headless import (
|
||||
HEADLESS_TOOL_METADATA_KEY,
|
||||
HeadlessTool,
|
||||
create_headless_tool,
|
||||
tool,
|
||||
)
|
||||
from langchain.tools.headless import HEADLESS_TOOL_METADATA_KEY, HeadlessTool, tool
|
||||
from langchain.tools.tool_node import InjectedState, InjectedStore, ToolRuntime
|
||||
|
||||
__all__ = [
|
||||
@@ -25,6 +20,5 @@ __all__ = [
|
||||
"InjectedToolCallId",
|
||||
"ToolException",
|
||||
"ToolRuntime",
|
||||
"create_headless_tool",
|
||||
"tool",
|
||||
]
|
||||
|
||||
@@ -20,7 +20,10 @@ from langchain_core.utils.pydantic import is_basemodel_subclass, is_pydantic_v2_
|
||||
from langgraph.types import interrupt
|
||||
from pydantic import BaseModel, create_model
|
||||
|
||||
# Metadata flag for clients that need to detect headless tools (e.g. SDKs).
|
||||
# Metadata on the tool definition for introspection (e.g. listing tools, bind_tools).
|
||||
# This does not appear as LLM token stream chunks. When a headless tool runs inside a
|
||||
# LangGraph graph, the interrupt value (see `_headless_interrupt_payload`) is what
|
||||
# surfaces to clients during streamed graph execution (interrupt events).
|
||||
HEADLESS_TOOL_METADATA_KEY = "headless_tool"
|
||||
|
||||
|
||||
@@ -51,22 +54,37 @@ def _args_schema_with_injected_tool_call_id(
|
||||
)
|
||||
|
||||
|
||||
def _headless_interrupt_payload(tool_name: str, **kwargs: Any) -> Any:
|
||||
"""Build the LangGraph interrupt value for a headless tool call."""
|
||||
tool_call_id = kwargs.pop("tool_call_id", None)
|
||||
return interrupt(
|
||||
{
|
||||
"type": "tool",
|
||||
"tool_call": {
|
||||
"id": tool_call_id,
|
||||
"name": tool_name,
|
||||
"args": kwargs,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def _make_headless_sync(tool_name: str) -> Callable[..., Any]:
|
||||
def _headless_sync(
|
||||
_config: RunnableConfig,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
return _headless_interrupt_payload(tool_name, **kwargs)
|
||||
|
||||
return _headless_sync
|
||||
|
||||
|
||||
def _make_headless_coroutine(tool_name: str) -> Callable[..., Awaitable[Any]]:
|
||||
async def _headless_coroutine(
|
||||
_config: RunnableConfig,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
tool_call_id = kwargs.pop("tool_call_id", None)
|
||||
return interrupt(
|
||||
{
|
||||
"type": "tool",
|
||||
"tool_call": {
|
||||
"id": tool_call_id,
|
||||
"name": tool_name,
|
||||
"args": kwargs,
|
||||
},
|
||||
}
|
||||
)
|
||||
return _headless_interrupt_payload(tool_name, **kwargs)
|
||||
|
||||
return _headless_coroutine
|
||||
|
||||
@@ -75,7 +93,7 @@ class HeadlessTool(StructuredTool):
|
||||
"""Structured tool that interrupts instead of executing locally."""
|
||||
|
||||
|
||||
def create_headless_tool(
|
||||
def _create_headless_tool(
|
||||
*,
|
||||
name: str,
|
||||
description: str,
|
||||
@@ -84,18 +102,7 @@ def create_headless_tool(
|
||||
response_format: Literal["content", "content_and_artifact"] = "content",
|
||||
extras: dict[str, Any] | None = None,
|
||||
) -> HeadlessTool:
|
||||
"""Create a headless tool from a name, description, and argument schema.
|
||||
|
||||
Args:
|
||||
name: Tool name exposed to the model.
|
||||
description: Tool description exposed to the model.
|
||||
args_schema: Pydantic model or JSON-schema dict for arguments.
|
||||
return_direct: Whether the tool result should end the agent turn.
|
||||
response_format: Same as `StructuredTool`.
|
||||
extras: Optional provider-specific extras merged into tool extras.
|
||||
|
||||
Returns:
|
||||
A `HeadlessTool` whose coroutine raises a LangGraph interrupt when run.
|
||||
"""Instantiate a headless tool. Prefer the public `tool()` overload for new code.
|
||||
|
||||
Raises:
|
||||
TypeError: If `args_schema` is not a Pydantic model or dict.
|
||||
@@ -112,10 +119,11 @@ def create_headless_tool(
|
||||
raise TypeError(msg)
|
||||
|
||||
metadata = {HEADLESS_TOOL_METADATA_KEY: True}
|
||||
sync_fn = _make_headless_sync(name)
|
||||
coroutine = _make_headless_coroutine(name)
|
||||
return HeadlessTool(
|
||||
name=name,
|
||||
func=None,
|
||||
func=sync_fn,
|
||||
coroutine=coroutine,
|
||||
description=description,
|
||||
args_schema=schema_for_tool,
|
||||
@@ -214,9 +222,10 @@ def tool(
|
||||
) -> BaseTool | Callable[[Callable[..., Any] | Runnable[Any, Any]], BaseTool] | HeadlessTool:
|
||||
"""Create a tool, including headless (interrupting) tools.
|
||||
|
||||
If called with keyword-only `name`, `description`, and `args_schema` and no
|
||||
implementation callable, returns a `HeadlessTool` that triggers a LangGraph
|
||||
`interrupt` when executed. Otherwise delegates to `langchain_core.tools.tool`.
|
||||
This is the supported entry point for headless tools: use keyword-only
|
||||
`name`, `description`, and `args_schema` with no implementation callable to get
|
||||
a `HeadlessTool` that calls LangGraph `interrupt` on both sync `invoke` and
|
||||
async `ainvoke`. Otherwise delegates to `langchain_core.tools.tool`.
|
||||
|
||||
Args:
|
||||
name_or_callable: Passed through to core `tool` when not using headless mode.
|
||||
@@ -242,7 +251,7 @@ def tool(
|
||||
and description is not None
|
||||
and args_schema is not None
|
||||
):
|
||||
return create_headless_tool(
|
||||
return _create_headless_tool(
|
||||
name=name,
|
||||
description=description,
|
||||
args_schema=args_schema,
|
||||
|
||||
@@ -9,20 +9,15 @@ from unittest.mock import patch
|
||||
import pytest
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from langchain.tools import (
|
||||
HEADLESS_TOOL_METADATA_KEY,
|
||||
HeadlessTool,
|
||||
create_headless_tool,
|
||||
tool,
|
||||
)
|
||||
from langchain.tools import HEADLESS_TOOL_METADATA_KEY, HeadlessTool, tool
|
||||
|
||||
|
||||
class _MessageArgs(BaseModel):
|
||||
message: str = Field(..., description="A message.")
|
||||
|
||||
|
||||
def test_create_headless_tool_properties() -> None:
|
||||
t = create_headless_tool(
|
||||
def test_headless_tool_properties() -> None:
|
||||
t = tool(
|
||||
name="test_tool",
|
||||
description="A test headless tool.",
|
||||
args_schema=_MessageArgs,
|
||||
@@ -55,7 +50,7 @@ def test_tool_normal_still_returns_structured_tool() -> None:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_headless_coroutine_calls_interrupt() -> None:
|
||||
ht = create_headless_tool(
|
||||
ht = tool(
|
||||
name="interrupt_me",
|
||||
description="d",
|
||||
args_schema=_MessageArgs,
|
||||
@@ -81,13 +76,36 @@ async def test_headless_coroutine_calls_interrupt() -> None:
|
||||
assert getattr(result, "content", result) == "resumed"
|
||||
|
||||
|
||||
def test_headless_sync_invoke_calls_interrupt() -> None:
|
||||
"""Sync `invoke` must work (StructuredTool previously had no sync path)."""
|
||||
ht = tool(
|
||||
name="sync_interrupt",
|
||||
description="d",
|
||||
args_schema=_MessageArgs,
|
||||
)
|
||||
with patch("langchain.tools.headless.interrupt") as mock_interrupt:
|
||||
mock_interrupt.return_value = "ok"
|
||||
result = ht.invoke(
|
||||
{
|
||||
"type": "tool_call",
|
||||
"name": "sync_interrupt",
|
||||
"id": "cid-9",
|
||||
"args": {"message": "sync"},
|
||||
}
|
||||
)
|
||||
mock_interrupt.assert_called_once()
|
||||
payload = mock_interrupt.call_args[0][0]
|
||||
assert payload["tool_call"]["id"] == "cid-9"
|
||||
assert getattr(result, "content", result) == "ok"
|
||||
|
||||
|
||||
def test_headless_dict_schema_has_metadata() -> None:
|
||||
schema: dict[str, Any] = {
|
||||
"type": "object",
|
||||
"properties": {"q": {"type": "string"}},
|
||||
"required": ["q"],
|
||||
}
|
||||
ht = create_headless_tool(
|
||||
ht = tool(
|
||||
name="dict_tool",
|
||||
description="Uses JSON schema.",
|
||||
args_schema=schema,
|
||||
@@ -97,7 +115,7 @@ def test_headless_dict_schema_has_metadata() -> None:
|
||||
|
||||
|
||||
def test_invoke_without_graph_context_errors() -> None:
|
||||
ht = create_headless_tool(
|
||||
ht = tool(
|
||||
name="t",
|
||||
description="d",
|
||||
args_schema=_MessageArgs,
|
||||
|
||||
@@ -10,7 +10,6 @@ EXPECTED_ALL = {
|
||||
"InjectedToolCallId",
|
||||
"ToolException",
|
||||
"ToolRuntime",
|
||||
"create_headless_tool",
|
||||
"tool",
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user