mirror of
https://github.com/hwchase17/langchain.git
synced 2026-06-09 10:17:00 +00:00
chore(langchain): fix types in test_composition (#34580)
This commit is contained in:
committed by
GitHub
parent
3b65985551
commit
7979fd3d9f
@@ -1,19 +1,19 @@
|
||||
"""Unit tests for _chain_model_call_handlers handler composition."""
|
||||
|
||||
from typing import TYPE_CHECKING, cast
|
||||
from collections.abc import Callable
|
||||
from typing import Any, TypedDict, cast
|
||||
|
||||
from langchain_core.messages import AIMessage
|
||||
from langgraph.runtime import Runtime
|
||||
|
||||
from langchain.agents import AgentState
|
||||
from langchain.agents.factory import _chain_model_call_handlers
|
||||
from langchain.agents.middleware.types import ModelRequest, ModelResponse
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langgraph.runtime import Runtime
|
||||
|
||||
|
||||
def create_test_request(**kwargs):
|
||||
def create_test_request(**kwargs: Any) -> ModelRequest:
|
||||
"""Helper to create a `ModelRequest` with sensible defaults."""
|
||||
defaults = {
|
||||
defaults: dict[str, Any] = {
|
||||
"messages": [],
|
||||
"model": None,
|
||||
"system_prompt": None,
|
||||
@@ -27,10 +27,10 @@ def create_test_request(**kwargs):
|
||||
return ModelRequest(**defaults)
|
||||
|
||||
|
||||
def create_mock_base_handler(content="test"):
|
||||
def create_mock_base_handler(content: str = "test") -> Callable[[ModelRequest], ModelResponse]:
|
||||
"""Helper to create a base handler that returns `ModelResponse`."""
|
||||
|
||||
def mock_base_handler(req):
|
||||
def mock_base_handler(req: ModelRequest) -> ModelResponse:
|
||||
return ModelResponse(result=[AIMessage(content=content)], structured_response=None)
|
||||
|
||||
return mock_base_handler
|
||||
@@ -47,7 +47,9 @@ class TestChainModelCallHandlers:
|
||||
def test_single_handler_returns_unchanged(self) -> None:
|
||||
"""Test that single handler is wrapped to normalize output."""
|
||||
|
||||
def handler(request, base_handler):
|
||||
def handler(
|
||||
request: ModelRequest, base_handler: Callable[[ModelRequest], ModelResponse]
|
||||
) -> ModelResponse:
|
||||
return base_handler(request)
|
||||
|
||||
result = _chain_model_call_handlers([handler])
|
||||
@@ -59,13 +61,17 @@ class TestChainModelCallHandlers:
|
||||
"""Test basic composition of two handlers."""
|
||||
execution_order = []
|
||||
|
||||
def outer(request, handler):
|
||||
def outer(
|
||||
request: ModelRequest, handler: Callable[[ModelRequest], ModelResponse]
|
||||
) -> ModelResponse:
|
||||
execution_order.append("outer-before")
|
||||
result = handler(request)
|
||||
execution_order.append("outer-after")
|
||||
return result
|
||||
|
||||
def inner(request, handler):
|
||||
def inner(
|
||||
request: ModelRequest, handler: Callable[[ModelRequest], ModelResponse]
|
||||
) -> ModelResponse:
|
||||
execution_order.append("inner-before")
|
||||
result = handler(request)
|
||||
execution_order.append("inner-after")
|
||||
@@ -90,19 +96,25 @@ class TestChainModelCallHandlers:
|
||||
"""Test composition of three handlers."""
|
||||
execution_order = []
|
||||
|
||||
def first(request, handler):
|
||||
def first(
|
||||
request: ModelRequest, handler: Callable[[ModelRequest], ModelResponse]
|
||||
) -> ModelResponse:
|
||||
execution_order.append("first-before")
|
||||
result = handler(request)
|
||||
execution_order.append("first-after")
|
||||
return result
|
||||
|
||||
def second(request, handler):
|
||||
def second(
|
||||
request: ModelRequest, handler: Callable[[ModelRequest], ModelResponse]
|
||||
) -> ModelResponse:
|
||||
execution_order.append("second-before")
|
||||
result = handler(request)
|
||||
execution_order.append("second-after")
|
||||
return result
|
||||
|
||||
def third(request, handler):
|
||||
def third(
|
||||
request: ModelRequest, handler: Callable[[ModelRequest], ModelResponse]
|
||||
) -> ModelResponse:
|
||||
execution_order.append("third-before")
|
||||
result = handler(request)
|
||||
execution_order.append("third-after")
|
||||
@@ -129,10 +141,14 @@ class TestChainModelCallHandlers:
|
||||
"""Test inner handler retrying before outer sees response."""
|
||||
inner_attempts = []
|
||||
|
||||
def outer_passthrough(request, handler):
|
||||
def outer_passthrough(
|
||||
request: ModelRequest, handler: Callable[[ModelRequest], ModelResponse]
|
||||
) -> ModelResponse:
|
||||
return handler(request)
|
||||
|
||||
def inner_with_retry(request, handler):
|
||||
def inner_with_retry(
|
||||
request: ModelRequest, handler: Callable[[ModelRequest], ModelResponse]
|
||||
) -> ModelResponse | AIMessage:
|
||||
for attempt in range(3):
|
||||
inner_attempts.append(attempt)
|
||||
try:
|
||||
@@ -147,7 +163,7 @@ class TestChainModelCallHandlers:
|
||||
|
||||
call_count = {"value": 0}
|
||||
|
||||
def mock_base_handler(req):
|
||||
def mock_base_handler(req: ModelRequest) -> ModelResponse:
|
||||
call_count["value"] += 1
|
||||
if call_count["value"] < 3:
|
||||
msg = "fail"
|
||||
@@ -163,20 +179,24 @@ class TestChainModelCallHandlers:
|
||||
def test_error_to_success_conversion(self) -> None:
|
||||
"""Test handler converting error to success response."""
|
||||
|
||||
def outer_error_handler(request, handler):
|
||||
def outer_error_handler(
|
||||
request: ModelRequest, handler: Callable[[ModelRequest], ModelResponse]
|
||||
) -> ModelResponse | AIMessage:
|
||||
try:
|
||||
return handler(request)
|
||||
except Exception:
|
||||
# Middleware can return AIMessage - it will be normalized to ModelResponse
|
||||
return AIMessage(content="Fallback response")
|
||||
|
||||
def inner_passthrough(request, handler):
|
||||
def inner_passthrough(
|
||||
request: ModelRequest, handler: Callable[[ModelRequest], ModelResponse]
|
||||
) -> ModelResponse:
|
||||
return handler(request)
|
||||
|
||||
composed = _chain_model_call_handlers([outer_error_handler, inner_passthrough])
|
||||
assert composed is not None
|
||||
|
||||
def mock_base_handler(req):
|
||||
def mock_base_handler(req: ModelRequest) -> ModelResponse:
|
||||
msg = "Model failed"
|
||||
raise ValueError(msg)
|
||||
|
||||
@@ -191,13 +211,17 @@ class TestChainModelCallHandlers:
|
||||
"""Test handlers modifying the request."""
|
||||
requests_seen = []
|
||||
|
||||
def outer_add_context(request, handler):
|
||||
def outer_add_context(
|
||||
request: ModelRequest, handler: Callable[[ModelRequest], ModelResponse]
|
||||
) -> ModelResponse:
|
||||
modified_request = create_test_request(
|
||||
messages=[*request.messages], system_prompt="Added by outer"
|
||||
)
|
||||
return handler(modified_request)
|
||||
|
||||
def inner_track_request(request, handler):
|
||||
def inner_track_request(
|
||||
request: ModelRequest, handler: Callable[[ModelRequest], ModelResponse]
|
||||
) -> ModelResponse:
|
||||
requests_seen.append(request.system_prompt)
|
||||
return handler(request)
|
||||
|
||||
@@ -212,15 +236,26 @@ class TestChainModelCallHandlers:
|
||||
|
||||
def test_composition_preserves_state_and_runtime(self) -> None:
|
||||
"""Test that state and runtime are passed through composition."""
|
||||
|
||||
class CustomState(AgentState[Any]):
|
||||
test: str
|
||||
|
||||
class CustomContext(TypedDict):
|
||||
test: str
|
||||
|
||||
state_values = []
|
||||
runtime_values = []
|
||||
|
||||
def outer(request, handler):
|
||||
def outer(
|
||||
request: ModelRequest, handler: Callable[[ModelRequest], ModelResponse]
|
||||
) -> ModelResponse:
|
||||
state_values.append(("outer", request.state))
|
||||
runtime_values.append(("outer", request.runtime))
|
||||
return handler(request)
|
||||
|
||||
def inner(request, handler):
|
||||
def inner(
|
||||
request: ModelRequest, handler: Callable[[ModelRequest], ModelResponse]
|
||||
) -> ModelResponse:
|
||||
state_values.append(("inner", request.state))
|
||||
runtime_values.append(("inner", request.runtime))
|
||||
return handler(request)
|
||||
@@ -228,8 +263,8 @@ class TestChainModelCallHandlers:
|
||||
composed = _chain_model_call_handlers([outer, inner])
|
||||
assert composed is not None
|
||||
|
||||
test_state = {"test": "state"}
|
||||
test_runtime = {"test": "runtime"}
|
||||
test_state = CustomState(messages=[], test="state")
|
||||
test_runtime = Runtime(context=CustomContext(test="runtime"))
|
||||
|
||||
# Create request with state and runtime
|
||||
test_request = create_test_request(state=test_state, runtime=test_runtime)
|
||||
@@ -245,11 +280,15 @@ class TestChainModelCallHandlers:
|
||||
"""Test handler that retries multiple times."""
|
||||
call_count = {"value": 0}
|
||||
|
||||
def outer_counts_calls(request, handler):
|
||||
def outer_counts_calls(
|
||||
request: ModelRequest, handler: Callable[[ModelRequest], ModelResponse]
|
||||
) -> ModelResponse:
|
||||
call_count["value"] += 1
|
||||
return handler(request)
|
||||
|
||||
def inner_retries(request, handler):
|
||||
def inner_retries(
|
||||
request: ModelRequest, handler: Callable[[ModelRequest], ModelResponse]
|
||||
) -> ModelResponse:
|
||||
try:
|
||||
return handler(request)
|
||||
except ValueError:
|
||||
@@ -261,7 +300,7 @@ class TestChainModelCallHandlers:
|
||||
|
||||
attempt = {"value": 0}
|
||||
|
||||
def mock_base_handler(req):
|
||||
def mock_base_handler(req: ModelRequest) -> ModelResponse:
|
||||
attempt["value"] += 1
|
||||
if attempt["value"] == 1:
|
||||
msg = "fail"
|
||||
|
||||
Reference in New Issue
Block a user