mirror of
https://github.com/hwchase17/langchain.git
synced 2026-03-16 18:13:33 +00:00
boom
This commit is contained in:
@@ -9,7 +9,6 @@ from langchain_core.tools import InjectedToolCallId, tool
|
||||
from langgraph.checkpoint.base import BaseCheckpointSaver
|
||||
from langgraph.checkpoint.memory import InMemorySaver
|
||||
from langgraph.runtime import Runtime
|
||||
from langgraph.types import GraphOutput
|
||||
from pydantic import BaseModel, Field
|
||||
from syrupy.assertion import SnapshotAssertion
|
||||
from typing_extensions import override
|
||||
@@ -96,10 +95,7 @@ def test_create_agent_invoke(
|
||||
|
||||
thread1 = {"configurable": {"thread_id": "1"}}
|
||||
result = agent_one.invoke({"messages": ["hello"]}, thread1)
|
||||
# v2 stream format returns GraphOutput; unwrap to get the dict
|
||||
assert isinstance(result, GraphOutput)
|
||||
result_dict = result.value
|
||||
assert result_dict == {
|
||||
assert result.value == {
|
||||
"messages": [
|
||||
_AnyIdHumanMessage(content="hello"),
|
||||
AIMessage(
|
||||
@@ -201,9 +197,7 @@ def test_create_agent_jump(
|
||||
|
||||
thread1 = {"configurable": {"thread_id": "1"}}
|
||||
result = agent_one.invoke({"messages": []}, thread1)
|
||||
assert isinstance(result, GraphOutput)
|
||||
result_dict = result.value
|
||||
assert result_dict == {"messages": []}
|
||||
assert result.value == {"messages": []}
|
||||
assert calls == ["NoopSeven.before_model", "NoopEight.before_model"]
|
||||
|
||||
|
||||
@@ -701,9 +695,7 @@ async def test_create_agent_async_jump() -> None:
|
||||
|
||||
result = await agent.ainvoke({"messages": []})
|
||||
|
||||
assert isinstance(result, GraphOutput)
|
||||
result_dict = result.value
|
||||
assert result_dict == {"messages": []}
|
||||
assert result.value == {"messages": []}
|
||||
assert calls == ["AsyncMiddlewareOne.abefore_model", "AsyncMiddlewareTwo.abefore_model"]
|
||||
|
||||
|
||||
|
||||
@@ -11,7 +11,7 @@ import pytest
|
||||
from langchain_core.language_models.fake_chat_models import GenericFakeChatModel
|
||||
from langchain_core.messages import AIMessage, HumanMessage
|
||||
from langgraph.errors import InvalidUpdateError
|
||||
from langgraph.types import Command, GraphOutput
|
||||
from langgraph.types import Command
|
||||
|
||||
from langchain.agents import AgentState, create_agent
|
||||
from langchain.agents.middleware.types import (
|
||||
@@ -630,10 +630,8 @@ class TestComposition:
|
||||
|
||||
result = agent.invoke({"messages": [HumanMessage("Hi")]})
|
||||
|
||||
assert isinstance(result, GraphOutput)
|
||||
result_dict = result.value
|
||||
assert result_dict.get("structured_response") == {"key": "value"}
|
||||
messages = result_dict["messages"]
|
||||
assert result.value.get("structured_response") == {"key": "value"}
|
||||
messages = result.value["messages"]
|
||||
assert len(messages) == 2
|
||||
assert messages[1].content == "Hello"
|
||||
|
||||
|
||||
@@ -1,16 +1,13 @@
|
||||
"""Debug: trace exactly what happens when middleware processes parallel calls.
|
||||
"""Test that middleware correctly handles parallel tool calls with limits.
|
||||
|
||||
Instruments after_model and model_to_tools to see if the ToolMessage injected
|
||||
by middleware is visible to the routing edge.
|
||||
Verifies that when middleware blocks some parallel tool calls, only the
|
||||
permitted calls execute and interrupts propagate correctly.
|
||||
"""
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
from langchain_core.messages import AIMessage, HumanMessage, ToolCall, ToolMessage
|
||||
from langchain_core.messages import HumanMessage, ToolCall
|
||||
from langchain_core.tools import tool
|
||||
from langgraph.checkpoint.memory import InMemorySaver, MemorySaver
|
||||
from langgraph.graph import MessagesState, StateGraph
|
||||
from langgraph.types import GraphOutput, interrupt
|
||||
from langgraph.checkpoint.memory import MemorySaver
|
||||
from langgraph.types import interrupt
|
||||
|
||||
from langchain.agents.factory import create_agent
|
||||
from langchain.agents.middleware.tool_call_limit import ToolCallLimitMiddleware
|
||||
@@ -47,61 +44,22 @@ def test_instrument_middleware_and_routing() -> None:
|
||||
fruit_mw = ToolCallLimitMiddleware(tool_name="ask_fruit_expert", run_limit=1)
|
||||
veggie_mw = ToolCallLimitMiddleware(tool_name="ask_veggie_expert", run_limit=1)
|
||||
|
||||
# Wrap after_model to trace its return value
|
||||
original_after_model = ToolCallLimitMiddleware.after_model
|
||||
|
||||
def traced_after_model(self, state, runtime):
|
||||
result = original_after_model(self, state, runtime)
|
||||
msgs = state.get("messages", [])
|
||||
ai_msgs = [m for m in msgs if isinstance(m, AIMessage)]
|
||||
tool_msgs = [m for m in msgs if isinstance(m, ToolMessage)]
|
||||
print(f"\n [{self.name}].after_model:")
|
||||
print(f" state has {len(msgs)} messages ({len(ai_msgs)} AI, {len(tool_msgs)} Tool)")
|
||||
if ai_msgs:
|
||||
last_ai = ai_msgs[-1]
|
||||
print(f" last AI tool_calls: {[tc['id'] for tc in last_ai.tool_calls]}")
|
||||
print(f" returning: {result}")
|
||||
if result and "messages" in result:
|
||||
for m in result["messages"]:
|
||||
if isinstance(m, ToolMessage):
|
||||
print(f" -> injecting ToolMessage(call_id={m.tool_call_id}, status={m.status})")
|
||||
return result
|
||||
|
||||
checkpointer = MemorySaver()
|
||||
|
||||
with patch.object(ToolCallLimitMiddleware, "after_model", traced_after_model):
|
||||
agent = create_agent(
|
||||
model=model,
|
||||
tools=[ask_fruit_expert, ask_veggie_expert],
|
||||
middleware=[fruit_mw, veggie_mw],
|
||||
checkpointer=checkpointer,
|
||||
)
|
||||
agent = create_agent(
|
||||
model=model,
|
||||
tools=[ask_fruit_expert, ask_veggie_expert],
|
||||
middleware=[fruit_mw, veggie_mw],
|
||||
checkpointer=checkpointer,
|
||||
)
|
||||
|
||||
# Also trace the model_to_tools edge
|
||||
# Get the compiled graph and print its structure
|
||||
print("\n=== Graph nodes ===")
|
||||
for name in agent.nodes:
|
||||
print(f" {name}")
|
||||
|
||||
config = {"configurable": {"thread_id": "debug1"}}
|
||||
result = agent.invoke(
|
||||
{"messages": [HumanMessage("Tell me about apples and bananas")]},
|
||||
config,
|
||||
)
|
||||
|
||||
# v2 stream format: unwrap GraphOutput
|
||||
assert isinstance(result, GraphOutput)
|
||||
result_dict = result.value
|
||||
interrupts = list(result.interrupts)
|
||||
print(f"\n=== Results ===")
|
||||
print(f"call_log: {call_log}")
|
||||
print(f"interrupts: {len(interrupts)}")
|
||||
|
||||
tool_msgs = [m for m in result_dict["messages"] if isinstance(m, ToolMessage)]
|
||||
error_msgs = [m for m in tool_msgs if m.status == "error"]
|
||||
print(f"error tool messages: {len(error_msgs)}")
|
||||
for m in error_msgs:
|
||||
print(f" call_id={m.tool_call_id} content={m.content!r}")
|
||||
config = {"configurable": {"thread_id": "debug1"}}
|
||||
result = agent.invoke(
|
||||
{"messages": [HumanMessage("Tell me about apples and bananas")]},
|
||||
config,
|
||||
)
|
||||
|
||||
assert len(call_log) == 1, f"Expected 1 execution, got {len(call_log)}: {call_log}"
|
||||
assert len(interrupts) == 1, f"Expected 1 interrupt, got {len(interrupts)}"
|
||||
assert len(result.interrupts) == 1, (
|
||||
f"Expected 1 interrupt, got {len(result.interrupts)}"
|
||||
)
|
||||
|
||||
10
libs/langchain_v1/uv.lock
generated
10
libs/langchain_v1/uv.lock
generated
@@ -1983,7 +1983,7 @@ requires-dist = [
|
||||
{ name = "langchain-perplexity", marker = "extra == 'perplexity'" },
|
||||
{ name = "langchain-together", marker = "extra == 'together'" },
|
||||
{ name = "langchain-xai", marker = "extra == 'xai'" },
|
||||
{ name = "langgraph", git = "https://github.com/langchain-ai/langgraph.git?subdirectory=libs%2Flanggraph&branch=1.1" },
|
||||
{ name = "langgraph", git = "https://github.com/langchain-ai/langgraph.git?subdirectory=libs%2Flanggraph&branch=sr%2F1.1-v2-default" },
|
||||
{ name = "pydantic", specifier = ">=2.7.4,<3.0.0" },
|
||||
]
|
||||
provides-extras = ["community", "anthropic", "openai", "azure-ai", "google-vertexai", "google-genai", "fireworks", "ollama", "together", "mistralai", "huggingface", "groq", "aws", "deepseek", "xai", "perplexity"]
|
||||
@@ -2536,7 +2536,7 @@ wheels = [
|
||||
[[package]]
|
||||
name = "langgraph"
|
||||
version = "1.0.10rc1"
|
||||
source = { git = "https://github.com/langchain-ai/langgraph.git?subdirectory=libs%2Flanggraph&branch=1.1#dec15d13dc47e032886816a8cbd16e4e99d6d45a" }
|
||||
source = { git = "https://github.com/langchain-ai/langgraph.git?subdirectory=libs%2Flanggraph&branch=sr%2F1.1-v2-default#33289349a938ee0d6a8d3d5f01db989d43fc28d7" }
|
||||
dependencies = [
|
||||
{ name = "langchain-core" },
|
||||
{ name = "langgraph-checkpoint" },
|
||||
@@ -2549,7 +2549,7 @@ dependencies = [
|
||||
[[package]]
|
||||
name = "langgraph-checkpoint"
|
||||
version = "4.0.1rc4"
|
||||
source = { git = "https://github.com/langchain-ai/langgraph.git?subdirectory=libs%2Fcheckpoint&branch=1.1#dec15d13dc47e032886816a8cbd16e4e99d6d45a" }
|
||||
source = { git = "https://github.com/langchain-ai/langgraph.git?subdirectory=libs%2Fcheckpoint&branch=sr%2F1.1-v2-default#33289349a938ee0d6a8d3d5f01db989d43fc28d7" }
|
||||
dependencies = [
|
||||
{ name = "langchain-core" },
|
||||
{ name = "ormsgpack" },
|
||||
@@ -2558,7 +2558,7 @@ dependencies = [
|
||||
[[package]]
|
||||
name = "langgraph-prebuilt"
|
||||
version = "1.0.8"
|
||||
source = { git = "https://github.com/langchain-ai/langgraph.git?subdirectory=libs%2Fprebuilt&branch=1.1#dec15d13dc47e032886816a8cbd16e4e99d6d45a" }
|
||||
source = { git = "https://github.com/langchain-ai/langgraph.git?subdirectory=libs%2Fprebuilt&branch=sr%2F1.1-v2-default#33289349a938ee0d6a8d3d5f01db989d43fc28d7" }
|
||||
dependencies = [
|
||||
{ name = "langchain-core" },
|
||||
{ name = "langgraph-checkpoint" },
|
||||
@@ -2567,7 +2567,7 @@ dependencies = [
|
||||
[[package]]
|
||||
name = "langgraph-sdk"
|
||||
version = "0.3.9"
|
||||
source = { git = "https://github.com/langchain-ai/langgraph.git?subdirectory=libs%2Fsdk-py&branch=1.1#dec15d13dc47e032886816a8cbd16e4e99d6d45a" }
|
||||
source = { git = "https://github.com/langchain-ai/langgraph.git?subdirectory=libs%2Fsdk-py&branch=sr%2F1.1-v2-default#33289349a938ee0d6a8d3d5f01db989d43fc28d7" }
|
||||
dependencies = [
|
||||
{ name = "httpx" },
|
||||
{ name = "orjson" },
|
||||
|
||||
Reference in New Issue
Block a user