feat: tracing for wrap model + tool call (#35765)

Adding tracing for `wrap_model_call` and `wrap_tool_call`
Scrubbing `request.runtime` and `handler` for now

`wrap_model_call`:
<img width="1292" height="433" alt="Screenshot 2026-03-11 at 2 22 31 PM"
src="https://github.com/user-attachments/assets/7717ef52-1498-41cf-97da-93e171377c9f"
/>


`wrap_tool_call`:
<img width="1301" height="664" alt="Screenshot 2026-03-11 at 2 22 50 PM"
src="https://github.com/user-attachments/assets/8722b28a-2482-40cf-911e-dae5cd383373"
/>
This commit is contained in:
Sydney Runkle
2026-03-11 15:57:41 -04:00
committed by GitHub
parent f9dbd22fe1
commit 25f94eecce

View File

@@ -3,7 +3,7 @@
from __future__ import annotations
import itertools
from dataclasses import dataclass, field
from dataclasses import dataclass, field, fields
from typing import (
TYPE_CHECKING,
Annotated,
@@ -23,6 +23,7 @@ from langgraph.constants import END, START
from langgraph.graph.state import StateGraph
from langgraph.prebuilt.tool_node import ToolCallWithContext, ToolNode
from langgraph.types import Command, Send
from langsmith import traceable
from typing_extensions import NotRequired, Required, TypedDict
from langchain.agents.middleware.types import (
@@ -36,6 +37,7 @@ from langchain.agents.middleware.types import (
OmitFromSchema,
ResponseT,
StateT_co,
ToolCallRequest,
_InputAgentState,
_OutputAgentState,
)
@@ -79,7 +81,7 @@ if TYPE_CHECKING:
from langgraph.store.base import BaseStore
from langgraph.types import Checkpointer
from langchain.agents.middleware.types import ToolCallRequest, ToolCallWrapper
from langchain.agents.middleware.types import ToolCallWrapper
_ModelCallHandler = Callable[
[ModelRequest[ContextT], Callable[[ModelRequest[ContextT]], ModelResponse]],
@@ -130,6 +132,19 @@ Option 2: Handle dynamic tools in middleware (for tools created at runtime)
return handler(request)
""".strip()
def _scrub_inputs(inputs: dict[str, Any]) -> dict[str, Any]:
"""Remove ``runtime`` and ``handler`` from trace inputs before sending to LangSmith."""
filtered = inputs.copy()
filtered.pop("handler", None)
req = filtered.get("request")
if isinstance(req, (ModelRequest, ToolCallRequest)):
filtered["request"] = {
f.name: getattr(req, f.name) for f in fields(req) if f.name != "runtime"
}
return filtered
FALLBACK_MODELS_WITH_STRUCTURED_OUTPUT = [
# if model profile data are not available, these models are assumed to support
# structured output
@@ -862,7 +877,12 @@ def create_agent(
# Chain all wrap_tool_call handlers into a single composed handler
wrap_tool_call_wrapper = None
if middleware_w_wrap_tool_call:
wrappers = [m.wrap_tool_call for m in middleware_w_wrap_tool_call]
wrappers = [
traceable(name=f"{m.name}.wrap_tool_call", process_inputs=_scrub_inputs)(
m.wrap_tool_call
)
for m in middleware_w_wrap_tool_call
]
wrap_tool_call_wrapper = _chain_tool_call_wrappers(wrappers)
# Collect middleware with awrap_tool_call or wrap_tool_call hooks
@@ -878,7 +898,12 @@ def create_agent(
# Chain all awrap_tool_call handlers into a single composed async handler
awrap_tool_call_wrapper = None
if middleware_w_awrap_tool_call:
async_wrappers = [m.awrap_tool_call for m in middleware_w_awrap_tool_call]
async_wrappers = [
traceable(name=f"{m.name}.awrap_tool_call", process_inputs=_scrub_inputs)(
m.awrap_tool_call
)
for m in middleware_w_awrap_tool_call
]
awrap_tool_call_wrapper = _chain_async_tool_call_wrappers(async_wrappers)
# Setup tools
@@ -961,13 +986,23 @@ def create_agent(
# Compose wrap_model_call handlers into a single middleware stack (sync)
wrap_model_call_handler = None
if middleware_w_wrap_model_call:
sync_handlers = [m.wrap_model_call for m in middleware_w_wrap_model_call]
sync_handlers = [
traceable(name=f"{m.name}.wrap_model_call", process_inputs=_scrub_inputs)(
m.wrap_model_call
)
for m in middleware_w_wrap_model_call
]
wrap_model_call_handler = _chain_model_call_handlers(sync_handlers)
# Compose awrap_model_call handlers into a single middleware stack (async)
awrap_model_call_handler = None
if middleware_w_awrap_model_call:
async_handlers = [m.awrap_model_call for m in middleware_w_awrap_model_call]
async_handlers = [
traceable(name=f"{m.name}.awrap_model_call", process_inputs=_scrub_inputs)(
m.awrap_model_call
)
for m in middleware_w_awrap_model_call
]
awrap_model_call_handler = _chain_async_model_call_handlers(async_handlers)
state_schemas: set[type] = {m.state_schema for m in middleware}