diff --git a/libs/langchain_v1/tests/unit_tests/agents/middleware/implementations/test_model_call_limit.py b/libs/langchain_v1/tests/unit_tests/agents/middleware/implementations/test_model_call_limit.py index 5c264ed0943..05ac31d6abe 100644 --- a/libs/langchain_v1/tests/unit_tests/agents/middleware/implementations/test_model_call_limit.py +++ b/libs/langchain_v1/tests/unit_tests/agents/middleware/implementations/test_model_call_limit.py @@ -2,11 +2,13 @@ import pytest from langchain_core.messages import AIMessage, HumanMessage, ToolMessage from langchain_core.tools import tool from langgraph.checkpoint.memory import InMemorySaver +from langgraph.runtime import Runtime from langchain.agents.factory import create_agent from langchain.agents.middleware.model_call_limit import ( ModelCallLimitExceededError, ModelCallLimitMiddleware, + ModelCallLimitState, ) from tests.unit_tests.agents.model import FakeToolCallingModel @@ -17,21 +19,20 @@ def simple_tool(value: str) -> str: return value -def test_middleware_unit_functionality(): +def test_middleware_unit_functionality() -> None: """Test that the middleware works as expected in isolation.""" # Test with end behavior middleware = ModelCallLimitMiddleware(thread_limit=2, run_limit=1) - # Mock runtime (not used in current implementation) - runtime = None + runtime = Runtime() # Test when limits are not exceeded - state = {"thread_model_call_count": 0, "run_model_call_count": 0} + state = ModelCallLimitState(messages=[], thread_model_call_count=0, run_model_call_count=0) result = middleware.before_model(state, runtime) assert result is None # Test when thread limit is exceeded - state = {"thread_model_call_count": 2, "run_model_call_count": 0} + state = ModelCallLimitState(messages=[], thread_model_call_count=2, run_model_call_count=0) result = middleware.before_model(state, runtime) assert result is not None assert result["jump_to"] == "end" @@ -40,7 +41,7 @@ def test_middleware_unit_functionality(): assert "thread limit (2/2)" in result["messages"][0].content # Test when run limit is exceeded - state = {"thread_model_call_count": 1, "run_model_call_count": 1} + state = ModelCallLimitState(messages=[], thread_model_call_count=1, run_model_call_count=1) result = middleware.before_model(state, runtime) assert result is not None assert result["jump_to"] == "end" @@ -54,21 +55,21 @@ def test_middleware_unit_functionality(): ) # Test exception when thread limit exceeded - state = {"thread_model_call_count": 2, "run_model_call_count": 0} + state = ModelCallLimitState(messages=[], thread_model_call_count=2, run_model_call_count=0) with pytest.raises(ModelCallLimitExceededError) as exc_info: middleware_exception.before_model(state, runtime) assert "thread limit (2/2)" in str(exc_info.value) # Test exception when run limit exceeded - state = {"thread_model_call_count": 1, "run_model_call_count": 1} + state = ModelCallLimitState(messages=[], thread_model_call_count=1, run_model_call_count=1) with pytest.raises(ModelCallLimitExceededError) as exc_info: middleware_exception.before_model(state, runtime) assert "run limit (1/1)" in str(exc_info.value) -def test_thread_limit_with_create_agent(): +def test_thread_limit_with_create_agent() -> None: """Test that thread limits work correctly with create_agent.""" model = FakeToolCallingModel() @@ -106,7 +107,7 @@ def test_thread_limit_with_create_agent(): assert "thread limit" in result2["messages"][3].content -def test_run_limit_with_create_agent(): +def test_run_limit_with_create_agent() -> None: """Test that run limits work correctly with create_agent.""" # Create a model that will make 2 calls model = FakeToolCallingModel( @@ -140,7 +141,7 @@ def test_run_limit_with_create_agent(): assert "run limit" in result["messages"][3].content -def test_middleware_initialization_validation(): +def test_middleware_initialization_validation() -> None: """Test that middleware initialization validates parameters correctly.""" # Test that at least one limit must be specified with pytest.raises(ValueError, match="At least one limit must be specified"): @@ -148,7 +149,7 @@ def test_middleware_initialization_validation(): # Test invalid exit behavior with pytest.raises(ValueError, match="Invalid exit_behavior"): - ModelCallLimitMiddleware(thread_limit=5, exit_behavior="invalid") + ModelCallLimitMiddleware(thread_limit=5, exit_behavior="invalid") # type: ignore[arg-type] # Test valid initialization middleware = ModelCallLimitMiddleware(thread_limit=5, run_limit=3) @@ -167,32 +168,32 @@ def test_middleware_initialization_validation(): assert middleware.run_limit == 3 -def test_exception_error_message(): +def test_exception_error_message() -> None: """Test that the exception provides clear error messages.""" middleware = ModelCallLimitMiddleware(thread_limit=2, run_limit=1, exit_behavior="error") # Test thread limit exceeded - state = {"thread_model_call_count": 2, "run_model_call_count": 0} + state = ModelCallLimitState(messages=[], thread_model_call_count=2, run_model_call_count=0) with pytest.raises(ModelCallLimitExceededError) as exc_info: - middleware.before_model(state, None) + middleware.before_model(state, Runtime()) error_msg = str(exc_info.value) assert "Model call limits exceeded" in error_msg assert "thread limit (2/2)" in error_msg # Test run limit exceeded - state = {"thread_model_call_count": 0, "run_model_call_count": 1} + state = ModelCallLimitState(messages=[], thread_model_call_count=0, run_model_call_count=1) with pytest.raises(ModelCallLimitExceededError) as exc_info: - middleware.before_model(state, None) + middleware.before_model(state, Runtime()) error_msg = str(exc_info.value) assert "Model call limits exceeded" in error_msg assert "run limit (1/1)" in error_msg # Test both limits exceeded - state = {"thread_model_call_count": 2, "run_model_call_count": 1} + state = ModelCallLimitState(messages=[], thread_model_call_count=2, run_model_call_count=1) with pytest.raises(ModelCallLimitExceededError) as exc_info: - middleware.before_model(state, None) + middleware.before_model(state, Runtime()) error_msg = str(exc_info.value) assert "Model call limits exceeded" in error_msg