mirror of
https://github.com/hwchase17/langchain.git
synced 2026-06-09 10:17:00 +00:00
chore(langchain): fix types in test_model_call_limit_types (#34601)
This commit is contained in:
committed by
GitHub
parent
901690ceec
commit
5ae53fdfb3
@@ -2,11 +2,13 @@ import pytest
|
|||||||
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
|
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
|
||||||
from langchain_core.tools import tool
|
from langchain_core.tools import tool
|
||||||
from langgraph.checkpoint.memory import InMemorySaver
|
from langgraph.checkpoint.memory import InMemorySaver
|
||||||
|
from langgraph.runtime import Runtime
|
||||||
|
|
||||||
from langchain.agents.factory import create_agent
|
from langchain.agents.factory import create_agent
|
||||||
from langchain.agents.middleware.model_call_limit import (
|
from langchain.agents.middleware.model_call_limit import (
|
||||||
ModelCallLimitExceededError,
|
ModelCallLimitExceededError,
|
||||||
ModelCallLimitMiddleware,
|
ModelCallLimitMiddleware,
|
||||||
|
ModelCallLimitState,
|
||||||
)
|
)
|
||||||
from tests.unit_tests.agents.model import FakeToolCallingModel
|
from tests.unit_tests.agents.model import FakeToolCallingModel
|
||||||
|
|
||||||
@@ -17,21 +19,20 @@ def simple_tool(value: str) -> str:
|
|||||||
return value
|
return value
|
||||||
|
|
||||||
|
|
||||||
def test_middleware_unit_functionality():
|
def test_middleware_unit_functionality() -> None:
|
||||||
"""Test that the middleware works as expected in isolation."""
|
"""Test that the middleware works as expected in isolation."""
|
||||||
# Test with end behavior
|
# Test with end behavior
|
||||||
middleware = ModelCallLimitMiddleware(thread_limit=2, run_limit=1)
|
middleware = ModelCallLimitMiddleware(thread_limit=2, run_limit=1)
|
||||||
|
|
||||||
# Mock runtime (not used in current implementation)
|
runtime = Runtime()
|
||||||
runtime = None
|
|
||||||
|
|
||||||
# Test when limits are not exceeded
|
# Test when limits are not exceeded
|
||||||
state = {"thread_model_call_count": 0, "run_model_call_count": 0}
|
state = ModelCallLimitState(messages=[], thread_model_call_count=0, run_model_call_count=0)
|
||||||
result = middleware.before_model(state, runtime)
|
result = middleware.before_model(state, runtime)
|
||||||
assert result is None
|
assert result is None
|
||||||
|
|
||||||
# Test when thread limit is exceeded
|
# Test when thread limit is exceeded
|
||||||
state = {"thread_model_call_count": 2, "run_model_call_count": 0}
|
state = ModelCallLimitState(messages=[], thread_model_call_count=2, run_model_call_count=0)
|
||||||
result = middleware.before_model(state, runtime)
|
result = middleware.before_model(state, runtime)
|
||||||
assert result is not None
|
assert result is not None
|
||||||
assert result["jump_to"] == "end"
|
assert result["jump_to"] == "end"
|
||||||
@@ -40,7 +41,7 @@ def test_middleware_unit_functionality():
|
|||||||
assert "thread limit (2/2)" in result["messages"][0].content
|
assert "thread limit (2/2)" in result["messages"][0].content
|
||||||
|
|
||||||
# Test when run limit is exceeded
|
# Test when run limit is exceeded
|
||||||
state = {"thread_model_call_count": 1, "run_model_call_count": 1}
|
state = ModelCallLimitState(messages=[], thread_model_call_count=1, run_model_call_count=1)
|
||||||
result = middleware.before_model(state, runtime)
|
result = middleware.before_model(state, runtime)
|
||||||
assert result is not None
|
assert result is not None
|
||||||
assert result["jump_to"] == "end"
|
assert result["jump_to"] == "end"
|
||||||
@@ -54,21 +55,21 @@ def test_middleware_unit_functionality():
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Test exception when thread limit exceeded
|
# Test exception when thread limit exceeded
|
||||||
state = {"thread_model_call_count": 2, "run_model_call_count": 0}
|
state = ModelCallLimitState(messages=[], thread_model_call_count=2, run_model_call_count=0)
|
||||||
with pytest.raises(ModelCallLimitExceededError) as exc_info:
|
with pytest.raises(ModelCallLimitExceededError) as exc_info:
|
||||||
middleware_exception.before_model(state, runtime)
|
middleware_exception.before_model(state, runtime)
|
||||||
|
|
||||||
assert "thread limit (2/2)" in str(exc_info.value)
|
assert "thread limit (2/2)" in str(exc_info.value)
|
||||||
|
|
||||||
# Test exception when run limit exceeded
|
# Test exception when run limit exceeded
|
||||||
state = {"thread_model_call_count": 1, "run_model_call_count": 1}
|
state = ModelCallLimitState(messages=[], thread_model_call_count=1, run_model_call_count=1)
|
||||||
with pytest.raises(ModelCallLimitExceededError) as exc_info:
|
with pytest.raises(ModelCallLimitExceededError) as exc_info:
|
||||||
middleware_exception.before_model(state, runtime)
|
middleware_exception.before_model(state, runtime)
|
||||||
|
|
||||||
assert "run limit (1/1)" in str(exc_info.value)
|
assert "run limit (1/1)" in str(exc_info.value)
|
||||||
|
|
||||||
|
|
||||||
def test_thread_limit_with_create_agent():
|
def test_thread_limit_with_create_agent() -> None:
|
||||||
"""Test that thread limits work correctly with create_agent."""
|
"""Test that thread limits work correctly with create_agent."""
|
||||||
model = FakeToolCallingModel()
|
model = FakeToolCallingModel()
|
||||||
|
|
||||||
@@ -106,7 +107,7 @@ def test_thread_limit_with_create_agent():
|
|||||||
assert "thread limit" in result2["messages"][3].content
|
assert "thread limit" in result2["messages"][3].content
|
||||||
|
|
||||||
|
|
||||||
def test_run_limit_with_create_agent():
|
def test_run_limit_with_create_agent() -> None:
|
||||||
"""Test that run limits work correctly with create_agent."""
|
"""Test that run limits work correctly with create_agent."""
|
||||||
# Create a model that will make 2 calls
|
# Create a model that will make 2 calls
|
||||||
model = FakeToolCallingModel(
|
model = FakeToolCallingModel(
|
||||||
@@ -140,7 +141,7 @@ def test_run_limit_with_create_agent():
|
|||||||
assert "run limit" in result["messages"][3].content
|
assert "run limit" in result["messages"][3].content
|
||||||
|
|
||||||
|
|
||||||
def test_middleware_initialization_validation():
|
def test_middleware_initialization_validation() -> None:
|
||||||
"""Test that middleware initialization validates parameters correctly."""
|
"""Test that middleware initialization validates parameters correctly."""
|
||||||
# Test that at least one limit must be specified
|
# Test that at least one limit must be specified
|
||||||
with pytest.raises(ValueError, match="At least one limit must be specified"):
|
with pytest.raises(ValueError, match="At least one limit must be specified"):
|
||||||
@@ -148,7 +149,7 @@ def test_middleware_initialization_validation():
|
|||||||
|
|
||||||
# Test invalid exit behavior
|
# Test invalid exit behavior
|
||||||
with pytest.raises(ValueError, match="Invalid exit_behavior"):
|
with pytest.raises(ValueError, match="Invalid exit_behavior"):
|
||||||
ModelCallLimitMiddleware(thread_limit=5, exit_behavior="invalid")
|
ModelCallLimitMiddleware(thread_limit=5, exit_behavior="invalid") # type: ignore[arg-type]
|
||||||
|
|
||||||
# Test valid initialization
|
# Test valid initialization
|
||||||
middleware = ModelCallLimitMiddleware(thread_limit=5, run_limit=3)
|
middleware = ModelCallLimitMiddleware(thread_limit=5, run_limit=3)
|
||||||
@@ -167,32 +168,32 @@ def test_middleware_initialization_validation():
|
|||||||
assert middleware.run_limit == 3
|
assert middleware.run_limit == 3
|
||||||
|
|
||||||
|
|
||||||
def test_exception_error_message():
|
def test_exception_error_message() -> None:
|
||||||
"""Test that the exception provides clear error messages."""
|
"""Test that the exception provides clear error messages."""
|
||||||
middleware = ModelCallLimitMiddleware(thread_limit=2, run_limit=1, exit_behavior="error")
|
middleware = ModelCallLimitMiddleware(thread_limit=2, run_limit=1, exit_behavior="error")
|
||||||
|
|
||||||
# Test thread limit exceeded
|
# Test thread limit exceeded
|
||||||
state = {"thread_model_call_count": 2, "run_model_call_count": 0}
|
state = ModelCallLimitState(messages=[], thread_model_call_count=2, run_model_call_count=0)
|
||||||
with pytest.raises(ModelCallLimitExceededError) as exc_info:
|
with pytest.raises(ModelCallLimitExceededError) as exc_info:
|
||||||
middleware.before_model(state, None)
|
middleware.before_model(state, Runtime())
|
||||||
|
|
||||||
error_msg = str(exc_info.value)
|
error_msg = str(exc_info.value)
|
||||||
assert "Model call limits exceeded" in error_msg
|
assert "Model call limits exceeded" in error_msg
|
||||||
assert "thread limit (2/2)" in error_msg
|
assert "thread limit (2/2)" in error_msg
|
||||||
|
|
||||||
# Test run limit exceeded
|
# Test run limit exceeded
|
||||||
state = {"thread_model_call_count": 0, "run_model_call_count": 1}
|
state = ModelCallLimitState(messages=[], thread_model_call_count=0, run_model_call_count=1)
|
||||||
with pytest.raises(ModelCallLimitExceededError) as exc_info:
|
with pytest.raises(ModelCallLimitExceededError) as exc_info:
|
||||||
middleware.before_model(state, None)
|
middleware.before_model(state, Runtime())
|
||||||
|
|
||||||
error_msg = str(exc_info.value)
|
error_msg = str(exc_info.value)
|
||||||
assert "Model call limits exceeded" in error_msg
|
assert "Model call limits exceeded" in error_msg
|
||||||
assert "run limit (1/1)" in error_msg
|
assert "run limit (1/1)" in error_msg
|
||||||
|
|
||||||
# Test both limits exceeded
|
# Test both limits exceeded
|
||||||
state = {"thread_model_call_count": 2, "run_model_call_count": 1}
|
state = ModelCallLimitState(messages=[], thread_model_call_count=2, run_model_call_count=1)
|
||||||
with pytest.raises(ModelCallLimitExceededError) as exc_info:
|
with pytest.raises(ModelCallLimitExceededError) as exc_info:
|
||||||
middleware.before_model(state, None)
|
middleware.before_model(state, Runtime())
|
||||||
|
|
||||||
error_msg = str(exc_info.value)
|
error_msg = str(exc_info.value)
|
||||||
assert "Model call limits exceeded" in error_msg
|
assert "Model call limits exceeded" in error_msg
|
||||||
|
|||||||
Reference in New Issue
Block a user