mirror of
https://github.com/hwchase17/langchain.git
synced 2026-02-21 22:56:05 +00:00
better tests
This commit is contained in:
@@ -1,22 +1,16 @@
|
||||
"""Tests for wrap_model_call and awrap_model_call functionality."""
|
||||
|
||||
from dataclasses import dataclass
|
||||
"""Tests for AgentRuntime access via wrap_model_call middleware."""
|
||||
|
||||
import pytest
|
||||
from langchain_core.language_models.fake_chat_models import GenericFakeChatModel
|
||||
from langchain_core.messages import AIMessage, HumanMessage
|
||||
from langchain_core.tools import tool
|
||||
|
||||
from langchain.agents import create_agent
|
||||
from langchain.agents.middleware import wrap_model_call
|
||||
from langchain.agents.middleware.types import ModelRequest, ModelResponse
|
||||
from langchain.agents.middleware.types import ModelRequest
|
||||
from langchain.tools import ToolRuntime
|
||||
|
||||
|
||||
@dataclass
|
||||
class Context:
|
||||
"""Test context for agent runtime."""
|
||||
|
||||
user_id: str
|
||||
session_id: str
|
||||
from .model import FakeToolCallingModel
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -25,312 +19,90 @@ def fake_chat_model():
|
||||
return GenericFakeChatModel(messages=iter([AIMessage(content="test response")]))
|
||||
|
||||
|
||||
def test_wrap_model_call_basic(fake_chat_model):
|
||||
"""Test basic wrap_model_call functionality."""
|
||||
call_count = 0
|
||||
def test_agent_name_accessible_in_middleware(fake_chat_model):
|
||||
"""Test that agent name can be accessed via middleware."""
|
||||
captured_agent_name = None
|
||||
|
||||
@wrap_model_call
|
||||
def count_calls(request: ModelRequest, handler):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
def capture_agent_name(request: ModelRequest, handler):
|
||||
nonlocal captured_agent_name
|
||||
captured_agent_name = request.runtime.agent_name
|
||||
return handler(request)
|
||||
|
||||
agent = create_agent(
|
||||
fake_chat_model,
|
||||
tools=[],
|
||||
middleware=[count_calls],
|
||||
middleware=[capture_agent_name],
|
||||
name="TestAgent",
|
||||
)
|
||||
|
||||
agent.invoke({"messages": [HumanMessage("Hello")]})
|
||||
assert call_count == 1
|
||||
assert captured_agent_name == "TestAgent"
|
||||
|
||||
|
||||
def test_wrap_model_call_access_runtime(fake_chat_model):
|
||||
"""Test accessing AgentRuntime via ModelRequest in wrap_model_call."""
|
||||
captured_agent_name = None
|
||||
captured_context = None
|
||||
def test_nested_agent_name_accessible_in_tool():
|
||||
"""Test that nested agent's name is accessible when agent is used in a tool."""
|
||||
# Track which agent names were captured
|
||||
captured_agent_names = []
|
||||
|
||||
@wrap_model_call
|
||||
def capture_from_request(request: ModelRequest, handler):
|
||||
nonlocal captured_agent_name, captured_context
|
||||
captured_agent_name = request.runtime.agent_name
|
||||
captured_context = request.runtime.context
|
||||
def capture_agent_name(request: ModelRequest, handler):
|
||||
captured_agent_names.append(request.runtime.agent_name)
|
||||
return handler(request)
|
||||
|
||||
agent = create_agent(
|
||||
fake_chat_model,
|
||||
# Create a nested agent that will be called from within a tool
|
||||
nested_agent = create_agent(
|
||||
FakeToolCallingModel(),
|
||||
tools=[],
|
||||
middleware=[capture_from_request],
|
||||
name="RuntimeAgent",
|
||||
context_schema=Context,
|
||||
middleware=[capture_agent_name],
|
||||
name="NestedAgent",
|
||||
)
|
||||
|
||||
agent.invoke(
|
||||
{"messages": [HumanMessage("Hello")]},
|
||||
context=Context(user_id="user123", session_id="session456"),
|
||||
# Create a tool that invokes the nested agent
|
||||
@tool
|
||||
def call_nested_agent(query: str, runtime: ToolRuntime) -> str:
|
||||
"""Tool that calls a nested agent."""
|
||||
result = nested_agent.invoke({"messages": [HumanMessage(query)]})
|
||||
return result["messages"][-1].content
|
||||
|
||||
# Create outer agent that uses the tool
|
||||
outer_agent = create_agent(
|
||||
FakeToolCallingModel(
|
||||
tool_calls=[
|
||||
[{"name": "call_nested_agent", "args": {"query": "test"}, "id": "1"}],
|
||||
[],
|
||||
]
|
||||
),
|
||||
tools=[call_nested_agent],
|
||||
middleware=[capture_agent_name],
|
||||
name="OuterAgent",
|
||||
)
|
||||
|
||||
assert captured_agent_name == "RuntimeAgent"
|
||||
assert captured_context.user_id == "user123"
|
||||
assert captured_context.session_id == "session456"
|
||||
# Invoke the outer agent, which should call the tool, which calls the nested agent
|
||||
outer_agent.invoke({"messages": [HumanMessage("Hello")]})
|
||||
|
||||
# Both agents should have captured their names
|
||||
assert "OuterAgent" in captured_agent_names
|
||||
assert "NestedAgent" in captured_agent_names
|
||||
|
||||
|
||||
def test_wrap_model_call_modify_request(fake_chat_model):
|
||||
"""Test modifying the model request in wrap_model_call."""
|
||||
modified_messages = []
|
||||
|
||||
@wrap_model_call
|
||||
def modify_request(request: ModelRequest, handler):
|
||||
# Add a system prompt
|
||||
modified_request = request.override(system_prompt="You are a helpful assistant")
|
||||
modified_messages.append(modified_request.system_prompt)
|
||||
return handler(modified_request)
|
||||
|
||||
agent = create_agent(
|
||||
fake_chat_model,
|
||||
tools=[],
|
||||
middleware=[modify_request],
|
||||
name="ModifyAgent",
|
||||
)
|
||||
|
||||
agent.invoke({"messages": [HumanMessage("Hello")]})
|
||||
assert modified_messages[0] == "You are a helpful assistant"
|
||||
|
||||
|
||||
def test_wrap_model_call_modify_response(fake_chat_model):
|
||||
"""Test modifying the model response in wrap_model_call."""
|
||||
|
||||
@wrap_model_call
|
||||
def modify_response(request: ModelRequest, handler):
|
||||
response = handler(request)
|
||||
# Modify the response content
|
||||
original_msg = response.result[0]
|
||||
modified_msg = AIMessage(
|
||||
content=f"[MODIFIED] {original_msg.content}",
|
||||
id=original_msg.id,
|
||||
)
|
||||
return ModelResponse(
|
||||
result=[modified_msg],
|
||||
structured_response=response.structured_response,
|
||||
)
|
||||
|
||||
agent = create_agent(
|
||||
fake_chat_model,
|
||||
tools=[],
|
||||
middleware=[modify_response],
|
||||
name="ModifyResponseAgent",
|
||||
)
|
||||
|
||||
result = agent.invoke({"messages": [HumanMessage("Hello")]})
|
||||
assert result["messages"][-1].content == "[MODIFIED] test response"
|
||||
|
||||
|
||||
def test_wrap_model_call_retry_logic(fake_chat_model):
|
||||
"""Test retry logic in wrap_model_call."""
|
||||
attempt_count = 0
|
||||
model_call_count = 0
|
||||
|
||||
@wrap_model_call
|
||||
def retry_on_error(request: ModelRequest, handler):
|
||||
nonlocal attempt_count, model_call_count
|
||||
max_retries = 3
|
||||
last_error = None
|
||||
|
||||
for attempt in range(max_retries):
|
||||
attempt_count += 1
|
||||
try:
|
||||
# Simulate failure on first two attempts
|
||||
model_call_count += 1
|
||||
if model_call_count < 3:
|
||||
raise ValueError("Simulated failure")
|
||||
return handler(request)
|
||||
except ValueError as e:
|
||||
last_error = e
|
||||
if attempt == max_retries - 1:
|
||||
raise
|
||||
|
||||
raise last_error # Should never reach here
|
||||
|
||||
agent = create_agent(
|
||||
fake_chat_model,
|
||||
tools=[],
|
||||
middleware=[retry_on_error],
|
||||
name="RetryAgent",
|
||||
)
|
||||
|
||||
result = agent.invoke({"messages": [HumanMessage("Hello")]})
|
||||
assert attempt_count == 3
|
||||
# The model response should be from fake_chat_model
|
||||
assert result["messages"][-1].content == "test response"
|
||||
|
||||
|
||||
def test_wrap_model_call_short_circuit(fake_chat_model):
|
||||
"""Test short-circuiting model call in wrap_model_call."""
|
||||
handler_called = False
|
||||
|
||||
@wrap_model_call
|
||||
def short_circuit(request: ModelRequest, handler):
|
||||
nonlocal handler_called
|
||||
# Check if we should short-circuit
|
||||
if len(request.messages) > 0 and "bypass" in request.messages[-1].content:
|
||||
# Return cached response without calling handler
|
||||
return AIMessage(content="Cached response")
|
||||
|
||||
handler_called = True
|
||||
return handler(request)
|
||||
|
||||
agent = create_agent(
|
||||
fake_chat_model,
|
||||
tools=[],
|
||||
middleware=[short_circuit],
|
||||
name="ShortCircuitAgent",
|
||||
)
|
||||
|
||||
result = agent.invoke({"messages": [HumanMessage("bypass")]})
|
||||
assert not handler_called
|
||||
assert result["messages"][-1].content == "Cached response"
|
||||
|
||||
|
||||
def test_wrap_model_call_multiple_middleware(fake_chat_model):
|
||||
"""Test composing multiple wrap_model_call middleware."""
|
||||
execution_order = []
|
||||
|
||||
@wrap_model_call(name="first")
|
||||
def first_middleware(request: ModelRequest, handler):
|
||||
execution_order.append("first_before")
|
||||
response = handler(request)
|
||||
execution_order.append("first_after")
|
||||
return response
|
||||
|
||||
@wrap_model_call(name="second")
|
||||
def second_middleware(request: ModelRequest, handler):
|
||||
execution_order.append("second_before")
|
||||
response = handler(request)
|
||||
execution_order.append("second_after")
|
||||
return response
|
||||
|
||||
agent = create_agent(
|
||||
fake_chat_model,
|
||||
tools=[],
|
||||
middleware=[first_middleware, second_middleware],
|
||||
name="MultiWrapAgent",
|
||||
)
|
||||
|
||||
agent.invoke({"messages": [HumanMessage("Hello")]})
|
||||
|
||||
# Middleware should compose as: first -> second -> model -> second -> first
|
||||
assert execution_order == [
|
||||
"first_before",
|
||||
"second_before",
|
||||
"second_after",
|
||||
"first_after",
|
||||
]
|
||||
|
||||
|
||||
async def test_awrap_model_call_basic(fake_chat_model):
|
||||
"""Test basic awrap_model_call functionality."""
|
||||
call_count = 0
|
||||
|
||||
@wrap_model_call
|
||||
async def count_calls_async(request: ModelRequest, handler):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
return await handler(request)
|
||||
|
||||
agent = create_agent(
|
||||
fake_chat_model,
|
||||
tools=[],
|
||||
middleware=[count_calls_async],
|
||||
name="AsyncTestAgent",
|
||||
)
|
||||
|
||||
await agent.ainvoke({"messages": [HumanMessage("Hello")]})
|
||||
assert call_count == 1
|
||||
|
||||
|
||||
async def test_awrap_model_call_access_runtime(fake_chat_model):
|
||||
"""Test accessing AgentRuntime in async wrap_model_call."""
|
||||
async def test_agent_name_accessible_in_async_middleware():
|
||||
"""Test that agent name can be accessed in async middleware."""
|
||||
captured_agent_name = None
|
||||
|
||||
@wrap_model_call
|
||||
async def capture_async(request: ModelRequest, handler):
|
||||
async def capture_agent_name_async(request: ModelRequest, handler):
|
||||
nonlocal captured_agent_name
|
||||
captured_agent_name = request.runtime.agent_name
|
||||
return await handler(request)
|
||||
|
||||
fake_model = GenericFakeChatModel(messages=iter([AIMessage(content="async response")]))
|
||||
agent = create_agent(
|
||||
fake_chat_model,
|
||||
fake_model,
|
||||
tools=[],
|
||||
middleware=[capture_async],
|
||||
name="AsyncRuntimeAgent",
|
||||
middleware=[capture_agent_name_async],
|
||||
name="AsyncAgent",
|
||||
)
|
||||
|
||||
await agent.ainvoke({"messages": [HumanMessage("Hello")]})
|
||||
assert captured_agent_name == "AsyncRuntimeAgent"
|
||||
|
||||
|
||||
async def test_awrap_model_call_retry_logic(fake_chat_model):
|
||||
"""Test async retry logic in awrap_model_call."""
|
||||
attempt_count = 0
|
||||
model_call_count = 0
|
||||
|
||||
@wrap_model_call
|
||||
async def async_retry_on_error(request: ModelRequest, handler):
|
||||
nonlocal attempt_count, model_call_count
|
||||
max_retries = 3
|
||||
last_error = None
|
||||
|
||||
for attempt in range(max_retries):
|
||||
attempt_count += 1
|
||||
try:
|
||||
# Simulate failure on first two attempts
|
||||
model_call_count += 1
|
||||
if model_call_count < 3:
|
||||
raise ValueError("Simulated async failure")
|
||||
return await handler(request)
|
||||
except ValueError as e:
|
||||
last_error = e
|
||||
if attempt == max_retries - 1:
|
||||
raise
|
||||
|
||||
raise last_error # Should never reach here
|
||||
|
||||
agent = create_agent(
|
||||
fake_chat_model,
|
||||
tools=[],
|
||||
middleware=[async_retry_on_error],
|
||||
name="AsyncRetryAgent",
|
||||
)
|
||||
|
||||
result = await agent.ainvoke({"messages": [HumanMessage("Hello")]})
|
||||
assert attempt_count == 3
|
||||
# The model response should be from fake_chat_model
|
||||
assert result["messages"][-1].content == "test response"
|
||||
|
||||
|
||||
async def test_awrap_model_call_modify_response(fake_chat_model):
|
||||
"""Test modifying response in async wrap_model_call."""
|
||||
|
||||
@wrap_model_call
|
||||
async def async_modify_response(request: ModelRequest, handler):
|
||||
response = await handler(request)
|
||||
original_msg = response.result[0]
|
||||
modified_msg = AIMessage(
|
||||
content=f"[ASYNC MODIFIED] {original_msg.content}",
|
||||
id=original_msg.id,
|
||||
)
|
||||
return ModelResponse(
|
||||
result=[modified_msg],
|
||||
structured_response=response.structured_response,
|
||||
)
|
||||
|
||||
agent = create_agent(
|
||||
fake_chat_model,
|
||||
tools=[],
|
||||
middleware=[async_modify_response],
|
||||
name="AsyncModifyAgent",
|
||||
)
|
||||
|
||||
result = await agent.ainvoke({"messages": [HumanMessage("Hello")]})
|
||||
assert result["messages"][-1].content == "[ASYNC MODIFIED] test response"
|
||||
assert captured_agent_name == "AsyncAgent"
|
||||
|
||||
Reference in New Issue
Block a user