From ea3ec45441ea48afcf92cdf872deddacd65503aa Mon Sep 17 00:00:00 2001 From: Sydney Runkle Date: Mon, 27 Oct 2025 14:50:33 -0700 Subject: [PATCH] better tests --- .../unit_tests/agents/test_agent_runtime.py | 340 +++--------------- 1 file changed, 56 insertions(+), 284 deletions(-) diff --git a/libs/langchain_v1/tests/unit_tests/agents/test_agent_runtime.py b/libs/langchain_v1/tests/unit_tests/agents/test_agent_runtime.py index 0a51530e58d..fd77cf63b39 100644 --- a/libs/langchain_v1/tests/unit_tests/agents/test_agent_runtime.py +++ b/libs/langchain_v1/tests/unit_tests/agents/test_agent_runtime.py @@ -1,22 +1,16 @@ -"""Tests for wrap_model_call and awrap_model_call functionality.""" - -from dataclasses import dataclass +"""Tests for AgentRuntime access via wrap_model_call middleware.""" import pytest from langchain_core.language_models.fake_chat_models import GenericFakeChatModel from langchain_core.messages import AIMessage, HumanMessage +from langchain_core.tools import tool from langchain.agents import create_agent from langchain.agents.middleware import wrap_model_call -from langchain.agents.middleware.types import ModelRequest, ModelResponse +from langchain.agents.middleware.types import ModelRequest +from langchain.tools import ToolRuntime - -@dataclass -class Context: - """Test context for agent runtime.""" - - user_id: str - session_id: str +from .model import FakeToolCallingModel @pytest.fixture @@ -25,312 +19,90 @@ def fake_chat_model(): return GenericFakeChatModel(messages=iter([AIMessage(content="test response")])) -def test_wrap_model_call_basic(fake_chat_model): - """Test basic wrap_model_call functionality.""" - call_count = 0 +def test_agent_name_accessible_in_middleware(fake_chat_model): + """Test that agent name can be accessed via middleware.""" + captured_agent_name = None @wrap_model_call - def count_calls(request: ModelRequest, handler): - nonlocal call_count - call_count += 1 + def capture_agent_name(request: ModelRequest, handler): + nonlocal captured_agent_name + captured_agent_name = request.runtime.agent_name return handler(request) agent = create_agent( fake_chat_model, tools=[], - middleware=[count_calls], + middleware=[capture_agent_name], name="TestAgent", ) agent.invoke({"messages": [HumanMessage("Hello")]}) - assert call_count == 1 + assert captured_agent_name == "TestAgent" -def test_wrap_model_call_access_runtime(fake_chat_model): - """Test accessing AgentRuntime via ModelRequest in wrap_model_call.""" - captured_agent_name = None - captured_context = None +def test_nested_agent_name_accessible_in_tool(): + """Test that nested agent's name is accessible when agent is used in a tool.""" + # Track which agent names were captured + captured_agent_names = [] @wrap_model_call - def capture_from_request(request: ModelRequest, handler): - nonlocal captured_agent_name, captured_context - captured_agent_name = request.runtime.agent_name - captured_context = request.runtime.context + def capture_agent_name(request: ModelRequest, handler): + captured_agent_names.append(request.runtime.agent_name) return handler(request) - agent = create_agent( - fake_chat_model, + # Create a nested agent that will be called from within a tool + nested_agent = create_agent( + FakeToolCallingModel(), tools=[], - middleware=[capture_from_request], - name="RuntimeAgent", - context_schema=Context, + middleware=[capture_agent_name], + name="NestedAgent", ) - agent.invoke( - {"messages": [HumanMessage("Hello")]}, - context=Context(user_id="user123", session_id="session456"), + # Create a tool that invokes the nested agent + @tool + def call_nested_agent(query: str, runtime: ToolRuntime) -> str: + """Tool that calls a nested agent.""" + result = nested_agent.invoke({"messages": [HumanMessage(query)]}) + return result["messages"][-1].content + + # Create outer agent that uses the tool + outer_agent = create_agent( + FakeToolCallingModel( + tool_calls=[ + [{"name": "call_nested_agent", "args": {"query": "test"}, "id": "1"}], + [], + ] + ), + tools=[call_nested_agent], + middleware=[capture_agent_name], + name="OuterAgent", ) - assert captured_agent_name == "RuntimeAgent" - assert captured_context.user_id == "user123" - assert captured_context.session_id == "session456" + # Invoke the outer agent, which should call the tool, which calls the nested agent + outer_agent.invoke({"messages": [HumanMessage("Hello")]}) + + # Both agents should have captured their names + assert "OuterAgent" in captured_agent_names + assert "NestedAgent" in captured_agent_names -def test_wrap_model_call_modify_request(fake_chat_model): - """Test modifying the model request in wrap_model_call.""" - modified_messages = [] - - @wrap_model_call - def modify_request(request: ModelRequest, handler): - # Add a system prompt - modified_request = request.override(system_prompt="You are a helpful assistant") - modified_messages.append(modified_request.system_prompt) - return handler(modified_request) - - agent = create_agent( - fake_chat_model, - tools=[], - middleware=[modify_request], - name="ModifyAgent", - ) - - agent.invoke({"messages": [HumanMessage("Hello")]}) - assert modified_messages[0] == "You are a helpful assistant" - - -def test_wrap_model_call_modify_response(fake_chat_model): - """Test modifying the model response in wrap_model_call.""" - - @wrap_model_call - def modify_response(request: ModelRequest, handler): - response = handler(request) - # Modify the response content - original_msg = response.result[0] - modified_msg = AIMessage( - content=f"[MODIFIED] {original_msg.content}", - id=original_msg.id, - ) - return ModelResponse( - result=[modified_msg], - structured_response=response.structured_response, - ) - - agent = create_agent( - fake_chat_model, - tools=[], - middleware=[modify_response], - name="ModifyResponseAgent", - ) - - result = agent.invoke({"messages": [HumanMessage("Hello")]}) - assert result["messages"][-1].content == "[MODIFIED] test response" - - -def test_wrap_model_call_retry_logic(fake_chat_model): - """Test retry logic in wrap_model_call.""" - attempt_count = 0 - model_call_count = 0 - - @wrap_model_call - def retry_on_error(request: ModelRequest, handler): - nonlocal attempt_count, model_call_count - max_retries = 3 - last_error = None - - for attempt in range(max_retries): - attempt_count += 1 - try: - # Simulate failure on first two attempts - model_call_count += 1 - if model_call_count < 3: - raise ValueError("Simulated failure") - return handler(request) - except ValueError as e: - last_error = e - if attempt == max_retries - 1: - raise - - raise last_error # Should never reach here - - agent = create_agent( - fake_chat_model, - tools=[], - middleware=[retry_on_error], - name="RetryAgent", - ) - - result = agent.invoke({"messages": [HumanMessage("Hello")]}) - assert attempt_count == 3 - # The model response should be from fake_chat_model - assert result["messages"][-1].content == "test response" - - -def test_wrap_model_call_short_circuit(fake_chat_model): - """Test short-circuiting model call in wrap_model_call.""" - handler_called = False - - @wrap_model_call - def short_circuit(request: ModelRequest, handler): - nonlocal handler_called - # Check if we should short-circuit - if len(request.messages) > 0 and "bypass" in request.messages[-1].content: - # Return cached response without calling handler - return AIMessage(content="Cached response") - - handler_called = True - return handler(request) - - agent = create_agent( - fake_chat_model, - tools=[], - middleware=[short_circuit], - name="ShortCircuitAgent", - ) - - result = agent.invoke({"messages": [HumanMessage("bypass")]}) - assert not handler_called - assert result["messages"][-1].content == "Cached response" - - -def test_wrap_model_call_multiple_middleware(fake_chat_model): - """Test composing multiple wrap_model_call middleware.""" - execution_order = [] - - @wrap_model_call(name="first") - def first_middleware(request: ModelRequest, handler): - execution_order.append("first_before") - response = handler(request) - execution_order.append("first_after") - return response - - @wrap_model_call(name="second") - def second_middleware(request: ModelRequest, handler): - execution_order.append("second_before") - response = handler(request) - execution_order.append("second_after") - return response - - agent = create_agent( - fake_chat_model, - tools=[], - middleware=[first_middleware, second_middleware], - name="MultiWrapAgent", - ) - - agent.invoke({"messages": [HumanMessage("Hello")]}) - - # Middleware should compose as: first -> second -> model -> second -> first - assert execution_order == [ - "first_before", - "second_before", - "second_after", - "first_after", - ] - - -async def test_awrap_model_call_basic(fake_chat_model): - """Test basic awrap_model_call functionality.""" - call_count = 0 - - @wrap_model_call - async def count_calls_async(request: ModelRequest, handler): - nonlocal call_count - call_count += 1 - return await handler(request) - - agent = create_agent( - fake_chat_model, - tools=[], - middleware=[count_calls_async], - name="AsyncTestAgent", - ) - - await agent.ainvoke({"messages": [HumanMessage("Hello")]}) - assert call_count == 1 - - -async def test_awrap_model_call_access_runtime(fake_chat_model): - """Test accessing AgentRuntime in async wrap_model_call.""" +async def test_agent_name_accessible_in_async_middleware(): + """Test that agent name can be accessed in async middleware.""" captured_agent_name = None @wrap_model_call - async def capture_async(request: ModelRequest, handler): + async def capture_agent_name_async(request: ModelRequest, handler): nonlocal captured_agent_name captured_agent_name = request.runtime.agent_name return await handler(request) + fake_model = GenericFakeChatModel(messages=iter([AIMessage(content="async response")])) agent = create_agent( - fake_chat_model, + fake_model, tools=[], - middleware=[capture_async], - name="AsyncRuntimeAgent", + middleware=[capture_agent_name_async], + name="AsyncAgent", ) await agent.ainvoke({"messages": [HumanMessage("Hello")]}) - assert captured_agent_name == "AsyncRuntimeAgent" - - -async def test_awrap_model_call_retry_logic(fake_chat_model): - """Test async retry logic in awrap_model_call.""" - attempt_count = 0 - model_call_count = 0 - - @wrap_model_call - async def async_retry_on_error(request: ModelRequest, handler): - nonlocal attempt_count, model_call_count - max_retries = 3 - last_error = None - - for attempt in range(max_retries): - attempt_count += 1 - try: - # Simulate failure on first two attempts - model_call_count += 1 - if model_call_count < 3: - raise ValueError("Simulated async failure") - return await handler(request) - except ValueError as e: - last_error = e - if attempt == max_retries - 1: - raise - - raise last_error # Should never reach here - - agent = create_agent( - fake_chat_model, - tools=[], - middleware=[async_retry_on_error], - name="AsyncRetryAgent", - ) - - result = await agent.ainvoke({"messages": [HumanMessage("Hello")]}) - assert attempt_count == 3 - # The model response should be from fake_chat_model - assert result["messages"][-1].content == "test response" - - -async def test_awrap_model_call_modify_response(fake_chat_model): - """Test modifying response in async wrap_model_call.""" - - @wrap_model_call - async def async_modify_response(request: ModelRequest, handler): - response = await handler(request) - original_msg = response.result[0] - modified_msg = AIMessage( - content=f"[ASYNC MODIFIED] {original_msg.content}", - id=original_msg.id, - ) - return ModelResponse( - result=[modified_msg], - structured_response=response.structured_response, - ) - - agent = create_agent( - fake_chat_model, - tools=[], - middleware=[async_modify_response], - name="AsyncModifyAgent", - ) - - result = await agent.ainvoke({"messages": [HumanMessage("Hello")]}) - assert result["messages"][-1].content == "[ASYNC MODIFIED] test response" + assert captured_agent_name == "AsyncAgent"