chore(langchain): fix types in test_model_call_limit_types (#34601)

This commit is contained in:
Christophe Bornet
2026-01-05 20:37:03 +01:00
committed by GitHub
parent 901690ceec
commit 5ae53fdfb3

View File

@@ -2,11 +2,13 @@ import pytest
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
from langchain_core.tools import tool from langchain_core.tools import tool
from langgraph.checkpoint.memory import InMemorySaver from langgraph.checkpoint.memory import InMemorySaver
from langgraph.runtime import Runtime
from langchain.agents.factory import create_agent from langchain.agents.factory import create_agent
from langchain.agents.middleware.model_call_limit import ( from langchain.agents.middleware.model_call_limit import (
ModelCallLimitExceededError, ModelCallLimitExceededError,
ModelCallLimitMiddleware, ModelCallLimitMiddleware,
ModelCallLimitState,
) )
from tests.unit_tests.agents.model import FakeToolCallingModel from tests.unit_tests.agents.model import FakeToolCallingModel
@@ -17,21 +19,20 @@ def simple_tool(value: str) -> str:
return value return value
def test_middleware_unit_functionality(): def test_middleware_unit_functionality() -> None:
"""Test that the middleware works as expected in isolation.""" """Test that the middleware works as expected in isolation."""
# Test with end behavior # Test with end behavior
middleware = ModelCallLimitMiddleware(thread_limit=2, run_limit=1) middleware = ModelCallLimitMiddleware(thread_limit=2, run_limit=1)
# Mock runtime (not used in current implementation) runtime = Runtime()
runtime = None
# Test when limits are not exceeded # Test when limits are not exceeded
state = {"thread_model_call_count": 0, "run_model_call_count": 0} state = ModelCallLimitState(messages=[], thread_model_call_count=0, run_model_call_count=0)
result = middleware.before_model(state, runtime) result = middleware.before_model(state, runtime)
assert result is None assert result is None
# Test when thread limit is exceeded # Test when thread limit is exceeded
state = {"thread_model_call_count": 2, "run_model_call_count": 0} state = ModelCallLimitState(messages=[], thread_model_call_count=2, run_model_call_count=0)
result = middleware.before_model(state, runtime) result = middleware.before_model(state, runtime)
assert result is not None assert result is not None
assert result["jump_to"] == "end" assert result["jump_to"] == "end"
@@ -40,7 +41,7 @@ def test_middleware_unit_functionality():
assert "thread limit (2/2)" in result["messages"][0].content assert "thread limit (2/2)" in result["messages"][0].content
# Test when run limit is exceeded # Test when run limit is exceeded
state = {"thread_model_call_count": 1, "run_model_call_count": 1} state = ModelCallLimitState(messages=[], thread_model_call_count=1, run_model_call_count=1)
result = middleware.before_model(state, runtime) result = middleware.before_model(state, runtime)
assert result is not None assert result is not None
assert result["jump_to"] == "end" assert result["jump_to"] == "end"
@@ -54,21 +55,21 @@ def test_middleware_unit_functionality():
) )
# Test exception when thread limit exceeded # Test exception when thread limit exceeded
state = {"thread_model_call_count": 2, "run_model_call_count": 0} state = ModelCallLimitState(messages=[], thread_model_call_count=2, run_model_call_count=0)
with pytest.raises(ModelCallLimitExceededError) as exc_info: with pytest.raises(ModelCallLimitExceededError) as exc_info:
middleware_exception.before_model(state, runtime) middleware_exception.before_model(state, runtime)
assert "thread limit (2/2)" in str(exc_info.value) assert "thread limit (2/2)" in str(exc_info.value)
# Test exception when run limit exceeded # Test exception when run limit exceeded
state = {"thread_model_call_count": 1, "run_model_call_count": 1} state = ModelCallLimitState(messages=[], thread_model_call_count=1, run_model_call_count=1)
with pytest.raises(ModelCallLimitExceededError) as exc_info: with pytest.raises(ModelCallLimitExceededError) as exc_info:
middleware_exception.before_model(state, runtime) middleware_exception.before_model(state, runtime)
assert "run limit (1/1)" in str(exc_info.value) assert "run limit (1/1)" in str(exc_info.value)
def test_thread_limit_with_create_agent(): def test_thread_limit_with_create_agent() -> None:
"""Test that thread limits work correctly with create_agent.""" """Test that thread limits work correctly with create_agent."""
model = FakeToolCallingModel() model = FakeToolCallingModel()
@@ -106,7 +107,7 @@ def test_thread_limit_with_create_agent():
assert "thread limit" in result2["messages"][3].content assert "thread limit" in result2["messages"][3].content
def test_run_limit_with_create_agent(): def test_run_limit_with_create_agent() -> None:
"""Test that run limits work correctly with create_agent.""" """Test that run limits work correctly with create_agent."""
# Create a model that will make 2 calls # Create a model that will make 2 calls
model = FakeToolCallingModel( model = FakeToolCallingModel(
@@ -140,7 +141,7 @@ def test_run_limit_with_create_agent():
assert "run limit" in result["messages"][3].content assert "run limit" in result["messages"][3].content
def test_middleware_initialization_validation(): def test_middleware_initialization_validation() -> None:
"""Test that middleware initialization validates parameters correctly.""" """Test that middleware initialization validates parameters correctly."""
# Test that at least one limit must be specified # Test that at least one limit must be specified
with pytest.raises(ValueError, match="At least one limit must be specified"): with pytest.raises(ValueError, match="At least one limit must be specified"):
@@ -148,7 +149,7 @@ def test_middleware_initialization_validation():
# Test invalid exit behavior # Test invalid exit behavior
with pytest.raises(ValueError, match="Invalid exit_behavior"): with pytest.raises(ValueError, match="Invalid exit_behavior"):
ModelCallLimitMiddleware(thread_limit=5, exit_behavior="invalid") ModelCallLimitMiddleware(thread_limit=5, exit_behavior="invalid") # type: ignore[arg-type]
# Test valid initialization # Test valid initialization
middleware = ModelCallLimitMiddleware(thread_limit=5, run_limit=3) middleware = ModelCallLimitMiddleware(thread_limit=5, run_limit=3)
@@ -167,32 +168,32 @@ def test_middleware_initialization_validation():
assert middleware.run_limit == 3 assert middleware.run_limit == 3
def test_exception_error_message(): def test_exception_error_message() -> None:
"""Test that the exception provides clear error messages.""" """Test that the exception provides clear error messages."""
middleware = ModelCallLimitMiddleware(thread_limit=2, run_limit=1, exit_behavior="error") middleware = ModelCallLimitMiddleware(thread_limit=2, run_limit=1, exit_behavior="error")
# Test thread limit exceeded # Test thread limit exceeded
state = {"thread_model_call_count": 2, "run_model_call_count": 0} state = ModelCallLimitState(messages=[], thread_model_call_count=2, run_model_call_count=0)
with pytest.raises(ModelCallLimitExceededError) as exc_info: with pytest.raises(ModelCallLimitExceededError) as exc_info:
middleware.before_model(state, None) middleware.before_model(state, Runtime())
error_msg = str(exc_info.value) error_msg = str(exc_info.value)
assert "Model call limits exceeded" in error_msg assert "Model call limits exceeded" in error_msg
assert "thread limit (2/2)" in error_msg assert "thread limit (2/2)" in error_msg
# Test run limit exceeded # Test run limit exceeded
state = {"thread_model_call_count": 0, "run_model_call_count": 1} state = ModelCallLimitState(messages=[], thread_model_call_count=0, run_model_call_count=1)
with pytest.raises(ModelCallLimitExceededError) as exc_info: with pytest.raises(ModelCallLimitExceededError) as exc_info:
middleware.before_model(state, None) middleware.before_model(state, Runtime())
error_msg = str(exc_info.value) error_msg = str(exc_info.value)
assert "Model call limits exceeded" in error_msg assert "Model call limits exceeded" in error_msg
assert "run limit (1/1)" in error_msg assert "run limit (1/1)" in error_msg
# Test both limits exceeded # Test both limits exceeded
state = {"thread_model_call_count": 2, "run_model_call_count": 1} state = ModelCallLimitState(messages=[], thread_model_call_count=2, run_model_call_count=1)
with pytest.raises(ModelCallLimitExceededError) as exc_info: with pytest.raises(ModelCallLimitExceededError) as exc_info:
middleware.before_model(state, None) middleware.before_model(state, Runtime())
error_msg = str(exc_info.value) error_msg = str(exc_info.value)
assert "Model call limits exceeded" in error_msg assert "Model call limits exceeded" in error_msg