chore(langchain): fix types in test_composition (#34580)

This commit is contained in:
Christophe Bornet
2026-01-05 20:49:34 +01:00
committed by GitHub
parent 3b65985551
commit 7979fd3d9f

View File

@@ -1,19 +1,19 @@
"""Unit tests for _chain_model_call_handlers handler composition."""
from typing import TYPE_CHECKING, cast
from collections.abc import Callable
from typing import Any, TypedDict, cast
from langchain_core.messages import AIMessage
from langgraph.runtime import Runtime
from langchain.agents import AgentState
from langchain.agents.factory import _chain_model_call_handlers
from langchain.agents.middleware.types import ModelRequest, ModelResponse
if TYPE_CHECKING:
from langgraph.runtime import Runtime
def create_test_request(**kwargs):
def create_test_request(**kwargs: Any) -> ModelRequest:
"""Helper to create a `ModelRequest` with sensible defaults."""
defaults = {
defaults: dict[str, Any] = {
"messages": [],
"model": None,
"system_prompt": None,
@@ -27,10 +27,10 @@ def create_test_request(**kwargs):
return ModelRequest(**defaults)
def create_mock_base_handler(content="test"):
def create_mock_base_handler(content: str = "test") -> Callable[[ModelRequest], ModelResponse]:
"""Helper to create a base handler that returns `ModelResponse`."""
def mock_base_handler(req):
def mock_base_handler(req: ModelRequest) -> ModelResponse:
return ModelResponse(result=[AIMessage(content=content)], structured_response=None)
return mock_base_handler
@@ -47,7 +47,9 @@ class TestChainModelCallHandlers:
def test_single_handler_returns_unchanged(self) -> None:
"""Test that single handler is wrapped to normalize output."""
def handler(request, base_handler):
def handler(
request: ModelRequest, base_handler: Callable[[ModelRequest], ModelResponse]
) -> ModelResponse:
return base_handler(request)
result = _chain_model_call_handlers([handler])
@@ -59,13 +61,17 @@ class TestChainModelCallHandlers:
"""Test basic composition of two handlers."""
execution_order = []
def outer(request, handler):
def outer(
request: ModelRequest, handler: Callable[[ModelRequest], ModelResponse]
) -> ModelResponse:
execution_order.append("outer-before")
result = handler(request)
execution_order.append("outer-after")
return result
def inner(request, handler):
def inner(
request: ModelRequest, handler: Callable[[ModelRequest], ModelResponse]
) -> ModelResponse:
execution_order.append("inner-before")
result = handler(request)
execution_order.append("inner-after")
@@ -90,19 +96,25 @@ class TestChainModelCallHandlers:
"""Test composition of three handlers."""
execution_order = []
def first(request, handler):
def first(
request: ModelRequest, handler: Callable[[ModelRequest], ModelResponse]
) -> ModelResponse:
execution_order.append("first-before")
result = handler(request)
execution_order.append("first-after")
return result
def second(request, handler):
def second(
request: ModelRequest, handler: Callable[[ModelRequest], ModelResponse]
) -> ModelResponse:
execution_order.append("second-before")
result = handler(request)
execution_order.append("second-after")
return result
def third(request, handler):
def third(
request: ModelRequest, handler: Callable[[ModelRequest], ModelResponse]
) -> ModelResponse:
execution_order.append("third-before")
result = handler(request)
execution_order.append("third-after")
@@ -129,10 +141,14 @@ class TestChainModelCallHandlers:
"""Test inner handler retrying before outer sees response."""
inner_attempts = []
def outer_passthrough(request, handler):
def outer_passthrough(
request: ModelRequest, handler: Callable[[ModelRequest], ModelResponse]
) -> ModelResponse:
return handler(request)
def inner_with_retry(request, handler):
def inner_with_retry(
request: ModelRequest, handler: Callable[[ModelRequest], ModelResponse]
) -> ModelResponse | AIMessage:
for attempt in range(3):
inner_attempts.append(attempt)
try:
@@ -147,7 +163,7 @@ class TestChainModelCallHandlers:
call_count = {"value": 0}
def mock_base_handler(req):
def mock_base_handler(req: ModelRequest) -> ModelResponse:
call_count["value"] += 1
if call_count["value"] < 3:
msg = "fail"
@@ -163,20 +179,24 @@ class TestChainModelCallHandlers:
def test_error_to_success_conversion(self) -> None:
"""Test handler converting error to success response."""
def outer_error_handler(request, handler):
def outer_error_handler(
request: ModelRequest, handler: Callable[[ModelRequest], ModelResponse]
) -> ModelResponse | AIMessage:
try:
return handler(request)
except Exception:
# Middleware can return AIMessage - it will be normalized to ModelResponse
return AIMessage(content="Fallback response")
def inner_passthrough(request, handler):
def inner_passthrough(
request: ModelRequest, handler: Callable[[ModelRequest], ModelResponse]
) -> ModelResponse:
return handler(request)
composed = _chain_model_call_handlers([outer_error_handler, inner_passthrough])
assert composed is not None
def mock_base_handler(req):
def mock_base_handler(req: ModelRequest) -> ModelResponse:
msg = "Model failed"
raise ValueError(msg)
@@ -191,13 +211,17 @@ class TestChainModelCallHandlers:
"""Test handlers modifying the request."""
requests_seen = []
def outer_add_context(request, handler):
def outer_add_context(
request: ModelRequest, handler: Callable[[ModelRequest], ModelResponse]
) -> ModelResponse:
modified_request = create_test_request(
messages=[*request.messages], system_prompt="Added by outer"
)
return handler(modified_request)
def inner_track_request(request, handler):
def inner_track_request(
request: ModelRequest, handler: Callable[[ModelRequest], ModelResponse]
) -> ModelResponse:
requests_seen.append(request.system_prompt)
return handler(request)
@@ -212,15 +236,26 @@ class TestChainModelCallHandlers:
def test_composition_preserves_state_and_runtime(self) -> None:
"""Test that state and runtime are passed through composition."""
class CustomState(AgentState[Any]):
test: str
class CustomContext(TypedDict):
test: str
state_values = []
runtime_values = []
def outer(request, handler):
def outer(
request: ModelRequest, handler: Callable[[ModelRequest], ModelResponse]
) -> ModelResponse:
state_values.append(("outer", request.state))
runtime_values.append(("outer", request.runtime))
return handler(request)
def inner(request, handler):
def inner(
request: ModelRequest, handler: Callable[[ModelRequest], ModelResponse]
) -> ModelResponse:
state_values.append(("inner", request.state))
runtime_values.append(("inner", request.runtime))
return handler(request)
@@ -228,8 +263,8 @@ class TestChainModelCallHandlers:
composed = _chain_model_call_handlers([outer, inner])
assert composed is not None
test_state = {"test": "state"}
test_runtime = {"test": "runtime"}
test_state = CustomState(messages=[], test="state")
test_runtime = Runtime(context=CustomContext(test="runtime"))
# Create request with state and runtime
test_request = create_test_request(state=test_state, runtime=test_runtime)
@@ -245,11 +280,15 @@ class TestChainModelCallHandlers:
"""Test handler that retries multiple times."""
call_count = {"value": 0}
def outer_counts_calls(request, handler):
def outer_counts_calls(
request: ModelRequest, handler: Callable[[ModelRequest], ModelResponse]
) -> ModelResponse:
call_count["value"] += 1
return handler(request)
def inner_retries(request, handler):
def inner_retries(
request: ModelRequest, handler: Callable[[ModelRequest], ModelResponse]
) -> ModelResponse:
try:
return handler(request)
except ValueError:
@@ -261,7 +300,7 @@ class TestChainModelCallHandlers:
attempt = {"value": 0}
def mock_base_handler(req):
def mock_base_handler(req: ModelRequest) -> ModelResponse:
attempt["value"] += 1
if attempt["value"] == 1:
msg = "fail"