This commit is contained in:
Sydney Runkle
2026-03-03 14:38:53 -08:00
parent c52c5d5bf0
commit d7bbc0ae90
4 changed files with 31 additions and 83 deletions

View File

@@ -9,7 +9,6 @@ from langchain_core.tools import InjectedToolCallId, tool
from langgraph.checkpoint.base import BaseCheckpointSaver
from langgraph.checkpoint.memory import InMemorySaver
from langgraph.runtime import Runtime
from langgraph.types import GraphOutput
from pydantic import BaseModel, Field
from syrupy.assertion import SnapshotAssertion
from typing_extensions import override
@@ -96,10 +95,7 @@ def test_create_agent_invoke(
thread1 = {"configurable": {"thread_id": "1"}}
result = agent_one.invoke({"messages": ["hello"]}, thread1)
# v2 stream format returns GraphOutput; unwrap to get the dict
assert isinstance(result, GraphOutput)
result_dict = result.value
assert result_dict == {
assert result.value == {
"messages": [
_AnyIdHumanMessage(content="hello"),
AIMessage(
@@ -201,9 +197,7 @@ def test_create_agent_jump(
thread1 = {"configurable": {"thread_id": "1"}}
result = agent_one.invoke({"messages": []}, thread1)
assert isinstance(result, GraphOutput)
result_dict = result.value
assert result_dict == {"messages": []}
assert result.value == {"messages": []}
assert calls == ["NoopSeven.before_model", "NoopEight.before_model"]
@@ -701,9 +695,7 @@ async def test_create_agent_async_jump() -> None:
result = await agent.ainvoke({"messages": []})
assert isinstance(result, GraphOutput)
result_dict = result.value
assert result_dict == {"messages": []}
assert result.value == {"messages": []}
assert calls == ["AsyncMiddlewareOne.abefore_model", "AsyncMiddlewareTwo.abefore_model"]

View File

@@ -11,7 +11,7 @@ import pytest
from langchain_core.language_models.fake_chat_models import GenericFakeChatModel
from langchain_core.messages import AIMessage, HumanMessage
from langgraph.errors import InvalidUpdateError
from langgraph.types import Command, GraphOutput
from langgraph.types import Command
from langchain.agents import AgentState, create_agent
from langchain.agents.middleware.types import (
@@ -630,10 +630,8 @@ class TestComposition:
result = agent.invoke({"messages": [HumanMessage("Hi")]})
assert isinstance(result, GraphOutput)
result_dict = result.value
assert result_dict.get("structured_response") == {"key": "value"}
messages = result_dict["messages"]
assert result.value.get("structured_response") == {"key": "value"}
messages = result.value["messages"]
assert len(messages) == 2
assert messages[1].content == "Hello"

View File

@@ -1,16 +1,13 @@
"""Debug: trace exactly what happens when middleware processes parallel calls.
"""Test that middleware correctly handles parallel tool calls with limits.
Instruments after_model and model_to_tools to see if the ToolMessage injected
by middleware is visible to the routing edge.
Verifies that when middleware blocks some parallel tool calls, only the
permitted calls execute and interrupts propagate correctly.
"""
from unittest.mock import patch
from langchain_core.messages import AIMessage, HumanMessage, ToolCall, ToolMessage
from langchain_core.messages import HumanMessage, ToolCall
from langchain_core.tools import tool
from langgraph.checkpoint.memory import InMemorySaver, MemorySaver
from langgraph.graph import MessagesState, StateGraph
from langgraph.types import GraphOutput, interrupt
from langgraph.checkpoint.memory import MemorySaver
from langgraph.types import interrupt
from langchain.agents.factory import create_agent
from langchain.agents.middleware.tool_call_limit import ToolCallLimitMiddleware
@@ -47,61 +44,22 @@ def test_instrument_middleware_and_routing() -> None:
fruit_mw = ToolCallLimitMiddleware(tool_name="ask_fruit_expert", run_limit=1)
veggie_mw = ToolCallLimitMiddleware(tool_name="ask_veggie_expert", run_limit=1)
# Wrap after_model to trace its return value
original_after_model = ToolCallLimitMiddleware.after_model
def traced_after_model(self, state, runtime):
result = original_after_model(self, state, runtime)
msgs = state.get("messages", [])
ai_msgs = [m for m in msgs if isinstance(m, AIMessage)]
tool_msgs = [m for m in msgs if isinstance(m, ToolMessage)]
print(f"\n [{self.name}].after_model:")
print(f" state has {len(msgs)} messages ({len(ai_msgs)} AI, {len(tool_msgs)} Tool)")
if ai_msgs:
last_ai = ai_msgs[-1]
print(f" last AI tool_calls: {[tc['id'] for tc in last_ai.tool_calls]}")
print(f" returning: {result}")
if result and "messages" in result:
for m in result["messages"]:
if isinstance(m, ToolMessage):
print(f" -> injecting ToolMessage(call_id={m.tool_call_id}, status={m.status})")
return result
checkpointer = MemorySaver()
with patch.object(ToolCallLimitMiddleware, "after_model", traced_after_model):
agent = create_agent(
model=model,
tools=[ask_fruit_expert, ask_veggie_expert],
middleware=[fruit_mw, veggie_mw],
checkpointer=checkpointer,
)
agent = create_agent(
model=model,
tools=[ask_fruit_expert, ask_veggie_expert],
middleware=[fruit_mw, veggie_mw],
checkpointer=checkpointer,
)
# Also trace the model_to_tools edge
# Get the compiled graph and print its structure
print("\n=== Graph nodes ===")
for name in agent.nodes:
print(f" {name}")
config = {"configurable": {"thread_id": "debug1"}}
result = agent.invoke(
{"messages": [HumanMessage("Tell me about apples and bananas")]},
config,
)
# v2 stream format: unwrap GraphOutput
assert isinstance(result, GraphOutput)
result_dict = result.value
interrupts = list(result.interrupts)
print(f"\n=== Results ===")
print(f"call_log: {call_log}")
print(f"interrupts: {len(interrupts)}")
tool_msgs = [m for m in result_dict["messages"] if isinstance(m, ToolMessage)]
error_msgs = [m for m in tool_msgs if m.status == "error"]
print(f"error tool messages: {len(error_msgs)}")
for m in error_msgs:
print(f" call_id={m.tool_call_id} content={m.content!r}")
config = {"configurable": {"thread_id": "debug1"}}
result = agent.invoke(
{"messages": [HumanMessage("Tell me about apples and bananas")]},
config,
)
assert len(call_log) == 1, f"Expected 1 execution, got {len(call_log)}: {call_log}"
assert len(interrupts) == 1, f"Expected 1 interrupt, got {len(interrupts)}"
assert len(result.interrupts) == 1, (
f"Expected 1 interrupt, got {len(result.interrupts)}"
)

View File

@@ -1983,7 +1983,7 @@ requires-dist = [
{ name = "langchain-perplexity", marker = "extra == 'perplexity'" },
{ name = "langchain-together", marker = "extra == 'together'" },
{ name = "langchain-xai", marker = "extra == 'xai'" },
{ name = "langgraph", git = "https://github.com/langchain-ai/langgraph.git?subdirectory=libs%2Flanggraph&branch=1.1" },
{ name = "langgraph", git = "https://github.com/langchain-ai/langgraph.git?subdirectory=libs%2Flanggraph&branch=sr%2F1.1-v2-default" },
{ name = "pydantic", specifier = ">=2.7.4,<3.0.0" },
]
provides-extras = ["community", "anthropic", "openai", "azure-ai", "google-vertexai", "google-genai", "fireworks", "ollama", "together", "mistralai", "huggingface", "groq", "aws", "deepseek", "xai", "perplexity"]
@@ -2536,7 +2536,7 @@ wheels = [
[[package]]
name = "langgraph"
version = "1.0.10rc1"
source = { git = "https://github.com/langchain-ai/langgraph.git?subdirectory=libs%2Flanggraph&branch=1.1#dec15d13dc47e032886816a8cbd16e4e99d6d45a" }
source = { git = "https://github.com/langchain-ai/langgraph.git?subdirectory=libs%2Flanggraph&branch=sr%2F1.1-v2-default#33289349a938ee0d6a8d3d5f01db989d43fc28d7" }
dependencies = [
{ name = "langchain-core" },
{ name = "langgraph-checkpoint" },
@@ -2549,7 +2549,7 @@ dependencies = [
[[package]]
name = "langgraph-checkpoint"
version = "4.0.1rc4"
source = { git = "https://github.com/langchain-ai/langgraph.git?subdirectory=libs%2Fcheckpoint&branch=1.1#dec15d13dc47e032886816a8cbd16e4e99d6d45a" }
source = { git = "https://github.com/langchain-ai/langgraph.git?subdirectory=libs%2Fcheckpoint&branch=sr%2F1.1-v2-default#33289349a938ee0d6a8d3d5f01db989d43fc28d7" }
dependencies = [
{ name = "langchain-core" },
{ name = "ormsgpack" },
@@ -2558,7 +2558,7 @@ dependencies = [
[[package]]
name = "langgraph-prebuilt"
version = "1.0.8"
source = { git = "https://github.com/langchain-ai/langgraph.git?subdirectory=libs%2Fprebuilt&branch=1.1#dec15d13dc47e032886816a8cbd16e4e99d6d45a" }
source = { git = "https://github.com/langchain-ai/langgraph.git?subdirectory=libs%2Fprebuilt&branch=sr%2F1.1-v2-default#33289349a938ee0d6a8d3d5f01db989d43fc28d7" }
dependencies = [
{ name = "langchain-core" },
{ name = "langgraph-checkpoint" },
@@ -2567,7 +2567,7 @@ dependencies = [
[[package]]
name = "langgraph-sdk"
version = "0.3.9"
source = { git = "https://github.com/langchain-ai/langgraph.git?subdirectory=libs%2Fsdk-py&branch=1.1#dec15d13dc47e032886816a8cbd16e4e99d6d45a" }
source = { git = "https://github.com/langchain-ai/langgraph.git?subdirectory=libs%2Fsdk-py&branch=sr%2F1.1-v2-default#33289349a938ee0d6a8d3d5f01db989d43fc28d7" }
dependencies = [
{ name = "httpx" },
{ name = "orjson" },