Compare commits

...

3 Commits

Author SHA1 Message Date
Christian Bromann
5973fea296 add demo 2026-04-03 23:16:22 -07:00
Christian Bromann
a21969543c cr 2026-03-31 10:11:46 -07:00
Christian Bromann
d9b22fe892 feat(langchain): support for headless tools 2026-03-31 10:07:12 -07:00
21 changed files with 4957 additions and 1 deletions

View File

@@ -0,0 +1,3 @@
OPENAI_API_KEY=sk-...
# Optional override (default: openai:gpt-4o-mini)
# OPENAI_MODEL=openai:gpt-4o-mini

View File

@@ -0,0 +1,7 @@
.env
.venv/
__pycache__/
*.egg-info/
.langgraph_api/
frontend/node_modules/
frontend/dist/

View File

@@ -0,0 +1,87 @@
# Headless tool + LangGraph dev server (E2E)
This example runs a **local LangGraph API** (`langgraph dev`) with a graph built via
[`create_agent`](https://docs.langchain.com/oss/python/langchain/agents) and a **headless
tool** from this branch (`langchain.tools.tool` with `name`, `description`, and
`args_schema` only). The tool calls LangGraph `interrupt()` so you can handle execution in
a browser or other client and **resume** with the result.
## Prerequisites
- [uv](https://docs.astral.sh/uv/)
- An `OPENAI_API_KEY` in `.env` (the demo model defaults to `openai:gpt-4o-mini`)
## 1. Configure environment
```bash
cd examples/headless-langgraph-e2e
cp .env.example .env
# Edit .env and set OPENAI_API_KEY
```
## 2. Install dependencies
```bash
uv sync --group dev
```
## 3. Start the LangGraph dev server
Default API URL is `http://127.0.0.1:2024` (CORS allows `*` by default).
```bash
uv run langgraph dev --config langgraph.json --no-browser
```
The dev server listens on **port 2024** by default (see the banner for the exact API URL).
### Port 2024 already in use
Another process (often a previous `langgraph dev`) is bound to 2024. Either stop it or use a free port:
```bash
# See what is using 2024 (macOS / Linux)
lsof -i :2024
```
Then quit that process, or start on another port:
```bash
uv run langgraph dev --config langgraph.json --no-browser --port 2025
```
Set the **LangGraph API base URL** in the frontend to match (for example `http://127.0.0.1:2025`).
## 4. Run the React frontend (Vite + `useStream`)
The UI mirrors the headless-tools pattern from `ui-playground` (`tool({ name, description, schema })` +
`.implement(...)` passed to `useStream({ tools: [...] })`).
In another shell:
```bash
cd examples/headless-langgraph-e2e/frontend
npm install
npm run dev
```
Open `http://127.0.0.1:8765`. The **LangGraph API base URL** field defaults to
`http://127.0.0.1:2024`; change it if you used `--port` with a different value.
Send a message that asks to open a URL or for your location; the model should call the
matching headless tool, the client runs the implementation in `src/tools.ts`, and streaming
continues after the tool result is applied.
## Notes
- **Checkpointer:** For a plain Python script you normally pass `checkpointer=InMemorySaver()`
to `create_agent` so `interrupt()` / headless tools persist. The **LangGraph dev server**
injects its own persistence and **rejects** a custom checkpointer on the graph, so this
example omits it.
- **Studio:** The server banner prints a link to open **LangSmith Studio** against your
local API (useful for debugging the same graph).
## Graph id
The graph is registered as **`agent`** in `langgraph.json` (this is the assistant / graph
id passed to the SDK).

View File

@@ -0,0 +1,65 @@
"""LangGraph dev graph: `create_agent` with a headless (interrupting) tool.
Headless tools (`browser_navigate`, `geolocation_get`) have no local implementation;
they call LangGraph `interrupt()` so the client can execute them and resume with a result.
Run the API: ``uv run langgraph dev`` from this directory (see README).
"""
from __future__ import annotations
import os
from langchain.agents import create_agent
from langchain.tools import tool
from pydantic import BaseModel, Field
# Model id passed to create_agent (resolved via init_chat_model).
# Requires OPENAI_API_KEY in .env for OpenAI models; override with OPENAI_MODEL.
_model = os.environ.get("OPENAI_MODEL", "openai:gpt-4o-mini")


class BrowserNavigateArgs(BaseModel):
    """Arguments for the headless browser navigation tool."""

    # Absolute URL; validation of scheme/host happens client-side (see frontend/src/tools.ts).
    url: str = Field(..., description="URL to open in the headless browser.")


class GeolocationGetArgs(BaseModel):
    """Arguments for the headless geolocation tool."""

    # None means "let the client decide"; the frontend defaults to high accuracy.
    high_accuracy: bool | None = Field(
        default=None,
        description="If true, request high-accuracy GPS from the browser when supported.",
    )


# Headless tool: name/description/args_schema only, no implementation callable.
# Invoking it raises a LangGraph interrupt; the client executes and resumes.
browser_navigate = tool(
    name="browser_navigate",
    description=(
        "Navigate a headless browser to a URL. Execution is delegated to the client: "
        "the graph pauses until the client resumes with the page result."
    ),
    args_schema=BrowserNavigateArgs,
)

# Second headless tool; names must match the client-side definitions in
# frontend/src/tools.ts exactly, or the client cannot match interrupts to implementations.
geolocation_get = tool(
    name="geolocation_get",
    description=(
        "Get the user's current GPS coordinates using the browser Geolocation API. "
        "The client shows their position on a map (OpenStreetMap). "
        "The browser may prompt for permission the first time."
    ),
    args_schema=GeolocationGetArgs,
)

# Exported graph object referenced by langgraph.json ("./agent.py:graph").
graph = create_agent(
    _model,
    tools=[browser_navigate, geolocation_get],
    system_prompt=(
        "You are a helpful assistant. When the user asks to open, visit, or browse "
        "a URL, call browser_navigate with that URL. When they ask where they are, "
        "their location, or to show them on a map, call geolocation_get. "
        "After the client returns out-of-process results, reply briefly using that information."
    ),
    # No custom checkpointer: LangGraph API (`langgraph dev`) supplies persistence.
)

View File

@@ -0,0 +1,12 @@
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
<title>Headless tool demo</title>
</head>
<body>
<div id="root"></div>
<script type="module" src="/src/main.tsx"></script>
</body>
</html>

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,24 @@
{
"name": "headless-langgraph-e2e-frontend",
"private": true,
"type": "module",
"scripts": {
"dev": "vite --port 8765",
"build": "vite build",
"preview": "vite preview --port 8765"
},
"dependencies": {
"@langchain/react": "^0.3.2",
"langchain": "^1.3.0",
"react": "^19",
"react-dom": "^19",
"zod": "^4.3.6"
},
"devDependencies": {
"@types/react": "^19",
"@types/react-dom": "^19",
"@vitejs/plugin-react": "^6",
"typescript": "^5.8",
"vite": "^8"
}
}

View File

@@ -0,0 +1,243 @@
import { useCallback, useMemo, useState, type ReactNode } from "react";
import { AIMessage, HumanMessage } from "langchain";
import {
useStream,
type DefaultToolCall,
type ToolCallWithResult,
} from "@langchain/react";
import { LocationMap } from "./LocationMap";
import { browserNavigateClient, geolocationGetClient } from "./tools";
const ASSISTANT_ID = "agent";
const PRESETS = [
"Please open https://example.com in the browser.",
"Visit https://langchain.com and tell me the page title if you can.",
"Where am I right now? Show me on a map.",
];
/**
 * Best-effort decode of a tool result's content: JSON-parse when possible,
 * otherwise return the raw string; `undefined` when there is no result yet.
 */
function parseResultContent(result: ToolCallWithResult<DefaultToolCall>["result"]): unknown {
  if (!result) return undefined;
  const content = result.content;
  const text = typeof content === "string" ? content : JSON.stringify(content);
  try {
    return JSON.parse(text) as unknown;
  } catch {
    return text;
  }
}
/**
 * Card rendering one headless tool call: the call args, an optional
 * OpenStreetMap embed (successful geolocation results), and the raw result.
 */
function HeadlessToolCallCard({
  toolCall,
}: {
  toolCall: ToolCallWithResult<DefaultToolCall>;
}) {
  const { call, result, state } = toolCall;
  const pending = state === "pending";
  const parsed = parseResultContent(result);
  const title =
    call.name === "geolocation_get"
      ? "geolocation_get"
      : call.name === "browser_navigate"
        ? "browser_navigate"
        : call.name;
  const icon =
    call.name === "geolocation_get" ? "📍" : call.name === "browser_navigate" ? "🌐" : "🔧";
  // Show a map only for a successful geolocation result carrying coordinates
  // (matches the JSON shape produced by geolocationGetClient in tools.ts).
  let map: ReactNode = null;
  if (
    call.name === "geolocation_get" &&
    parsed &&
    typeof parsed === "object" &&
    parsed !== null &&
    "success" in parsed &&
    (parsed as { success?: boolean }).success === true &&
    "latitude" in parsed &&
    "longitude" in parsed
  ) {
    const p = parsed as { latitude: number; longitude: number; accuracy?: number };
    map = <LocationMap latitude={p.latitude} longitude={p.longitude} accuracy={p.accuracy} />;
  }
  // Raw (unparsed) result text, shown beneath the map when both exist.
  let resultText = "";
  if (result) {
    const c = result.content;
    resultText = typeof c === "string" ? c : JSON.stringify(c);
  }
  return (
    <div className="tool-card">
      <header>
        <span aria-hidden="true">{icon}</span>
        <span>{pending ? `Running ${call.name}` : title}</span>
      </header>
      <pre>{JSON.stringify(call.args, null, 2)}</pre>
      {map}
      {resultText ? (
        <pre>
          {map ? "---\n" : ""}
          {resultText}
        </pre>
      ) : null}
    </div>
  );
}
/**
 * Demo UI: connects `useStream` to a local LangGraph API, sends messages, and
 * renders headless tool calls (with client-side implementations from tools.ts).
 *
 * NOTE(review): the tool-call list below keys items with `tc.id`; confirm
 * `ToolCallWithResult` exposes a top-level `id` (vs. `tc.call.id`).
 */
export function App() {
  // Base URL of the LangGraph dev server; user-editable to match `--port`.
  const [apiUrl, setApiUrl] = useState("http://127.0.0.1:2024");
  // Stable tools array so useStream does not re-register on every render.
  const tools = useMemo(() => [browserNavigateClient, geolocationGetClient], []);
  const stream = useStream({
    // Strip a single trailing slash so path joining stays consistent.
    apiUrl: apiUrl.replace(/\/$/, ""),
    assistantId: ASSISTANT_ID,
    tools,
  });
  const handleSubmit = useCallback(
    (text: string) => {
      // Fire-and-forget: streaming state is surfaced via stream.isLoading.
      void stream.submit({
        messages: [{ type: "human" as const, content: text }],
      });
    },
    [stream],
  );
  return (
    <>
      <h1>Headless tool (useStream)</h1>
      <p className="hint">
        Start the graph API from the parent directory:{" "}
        <code>uv run langgraph dev --config langgraph.json --no-browser</code>. Tool name{" "}
        <code>browser_navigate</code> and <code>geolocation_get</code> match <code>agent.py</code>
        ; the client runs implementations passed to <code>useStream</code> (see{" "}
        <code>ui-playground/.../headless-tools</code>). Geolocation uses OpenStreetMap embeds.
      </p>
      <div className="row">
        <div style={{ flex: 1, minWidth: "12rem" }}>
          <label htmlFor="apiUrl">LangGraph API base URL</label>
          <input
            id="apiUrl"
            type="url"
            value={apiUrl}
            onChange={(e) => {
              setApiUrl(e.target.value);
            }}
            autoComplete="off"
          />
        </div>
        {stream.messages.length > 0 ? (
          <button
            type="button"
            className="secondary"
            onClick={() => {
              // Reset to a fresh thread (clears the visible conversation).
              stream.switchThread(null);
            }}
          >
            New thread
          </button>
        ) : null}
      </div>
      <label htmlFor="msg">Message</label>
      <textarea
        id="msg"
        rows={3}
        defaultValue={PRESETS[0]}
        onKeyDown={(e) => {
          // Cmd/Ctrl+Enter submits without inserting a newline.
          if (e.key === "Enter" && (e.metaKey || e.ctrlKey)) {
            e.preventDefault();
            handleSubmit((e.target as HTMLTextAreaElement).value);
          }
        }}
      />
      <button
        type="button"
        disabled={stream.isLoading}
        onClick={() => {
          // Uncontrolled textarea: read the current value straight from the DOM.
          const ta = document.getElementById("msg") as HTMLTextAreaElement | null;
          handleSubmit(ta?.value ?? "");
        }}
      >
        {stream.isLoading ? "Running…" : "Send"}
      </button>
      {stream.messages.length === 0 ? (
        <div style={{ marginTop: "1rem" }}>
          <span className="hint">Try a preset:</span>
          <div style={{ display: "flex", flexWrap: "wrap", gap: "0.5rem", marginTop: "0.5rem" }}>
            {PRESETS.map((p) => (
              <button
                key={p}
                type="button"
                className="secondary"
                style={{ width: "auto" }}
                onClick={() => {
                  handleSubmit(p);
                }}
              >
                {p.slice(0, 42)}
                {p.length > 42 ? "…" : ""}
              </button>
            ))}
          </div>
        </div>
      ) : null}
      <div style={{ marginTop: "1.25rem" }}>
        {stream.messages.map((msg, idx) => {
          if (HumanMessage.isInstance(msg)) {
            return (
              <div key={msg.id ?? idx} className="bubble-user">
                {msg.text}
              </div>
            );
          }
          if (AIMessage.isInstance(msg)) {
            // Pair this AI message's tool_calls with their tracked results.
            const msgToolCalls = stream.toolCalls.filter((tc) =>
              msg.tool_calls?.some((t) => t.id === tc.call.id),
            );
            if (msgToolCalls.length > 0) {
              return (
                <div key={msg.id ?? idx}>
                  {msgToolCalls.map((tc) => (
                    <HeadlessToolCallCard
                      key={tc.id}
                      toolCall={tc as ToolCallWithResult<DefaultToolCall>}
                    />
                  ))}
                </div>
              );
            }
            if (!msg.text) return null;
            return (
              <div key={msg.id ?? idx} className="bubble-ai">
                {msg.text}
              </div>
            );
          }
          return null;
        })}
        {stream.isLoading &&
          !stream.messages.some((m) => AIMessage.isInstance(m) && m.text) &&
          stream.toolCalls.length === 0 && <div className="typing">Thinking</div>}
        {stream.error ? (
          <div className="error" role="alert">
            {stream.error instanceof Error ? stream.error.message : String(stream.error)}
          </div>
        ) : null}
      </div>
    </>
  );
}

View File

@@ -0,0 +1,40 @@
/**
 * OpenStreetMap embed for a coordinate pair, plus a link to the full site
 * (same approach as ui-playground headless-tools/components.tsx).
 */
type LocationMapProps = {
  latitude: number;
  longitude: number;
  accuracy?: number;
};

export function LocationMap({ latitude, longitude, accuracy }: LocationMapProps) {
  // Half-width of the bounding box in degrees (~500 m at the equator).
  const span = 0.005;
  const corners = [longitude - span, latitude - span, longitude + span, latitude + span];
  const embedUrl =
    `https://www.openstreetmap.org/export/embed.html?bbox=${corners.join(",")}` +
    `&layer=mapnik&marker=${latitude},${longitude}`;
  const osmUrl = `https://www.openstreetmap.org/?mlat=${latitude}&mlon=${longitude}#map=16/${latitude}/${longitude}`;
  const accuracyLabel =
    accuracy !== undefined ? (
      <span className="location-map-acc"> ±{Math.round(accuracy)} m</span>
    ) : null;
  return (
    <div className="location-map">
      <div className="location-map-frame">
        <iframe
          src={embedUrl}
          title="Your location on OpenStreetMap"
          loading="lazy"
          referrerPolicy="no-referrer"
        />
      </div>
      <div className="location-map-meta">
        <span className="location-map-coords">
          {latitude.toFixed(5)}, {longitude.toFixed(5)}
          {accuracyLabel}
        </span>
        <a href={osmUrl} target="_blank" rel="noopener noreferrer" className="location-map-link">
          Open in OSM
        </a>
      </div>
    </div>
  );
}

View File

@@ -0,0 +1,181 @@
:root {
font-family: system-ui, sans-serif;
line-height: 1.45;
color: #e8e8ef;
background: #12141a;
}
body {
margin: 0;
}
#root {
max-width: 52rem;
margin: 0 auto;
padding: 2rem 1rem;
}
h1 {
font-weight: 600;
font-size: 1.25rem;
margin: 0 0 0.5rem;
}
.hint {
font-size: 0.85rem;
color: #8b92a6;
margin: 0 0 1rem;
}
label {
display: block;
margin-top: 1rem;
font-size: 0.85rem;
color: #a8adbc;
}
input,
textarea,
button {
width: 100%;
box-sizing: border-box;
margin-top: 0.35rem;
padding: 0.5rem 0.65rem;
border-radius: 6px;
border: 1px solid #2c3140;
background: #1a1d26;
color: #e8e8ef;
font: inherit;
}
button {
cursor: pointer;
background: #3b5bdb;
border-color: #3b5bdb;
font-weight: 500;
margin-top: 0.75rem;
}
button.secondary {
background: #2c3140;
border-color: #3d4455;
}
button:disabled {
opacity: 0.55;
cursor: not-allowed;
}
.row {
display: flex;
gap: 0.75rem;
flex-wrap: wrap;
align-items: flex-end;
}
.row input {
flex: 1;
min-width: 12rem;
}
.error {
margin-top: 1rem;
padding: 0.75rem 1rem;
border-radius: 8px;
border: 1px solid rgba(220, 38, 38, 0.45);
background: rgba(127, 29, 29, 0.25);
color: #fecaca;
font-size: 0.875rem;
}
.tool-card {
border-radius: 8px;
border: 1px solid #2c3140;
background: #1a1d26;
padding: 0.75rem 1rem;
margin: 0.5rem 0;
}
.tool-card header {
display: flex;
align-items: center;
gap: 0.5rem;
font-size: 0.9rem;
font-weight: 500;
}
.tool-card pre {
margin: 0.5rem 0 0;
font-size: 0.75rem;
white-space: pre-wrap;
word-break: break-word;
color: #c8ccd8;
}
.bubble-user {
background: #2c3140;
border-radius: 8px;
padding: 0.65rem 0.85rem;
margin: 0.5rem 0;
font-size: 0.9rem;
}
.bubble-ai {
background: #1e2433;
border: 1px solid #2c3140;
border-radius: 8px;
padding: 0.65rem 0.85rem;
margin: 0.5rem 0;
font-size: 0.9rem;
}
.typing {
opacity: 0.7;
font-size: 0.85rem;
margin: 0.5rem 0;
}
.location-map {
margin-top: 0.5rem;
}
.location-map-frame {
overflow: hidden;
border-radius: 8px;
border: 1px solid #2c3140;
height: 200px;
}
.location-map-frame iframe {
width: 100%;
height: 100%;
border: 0;
}
.location-map-meta {
display: flex;
align-items: center;
justify-content: space-between;
flex-wrap: wrap;
gap: 0.5rem;
margin-top: 0.5rem;
font-size: 0.75rem;
color: #a8adbc;
}
.location-map-coords {
font-family: ui-monospace, monospace;
}
.location-map-acc {
opacity: 0.75;
}
.location-map-link {
color: #7c9cff;
text-decoration: none;
}
.location-map-link:hover {
text-decoration: underline;
}

View File

@@ -0,0 +1,16 @@
import { StrictMode } from "react";
import { createRoot } from "react-dom/client";
import { App } from "./App";
import "./index.css";
// Mount the app under #root; fail loudly if the element is missing from index.html.
const container = document.getElementById("root");
if (!container) {
  throw new Error("Missing #root");
}
const reactRoot = createRoot(container);
reactRoot.render(
  <StrictMode>
    <App />
  </StrictMode>,
);

View File

@@ -0,0 +1,97 @@
/**
* Headless tool definitions + client implementations.
*
* Mirrors `agent.py`: schema-only headless tools; execution runs here via
* `useStream({ tools: [...] })`.
*
* Pattern: https://github.com/langchain-ai/langchainjs (headless `tool` + `.implement()`)
*/
import { tool } from "langchain";
import { z } from "zod";
/**
 * Schema-only headless tool; no local handler is attached here.
 * Name and argument shape must match `browser_navigate` in `agent.py`,
 * or the server-side interrupt cannot be matched to this definition.
 */
export const browserNavigate = tool({
  name: "browser_navigate",
  description:
    "Navigate a headless browser to a URL. The real browser runs out-of-process; " +
    "the graph pauses until the client resumes with a result.",
  schema: z.object({
    url: z.string().describe("URL to open in the headless browser."),
  }),
});
/**
 * Schema-only headless tool; name and argument shape must match
 * `geolocation_get` in `agent.py`.
 *
 * (The comment that previously sat here described `browserNavigateClient`;
 * see that definition below for the fetch-or-simulate implementation notes.)
 */
export const geolocationGet = tool({
  name: "geolocation_get",
  description:
    "Get the user's current GPS coordinates using the browser's Geolocation API. " +
    "The client can show the position on OpenStreetMap.",
  schema: z.object({
    high_accuracy: z
      .boolean()
      .optional()
      .describe("Request high-accuracy GPS when supported."),
  }),
});
/**
 * Client implementation for `geolocation_get`: resolves the browser's position
 * and returns a JSON string (`success`, coordinates, ISO timestamp, message).
 * All failure modes (unsupported API, permission denied, timeout) return
 * `success: false` rather than throwing, so the model always gets a tool result.
 */
export const geolocationGetClient = geolocationGet.implement(async ({ high_accuracy }) => {
  if (!navigator.geolocation) {
    return JSON.stringify({
      success: false,
      message: "Geolocation is not supported by this browser.",
    });
  }
  let position: GeolocationPosition;
  try {
    position = await new Promise<GeolocationPosition>((resolve, reject) => {
      navigator.geolocation.getCurrentPosition(resolve, reject, {
        // Default to high accuracy unless the model explicitly opted out.
        enableHighAccuracy: high_accuracy ?? true,
        timeout: 10_000,
        maximumAge: 5 * 60 * 1000,
      });
    });
  } catch (err) {
    // GeolocationPositionError (denied / unavailable / timeout): report it as a
    // failed tool result instead of letting the rejection escape the tool call.
    const reason =
      err && typeof err === "object" && "message" in err
        ? String((err as { message: unknown }).message)
        : "unknown error";
    return JSON.stringify({
      success: false,
      message: `Could not get location: ${reason}`,
    });
  }
  const { latitude, longitude, accuracy } = position.coords;
  const timestamp = new Date(position.timestamp).toISOString();
  return JSON.stringify({
    success: true,
    latitude,
    longitude,
    accuracy,
    timestamp,
    // Fixed: original template read `...toFixed(5)}${Math.round(accuracy)} m)`,
    // gluing the accuracy onto the longitude with an unbalanced ")".
    message: `Location: ${latitude.toFixed(5)}, ${longitude.toFixed(5)} (±${Math.round(accuracy)} m)`,
  });
});
/**
 * Client implementation for `browser_navigate`: try a real `fetch` for
 * CORS-friendly URLs; otherwise return a simulated success (many sites block
 * cross-origin browser fetch, and the E2E flow only needs a resumable result).
 */
export const browserNavigateClient = browserNavigate.implement(async ({ url }) => {
  let parsed: URL;
  try {
    parsed = new URL(url);
  } catch {
    // Malformed URL from the model: report it rather than attempting a fetch.
    return JSON.stringify({ ok: false, error: "Invalid URL", url });
  }
  try {
    const res = await fetch(parsed.href, { mode: "cors", credentials: "omit" });
    const text = await res.text();
    // Naive <title> scrape — fine for a demo; no HTML parser needed.
    const titleMatch = text.match(/<title>([^<]*)<\/title>/i);
    return JSON.stringify({
      ok: true,
      url: parsed.href,
      status: res.status,
      title: titleMatch?.[1]?.trim() ?? null,
    });
  } catch {
    // Deliberately `ok: true`: a CORS-blocked fetch still counts as a
    // successful (simulated) navigation for this end-to-end demo.
    return JSON.stringify({
      ok: true,
      url: parsed.href,
      simulated: true,
      note:
        "Could not fetch page (often cross-origin). Simulated successful navigation for E2E.",
    });
  }
});

View File

@@ -0,0 +1,20 @@
{
"compilerOptions": {
"target": "ES2022",
"useDefineForClassFields": true,
"lib": ["ES2022", "DOM", "DOM.Iterable"],
"module": "ESNext",
"skipLibCheck": true,
"moduleResolution": "bundler",
"isolatedModules": true,
"moduleDetection": "force",
"noEmit": true,
"jsx": "react-jsx",
"strict": true,
"noUnusedLocals": true,
"noUnusedParameters": true,
"noFallthroughCasesInSwitch": true,
"noUncheckedSideEffectImports": true
},
"include": ["src"]
}

View File

@@ -0,0 +1,9 @@
import react from "@vitejs/plugin-react";
import { defineConfig } from "vite";
// Vite config: React plugin plus a pinned dev-server port.
export default defineConfig({
  plugins: [react()],
  server: {
    // Keep in sync with the "dev"/"preview" scripts in package.json (--port 8765).
    port: 8765,
  },
});

View File

@@ -0,0 +1,8 @@
{
"$schema": "https://langgra.ph/schema.json",
"dependencies": ["."],
"graphs": {
"agent": "./agent.py:graph"
},
"env": ".env"
}

View File

@@ -0,0 +1,24 @@
[build-system]
requires = ["setuptools>=61", "wheel"]
build-backend = "setuptools.build_meta"
[project]
name = "headless-langgraph-e2e"
version = "0.1.0"
description = "Local LangGraph dev server + headless tool demo (create_agent)."
readme = "README.md"
requires-python = ">=3.10"
dependencies = [
"langchain",
"langchain-openai>=0.3.0",
"python-dotenv>=1.0.0",
]
[tool.uv.sources]
langchain = { path = "../../libs/langchain_v1", editable = true }
[dependency-groups]
dev = ["langgraph-cli[inmem]>=0.4.14"]
[tool.setuptools]
py-modules = ["agent"]

2241
examples/headless-langgraph-e2e/uv.lock generated Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -5,13 +5,15 @@ from langchain_core.tools import (
InjectedToolArg,
InjectedToolCallId,
ToolException,
tool,
)
from langchain.tools.headless import HEADLESS_TOOL_METADATA_KEY, HeadlessTool, tool
from langchain.tools.tool_node import InjectedState, InjectedStore, ToolRuntime
__all__ = [
"HEADLESS_TOOL_METADATA_KEY",
"BaseTool",
"HeadlessTool",
"InjectedState",
"InjectedStore",
"InjectedToolArg",

View File

@@ -0,0 +1,278 @@
"""Headless tools: schema-only tools that interrupt for out-of-process execution.
Mirrors the LangChain.js `tool` overload from
https://github.com/langchain-ai/langchainjs/pull/10430 — tools defined with
`name`, `description`, and `args_schema` only. When invoked inside a LangGraph
agent, execution pauses with an interrupt payload so a client can run the
implementation and resume.
"""
from __future__ import annotations
from collections.abc import Awaitable, Callable # noqa: TC003
from typing import Annotated, Any, Literal, cast, overload
from langchain_core.runnables import Runnable, RunnableConfig # noqa: TC002
from langchain_core.tools import tool as core_tool
from langchain_core.tools.base import ArgsSchema, BaseTool, InjectedToolCallId
from langchain_core.tools.structured import StructuredTool
from langchain_core.utils.pydantic import is_basemodel_subclass, is_pydantic_v2_subclass
from langgraph.types import interrupt
from pydantic import BaseModel, create_model
# Metadata on the tool definition for introspection (e.g. listing tools, bind_tools).
# This does not appear as LLM token stream chunks. When a headless tool runs inside a
# LangGraph graph, the interrupt value (see `_headless_interrupt_payload`) is what
# surfaces to clients during streamed graph execution (interrupt events).
HEADLESS_TOOL_METADATA_KEY = "headless_tool"
def _args_schema_with_injected_tool_call_id(
    tool_name: str,
    args_schema: type[BaseModel],
) -> type[BaseModel]:
    """Derive a schema that additionally carries an injected ``tool_call_id``.

    The extra field is annotated with `InjectedToolCallId`, so it is hidden
    from the model-facing tool schema but filled in at invocation time; the
    interrupt payload uses it to echo the originating tool call id.

    Args:
        tool_name: Base name for the generated schema type.
        args_schema: Original Pydantic model for tool arguments.

    Returns:
        A new model type including `tool_call_id` injection.
    """
    # (annotation, default) tuple — the form create_model expects for a field.
    injected_field = (Annotated[str | None, InjectedToolCallId], None)
    return create_model(
        f"{tool_name}HeadlessInput",
        __base__=args_schema,
        tool_call_id=injected_field,
    )
def _headless_interrupt_payload(tool_name: str, **kwargs: Any) -> Any:
    """Pause the graph and surface this tool call to the client via `interrupt`.

    The injected ``tool_call_id`` (when present) is lifted out of the kwargs so
    it appears as the call's ``id`` rather than as a tool argument.
    """
    call_id = kwargs.pop("tool_call_id", None)
    payload = {
        "type": "tool",
        "tool_call": {"id": call_id, "name": tool_name, "args": kwargs},
    }
    return interrupt(payload)
def _make_headless_sync(tool_name: str) -> Callable[..., Any]:
    """Build the sync implementation for a headless tool named *tool_name*."""

    def _headless_sync(_config: RunnableConfig, **tool_kwargs: Any) -> Any:
        # No local implementation: raise the LangGraph interrupt instead.
        return _headless_interrupt_payload(tool_name, **tool_kwargs)

    return _headless_sync
def _make_headless_coroutine(tool_name: str) -> Callable[..., Awaitable[Any]]:
    """Build the async implementation for a headless tool named *tool_name*."""

    async def _headless_coroutine(_config: RunnableConfig, **tool_kwargs: Any) -> Any:
        # Mirrors the sync path: both raise the same interrupt payload.
        return _headless_interrupt_payload(tool_name, **tool_kwargs)

    return _headless_coroutine
class HeadlessTool(StructuredTool):
    """Structured tool that interrupts instead of executing locally.

    Marker subclass: the interrupting behavior lives in the ``func``/``coroutine``
    callables supplied at construction time; this subclass exists so callers can
    `isinstance`-check for headless tools.
    """
def _create_headless_tool(
    *,
    name: str,
    description: str,
    args_schema: ArgsSchema,
    return_direct: bool = False,
    response_format: Literal["content", "content_and_artifact"] = "content",
    extras: dict[str, Any] | None = None,
) -> HeadlessTool:
    """Instantiate a headless tool. Prefer the public `tool()` overload for new code.

    Args:
        name: Tool name surfaced to the model and echoed in interrupt payloads.
        description: Tool description surfaced to the model.
        args_schema: Pydantic model or JSON-schema dict describing the arguments.
        return_direct: Whether the tool result returns directly from the tool node.
        response_format: Core tool response format.
        extras: Optional provider-specific extras.

    Raises:
        TypeError: If `args_schema` is not a Pydantic model or dict.
    """
    if isinstance(args_schema, dict):
        # JSON-schema dicts pass through unchanged: no field injection is
        # possible, so interrupt payloads for these tools carry id=None.
        schema_for_tool: ArgsSchema = args_schema
    elif is_basemodel_subclass(args_schema):
        if is_pydantic_v2_subclass(args_schema):
            # Pydantic v2: extend the model with an injected tool_call_id so the
            # interrupt payload can reference the originating tool call.
            schema_for_tool = _args_schema_with_injected_tool_call_id(name, args_schema)
        else:
            # NOTE(review): pydantic-v1 models skip tool_call_id injection, so
            # their interrupt payloads carry id=None — confirm this is intended.
            schema_for_tool = args_schema
    else:
        msg = "args_schema must be a Pydantic BaseModel subclass or a dict schema."
        raise TypeError(msg)
    # Marker metadata so bind_tools / listings can detect headless tools.
    metadata = {HEADLESS_TOOL_METADATA_KEY: True}
    # Sync and async entry points both raise the same LangGraph interrupt.
    sync_fn = _make_headless_sync(name)
    coroutine = _make_headless_coroutine(name)
    return HeadlessTool(
        name=name,
        func=sync_fn,
        coroutine=coroutine,
        description=description,
        args_schema=schema_for_tool,
        return_direct=return_direct,
        response_format=response_format,
        metadata=metadata,
        extras=extras,
    )
# Overload 1: headless tool — keyword-only name/description/args_schema, no callable.
@overload
def tool(
    *,
    name: str,
    description: str,
    args_schema: ArgsSchema,
    return_direct: bool = False,
    response_format: Literal["content", "content_and_artifact"] = "content",
    extras: dict[str, Any] | None = None,
) -> HeadlessTool: ...


# Overload 2: explicit name + runnable (delegates to langchain_core.tools.tool).
@overload
def tool(
    name_or_callable: str,
    runnable: Runnable[Any, Any],
    *,
    description: str | None = None,
    return_direct: bool = False,
    args_schema: ArgsSchema | None = None,
    infer_schema: bool = True,
    response_format: Literal["content", "content_and_artifact"] = "content",
    parse_docstring: bool = False,
    error_on_invalid_docstring: bool = True,
    extras: dict[str, Any] | None = None,
) -> BaseTool: ...


# Overload 3: bare-callable usage, e.g. `tool(fn)` or `@tool`.
@overload
def tool(
    name_or_callable: Callable[..., Any],
    *,
    description: str | None = None,
    return_direct: bool = False,
    args_schema: ArgsSchema | None = None,
    infer_schema: bool = True,
    response_format: Literal["content", "content_and_artifact"] = "content",
    parse_docstring: bool = False,
    error_on_invalid_docstring: bool = True,
    extras: dict[str, Any] | None = None,
) -> BaseTool: ...


# Overload 4: name-only decorator factory, e.g. `@tool("my_name")`.
@overload
def tool(
    name_or_callable: str,
    *,
    description: str | None = None,
    return_direct: bool = False,
    args_schema: ArgsSchema | None = None,
    infer_schema: bool = True,
    response_format: Literal["content", "content_and_artifact"] = "content",
    parse_docstring: bool = False,
    error_on_invalid_docstring: bool = True,
    extras: dict[str, Any] | None = None,
) -> Callable[[Callable[..., Any] | Runnable[Any, Any]], BaseTool]: ...


# Overload 5: keyword-only decorator factory, e.g. `@tool(description=...)`.
@overload
def tool(
    *,
    description: str | None = None,
    return_direct: bool = False,
    args_schema: ArgsSchema | None = None,
    infer_schema: bool = True,
    response_format: Literal["content", "content_and_artifact"] = "content",
    parse_docstring: bool = False,
    error_on_invalid_docstring: bool = True,
    extras: dict[str, Any] | None = None,
) -> Callable[[Callable[..., Any] | Runnable[Any, Any]], BaseTool]: ...


def tool(
    name_or_callable: str | Callable[..., Any] | None = None,
    runnable: Runnable[Any, Any] | None = None,
    *args: Any,
    name: str | None = None,
    description: str | None = None,
    return_direct: bool = False,
    args_schema: ArgsSchema | None = None,
    infer_schema: bool = True,
    response_format: Literal["content", "content_and_artifact"] = "content",
    parse_docstring: bool = False,
    error_on_invalid_docstring: bool = True,
    extras: dict[str, Any] | None = None,
) -> BaseTool | Callable[[Callable[..., Any] | Runnable[Any, Any]], BaseTool] | HeadlessTool:
    """Create a tool, including headless (interrupting) tools.

    This is the supported entry point for headless tools: use keyword-only
    `name`, `description`, and `args_schema` with no implementation callable to get
    a `HeadlessTool` that calls LangGraph `interrupt` on both sync `invoke` and
    async `ainvoke`. Otherwise delegates to `langchain_core.tools.tool`.

    Args:
        name_or_callable: Passed through to core `tool` when not using headless mode.
        runnable: Passed through to core `tool`.
        name: Tool name (headless overload only).
        description: Tool description.
        args_schema: Argument schema (`BaseModel` or JSON-schema dict).
        return_direct: Whether to return directly from the tool node.
        infer_schema: Whether to infer schema from a decorated function (core `tool`).
        response_format: Core tool response format.
        parse_docstring: Core `tool` docstring parsing flag.
        error_on_invalid_docstring: Core `tool` flag.
        extras: Optional provider-specific extras.

    Returns:
        A `HeadlessTool`, a `BaseTool`, or a decorator factory from core `tool`.
    """
    # Headless mode requires ALL of name/description/args_schema as keywords
    # and NO positional callable/runnable; any other combination falls through.
    if (
        len(args) == 0
        and name_or_callable is None
        and runnable is None
        and name is not None
        and description is not None
        and args_schema is not None
    ):
        return _create_headless_tool(
            name=name,
            description=description,
            args_schema=args_schema,
            return_direct=return_direct,
            response_format=response_format,
            extras=extras,
        )
    # NOTE(review): `name` is not forwarded on this path, so `tool(name=...)`
    # without description/args_schema silently drops the name — confirm intended.
    # NOTE(review): `runnable` is always passed positionally (possibly None);
    # assumes core `tool` accepts None there — verify against langchain_core.
    delegated = core_tool(
        cast("Any", name_or_callable),
        cast("Any", runnable),
        *args,
        description=description,
        return_direct=return_direct,
        args_schema=args_schema,
        infer_schema=infer_schema,
        response_format=response_format,
        parse_docstring=parse_docstring,
        error_on_invalid_docstring=error_on_invalid_docstring,
        extras=extras,
    )
    return cast(
        "BaseTool | Callable[[Callable[..., Any] | Runnable[Any, Any]], BaseTool]",
        delegated,
    )

View File

@@ -0,0 +1,133 @@
"""Tests for headless (interrupting) tools."""
from __future__ import annotations
import asyncio
from typing import Any
from unittest.mock import patch
import pytest
from pydantic import BaseModel, Field
from langchain.tools import HEADLESS_TOOL_METADATA_KEY, HeadlessTool, tool
class _MessageArgs(BaseModel):
    # Minimal shared args schema used by the headless-tool tests below.
    message: str = Field(..., description="A message.")


def test_headless_tool_properties() -> None:
    """Keyword-only name/description/args_schema yields a HeadlessTool with marker metadata."""
    t = tool(
        name="test_tool",
        description="A test headless tool.",
        args_schema=_MessageArgs,
    )
    assert isinstance(t, HeadlessTool)
    assert t.name == "test_tool"
    assert t.description == "A test headless tool."
    assert t.metadata == {HEADLESS_TOOL_METADATA_KEY: True}


def test_tool_headless_overload() -> None:
    """The unified `tool()` entry point routes to the headless overload."""
    t = tool(
        name="from_overload",
        description="via unified tool()",
        args_schema=_MessageArgs,
    )
    assert isinstance(t, HeadlessTool)
    assert t.name == "from_overload"


def test_tool_normal_still_returns_structured_tool() -> None:
    """A plain callable still delegates to core `tool` — not headless."""

    def get_weather(city: str) -> str:
        """Return a fake forecast for the city."""
        return f"sunny in {city}"

    w = tool(get_weather)
    assert not isinstance(w, HeadlessTool)
    assert w.name == "get_weather"
@pytest.mark.asyncio
async def test_headless_coroutine_calls_interrupt() -> None:
    """Async `ainvoke` raises the interrupt with the expected tool-call payload."""
    ht = tool(
        name="interrupt_me",
        description="d",
        args_schema=_MessageArgs,
    )
    # Patch the module-level `interrupt` so no LangGraph runtime is needed.
    with patch("langchain.tools.headless.interrupt") as mock_interrupt:
        mock_interrupt.return_value = "resumed"
        result = await ht.ainvoke(
            {
                "type": "tool_call",
                "name": "interrupt_me",
                "id": "call-1",
                "args": {"message": "hi"},
            }
        )
        mock_interrupt.assert_called_once()
        # The payload shape is the headless-tool client contract.
        payload = mock_interrupt.call_args[0][0]
        assert payload["type"] == "tool"
        assert payload["tool_call"]["id"] == "call-1"
        assert payload["tool_call"]["name"] == "interrupt_me"
        assert payload["tool_call"]["args"] == {"message": "hi"}
        # Result may be a ToolMessage (with .content) or a bare value.
        assert getattr(result, "content", result) == "resumed"


def test_headless_sync_invoke_calls_interrupt() -> None:
    """Sync `invoke` must work (StructuredTool previously had no sync path)."""
    ht = tool(
        name="sync_interrupt",
        description="d",
        args_schema=_MessageArgs,
    )
    with patch("langchain.tools.headless.interrupt") as mock_interrupt:
        mock_interrupt.return_value = "ok"
        result = ht.invoke(
            {
                "type": "tool_call",
                "name": "sync_interrupt",
                "id": "cid-9",
                "args": {"message": "sync"},
            }
        )
        mock_interrupt.assert_called_once()
        payload = mock_interrupt.call_args[0][0]
        assert payload["tool_call"]["id"] == "cid-9"
        assert getattr(result, "content", result) == "ok"
def test_headless_dict_schema_has_metadata() -> None:
    """A JSON-schema dict is accepted and still produces marker metadata."""
    schema: dict[str, Any] = {
        "type": "object",
        "properties": {"q": {"type": "string"}},
        "required": ["q"],
    }
    ht = tool(
        name="dict_tool",
        description="Uses JSON schema.",
        args_schema=schema,
    )
    assert ht.metadata == {HEADLESS_TOOL_METADATA_KEY: True}
    assert "q" in ht.args


def test_invoke_without_graph_context_errors() -> None:
    """Outside a LangGraph run, `interrupt()` has no runtime and must raise."""
    ht = tool(
        name="t",
        description="d",
        args_schema=_MessageArgs,
    )
    # NOTE(review): accepts RuntimeError or KeyError — presumably the error type
    # depends on the langgraph version; confirm and narrow if possible.
    with pytest.raises((RuntimeError, KeyError)):
        asyncio.run(
            ht.ainvoke(
                {
                    "type": "tool_call",
                    "name": "t",
                    "id": "x",
                    "args": {"message": "m"},
                }
            )
        )

View File

@@ -1,7 +1,9 @@
from langchain import tools
EXPECTED_ALL = {
"HEADLESS_TOOL_METADATA_KEY",
"BaseTool",
"HeadlessTool",
"InjectedState",
"InjectedStore",
"InjectedToolArg",