Compare commits

...

4 Commits

Author SHA1 Message Date
Sydney Runkle
74463b299e Merge branch 'sr/multiple-typing-fixes' of https://github.com/langchain-ai/langchain into sr/multiple-typing-fixes 2025-12-19 12:53:30 -05:00
Sydney Runkle
20e5fd4186 plumbing through theoretically 2025-12-19 12:53:25 -05:00
Sydney Runkle
46ad97c297 Update libs/langchain_v1/tests/integration_tests/agents/middleware/test_shell_tool_integration.py 2025-12-19 12:28:02 -05:00
Sydney Runkle
73d9061764 examples 2025-12-19 10:43:51 -05:00
3 changed files with 614 additions and 91 deletions

View File

@@ -86,13 +86,19 @@ def _normalize_to_model_response(result: ModelResponse | AIMessage) -> ModelResp
def _chain_model_call_handlers(
handlers: Sequence[
Callable[
[ModelRequest, Callable[[ModelRequest], ModelResponse]],
[
ModelRequest[Any, ContextT],
Callable[[ModelRequest[Any, ContextT]], ModelResponse],
],
ModelResponse | AIMessage,
]
],
) -> (
Callable[
[ModelRequest, Callable[[ModelRequest], ModelResponse]],
[
ModelRequest[Any, ContextT],
Callable[[ModelRequest[Any, ContextT]], ModelResponse],
],
ModelResponse,
]
| None
@@ -140,8 +146,8 @@ def _chain_model_call_handlers(
single_handler = handlers[0]
def normalized_single(
request: ModelRequest,
handler: Callable[[ModelRequest], ModelResponse],
request: ModelRequest[Any, ContextT],
handler: Callable[[ModelRequest[Any, ContextT]], ModelResponse],
) -> ModelResponse:
result = single_handler(request, handler)
return _normalize_to_model_response(result)
@@ -150,25 +156,34 @@ def _chain_model_call_handlers(
def compose_two(
outer: Callable[
[ModelRequest, Callable[[ModelRequest], ModelResponse]],
[
ModelRequest[Any, ContextT],
Callable[[ModelRequest[Any, ContextT]], ModelResponse],
],
ModelResponse | AIMessage,
],
inner: Callable[
[ModelRequest, Callable[[ModelRequest], ModelResponse]],
[
ModelRequest[Any, ContextT],
Callable[[ModelRequest[Any, ContextT]], ModelResponse],
],
ModelResponse | AIMessage,
],
) -> Callable[
[ModelRequest, Callable[[ModelRequest], ModelResponse]],
[
ModelRequest[Any, ContextT],
Callable[[ModelRequest[Any, ContextT]], ModelResponse],
],
ModelResponse,
]:
"""Compose two handlers where outer wraps inner."""
def composed(
request: ModelRequest,
handler: Callable[[ModelRequest], ModelResponse],
request: ModelRequest[Any, ContextT],
handler: Callable[[ModelRequest[Any, ContextT]], ModelResponse],
) -> ModelResponse:
# Create a wrapper that calls inner with the base handler and normalizes
def inner_handler(req: ModelRequest) -> ModelResponse:
def inner_handler(req: ModelRequest[Any, ContextT]) -> ModelResponse:
inner_result = inner(req, handler)
return _normalize_to_model_response(inner_result)
@@ -185,8 +200,8 @@ def _chain_model_call_handlers(
# Wrap to ensure final return type is exactly ModelResponse
def final_normalized(
request: ModelRequest,
handler: Callable[[ModelRequest], ModelResponse],
request: ModelRequest[Any, ContextT],
handler: Callable[[ModelRequest[Any, ContextT]], ModelResponse],
) -> ModelResponse:
# result here is typed as returning ModelResponse | AIMessage but compose_two normalizes
final_result = result(request, handler)
@@ -198,13 +213,19 @@ def _chain_model_call_handlers(
def _chain_async_model_call_handlers(
handlers: Sequence[
Callable[
[ModelRequest, Callable[[ModelRequest], Awaitable[ModelResponse]]],
[
ModelRequest[Any, ContextT],
Callable[[ModelRequest[Any, ContextT]], Awaitable[ModelResponse]],
],
Awaitable[ModelResponse | AIMessage],
]
],
) -> (
Callable[
[ModelRequest, Callable[[ModelRequest], Awaitable[ModelResponse]]],
[
ModelRequest[Any, ContextT],
Callable[[ModelRequest[Any, ContextT]], Awaitable[ModelResponse]],
],
Awaitable[ModelResponse],
]
| None
@@ -225,8 +246,8 @@ def _chain_async_model_call_handlers(
single_handler = handlers[0]
async def normalized_single(
request: ModelRequest,
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
request: ModelRequest[Any, ContextT],
handler: Callable[[ModelRequest[Any, ContextT]], Awaitable[ModelResponse]],
) -> ModelResponse:
result = await single_handler(request, handler)
return _normalize_to_model_response(result)
@@ -235,25 +256,34 @@ def _chain_async_model_call_handlers(
def compose_two(
outer: Callable[
[ModelRequest, Callable[[ModelRequest], Awaitable[ModelResponse]]],
[
ModelRequest[Any, ContextT],
Callable[[ModelRequest[Any, ContextT]], Awaitable[ModelResponse]],
],
Awaitable[ModelResponse | AIMessage],
],
inner: Callable[
[ModelRequest, Callable[[ModelRequest], Awaitable[ModelResponse]]],
[
ModelRequest[Any, ContextT],
Callable[[ModelRequest[Any, ContextT]], Awaitable[ModelResponse]],
],
Awaitable[ModelResponse | AIMessage],
],
) -> Callable[
[ModelRequest, Callable[[ModelRequest], Awaitable[ModelResponse]]],
[
ModelRequest[Any, ContextT],
Callable[[ModelRequest[Any, ContextT]], Awaitable[ModelResponse]],
],
Awaitable[ModelResponse],
]:
"""Compose two async handlers where outer wraps inner."""
async def composed(
request: ModelRequest,
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
request: ModelRequest[Any, ContextT],
handler: Callable[[ModelRequest[Any, ContextT]], Awaitable[ModelResponse]],
) -> ModelResponse:
# Create a wrapper that calls inner with the base handler and normalizes
async def inner_handler(req: ModelRequest) -> ModelResponse:
async def inner_handler(req: ModelRequest[Any, ContextT]) -> ModelResponse:
inner_result = await inner(req, handler)
return _normalize_to_model_response(inner_result)
@@ -270,8 +300,8 @@ def _chain_async_model_call_handlers(
# Wrap to ensure final return type is exactly ModelResponse
async def final_normalized(
request: ModelRequest,
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
request: ModelRequest[Any, ContextT],
handler: Callable[[ModelRequest[Any, ContextT]], Awaitable[ModelResponse]],
) -> ModelResponse:
# result here is typed as returning ModelResponse | AIMessage but compose_two normalizes
final_result = await result(request, handler)
@@ -973,7 +1003,9 @@ def create_agent(
return {"messages": [output]}
def _get_bound_model(request: ModelRequest) -> tuple[Runnable, ResponseFormat | None]:
def _get_bound_model(
request: ModelRequest[Any, ContextT],
) -> tuple[Runnable, ResponseFormat | None]:
"""Get the model with appropriate tool bindings.
Performs auto-detection of strategy if needed based on model capabilities.
@@ -1087,7 +1119,7 @@ def create_agent(
)
return request.model.bind(**request.model_settings), None
def _execute_model_sync(request: ModelRequest) -> ModelResponse:
def _execute_model_sync(request: ModelRequest[Any, ContextT]) -> ModelResponse:
"""Execute model and return response.
This is the core model execution logic wrapped by `wrap_model_call` handlers.
@@ -1140,7 +1172,7 @@ def create_agent(
return state_updates
async def _execute_model_async(request: ModelRequest) -> ModelResponse:
async def _execute_model_async(request: ModelRequest[Any, ContextT]) -> ModelResponse:
"""Execute model asynchronously and return response.
This is the core async model execution logic wrapped by `wrap_model_call`

View File

@@ -67,7 +67,9 @@ __all__ = [
JumpTo = Literal["tools", "model", "end"]
"""Destination to jump to when a middleware node returns."""
ResponseT = TypeVar("ResponseT")
ResponseT = TypeVar("ResponseT", default=Any)
# StateT uses string forward references since AgentState is defined later
StateT = TypeVar("StateT", bound="AgentState", default="AgentState")
class _ModelRequestOverrides(TypedDict, total=False):
@@ -83,7 +85,7 @@ class _ModelRequestOverrides(TypedDict, total=False):
@dataclass(init=False)
class ModelRequest:
class ModelRequest(Generic[StateT, ContextT]):
"""Model request information for the agent."""
model: BaseChatModel
@@ -92,8 +94,8 @@ class ModelRequest:
tool_choice: Any | None
tools: list[BaseTool | dict]
response_format: ResponseFormat | None
state: AgentState
runtime: Runtime[ContextT] # type: ignore[valid-type]
state: StateT
runtime: Runtime[ContextT]
model_settings: dict[str, Any] = field(default_factory=dict)
def __init__(
@@ -106,7 +108,7 @@ class ModelRequest:
tool_choice: Any | None = None,
tools: list[BaseTool | dict] | None = None,
response_format: ResponseFormat | None = None,
state: AgentState | None = None,
state: StateT | None = None,
runtime: Runtime[ContextT] | None = None,
model_settings: dict[str, Any] | None = None,
) -> None:
@@ -140,7 +142,7 @@ class ModelRequest:
self.tool_choice = tool_choice
self.tools = tools if tools is not None else []
self.response_format = response_format
self.state = state if state is not None else {"messages": []}
self.state = state if state is not None else cast("StateT", {"messages": []})
self.runtime = runtime # type: ignore[assignment]
self.model_settings = model_settings if model_settings is not None else {}
@@ -189,7 +191,9 @@ class ModelRequest:
)
object.__setattr__(self, name, value)
def override(self, **overrides: Unpack[_ModelRequestOverrides]) -> ModelRequest:
def override(
self, **overrides: Unpack[_ModelRequestOverrides]
) -> ModelRequest[StateT, ContextT]:
"""Replace the request with a new request with the given overrides.
Returns a new `ModelRequest` instance with the specified attributes replaced.
@@ -322,7 +326,6 @@ class _OutputAgentState(TypedDict, Generic[ResponseT]): # noqa: PYI049
structured_response: NotRequired[ResponseT]
StateT = TypeVar("StateT", bound=AgentState, default=AgentState)
StateT_co = TypeVar("StateT_co", bound=AgentState, default=AgentState, covariant=True)
StateT_contra = TypeVar("StateT_contra", bound=AgentState, contravariant=True)
@@ -383,8 +386,8 @@ class AgentMiddleware(Generic[StateT, ContextT]):
def wrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], ModelResponse],
request: ModelRequest[StateT, ContextT],
handler: Callable[[ModelRequest[StateT, ContextT]], ModelResponse],
) -> ModelCallResult:
"""Intercept and control model execution via handler callback.
@@ -478,8 +481,8 @@ class AgentMiddleware(Generic[StateT, ContextT]):
async def awrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
request: ModelRequest[StateT, ContextT],
handler: Callable[[ModelRequest[StateT, ContextT]], Awaitable[ModelResponse]],
) -> ModelCallResult:
"""Intercept and control async model execution via handler callback.
@@ -698,18 +701,24 @@ class _CallableWithStateAndRuntime(Protocol[StateT_contra, ContextT]):
...
class _CallableReturningSystemMessage(Protocol[StateT_contra, ContextT]): # type: ignore[misc]
"""Callable that returns a prompt string or SystemMessage given `ModelRequest`."""
class _SyncCallableReturningSystemMessage(Protocol[StateT, ContextT]):
"""Sync callable that returns a prompt string or SystemMessage given `ModelRequest`."""
def __call__(
self, request: ModelRequest
) -> str | SystemMessage | Awaitable[str | SystemMessage]:
def __call__(self, request: ModelRequest[StateT, ContextT]) -> str | SystemMessage:
"""Generate a system prompt string or SystemMessage based on the request."""
...
class _CallableReturningModelResponse(Protocol[StateT_contra, ContextT]): # type: ignore[misc]
"""Callable for model call interception with handler callback.
class _AsyncCallableReturningSystemMessage(Protocol[StateT, ContextT]):
"""Async callable that returns a prompt string or SystemMessage given `ModelRequest`."""
def __call__(self, request: ModelRequest[StateT, ContextT]) -> Awaitable[str | SystemMessage]:
"""Generate a system prompt string or SystemMessage based on the request."""
...
class _SyncCallableReturningModelResponse(Protocol[StateT, ContextT]):
"""Sync callable for model call interception with handler callback.
Receives handler callback to execute model and returns `ModelResponse` or
`AIMessage`.
@@ -717,15 +726,31 @@ class _CallableReturningModelResponse(Protocol[StateT_contra, ContextT]): # typ
def __call__(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], ModelResponse],
request: ModelRequest[StateT, ContextT],
handler: Callable[[ModelRequest[StateT, ContextT]], ModelResponse],
) -> ModelCallResult:
"""Intercept model execution via handler callback."""
...
class _CallableReturningToolResponse(Protocol):
"""Callable for tool call interception with handler callback.
class _AsyncCallableReturningModelResponse(Protocol[StateT, ContextT]):
"""Async callable for model call interception with handler callback.
Receives async handler callback to execute model and returns `ModelResponse` or
`AIMessage`.
"""
def __call__(
self,
request: ModelRequest[StateT, ContextT],
handler: Callable[[ModelRequest[StateT, ContextT]], Awaitable[ModelResponse]],
) -> Awaitable[ModelCallResult]:
"""Intercept model execution via async handler callback."""
...
class _SyncCallableReturningToolResponse(Protocol):
"""Sync callable for tool call interception with handler callback.
Receives handler callback to execute tool and returns final `ToolMessage` or
`Command`.
@@ -740,6 +765,22 @@ class _CallableReturningToolResponse(Protocol):
...
class _AsyncCallableReturningToolResponse(Protocol):
"""Async callable for tool call interception with handler callback.
Receives async handler callback to execute tool and returns final `ToolMessage` or
`Command`.
"""
def __call__(
self,
request: ToolCallRequest,
handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]],
) -> Awaitable[ToolMessage | Command]:
"""Intercept tool execution via async handler callback."""
...
CallableT = TypeVar("CallableT", bound=Callable[..., Any])
@@ -1385,24 +1426,44 @@ def after_agent(
@overload
def dynamic_prompt(
func: _CallableReturningSystemMessage[StateT, ContextT],
func: _SyncCallableReturningSystemMessage[StateT, ContextT],
) -> AgentMiddleware[StateT, ContextT]: ...
@overload
def dynamic_prompt(
func: _AsyncCallableReturningSystemMessage[StateT, ContextT],
) -> AgentMiddleware[StateT, ContextT]: ...
@overload
def dynamic_prompt(
func: None = None,
*,
state_schema: type[StateT] | None = None,
) -> Callable[
[_CallableReturningSystemMessage[StateT, ContextT]],
[
_SyncCallableReturningSystemMessage[StateT, ContextT]
| _AsyncCallableReturningSystemMessage[StateT, ContextT]
],
AgentMiddleware[StateT, ContextT],
]: ...
def dynamic_prompt(
func: _CallableReturningSystemMessage[StateT, ContextT] | None = None,
func: (
_SyncCallableReturningSystemMessage[StateT, ContextT]
| _AsyncCallableReturningSystemMessage[StateT, ContextT]
| None
) = None,
*,
state_schema: type[StateT] | None = None,
) -> (
Callable[
[_CallableReturningSystemMessage[StateT, ContextT]],
[
_SyncCallableReturningSystemMessage[StateT, ContextT]
| _AsyncCallableReturningSystemMessage[StateT, ContextT]
],
AgentMiddleware[StateT, ContextT],
]
| AgentMiddleware[StateT, ContextT]
@@ -1418,6 +1479,9 @@ def dynamic_prompt(
Must accept: `request: ModelRequest` - Model request (contains state and
runtime)
state_schema: Optional custom state schema type.
If not provided, uses the default `AgentState` schema.
Returns:
Either an `AgentMiddleware` instance (if func is provided) or a decorator
@@ -1456,18 +1520,22 @@ def dynamic_prompt(
"""
def decorator(
func: _CallableReturningSystemMessage[StateT, ContextT],
func: (
_SyncCallableReturningSystemMessage[StateT, ContextT]
| _AsyncCallableReturningSystemMessage[StateT, ContextT]
),
) -> AgentMiddleware[StateT, ContextT]:
is_async = iscoroutinefunction(func)
if is_async:
async_func = cast("_AsyncCallableReturningSystemMessage[StateT, ContextT]", func)
async def async_wrapped(
_self: AgentMiddleware[StateT, ContextT],
request: ModelRequest,
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
request: ModelRequest[StateT, ContextT],
handler: Callable[[ModelRequest[StateT, ContextT]], Awaitable[ModelResponse]],
) -> ModelCallResult:
prompt = await func(request) # type: ignore[misc]
prompt = await async_func(request)
if isinstance(prompt, SystemMessage):
request = request.override(system_message=prompt)
else:
@@ -1480,18 +1548,20 @@ def dynamic_prompt(
middleware_name,
(AgentMiddleware,),
{
"state_schema": AgentState,
"state_schema": state_schema or AgentState,
"tools": [],
"awrap_model_call": async_wrapped,
},
)()
sync_func = cast("_SyncCallableReturningSystemMessage[StateT, ContextT]", func)
def wrapped(
_self: AgentMiddleware[StateT, ContextT],
request: ModelRequest,
handler: Callable[[ModelRequest], ModelResponse],
request: ModelRequest[StateT, ContextT],
handler: Callable[[ModelRequest[StateT, ContextT]], ModelResponse],
) -> ModelCallResult:
prompt = cast("Callable[[ModelRequest], SystemMessage | str]", func)(request)
prompt = sync_func(request)
if isinstance(prompt, SystemMessage):
request = request.override(system_message=prompt)
else:
@@ -1500,11 +1570,11 @@ def dynamic_prompt(
async def async_wrapped_from_sync(
_self: AgentMiddleware[StateT, ContextT],
request: ModelRequest,
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
request: ModelRequest[StateT, ContextT],
handler: Callable[[ModelRequest[StateT, ContextT]], Awaitable[ModelResponse]],
) -> ModelCallResult:
# Delegate to sync function
prompt = cast("Callable[[ModelRequest], SystemMessage | str]", func)(request)
prompt = sync_func(request)
if isinstance(prompt, SystemMessage):
request = request.override(system_message=prompt)
else:
@@ -1517,7 +1587,7 @@ def dynamic_prompt(
middleware_name,
(AgentMiddleware,),
{
"state_schema": AgentState,
"state_schema": state_schema or AgentState,
"tools": [],
"wrap_model_call": wrapped,
"awrap_model_call": async_wrapped_from_sync,
@@ -1531,7 +1601,13 @@ def dynamic_prompt(
@overload
def wrap_model_call(
func: _CallableReturningModelResponse[StateT, ContextT],
func: _SyncCallableReturningModelResponse[StateT, ContextT],
) -> AgentMiddleware[StateT, ContextT]: ...
@overload
def wrap_model_call(
func: _AsyncCallableReturningModelResponse[StateT, ContextT],
) -> AgentMiddleware[StateT, ContextT]: ...
@@ -1543,20 +1619,30 @@ def wrap_model_call(
tools: list[BaseTool] | None = None,
name: str | None = None,
) -> Callable[
[_CallableReturningModelResponse[StateT, ContextT]],
[
_SyncCallableReturningModelResponse[StateT, ContextT]
| _AsyncCallableReturningModelResponse[StateT, ContextT]
],
AgentMiddleware[StateT, ContextT],
]: ...
def wrap_model_call(
func: _CallableReturningModelResponse[StateT, ContextT] | None = None,
func: (
_SyncCallableReturningModelResponse[StateT, ContextT]
| _AsyncCallableReturningModelResponse[StateT, ContextT]
| None
) = None,
*,
state_schema: type[StateT] | None = None,
tools: list[BaseTool] | None = None,
name: str | None = None,
) -> (
Callable[
[_CallableReturningModelResponse[StateT, ContextT]],
[
_SyncCallableReturningModelResponse[StateT, ContextT]
| _AsyncCallableReturningModelResponse[StateT, ContextT]
],
AgentMiddleware[StateT, ContextT],
]
| AgentMiddleware[StateT, ContextT]
@@ -1637,18 +1723,22 @@ def wrap_model_call(
"""
def decorator(
func: _CallableReturningModelResponse[StateT, ContextT],
func: (
_SyncCallableReturningModelResponse[StateT, ContextT]
| _AsyncCallableReturningModelResponse[StateT, ContextT]
),
) -> AgentMiddleware[StateT, ContextT]:
is_async = iscoroutinefunction(func)
if is_async:
async_func = cast("_AsyncCallableReturningModelResponse[StateT, ContextT]", func)
async def async_wrapped(
_self: AgentMiddleware[StateT, ContextT],
request: ModelRequest,
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
request: ModelRequest[StateT, ContextT],
handler: Callable[[ModelRequest[StateT, ContextT]], Awaitable[ModelResponse]],
) -> ModelCallResult:
return await func(request, handler) # type: ignore[misc, arg-type]
return await async_func(request, handler)
middleware_name = name or cast(
"str", getattr(func, "__name__", "WrapModelCallMiddleware")
@@ -1664,12 +1754,14 @@ def wrap_model_call(
},
)()
sync_func = cast("_SyncCallableReturningModelResponse[StateT, ContextT]", func)
def wrapped(
_self: AgentMiddleware[StateT, ContextT],
request: ModelRequest,
handler: Callable[[ModelRequest], ModelResponse],
request: ModelRequest[StateT, ContextT],
handler: Callable[[ModelRequest[StateT, ContextT]], ModelResponse],
) -> ModelCallResult:
return func(request, handler)
return sync_func(request, handler)
middleware_name = name or cast("str", getattr(func, "__name__", "WrapModelCallMiddleware"))
@@ -1690,8 +1782,14 @@ def wrap_model_call(
@overload
def wrap_tool_call(
func: _CallableReturningToolResponse,
) -> AgentMiddleware: ...
func: _SyncCallableReturningToolResponse,
) -> AgentMiddleware[AgentState, None]: ...
@overload
def wrap_tool_call(
func: _AsyncCallableReturningToolResponse,
) -> AgentMiddleware[AgentState, None]: ...
@overload
@@ -1701,22 +1799,22 @@ def wrap_tool_call(
tools: list[BaseTool] | None = None,
name: str | None = None,
) -> Callable[
[_CallableReturningToolResponse],
AgentMiddleware,
[_SyncCallableReturningToolResponse | _AsyncCallableReturningToolResponse],
AgentMiddleware[AgentState, None],
]: ...
def wrap_tool_call(
func: _CallableReturningToolResponse | None = None,
func: _SyncCallableReturningToolResponse | _AsyncCallableReturningToolResponse | None = None,
*,
tools: list[BaseTool] | None = None,
name: str | None = None,
) -> (
Callable[
[_CallableReturningToolResponse],
AgentMiddleware,
[_SyncCallableReturningToolResponse | _AsyncCallableReturningToolResponse],
AgentMiddleware[AgentState, None],
]
| AgentMiddleware
| AgentMiddleware[AgentState, None]
):
"""Create middleware with `wrap_tool_call` hook from a function.
@@ -1797,18 +1895,19 @@ def wrap_tool_call(
"""
def decorator(
func: _CallableReturningToolResponse,
) -> AgentMiddleware:
func: _SyncCallableReturningToolResponse | _AsyncCallableReturningToolResponse,
) -> AgentMiddleware[AgentState, None]:
is_async = iscoroutinefunction(func)
if is_async:
async_func = cast("_AsyncCallableReturningToolResponse", func)
async def async_wrapped(
_self: AgentMiddleware,
_self: AgentMiddleware[AgentState, None],
request: ToolCallRequest,
handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]],
) -> ToolMessage | Command:
return await func(request, handler) # type: ignore[arg-type,misc]
return await async_func(request, handler)
middleware_name = name or cast(
"str", getattr(func, "__name__", "WrapToolCallMiddleware")
@@ -1824,12 +1923,14 @@ def wrap_tool_call(
},
)()
sync_func = cast("_SyncCallableReturningToolResponse", func)
def wrapped(
_self: AgentMiddleware,
_self: AgentMiddleware[AgentState, None],
request: ToolCallRequest,
handler: Callable[[ToolCallRequest], ToolMessage | Command],
) -> ToolMessage | Command:
return func(request, handler)
return sync_func(request, handler)
middleware_name = name or cast("str", getattr(func, "__name__", "WrapToolCallMiddleware"))

View File

@@ -0,0 +1,390 @@
"""Tests demonstrating proper typing support for middleware.
This test file verifies that:
1. ModelRequest is properly generic over StateT and ContextT
2. Async middleware decorators work without type errors
3. Sync middleware decorators work without type errors
4. Custom context types flow through properly
5. Handler callbacks have correct async/sync signatures
These tests should pass mypy type checking without any type: ignore comments.
"""
from collections.abc import Awaitable, Callable
from typing import TypedDict
from unittest.mock import MagicMock
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
from langgraph.prebuilt.tool_node import ToolCallRequest
from langgraph.types import Command
from langchain.agents.middleware.types import (
AgentMiddleware,
AgentState,
ModelCallResult,
ModelRequest,
ModelResponse,
after_agent,
after_model,
before_agent,
before_model,
dynamic_prompt,
wrap_model_call,
wrap_tool_call,
)
# Custom context type for testing
class ServiceContext(TypedDict):
    """Custom context for service-level information."""

    # Identifier of the acting user — used only as a type-level fixture here.
    user_id: str
    # Identifier of the current session.
    session_id: str
    # Deployment environment label (e.g. "prod"/"staging") — value never read by tests.
    environment: str
class CustomState(AgentState):
    """Custom state extending AgentState.

    Exists to verify that state_schema-parameterized decorators and
    class-based middleware surface this extra key without type errors.
    """

    # Extra key beyond the base AgentState schema.
    custom_field: str
# =============================================================================
# Test 1: ModelRequest generic typing with custom context
# =============================================================================
def test_model_request_generic_context_typing() -> None:
    """Test that ModelRequest[StateT, ContextT] properly types the state and runtime fields."""
    # Create a mock model
    mock_model = MagicMock()
    # Create ModelRequest with explicit state and context type annotation.
    # state/runtime are omitted; ModelRequest supplies defaults — the point
    # here is that the annotated assignment type-checks, not the defaults.
    request: ModelRequest[AgentState, ServiceContext] = ModelRequest(
        model=mock_model,
        messages=[HumanMessage(content="Hello")],
    )
    # The request should be created without type errors
    assert request.model == mock_model
    assert len(request.messages) == 1
# =============================================================================
# Test 2: Sync dynamic_prompt decorator with proper typing
# =============================================================================
def test_sync_dynamic_prompt_typing() -> None:
    """Test that sync @dynamic_prompt decorator works without type errors."""

    @dynamic_prompt
    def my_prompt(request: ModelRequest[AgentState, ServiceContext]) -> str:
        # This should work without type: ignore - accessing generic ModelRequest
        return f"System prompt for messages: {len(request.messages)}"

    # The decorator should return an AgentMiddleware
    assert isinstance(my_prompt, AgentMiddleware)
def test_sync_dynamic_prompt_returning_system_message() -> None:
    """Test that sync @dynamic_prompt can return SystemMessage."""

    @dynamic_prompt
    def my_prompt(request: ModelRequest[AgentState, None]) -> SystemMessage:
        # SystemMessage return (vs str) is the other half of the union the
        # decorator's protocol accepts.
        return SystemMessage(content="You are a helpful assistant")

    assert isinstance(my_prompt, AgentMiddleware)
# =============================================================================
# Test 3: Async dynamic_prompt decorator with proper typing
# =============================================================================
def test_async_dynamic_prompt_typing() -> None:
    """Test that async @dynamic_prompt decorator works without type errors."""

    @dynamic_prompt
    async def my_async_prompt(request: ModelRequest[AgentState, ServiceContext]) -> str:
        # Async function should work without type errors
        # (exercises the async overload of dynamic_prompt).
        return "Async system prompt"

    assert isinstance(my_async_prompt, AgentMiddleware)
def test_async_dynamic_prompt_returning_system_message() -> None:
    """Test that async @dynamic_prompt can return SystemMessage."""

    @dynamic_prompt
    async def my_async_prompt(request: ModelRequest[AgentState, None]) -> SystemMessage:
        return SystemMessage(content="Async system message")

    assert isinstance(my_async_prompt, AgentMiddleware)
# =============================================================================
# Test 4: Sync wrap_model_call decorator with proper handler typing
# =============================================================================
def test_sync_wrap_model_call_typing() -> None:
    """Test that sync @wrap_model_call decorator properly types the handler."""

    @wrap_model_call
    def retry_middleware(
        request: ModelRequest[AgentState, ServiceContext],
        handler: Callable[[ModelRequest[AgentState, ServiceContext]], ModelResponse],
    ) -> ModelCallResult:
        # Handler should be typed as sync - no Awaitable
        return handler(request)

    assert isinstance(retry_middleware, AgentMiddleware)
def test_sync_wrap_model_call_returning_ai_message() -> None:
    """Test that sync @wrap_model_call can return AIMessage directly."""

    @wrap_model_call
    def simple_middleware(
        request: ModelRequest[AgentState, None],
        handler: Callable[[ModelRequest[AgentState, None]], ModelResponse],
    ) -> ModelCallResult:
        # Can return AIMessage directly (converted automatically)
        # — handler is deliberately never called here.
        return AIMessage(content="Simple response")

    assert isinstance(simple_middleware, AgentMiddleware)
# =============================================================================
# Test 5: Async wrap_model_call decorator with proper handler typing
# =============================================================================
def test_async_wrap_model_call_typing() -> None:
    """Test that async @wrap_model_call decorator properly types the async handler."""

    @wrap_model_call
    async def async_retry_middleware(
        request: ModelRequest[AgentState, ServiceContext],
        handler: Callable[[ModelRequest[AgentState, ServiceContext]], Awaitable[ModelResponse]],
    ) -> ModelCallResult:
        # Handler should be typed as async - returns Awaitable
        return await handler(request)

    assert isinstance(async_retry_middleware, AgentMiddleware)
def test_async_wrap_model_call_with_error_handling() -> None:
    """Test async @wrap_model_call with try/except pattern."""

    @wrap_model_call
    async def error_handling_middleware(
        request: ModelRequest[AgentState, None],
        handler: Callable[[ModelRequest[AgentState, None]], Awaitable[ModelResponse]],
    ) -> ModelCallResult:
        # Broad except is intentional for the fixture: the fallback AIMessage
        # must also satisfy ModelCallResult for this to type-check.
        try:
            return await handler(request)
        except Exception:
            return AIMessage(content="Error occurred")

    assert isinstance(error_handling_middleware, AgentMiddleware)
# =============================================================================
# Test 6: Sync wrap_tool_call decorator with proper handler typing
# =============================================================================
def test_sync_wrap_tool_call_typing() -> None:
    """Test that sync @wrap_tool_call decorator properly types the handler."""

    @wrap_tool_call
    def tool_error_handler(
        request: ToolCallRequest,
        handler: Callable[[ToolCallRequest], ToolMessage | Command],
    ) -> ToolMessage | Command:
        # On failure, surface the error as a ToolMessage tied to the
        # originating tool call id.
        try:
            return handler(request)
        except Exception as e:
            return ToolMessage(
                content=str(e),
                tool_call_id=request.tool_call["id"],
            )

    assert isinstance(tool_error_handler, AgentMiddleware)
# =============================================================================
# Test 7: Async wrap_tool_call decorator with proper handler typing
# =============================================================================
def test_async_wrap_tool_call_typing() -> None:
    """Test that async @wrap_tool_call decorator properly types the async handler."""

    @wrap_tool_call
    async def async_tool_error_handler(
        request: ToolCallRequest,
        handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]],
    ) -> ToolMessage | Command:
        # Same shape as the sync variant, but the handler is awaited.
        try:
            return await handler(request)
        except Exception as e:
            return ToolMessage(
                content=str(e),
                tool_call_id=request.tool_call["id"],
            )

    assert isinstance(async_tool_error_handler, AgentMiddleware)
# =============================================================================
# Test 8: before_model/after_model decorators with custom state
# =============================================================================
def test_before_model_with_custom_state_typing() -> None:
    """Test @before_model decorator with custom state schema."""

    @before_model(state_schema=CustomState)
    def log_before_model(
        state: CustomState,
        runtime: object,  # Runtime type comes from langgraph
    ) -> dict[str, object] | None:
        # Should have access to custom_field without type errors
        _ = state.get("custom_field")  # Access custom field
        return None

    assert isinstance(log_before_model, AgentMiddleware)
    # The passed schema should be stored on the resulting middleware.
    assert log_before_model.state_schema == CustomState
def test_after_model_with_custom_state_typing() -> None:
    """Test @after_model decorator with custom state schema."""

    @after_model(state_schema=CustomState)
    def log_after_model(
        state: CustomState,
        runtime: object,
    ) -> dict[str, object] | None:
        # Returning a state update dict keyed by the custom field.
        return {"custom_field": "updated"}

    assert isinstance(log_after_model, AgentMiddleware)
# =============================================================================
# Test 9: before_agent/after_agent decorators
# =============================================================================
def test_before_agent_async_typing() -> None:
    """Test async @before_agent decorator."""

    @before_agent
    async def setup_agent(
        state: AgentState,
        runtime: object,
    ) -> dict[str, object] | None:
        # No state update — returning None is a valid no-op.
        return None

    assert isinstance(setup_agent, AgentMiddleware)
def test_after_agent_async_typing() -> None:
    """Test async @after_agent decorator."""

    @after_agent
    async def cleanup_agent(
        state: AgentState,
        runtime: object,
    ) -> dict[str, object] | None:
        return None

    assert isinstance(cleanup_agent, AgentMiddleware)
# =============================================================================
# Test 10: Class-based middleware with proper generic typing
# =============================================================================
class TypedMiddleware(AgentMiddleware[CustomState, ServiceContext]):
    """Class-based middleware with explicit type parameters."""

    # Runtime schema must mirror the StateT type argument above.
    state_schema = CustomState

    def before_model(
        self,
        state: CustomState,
        runtime: object,
        # NOTE(review): base class presumably annotates runtime as
        # Runtime[ContextT]; `object` is a wider (contravariant-safe)
        # override — confirm mypy accepts this against the base signature.
    ) -> dict[str, object] | None:
        # State is properly typed as CustomState
        return None
def test_class_based_middleware_typing() -> None:
    """Test class-based middleware with explicit generics."""
    middleware = TypedMiddleware()
    # Instantiation succeeds and the class-level schema is preserved.
    assert middleware.state_schema == CustomState
# =============================================================================
# Test 11: ModelRequest.override() preserves generic type
# =============================================================================
def test_model_request_override_preserves_generic() -> None:
    """Test that ModelRequest.override() returns properly typed ModelRequest."""
    mock_model = MagicMock()
    request: ModelRequest[AgentState, ServiceContext] = ModelRequest(
        model=mock_model,
        messages=[HumanMessage(content="Hello")],
    )
    # override() should return ModelRequest[AgentState, ServiceContext], not ModelRequest[Any, Any]
    new_request = request.override(system_message=SystemMessage(content="New system prompt"))
    # This should be type-safe
    assert new_request.system_message is not None
    assert new_request.system_message.content == "New system prompt"
# =============================================================================
# Test 12: Multiple middleware in a list (simulating create_agent usage)
# =============================================================================
def test_middleware_list_typing() -> None:
    """Test that middleware can be collected in a properly typed list."""

    @dynamic_prompt
    async def system_prompt(request: ModelRequest[AgentState, ServiceContext]) -> SystemMessage:
        return SystemMessage(content="System")

    @wrap_model_call
    async def censor_response(
        request: ModelRequest[AgentState, ServiceContext],
        handler: Callable[[ModelRequest[AgentState, ServiceContext]], Awaitable[ModelResponse]],
    ) -> ModelCallResult:
        return await handler(request)

    @wrap_tool_call
    async def handle_errors(
        request: ToolCallRequest,
        handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]],
    ) -> ToolMessage | Command:
        try:
            return await handler(request)
        except Exception as e:
            return ToolMessage(content=str(e), tool_call_id=request.tool_call["id"])

    # All middleware should be assignable to a list of AgentMiddleware
    # Note: The decorators return AgentMiddleware with inferred generic parameters
    # NOTE(review): handle_errors is defined but deliberately(?) excluded —
    # wrap_tool_call yields AgentMiddleware[AgentState, None], which would not
    # be assignable to this ServiceContext-typed list. Confirm the omission is
    # intentional and not an oversight.
    middleware_list: list[AgentMiddleware[AgentState, ServiceContext]] = [
        system_prompt,
        censor_response,
    ]
    assert len(middleware_list) == 2