From d7bbc0ae90abb88220ec19641360d1a03f0ac72d Mon Sep 17 00:00:00 2001 From: Sydney Runkle Date: Tue, 3 Mar 2026 14:38:53 -0800 Subject: [PATCH] boom --- .../agents/middleware/core/test_framework.py | 14 +--- .../core/test_wrap_model_call_state_update.py | 8 +- .../test_tool_call_limit_parallel_bug.py | 82 +++++-------------- libs/langchain_v1/uv.lock | 10 +-- 4 files changed, 31 insertions(+), 83 deletions(-) diff --git a/libs/langchain_v1/tests/unit_tests/agents/middleware/core/test_framework.py b/libs/langchain_v1/tests/unit_tests/agents/middleware/core/test_framework.py index 20c93af4316..e9255ba3c49 100644 --- a/libs/langchain_v1/tests/unit_tests/agents/middleware/core/test_framework.py +++ b/libs/langchain_v1/tests/unit_tests/agents/middleware/core/test_framework.py @@ -9,7 +9,6 @@ from langchain_core.tools import InjectedToolCallId, tool from langgraph.checkpoint.base import BaseCheckpointSaver from langgraph.checkpoint.memory import InMemorySaver from langgraph.runtime import Runtime -from langgraph.types import GraphOutput from pydantic import BaseModel, Field from syrupy.assertion import SnapshotAssertion from typing_extensions import override @@ -96,10 +95,7 @@ def test_create_agent_invoke( thread1 = {"configurable": {"thread_id": "1"}} result = agent_one.invoke({"messages": ["hello"]}, thread1) - # v2 stream format returns GraphOutput; unwrap to get the dict - assert isinstance(result, GraphOutput) - result_dict = result.value - assert result_dict == { + assert result.value == { "messages": [ _AnyIdHumanMessage(content="hello"), AIMessage( @@ -201,9 +197,7 @@ def test_create_agent_jump( thread1 = {"configurable": {"thread_id": "1"}} result = agent_one.invoke({"messages": []}, thread1) - assert isinstance(result, GraphOutput) - result_dict = result.value - assert result_dict == {"messages": []} + assert result.value == {"messages": []} assert calls == ["NoopSeven.before_model", "NoopEight.before_model"] @@ -701,9 +695,7 @@ async def test_create_agent_async_jump() -> None: result = await agent.ainvoke({"messages": []}) - assert isinstance(result, GraphOutput) - result_dict = result.value - assert result_dict == {"messages": []} + assert result.value == {"messages": []} assert calls == ["AsyncMiddlewareOne.abefore_model", "AsyncMiddlewareTwo.abefore_model"] diff --git a/libs/langchain_v1/tests/unit_tests/agents/middleware/core/test_wrap_model_call_state_update.py b/libs/langchain_v1/tests/unit_tests/agents/middleware/core/test_wrap_model_call_state_update.py index cbc7b7de0a7..2b5330ab891 100644 --- a/libs/langchain_v1/tests/unit_tests/agents/middleware/core/test_wrap_model_call_state_update.py +++ b/libs/langchain_v1/tests/unit_tests/agents/middleware/core/test_wrap_model_call_state_update.py @@ -11,7 +11,7 @@ import pytest from langchain_core.language_models.fake_chat_models import GenericFakeChatModel from langchain_core.messages import AIMessage, HumanMessage from langgraph.errors import InvalidUpdateError -from langgraph.types import Command, GraphOutput +from langgraph.types import Command from langchain.agents import AgentState, create_agent from langchain.agents.middleware.types import ( @@ -630,10 +630,8 @@ class TestComposition: result = agent.invoke({"messages": [HumanMessage("Hi")]}) - assert isinstance(result, GraphOutput) - result_dict = result.value - assert result_dict.get("structured_response") == {"key": "value"} - messages = result_dict["messages"] + assert result.value.get("structured_response") == {"key": "value"} + messages = result.value["messages"] assert len(messages) == 2 assert messages[1].content == "Hello" diff --git a/libs/langchain_v1/tests/unit_tests/agents/middleware/implementations/test_tool_call_limit_parallel_bug.py b/libs/langchain_v1/tests/unit_tests/agents/middleware/implementations/test_tool_call_limit_parallel_bug.py index e7ffcdf0ad3..c4b1097a302 100644 --- a/libs/langchain_v1/tests/unit_tests/agents/middleware/implementations/test_tool_call_limit_parallel_bug.py +++ b/libs/langchain_v1/tests/unit_tests/agents/middleware/implementations/test_tool_call_limit_parallel_bug.py @@ -1,16 +1,13 @@ -"""Debug: trace exactly what happens when middleware processes parallel calls. +"""Test that middleware correctly handles parallel tool calls with limits. -Instruments after_model and model_to_tools to see if the ToolMessage injected -by middleware is visible to the routing edge. +Verifies that when middleware blocks some parallel tool calls, only the +permitted calls execute and interrupts propagate correctly. """ -from unittest.mock import patch - -from langchain_core.messages import AIMessage, HumanMessage, ToolCall, ToolMessage +from langchain_core.messages import HumanMessage, ToolCall from langchain_core.tools import tool -from langgraph.checkpoint.memory import InMemorySaver, MemorySaver -from langgraph.graph import MessagesState, StateGraph -from langgraph.types import GraphOutput, interrupt +from langgraph.checkpoint.memory import MemorySaver +from langgraph.types import interrupt from langchain.agents.factory import create_agent from langchain.agents.middleware.tool_call_limit import ToolCallLimitMiddleware @@ -47,61 +44,22 @@ def test_instrument_middleware_and_routing() -> None: fruit_mw = ToolCallLimitMiddleware(tool_name="ask_fruit_expert", run_limit=1) veggie_mw = ToolCallLimitMiddleware(tool_name="ask_veggie_expert", run_limit=1) - # Wrap after_model to trace its return value - original_after_model = ToolCallLimitMiddleware.after_model - - def traced_after_model(self, state, runtime): - result = original_after_model(self, state, runtime) - msgs = state.get("messages", []) - ai_msgs = [m for m in msgs if isinstance(m, AIMessage)] - tool_msgs = [m for m in msgs if isinstance(m, ToolMessage)] - print(f"\n [{self.name}].after_model:") - print(f" state has {len(msgs)} messages ({len(ai_msgs)} AI, {len(tool_msgs)} Tool)") - if ai_msgs: - last_ai = ai_msgs[-1] - print(f" last AI tool_calls: {[tc['id'] for tc in last_ai.tool_calls]}") - print(f" returning: {result}") - if result and "messages" in result: - for m in result["messages"]: - if isinstance(m, ToolMessage): - print(f" -> injecting ToolMessage(call_id={m.tool_call_id}, status={m.status})") - return result - checkpointer = MemorySaver() - with patch.object(ToolCallLimitMiddleware, "after_model", traced_after_model): - agent = create_agent( - model=model, - tools=[ask_fruit_expert, ask_veggie_expert], - middleware=[fruit_mw, veggie_mw], - checkpointer=checkpointer, - ) + agent = create_agent( + model=model, + tools=[ask_fruit_expert, ask_veggie_expert], + middleware=[fruit_mw, veggie_mw], + checkpointer=checkpointer, + ) - # Also trace the model_to_tools edge - # Get the compiled graph and print its structure - print("\n=== Graph nodes ===") - for name in agent.nodes: - print(f" {name}") - - config = {"configurable": {"thread_id": "debug1"}} - result = agent.invoke( - {"messages": [HumanMessage("Tell me about apples and bananas")]}, - config, - ) - - # v2 stream format: unwrap GraphOutput - assert isinstance(result, GraphOutput) - result_dict = result.value - interrupts = list(result.interrupts) - print(f"\n=== Results ===") - print(f"call_log: {call_log}") - print(f"interrupts: {len(interrupts)}") - - tool_msgs = [m for m in result_dict["messages"] if isinstance(m, ToolMessage)] - error_msgs = [m for m in tool_msgs if m.status == "error"] - print(f"error tool messages: {len(error_msgs)}") - for m in error_msgs: - print(f" call_id={m.tool_call_id} content={m.content!r}") + config = {"configurable": {"thread_id": "debug1"}} + result = agent.invoke( + {"messages": [HumanMessage("Tell me about apples and bananas")]}, + config, + ) assert len(call_log) == 1, f"Expected 1 execution, got {len(call_log)}: {call_log}" - assert len(interrupts) == 1, f"Expected 1 interrupt, got {len(interrupts)}" + assert len(result.interrupts) == 1, ( + f"Expected 1 interrupt, got {len(result.interrupts)}" + ) diff --git a/libs/langchain_v1/uv.lock b/libs/langchain_v1/uv.lock index ae39d8f4c3b..4f3707f40ed 100644 --- a/libs/langchain_v1/uv.lock +++ b/libs/langchain_v1/uv.lock @@ -1983,7 +1983,7 @@ requires-dist = [ { name = "langchain-perplexity", marker = "extra == 'perplexity'" }, { name = "langchain-together", marker = "extra == 'together'" }, { name = "langchain-xai", marker = "extra == 'xai'" }, - { name = "langgraph", git = "https://github.com/langchain-ai/langgraph.git?subdirectory=libs%2Flanggraph&branch=1.1" }, + { name = "langgraph", git = "https://github.com/langchain-ai/langgraph.git?subdirectory=libs%2Flanggraph&branch=sr%2F1.1-v2-default" }, { name = "pydantic", specifier = ">=2.7.4,<3.0.0" }, ] provides-extras = ["community", "anthropic", "openai", "azure-ai", "google-vertexai", "google-genai", "fireworks", "ollama", "together", "mistralai", "huggingface", "groq", "aws", "deepseek", "xai", "perplexity"] @@ -2536,7 +2536,7 @@ wheels = [ [[package]] name = "langgraph" version = "1.0.10rc1" -source = { git = "https://github.com/langchain-ai/langgraph.git?subdirectory=libs%2Flanggraph&branch=1.1#dec15d13dc47e032886816a8cbd16e4e99d6d45a" } +source = { git = "https://github.com/langchain-ai/langgraph.git?subdirectory=libs%2Flanggraph&branch=sr%2F1.1-v2-default#33289349a938ee0d6a8d3d5f01db989d43fc28d7" } dependencies = [ { name = "langchain-core" }, { name = "langgraph-checkpoint" }, @@ -2549,7 +2549,7 @@ dependencies = [ [[package]] name = "langgraph-checkpoint" version = "4.0.1rc4" -source = { git = "https://github.com/langchain-ai/langgraph.git?subdirectory=libs%2Fcheckpoint&branch=1.1#dec15d13dc47e032886816a8cbd16e4e99d6d45a" } +source = { git = "https://github.com/langchain-ai/langgraph.git?subdirectory=libs%2Fcheckpoint&branch=sr%2F1.1-v2-default#33289349a938ee0d6a8d3d5f01db989d43fc28d7" } dependencies = [ { name = "langchain-core" }, { name = "ormsgpack" }, @@ -2558,7 +2558,7 @@ dependencies = [ [[package]] name = "langgraph-prebuilt" version = "1.0.8" -source = { git = "https://github.com/langchain-ai/langgraph.git?subdirectory=libs%2Fprebuilt&branch=1.1#dec15d13dc47e032886816a8cbd16e4e99d6d45a" } +source = { git = "https://github.com/langchain-ai/langgraph.git?subdirectory=libs%2Fprebuilt&branch=sr%2F1.1-v2-default#33289349a938ee0d6a8d3d5f01db989d43fc28d7" } dependencies = [ { name = "langchain-core" }, { name = "langgraph-checkpoint" }, @@ -2567,7 +2567,7 @@ dependencies = [ [[package]] name = "langgraph-sdk" version = "0.3.9" -source = { git = "https://github.com/langchain-ai/langgraph.git?subdirectory=libs%2Fsdk-py&branch=1.1#dec15d13dc47e032886816a8cbd16e4e99d6d45a" } +source = { git = "https://github.com/langchain-ai/langgraph.git?subdirectory=libs%2Fsdk-py&branch=sr%2F1.1-v2-default#33289349a938ee0d6a8d3d5f01db989d43fc28d7" } dependencies = [ { name = "httpx" }, { name = "orjson" },