Compare commits

...

1 Commits

Author SHA1 Message Date
jacoblee93
4e0b69f506 Tag middleware runs with ls_agent_type 2026-04-17 22:02:09 -07:00
2 changed files with 177 additions and 98 deletions

View File

@@ -896,9 +896,11 @@ def create_agent(
wrap_tool_call_wrapper = None
if middleware_w_wrap_tool_call:
wrappers = [
traceable(name=f"{m.name}.wrap_tool_call", process_inputs=_scrub_inputs)(
m.wrap_tool_call
)
traceable(
name=f"{m.name}.wrap_tool_call",
process_inputs=_scrub_inputs,
metadata={"ls_agent_type": "middleware"},
)(m.wrap_tool_call)
for m in middleware_w_wrap_tool_call
]
wrap_tool_call_wrapper = _chain_tool_call_wrappers(wrappers)
@@ -917,9 +919,11 @@ def create_agent(
awrap_tool_call_wrapper = None
if middleware_w_awrap_tool_call:
async_wrappers = [
traceable(name=f"{m.name}.awrap_tool_call", process_inputs=_scrub_inputs)(
m.awrap_tool_call
)
traceable(
name=f"{m.name}.awrap_tool_call",
process_inputs=_scrub_inputs,
metadata={"ls_agent_type": "middleware"},
)(m.awrap_tool_call)
for m in middleware_w_awrap_tool_call
]
awrap_tool_call_wrapper = _chain_async_tool_call_wrappers(async_wrappers)
@@ -1005,9 +1009,11 @@ def create_agent(
wrap_model_call_handler = None
if middleware_w_wrap_model_call:
sync_handlers = [
traceable(name=f"{m.name}.wrap_model_call", process_inputs=_scrub_inputs)(
m.wrap_model_call
)
traceable(
name=f"{m.name}.wrap_model_call",
process_inputs=_scrub_inputs,
metadata={"ls_agent_type": "middleware"},
)(m.wrap_model_call)
for m in middleware_w_wrap_model_call
]
wrap_model_call_handler = _chain_model_call_handlers(sync_handlers)
@@ -1016,9 +1022,11 @@ def create_agent(
awrap_model_call_handler = None
if middleware_w_awrap_model_call:
async_handlers = [
traceable(name=f"{m.name}.awrap_model_call", process_inputs=_scrub_inputs)(
m.awrap_model_call
)
traceable(
name=f"{m.name}.awrap_model_call",
process_inputs=_scrub_inputs,
metadata={"ls_agent_type": "middleware"},
)(m.awrap_model_call)
for m in middleware_w_awrap_model_call
]
awrap_model_call_handler = _chain_async_model_call_handlers(async_handlers)

View File

@@ -29,11 +29,19 @@ from langsmith import Client
from langsmith.run_helpers import tracing_context
from langchain.agents import create_agent
from langchain.agents.middleware.types import AgentMiddleware, AgentState
from langchain.agents.middleware.types import (
AgentMiddleware,
AgentState,
ModelCallResult,
ModelRequest,
ModelResponse,
)
from langchain.tools import InjectedState, ToolRuntime
from tests.unit_tests.agents.model import FakeToolCallingModel
if TYPE_CHECKING:
from collections.abc import Callable
from langgraph.runtime import Runtime
@@ -849,107 +857,170 @@ async def test_combined_injected_state_runtime_store_async() -> None:
assert injected_data["store_write_success"] is True
def test_ls_agent_type_is_trace_only_metadata() -> None:
"""Test that ls_agent_type is added to metadata on tracing only, not in streamed chunks."""
# Capture metadata from regular callback handler (simulates streamed metadata)
captured_callback_metadata: list[dict[str, Any]] = []
# ---------------------------------------------------------------------------
# ls_agent_type tracing metadata
# ---------------------------------------------------------------------------
class CaptureHandler(BaseCallbackHandler):
def on_chain_start(
class _CaptureCallbackHandler(BaseCallbackHandler):
    """Callback handler that remembers what each ``on_chain_start`` call saw.

    Every invocation appends one dict to ``self.captured`` (in call order)
    with three keys: ``name``, ``tags`` and ``metadata``.
    """

    def __init__(self) -> None:
        # One entry per observed ``on_chain_start`` callback.
        self.captured: list[dict[str, Any]] = []

    def on_chain_start(
        self,
        serialized: dict[str, Any],
        inputs: dict[str, Any],
        *,
        run_id: str,
        parent_run_id: str | None = None,
        tags: list[str] | None = None,
        metadata: dict[str, Any] | None = None,
        **kwargs: Any,
    ) -> None:
        """Record the run name, tags and metadata passed to this callback.

        The name is taken from the ``name`` keyword argument when it is
        truthy, otherwise from ``serialized["name"]`` (``serialized`` may be
        falsy, in which case the name is ``None``). A falsy ``metadata`` is
        stored as an empty dict.
        """
        run_name = kwargs.get("name")
        if not run_name:
            run_name = (serialized or {}).get("name")
        record = {
            "name": run_name,
            "tags": tags,
            "metadata": metadata if metadata else {},
        }
        self.captured.append(record)
def _build_mock_langsmith_client() -> tuple[MagicMock, Client]:
    """Create a ``MagicMock`` session and a LangSmith ``Client`` wired to it.

    Returns:
        ``(session, client)``. The client is constructed with the mock as its
        session, ``api_key="test"`` and ``auto_batch_tracing=False``; calls
        made on the session are recorded by the mock for later inspection.
    """
    session = MagicMock()
    return session, Client(session=session, api_key="test", auto_batch_tracing=False)
def _posted_runs(mock_session: MagicMock) -> list[dict[str, Any]]:
"""Extract the run dicts POSTed to the LangSmith API by the mock session."""
posts: list[dict[str, Any]] = []
for call in mock_session.request.mock_calls:
if call.args and call.args[0] == "POST":
body = json.loads(call.kwargs["data"])
if "post" in body:
posts.extend(body["post"])
else:
posts.append(body)
return posts
def _run_metadata(post: dict[str, Any]) -> dict[str, Any]:
return post.get("extra", {}).get("metadata", {}) or {}
def test_ls_agent_type_root_is_trace_only_metadata() -> None:
    """The root run is traced with ``ls_agent_type='root'``; callbacks never see the key."""
    mock_session, mock_client = _build_mock_langsmith_client()
    handler = _CaptureCallbackHandler()

    agent = create_agent(
        model=FakeToolCallingModel(tool_calls=[[], []]),
        tools=[],
        system_prompt="You are a helpful assistant.",
    )

    agent_input = {"messages": [HumanMessage("hi?")]}
    run_config = {"callbacks": [handler]}
    with tracing_context(client=mock_client, enabled=True):
        agent.invoke(agent_input, config=run_config)

    # Callback side: the key must be absent from every on_chain_start metadata.
    assert handler.captured, "expected on_chain_start to fire at least once"
    for entry in handler.captured:
        assert entry["metadata"].get("ls_agent_type") is None, (
            f"ls_agent_type leaked into callback metadata: {entry['metadata']}"
        )

    # Tracer side: the first POSTed run is taken to be the root run.
    posts = _posted_runs(mock_session)
    assert posts, "expected at least one LangSmith POST"
    root_metadata = _run_metadata(posts[0])
    assert root_metadata.get("ls_agent_type") == "root"
def test_ls_agent_type_is_overridable_via_configurable() -> None:
    """``configurable`` can replace ``ls_agent_type`` and contribute extra tracer metadata."""
    mock_session, mock_client = _build_mock_langsmith_client()

    agent = create_agent(
        model=FakeToolCallingModel(tool_calls=[[], []]),
        tools=[],
        system_prompt="You are a helpful assistant.",
    )

    overrides = {
        "ls_agent_type": "subagent",
        "custom_key": "custom_value",
    }
    agent_input = {"messages": [HumanMessage("hi?")]}
    with tracing_context(client=mock_client, enabled=True):
        agent.invoke(agent_input, config={"configurable": overrides})

    posts = _posted_runs(mock_session)
    assert posts, "expected at least one LangSmith POST"

    # The first POSTed run is taken to be the root run.
    root_metadata = _run_metadata(posts[0])
    assert root_metadata.get("ls_agent_type") == "subagent"
    # Keys other than ls_agent_type are expected in the tracer metadata too.
    assert root_metadata.get("custom_key") == "custom_value"
def test_ls_agent_type_middleware_is_trace_only_metadata() -> None:
"""Middleware traceable runs are tagged with ``ls_agent_type='middleware'``.
The tag is attached via the ``metadata=`` argument of langsmith's
``traceable`` decorator, which routes it to ``run.extra.metadata`` for
LangSmith only -- it must not leak into on_chain_start callback metadata.
"""
class PassthroughMiddleware(AgentMiddleware):
name = "test-passthrough"
def wrap_model_call(
self,
serialized: dict[str, Any],
inputs: dict[str, Any],
*,
run_id: str,
parent_run_id: str | None = None,
tags: list[str] | None = None,
metadata: dict[str, Any] | None = None,
**kwargs: Any,
) -> None:
captured_callback_metadata.append({"tags": tags, "metadata": metadata})
request: ModelRequest,
handler: Callable[[ModelRequest], ModelResponse],
) -> ModelCallResult:
return handler(request)
# Create a mock client to capture what gets sent to LangSmith
mock_session = MagicMock()
mock_client = Client(session=mock_session, api_key="test", auto_batch_tracing=False)
handler = _CaptureCallbackHandler()
mock_session, mock_client = _build_mock_langsmith_client()
agent = create_agent(
model=FakeToolCallingModel(tool_calls=[[], []]),
tools=[],
system_prompt="You are a helpful assistant.",
middleware=[PassthroughMiddleware()],
)
# Use tracing_context to enable tracing with the mock client
with tracing_context(client=mock_client, enabled=True):
agent.invoke(
{"messages": [HumanMessage("hi?")]},
config={"callbacks": [CaptureHandler()]},
config={"callbacks": [handler]},
)
# Verify that ls_agent_type is NOT in the regular callback metadata
# (it should only go to the tracer via langsmith_inheritable_metadata)
assert len(captured_callback_metadata) > 0
for captured in captured_callback_metadata:
metadata = captured.get("metadata") or {}
assert metadata.get("ls_agent_type") is None, (
f"ls_agent_type should not be in callback metadata, but got: {metadata}"
# (1) ls_agent_type='middleware' must not leak into callback metadata.
for entry in handler.captured:
assert entry["metadata"].get("ls_agent_type") != "middleware", (
f"ls_agent_type='middleware' leaked into callback metadata for "
f"run {entry['name']!r}: {entry['metadata']}"
)
# Verify that ls_agent_type IS in the tracer metadata (sent to LangSmith)
# Get the POST requests to the LangSmith API
posts = []
for call in mock_session.request.mock_calls:
if call.args and call.args[0] == "POST":
body = json.loads(call.kwargs["data"])
if "post" in body:
posts.extend(body["post"])
else:
posts.append(body)
assert len(posts) >= 1
# Find the root run (the agent execution)
root_post = posts[0]
metadata = root_post.get("extra", {}).get("metadata", {})
assert metadata.get("ls_agent_type") == "root", (
f"ls_agent_type should be 'root' in tracer metadata, but got: {metadata}"
# (2) ls_agent_type='middleware' must reach the LangSmith tracer, on a run
# named after the middleware's traceable (e.g. 'test-passthrough.wrap_model_call').
posts = _posted_runs(mock_session)
middleware_posts = [p for p in posts if _run_metadata(p).get("ls_agent_type") == "middleware"]
assert middleware_posts, (
f"expected a LangSmith post with ls_agent_type='middleware'; "
f"saw metadatas: {[_run_metadata(p) for p in posts]}"
)
def test_ls_agent_type_is_overridable() -> None:
"""Test that ls_agent_type can be overridden via configurable in invoke config."""
# Create a mock client to capture what gets sent to LangSmith
mock_session = MagicMock()
mock_client = Client(session=mock_session, api_key="test", auto_batch_tracing=False)
agent = create_agent(
model=FakeToolCallingModel(tool_calls=[[], []]),
tools=[],
system_prompt="You are a helpful assistant.",
)
# Use tracing_context to enable tracing with the mock client
with tracing_context(client=mock_client, enabled=True):
agent.invoke(
{"messages": [HumanMessage("hi?")]},
config={"configurable": {"ls_agent_type": "subagent", "custom_key": "custom_value"}},
)
# Verify that ls_agent_type is overridden and configurable is merged in the tracer metadata
posts = []
for call in mock_session.request.mock_calls:
if call.args and call.args[0] == "POST":
body = json.loads(call.kwargs["data"])
if "post" in body:
posts.extend(body["post"])
else:
posts.append(body)
assert len(posts) >= 1
root_post = posts[0]
metadata = root_post.get("extra", {}).get("metadata", {})
assert metadata.get("ls_agent_type") == "subagent", (
f"ls_agent_type should be 'subagent' in tracer metadata, but got: {metadata}"
)
# Verify that the additional configurable key is merged into metadata
assert metadata.get("custom_key") == "custom_value", (
f"custom_key should be 'custom_value' in tracer metadata, but got: {metadata}"
assert any("test-passthrough" in (p.get("name") or "") for p in middleware_posts), (
f"expected a middleware run named like 'test-passthrough.wrap_model_call', "
f"got: {[p.get('name') for p in middleware_posts]}"
)