diff --git a/libs/langchain_v1/tests/unit_tests/agents/middleware/core/test_composition.py b/libs/langchain_v1/tests/unit_tests/agents/middleware/core/test_composition.py index ada493e1cb5..24b933c7997 100644 --- a/libs/langchain_v1/tests/unit_tests/agents/middleware/core/test_composition.py +++ b/libs/langchain_v1/tests/unit_tests/agents/middleware/core/test_composition.py @@ -1,19 +1,19 @@ """Unit tests for _chain_model_call_handlers handler composition.""" -from typing import TYPE_CHECKING, cast +from collections.abc import Callable +from typing import Any, TypedDict, cast from langchain_core.messages import AIMessage +from langgraph.runtime import Runtime +from langchain.agents import AgentState from langchain.agents.factory import _chain_model_call_handlers from langchain.agents.middleware.types import ModelRequest, ModelResponse -if TYPE_CHECKING: - from langgraph.runtime import Runtime - -def create_test_request(**kwargs): +def create_test_request(**kwargs: Any) -> ModelRequest: """Helper to create a `ModelRequest` with sensible defaults.""" - defaults = { + defaults: dict[str, Any] = { "messages": [], "model": None, "system_prompt": None, @@ -27,10 +27,10 @@ def create_test_request(**kwargs): return ModelRequest(**defaults) -def create_mock_base_handler(content="test"): +def create_mock_base_handler(content: str = "test") -> Callable[[ModelRequest], ModelResponse]: """Helper to create a base handler that returns `ModelResponse`.""" - def mock_base_handler(req): + def mock_base_handler(req: ModelRequest) -> ModelResponse: return ModelResponse(result=[AIMessage(content=content)], structured_response=None) return mock_base_handler @@ -47,7 +47,9 @@ class TestChainModelCallHandlers: def test_single_handler_returns_unchanged(self) -> None: """Test that single handler is wrapped to normalize output.""" - def handler(request, base_handler): + def handler( + request: ModelRequest, base_handler: Callable[[ModelRequest], ModelResponse] + ) -> ModelResponse: return base_handler(request) result = _chain_model_call_handlers([handler]) @@ -59,13 +61,17 @@ class TestChainModelCallHandlers: """Test basic composition of two handlers.""" execution_order = [] - def outer(request, handler): + def outer( + request: ModelRequest, handler: Callable[[ModelRequest], ModelResponse] + ) -> ModelResponse: execution_order.append("outer-before") result = handler(request) execution_order.append("outer-after") return result - def inner(request, handler): + def inner( + request: ModelRequest, handler: Callable[[ModelRequest], ModelResponse] + ) -> ModelResponse: execution_order.append("inner-before") result = handler(request) execution_order.append("inner-after") @@ -90,19 +96,25 @@ class TestChainModelCallHandlers: """Test composition of three handlers.""" execution_order = [] - def first(request, handler): + def first( + request: ModelRequest, handler: Callable[[ModelRequest], ModelResponse] + ) -> ModelResponse: execution_order.append("first-before") result = handler(request) execution_order.append("first-after") return result - def second(request, handler): + def second( + request: ModelRequest, handler: Callable[[ModelRequest], ModelResponse] + ) -> ModelResponse: execution_order.append("second-before") result = handler(request) execution_order.append("second-after") return result - def third(request, handler): + def third( + request: ModelRequest, handler: Callable[[ModelRequest], ModelResponse] + ) -> ModelResponse: execution_order.append("third-before") result = handler(request) execution_order.append("third-after") @@ -129,10 +141,14 @@ class TestChainModelCallHandlers: """Test inner handler retrying before outer sees response.""" inner_attempts = [] - def outer_passthrough(request, handler): + def outer_passthrough( + request: ModelRequest, handler: Callable[[ModelRequest], ModelResponse] + ) -> ModelResponse: return handler(request) - def inner_with_retry(request, handler): + def inner_with_retry( + request: ModelRequest, handler: Callable[[ModelRequest], ModelResponse] + ) -> ModelResponse | AIMessage: for attempt in range(3): inner_attempts.append(attempt) try: @@ -147,7 +163,7 @@ class TestChainModelCallHandlers: call_count = {"value": 0} - def mock_base_handler(req): + def mock_base_handler(req: ModelRequest) -> ModelResponse: call_count["value"] += 1 if call_count["value"] < 3: msg = "fail" @@ -163,20 +179,24 @@ class TestChainModelCallHandlers: def test_error_to_success_conversion(self) -> None: """Test handler converting error to success response.""" - def outer_error_handler(request, handler): + def outer_error_handler( + request: ModelRequest, handler: Callable[[ModelRequest], ModelResponse] + ) -> ModelResponse | AIMessage: try: return handler(request) except Exception: # Middleware can return AIMessage - it will be normalized to ModelResponse return AIMessage(content="Fallback response") - def inner_passthrough(request, handler): + def inner_passthrough( + request: ModelRequest, handler: Callable[[ModelRequest], ModelResponse] + ) -> ModelResponse: return handler(request) composed = _chain_model_call_handlers([outer_error_handler, inner_passthrough]) assert composed is not None - def mock_base_handler(req): + def mock_base_handler(req: ModelRequest) -> ModelResponse: msg = "Model failed" raise ValueError(msg) @@ -191,13 +211,17 @@ class TestChainModelCallHandlers: """Test handlers modifying the request.""" requests_seen = [] - def outer_add_context(request, handler): + def outer_add_context( + request: ModelRequest, handler: Callable[[ModelRequest], ModelResponse] + ) -> ModelResponse: modified_request = create_test_request( messages=[*request.messages], system_prompt="Added by outer" ) return handler(modified_request) - def inner_track_request(request, handler): + def inner_track_request( + request: ModelRequest, handler: Callable[[ModelRequest], ModelResponse] + ) -> ModelResponse: requests_seen.append(request.system_prompt) return handler(request) @@ -212,15 +236,26 @@ class TestChainModelCallHandlers: def test_composition_preserves_state_and_runtime(self) -> None: """Test that state and runtime are passed through composition.""" + + class CustomState(AgentState[Any]): + test: str + + class CustomContext(TypedDict): + test: str + state_values = [] runtime_values = [] - def outer(request, handler): + def outer( + request: ModelRequest, handler: Callable[[ModelRequest], ModelResponse] + ) -> ModelResponse: state_values.append(("outer", request.state)) runtime_values.append(("outer", request.runtime)) return handler(request) - def inner(request, handler): + def inner( + request: ModelRequest, handler: Callable[[ModelRequest], ModelResponse] + ) -> ModelResponse: state_values.append(("inner", request.state)) runtime_values.append(("inner", request.runtime)) return handler(request) @@ -228,8 +263,8 @@ class TestChainModelCallHandlers: composed = _chain_model_call_handlers([outer, inner]) assert composed is not None - test_state = {"test": "state"} - test_runtime = {"test": "runtime"} + test_state = CustomState(messages=[], test="state") + test_runtime = Runtime(context=CustomContext(test="runtime")) # Create request with state and runtime test_request = create_test_request(state=test_state, runtime=test_runtime) @@ -245,11 +280,15 @@ class TestChainModelCallHandlers: """Test handler that retries multiple times.""" call_count = {"value": 0} - def outer_counts_calls(request, handler): + def outer_counts_calls( + request: ModelRequest, handler: Callable[[ModelRequest], ModelResponse] + ) -> ModelResponse: call_count["value"] += 1 return handler(request) - def inner_retries(request, handler): + def inner_retries( + request: ModelRequest, handler: Callable[[ModelRequest], ModelResponse] + ) -> ModelResponse: try: return handler(request) except ValueError: @@ -261,7 +300,7 @@ class TestChainModelCallHandlers: attempt = {"value": 0} - def mock_base_handler(req): + def mock_base_handler(req: ModelRequest) -> ModelResponse: attempt["value"] += 1 if attempt["value"] == 1: msg = "fail"