poc/diff-for-v2

Sydney Runkle
2026-03-03 14:16:23 -08:00
parent 6b37ad43dd
commit 1587713786
6 changed files with 154 additions and 38 deletions

View File

@@ -25,7 +25,7 @@ version = "1.2.10"
requires-python = ">=3.10.0,<4.0.0"
dependencies = [
"langchain-core>=1.2.10,<2.0.0",
"langgraph>=1.0.8,<1.1.0",
"langgraph>=1.0.8,<2.0.0",
"pydantic>=2.7.4,<3.0.0",
]
@@ -100,6 +100,7 @@ langchain-tests = { path = "../standard-tests", editable = true }
langchain-text-splitters = { path = "../text-splitters", editable = true }
langchain-openai = { path = "../partners/openai", editable = true }
langchain-anthropic = { path = "../partners/anthropic", editable = true }
+langgraph = { git = "https://github.com/langchain-ai/langgraph.git", branch = "1.1", subdirectory = "libs/langgraph" }
[tool.ruff]
line-length = 100

View File

@@ -9,6 +9,7 @@ from langchain_core.tools import InjectedToolCallId, tool
from langgraph.checkpoint.base import BaseCheckpointSaver
from langgraph.checkpoint.memory import InMemorySaver
from langgraph.runtime import Runtime
+from langgraph.types import GraphOutput
from pydantic import BaseModel, Field
from syrupy.assertion import SnapshotAssertion
from typing_extensions import override
@@ -94,7 +95,11 @@ def test_create_agent_invoke(
)
thread1 = {"configurable": {"thread_id": "1"}}
-    assert agent_one.invoke({"messages": ["hello"]}, thread1) == {
+    result = agent_one.invoke({"messages": ["hello"]}, thread1)
+    # v2 invoke returns GraphOutput; unwrap to get the state dict
+    assert isinstance(result, GraphOutput)
+    result_dict = result.value
+    assert result_dict == {
"messages": [
_AnyIdHumanMessage(content="hello"),
AIMessage(
@@ -195,7 +200,10 @@ def test_create_agent_jump(
assert agent_one.get_graph().draw_mermaid() == snapshot
thread1 = {"configurable": {"thread_id": "1"}}
-    assert agent_one.invoke({"messages": []}, thread1) == {"messages": []}
+    result = agent_one.invoke({"messages": []}, thread1)
+    assert isinstance(result, GraphOutput)
+    result_dict = result.value
+    assert result_dict == {"messages": []}
assert calls == ["NoopSeven.before_model", "NoopEight.before_model"]
@@ -693,7 +701,9 @@ async def test_create_agent_async_jump() -> None:
result = await agent.ainvoke({"messages": []})
-    assert result == {"messages": []}
+    assert isinstance(result, GraphOutput)
+    result_dict = result.value
+    assert result_dict == {"messages": []}
assert calls == ["AsyncMiddlewareOne.abefore_model", "AsyncMiddlewareTwo.abefore_model"]
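
Note: the unwrap-and-assert pattern above repeats across these tests. A minimal helper sketch, assuming only what this diff itself shows of the v2 API (GraphOutput carrying the final state dict in .value and pending interrupts in .interrupts); the unwrap helper is hypothetical, not part of this commit:

    from langgraph.types import GraphOutput

    def unwrap(result: object) -> dict:
        # Hypothetical helper: extract the final state dict from a v2 result.
        assert isinstance(result, GraphOutput)
        interrupts = list(result.interrupts)
        if interrupts:
            # Fail loudly rather than returning state from a paused graph.
            raise RuntimeError(f"{len(interrupts)} pending interrupt(s)")
        return result.value

    # Usage: state = unwrap(agent_one.invoke({"messages": ["hello"]}, thread1))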

View File

@@ -11,7 +11,7 @@ import pytest
from langchain_core.language_models.fake_chat_models import GenericFakeChatModel
from langchain_core.messages import AIMessage, HumanMessage
from langgraph.errors import InvalidUpdateError
-from langgraph.types import Command
+from langgraph.types import Command, GraphOutput
from langchain.agents import AgentState, create_agent
from langchain.agents.middleware.types import (
@@ -630,8 +630,10 @@ class TestComposition:
result = agent.invoke({"messages": [HumanMessage("Hi")]})
-        assert result.get("structured_response") == {"key": "value"}
-        messages = result["messages"]
+        assert isinstance(result, GraphOutput)
+        result_dict = result.value
+        assert result_dict.get("structured_response") == {"key": "value"}
+        messages = result_dict["messages"]
assert len(messages) == 2
assert messages[1].content == "Hello"

View File

@@ -0,0 +1,107 @@
"""Debug: trace exactly what happens when middleware processes parallel calls.
Instruments after_model and model_to_tools to see if the ToolMessage injected
by middleware is visible to the routing edge.
"""
from unittest.mock import patch
from langchain_core.messages import AIMessage, HumanMessage, ToolCall, ToolMessage
from langchain_core.tools import tool
from langgraph.checkpoint.memory import MemorySaver
from langgraph.types import GraphOutput, interrupt
from langchain.agents.factory import create_agent
from langchain.agents.middleware.tool_call_limit import ToolCallLimitMiddleware
from tests.unit_tests.agents.model import FakeToolCallingModel
def test_instrument_middleware_and_routing() -> None:
"""Trace the middleware return value and what model_to_tools sees."""
call_log: list[str] = []
@tool
def ask_fruit_expert(question: str) -> str:
"""Ask the fruit expert."""
call_log.append(f"fruit:{question}")
interrupt("continue?")
return f"Fruit answer: {question}"
@tool
def ask_veggie_expert(question: str) -> str:
"""Ask the veggie expert."""
call_log.append(f"veggie:{question}")
return f"Veggie answer: {question}"
model = FakeToolCallingModel(
tool_calls=[
[
ToolCall(name="ask_fruit_expert", args={"question": "apples"}, id="c1"),
ToolCall(name="ask_fruit_expert", args={"question": "bananas"}, id="c2"),
],
[],
]
)
fruit_mw = ToolCallLimitMiddleware(tool_name="ask_fruit_expert", run_limit=1)
veggie_mw = ToolCallLimitMiddleware(tool_name="ask_veggie_expert", run_limit=1)
# Wrap after_model to trace its return value
original_after_model = ToolCallLimitMiddleware.after_model
def traced_after_model(self, state, runtime):
result = original_after_model(self, state, runtime)
msgs = state.get("messages", [])
ai_msgs = [m for m in msgs if isinstance(m, AIMessage)]
tool_msgs = [m for m in msgs if isinstance(m, ToolMessage)]
print(f"\n [{self.name}].after_model:")
print(f" state has {len(msgs)} messages ({len(ai_msgs)} AI, {len(tool_msgs)} Tool)")
if ai_msgs:
last_ai = ai_msgs[-1]
print(f" last AI tool_calls: {[tc['id'] for tc in last_ai.tool_calls]}")
print(f" returning: {result}")
if result and "messages" in result:
for m in result["messages"]:
if isinstance(m, ToolMessage):
print(f" -> injecting ToolMessage(call_id={m.tool_call_id}, status={m.status})")
return result
checkpointer = MemorySaver()
with patch.object(ToolCallLimitMiddleware, "after_model", traced_after_model):
agent = create_agent(
model=model,
tools=[ask_fruit_expert, ask_veggie_expert],
middleware=[fruit_mw, veggie_mw],
checkpointer=checkpointer,
)
# Print the compiled graph's nodes to confirm the routing structure
print("\n=== Graph nodes ===")
for name in agent.nodes:
print(f" {name}")
config = {"configurable": {"thread_id": "debug1"}}
result = agent.invoke(
{"messages": [HumanMessage("Tell me about apples and bananas")]},
config,
)
# v2 invoke returns GraphOutput: unwrap final state and pending interrupts
assert isinstance(result, GraphOutput)
result_dict = result.value
interrupts = list(result.interrupts)
print(f"\n=== Results ===")
print(f"call_log: {call_log}")
print(f"interrupts: {len(interrupts)}")
tool_msgs = [m for m in result_dict["messages"] if isinstance(m, ToolMessage)]
error_msgs = [m for m in tool_msgs if m.status == "error"]
print(f"error tool messages: {len(error_msgs)}")
for m in error_msgs:
print(f" call_id={m.tool_call_id} content={m.content!r}")
assert len(call_log) == 1, f"Expected 1 execution, got {len(call_log)}: {call_log}"
assert len(interrupts) == 1, f"Expected 1 interrupt, got {len(interrupts)}"
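
Taking this debug run past the pending interrupt would use langgraph's standard Command(resume=...) mechanism; a sketch under the assumption that a resumed invoke also returns a GraphOutput in the v2 format (resume_debug_run is a hypothetical helper, and the resume payload answers the "continue?" interrupt above):

    from langgraph.types import Command, GraphOutput

    def resume_debug_run(agent, config) -> dict:
        # Feed the resume value back to the interrupt() call in ask_fruit_expert.
        resumed = agent.invoke(Command(resume="continue"), config)
        assert isinstance(resumed, GraphOutput)  # assumed v2 return type
        return resumed.value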

View File

@@ -114,10 +114,12 @@ def test_lc_agent_name_in_stream_metadata() -> None:
)
metadata_with_agent_name = []
-    for _chunk, metadata in agent.stream(
+    for event in agent.stream(
{"messages": [HumanMessage("Hello")]},
stream_mode="messages",
):
assert event["type"] == "messages"
_chunk, metadata = event["data"]
if "lc_agent_name" in metadata:
metadata_with_agent_name.append(metadata["lc_agent_name"])
@@ -132,10 +134,12 @@ def test_lc_agent_name_not_in_stream_metadata_when_name_not_provided() -> None:
model=FakeToolCallingModel(tool_calls=tool_calls),
)
-    for _chunk, metadata in agent.stream(
+    for event in agent.stream(
{"messages": [HumanMessage("Hello")]},
stream_mode="messages",
):
+        assert event["type"] == "messages"
+        _chunk, metadata = event["data"]
assert "lc_agent_name" not in metadata
@@ -150,10 +154,12 @@ def test_lc_agent_name_in_stream_metadata_multiple_iterations() -> None:
)
metadata_with_agent_name = []
-    for _chunk, metadata in agent.stream(
+    for event in agent.stream(
{"messages": [HumanMessage("Call a tool")]},
stream_mode="messages",
):
+        assert event["type"] == "messages"
+        _chunk, metadata = event["data"]
if "lc_agent_name" in metadata:
metadata_with_agent_name.append(metadata["lc_agent_name"])
@@ -171,10 +177,12 @@ async def test_lc_agent_name_in_astream_metadata() -> None:
)
metadata_with_agent_name = []
-    async for _chunk, metadata in agent.astream(
+    async for event in agent.astream(
{"messages": [HumanMessage("Hello async")]},
stream_mode="messages",
):
+        assert event["type"] == "messages"
+        _chunk, metadata = event["data"]
if "lc_agent_name" in metadata:
metadata_with_agent_name.append(metadata["lc_agent_name"])
@@ -189,10 +197,12 @@ async def test_lc_agent_name_not_in_astream_metadata_when_name_not_provided() ->
model=FakeToolCallingModel(tool_calls=tool_calls),
)
-    async for _chunk, metadata in agent.astream(
+    async for event in agent.astream(
{"messages": [HumanMessage("Hello async")]},
stream_mode="messages",
):
+        assert event["type"] == "messages"
+        _chunk, metadata = event["data"]
assert "lc_agent_name" not in metadata
@@ -207,10 +217,12 @@ async def test_lc_agent_name_in_astream_metadata_multiple_iterations() -> None:
)
metadata_with_agent_name = []
-    async for _chunk, metadata in agent.astream(
+    async for event in agent.astream(
{"messages": [HumanMessage("Call tool async")]},
stream_mode="messages",
):
+        assert event["type"] == "messages"
+        _chunk, metadata = event["data"]
if "lc_agent_name" in metadata:
metadata_with_agent_name.append(metadata["lc_agent_name"])
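
All six tests above consume the same v2 event shape: a dict with a "type" key and a "data" payload carrying the familiar (chunk, metadata) tuple. A consolidated sketch of that loop (collect_agent_names is a hypothetical helper, not part of this diff):

    from langchain_core.messages import HumanMessage

    def collect_agent_names(agent, prompt: str) -> list[str]:
        # Gather lc_agent_name metadata from a v2 "messages" stream.
        names: list[str] = []
        events = agent.stream({"messages": [HumanMessage(prompt)]}, stream_mode="messages")
        for event in events:
            if event["type"] != "messages":
                continue  # tolerate other event types, if any are emitted
            _chunk, metadata = event["data"]
            if "lc_agent_name" in metadata:
                names.append(metadata["lc_agent_name"])
        return names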

View File

@@ -1983,7 +1983,7 @@ requires-dist = [
{ name = "langchain-perplexity", marker = "extra == 'perplexity'" },
{ name = "langchain-together", marker = "extra == 'together'" },
{ name = "langchain-xai", marker = "extra == 'xai'" },
{ name = "langgraph", specifier = ">=1.0.8,<1.1.0" },
{ name = "langgraph", git = "https://github.com/langchain-ai/langgraph.git?subdirectory=libs%2Flanggraph&branch=1.1" },
{ name = "pydantic", specifier = ">=2.7.4,<3.0.0" },
]
provides-extras = ["community", "anthropic", "openai", "azure-ai", "google-vertexai", "google-genai", "fireworks", "ollama", "together", "mistralai", "huggingface", "groq", "aws", "deepseek", "xai", "perplexity"]
@@ -2151,7 +2151,7 @@ wheels = [
[[package]]
name = "langchain-core"
version = "1.2.16"
version = "1.2.17"
source = { editable = "../core" }
dependencies = [
{ name = "jsonpatch" },
@@ -2535,8 +2535,8 @@ wheels = [
[[package]]
name = "langgraph"
version = "1.0.10"
source = { registry = "https://pypi.org/simple" }
version = "1.0.10rc1"
source = { git = "https://github.com/langchain-ai/langgraph.git?subdirectory=libs%2Flanggraph&branch=1.1#dec15d13dc47e032886816a8cbd16e4e99d6d45a" }
dependencies = [
{ name = "langchain-core" },
{ name = "langgraph-checkpoint" },
@@ -2545,49 +2545,33 @@ dependencies = [
{ name = "pydantic" },
{ name = "xxhash" },
]
-sdist = { url = "https://files.pythonhosted.org/packages/55/92/14df6fefba28c10caf1cb05aa5b8c7bf005838fe32a86d903b6c7cc4018d/langgraph-1.0.10.tar.gz", hash = "sha256:73bd10ee14a8020f31ef07e9cd4c1a70c35cc07b9c2b9cd637509a10d9d51e29", size = 511644, upload-time = "2026-02-27T21:04:38.743Z" }
-wheels = [
-    { url = "https://files.pythonhosted.org/packages/5d/60/260e0c04620a37ba8916b712766c341cc5fc685dabc6948c899494bbc2ae/langgraph-1.0.10-py3-none-any.whl", hash = "sha256:7c298bef4f6ea292fcf9824d6088fe41a6727e2904ad6066f240c4095af12247", size = 160920, upload-time = "2026-02-27T21:04:35.932Z" },
-]
[[package]]
name = "langgraph-checkpoint"
version = "4.0.0"
source = { registry = "https://pypi.org/simple" }
version = "4.0.1rc4"
source = { git = "https://github.com/langchain-ai/langgraph.git?subdirectory=libs%2Fcheckpoint&branch=1.1#dec15d13dc47e032886816a8cbd16e4e99d6d45a" }
dependencies = [
{ name = "langchain-core" },
{ name = "ormsgpack" },
]
-sdist = { url = "https://files.pythonhosted.org/packages/98/76/55a18c59dedf39688d72c4b06af73a5e3ea0d1a01bc867b88fbf0659f203/langgraph_checkpoint-4.0.0.tar.gz", hash = "sha256:814d1bd050fac029476558d8e68d87bce9009a0262d04a2c14b918255954a624", size = 137320, upload-time = "2026-01-12T20:30:26.38Z" }
-wheels = [
-    { url = "https://files.pythonhosted.org/packages/4a/de/ddd53b7032e623f3c7bcdab2b44e8bf635e468f62e10e5ff1946f62c9356/langgraph_checkpoint-4.0.0-py3-none-any.whl", hash = "sha256:3fa9b2635a7c5ac28b338f631abf6a030c3b508b7b9ce17c22611513b589c784", size = 46329, upload-time = "2026-01-12T20:30:25.2Z" },
-]
[[package]]
name = "langgraph-prebuilt"
version = "1.0.8"
-source = { registry = "https://pypi.org/simple" }
+source = { git = "https://github.com/langchain-ai/langgraph.git?subdirectory=libs%2Fprebuilt&branch=1.1#dec15d13dc47e032886816a8cbd16e4e99d6d45a" }
dependencies = [
{ name = "langchain-core" },
{ name = "langgraph-checkpoint" },
]
-sdist = { url = "https://files.pythonhosted.org/packages/0d/06/dd61a5c2dce009d1b03b1d56f2a85b3127659fdddf5b3be5d8f1d60820fb/langgraph_prebuilt-1.0.8.tar.gz", hash = "sha256:0cd3cf5473ced8a6cd687cc5294e08d3de57529d8dd14fdc6ae4899549efcf69", size = 164442, upload-time = "2026-02-19T18:14:39.083Z" }
-wheels = [
-    { url = "https://files.pythonhosted.org/packages/dc/41/ec966424ad3f2ed3996d24079d3342c8cd6c0bd0653c12b2a917a685ec6c/langgraph_prebuilt-1.0.8-py3-none-any.whl", hash = "sha256:d16a731e591ba4470f3e313a319c7eee7dbc40895bcf15c821f985a3522a7ce0", size = 35648, upload-time = "2026-02-19T18:14:37.611Z" },
-]
[[package]]
name = "langgraph-sdk"
version = "0.3.3"
source = { registry = "https://pypi.org/simple" }
version = "0.3.9"
source = { git = "https://github.com/langchain-ai/langgraph.git?subdirectory=libs%2Fsdk-py&branch=1.1#dec15d13dc47e032886816a8cbd16e4e99d6d45a" }
dependencies = [
{ name = "httpx" },
{ name = "orjson" },
]
-sdist = { url = "https://files.pythonhosted.org/packages/c3/0f/ed0634c222eed48a31ba48eab6881f94ad690d65e44fe7ca838240a260c1/langgraph_sdk-0.3.3.tar.gz", hash = "sha256:c34c3dce3b6848755eb61f0c94369d1ba04aceeb1b76015db1ea7362c544fb26", size = 130589, upload-time = "2026-01-13T00:30:43.894Z" }
-wheels = [
-    { url = "https://files.pythonhosted.org/packages/6e/be/4ad511bacfdd854afb12974f407cb30010dceb982dc20c55491867b34526/langgraph_sdk-0.3.3-py3-none-any.whl", hash = "sha256:a52ebaf09d91143e55378bb2d0b033ed98f57f48c9ad35c8f81493b88705fc7b", size = 67021, upload-time = "2026-01-13T00:30:42.264Z" },
-]
[[package]]
name = "langsmith"