Compare commits

...

1 Commits

Author SHA1 Message Date
Sydney Runkle
7023b6f496 experiment w/ typing 2026-02-12 18:08:44 -05:00
5 changed files with 249 additions and 28 deletions

View File

@@ -6,12 +6,12 @@ from typing import TYPE_CHECKING, Annotated, Any, Literal
from langchain_core.messages import AIMessage, ToolCall, ToolMessage
from langgraph.channels.untracked_value import UntrackedValue
from langgraph.typing import ContextT
from typing_extensions import NotRequired, override
from langchain.agents.middleware.types import (
AgentMiddleware,
AgentState,
ContextT,
PrivateStateAttr,
ResponseT,
hook_config,

View File

@@ -33,9 +33,13 @@ from langchain_core.messages import (
from langgraph.channels.ephemeral_value import EphemeralValue
from langgraph.graph.message import add_messages
from langgraph.prebuilt.tool_node import ToolCallRequest, ToolCallWrapper
from langgraph.typing import ContextT
from typing_extensions import NotRequired, Required, TypedDict, TypeVar, Unpack
if TYPE_CHECKING:
from langgraph._internal._typing import StateLike
ContextT = TypeVar("ContextT", bound="StateLike | None", default=Any)
if TYPE_CHECKING:
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.tools import BaseTool
@@ -90,7 +94,7 @@ class ModelRequest(Generic[ContextT]):
"""Model request information for the agent.
Type Parameters:
ContextT: The type of the runtime context. Defaults to `None` if not specified.
ContextT: The type of the runtime context. Defaults to `Any` if not specified.
"""
model: BaseChatModel
@@ -385,7 +389,7 @@ class AgentMiddleware(Generic[StateT, ContextT, ResponseT]):
Type Parameters:
StateT: The type of the agent state. Defaults to `AgentState[Any]`.
ContextT: The type of the runtime context. Defaults to `None`.
ContextT: The type of the runtime context. Defaults to `Any`.
ResponseT: The type of the structured response. Defaults to `Any`.
"""
@@ -850,6 +854,38 @@ class _CallableReturningToolResponse(Protocol):
...
class _AsyncCallableReturningModelResponse(Protocol[StateT_contra, ContextT, ResponseT]): # type: ignore[misc]
"""Async callable for model call interception with async handler callback.
Receives async handler callback to execute model and returns awaitable
`ModelResponse` or `AIMessage`.
"""
def __call__(
self,
request: ModelRequest[ContextT],
handler: Callable[[ModelRequest[ContextT]], Awaitable[ModelResponse[ResponseT]]],
) -> Awaitable[ModelResponse[ResponseT] | AIMessage]:
"""Intercept async model execution via handler callback."""
...
class _AsyncCallableReturningToolResponse(Protocol):
"""Async callable for tool call interception with async handler callback.
Receives async handler callback to execute tool and returns awaitable
`ToolMessage` or `Command`.
"""
def __call__(
self,
request: ToolCallRequest,
handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command[Any]]],
) -> Awaitable[ToolMessage | Command[Any]]:
"""Intercept async tool execution via handler callback."""
...
CallableT = TypeVar("CallableT", bound=Callable[..., Any])
@@ -1733,6 +1769,12 @@ def dynamic_prompt(
return decorator
@overload
def wrap_model_call(
func: _AsyncCallableReturningModelResponse[StateT, ContextT, ResponseT],
) -> AgentMiddleware[StateT, ContextT]: ...
@overload
def wrap_model_call(
func: _CallableReturningModelResponse[StateT, ContextT, ResponseT],
@@ -1747,20 +1789,28 @@ def wrap_model_call(
tools: list[BaseTool] | None = None,
name: str | None = None,
) -> Callable[
[_CallableReturningModelResponse[StateT, ContextT, ResponseT]],
[
_CallableReturningModelResponse[StateT, ContextT, ResponseT]
| _AsyncCallableReturningModelResponse[StateT, ContextT, ResponseT]
],
AgentMiddleware[StateT, ContextT],
]: ...
def wrap_model_call(
func: _CallableReturningModelResponse[StateT, ContextT, ResponseT] | None = None,
func: _CallableReturningModelResponse[StateT, ContextT, ResponseT]
| _AsyncCallableReturningModelResponse[StateT, ContextT, ResponseT]
| None = None,
*,
state_schema: type[StateT] | None = None,
tools: list[BaseTool] | None = None,
name: str | None = None,
) -> (
Callable[
[_CallableReturningModelResponse[StateT, ContextT, ResponseT]],
[
_CallableReturningModelResponse[StateT, ContextT, ResponseT]
| _AsyncCallableReturningModelResponse[StateT, ContextT, ResponseT]
],
AgentMiddleware[StateT, ContextT],
]
| AgentMiddleware[StateT, ContextT]
@@ -1841,18 +1891,24 @@ def wrap_model_call(
"""
def decorator(
func: _CallableReturningModelResponse[StateT, ContextT, ResponseT],
func: _CallableReturningModelResponse[StateT, ContextT, ResponseT]
| _AsyncCallableReturningModelResponse[StateT, ContextT, ResponseT],
) -> AgentMiddleware[StateT, ContextT]:
is_async = iscoroutinefunction(func)
if is_async:
# iscoroutinefunction is not a TypeGuard, so narrow manually
_async_func = cast(
"_AsyncCallableReturningModelResponse[StateT, ContextT, ResponseT]",
func,
)
async def async_wrapped(
_self: AgentMiddleware[StateT, ContextT],
request: ModelRequest[ContextT],
handler: Callable[[ModelRequest[ContextT]], Awaitable[ModelResponse[ResponseT]]],
) -> ModelResponse[ResponseT] | AIMessage:
return await func(request, handler) # type: ignore[misc, arg-type]
return await _async_func(request, handler)
middleware_name = name or cast(
"str", getattr(func, "__name__", "WrapModelCallMiddleware")
@@ -1868,12 +1924,14 @@ def wrap_model_call(
},
)()
_sync_func = cast("_CallableReturningModelResponse[StateT, ContextT, ResponseT]", func)
def wrapped(
_self: AgentMiddleware[StateT, ContextT],
request: ModelRequest[ContextT],
handler: Callable[[ModelRequest[ContextT]], ModelResponse[ResponseT]],
) -> ModelResponse[ResponseT] | AIMessage:
return func(request, handler)
return _sync_func(request, handler)
middleware_name = name or cast("str", getattr(func, "__name__", "WrapModelCallMiddleware"))
@@ -1892,6 +1950,12 @@ def wrap_model_call(
return decorator
@overload
def wrap_tool_call(
func: _AsyncCallableReturningToolResponse,
) -> AgentMiddleware: ...
@overload
def wrap_tool_call(
func: _CallableReturningToolResponse,
@@ -1905,19 +1969,19 @@ def wrap_tool_call(
tools: list[BaseTool] | None = None,
name: str | None = None,
) -> Callable[
[_CallableReturningToolResponse],
[_CallableReturningToolResponse | _AsyncCallableReturningToolResponse],
AgentMiddleware,
]: ...
def wrap_tool_call(
func: _CallableReturningToolResponse | None = None,
func: _CallableReturningToolResponse | _AsyncCallableReturningToolResponse | None = None,
*,
tools: list[BaseTool] | None = None,
name: str | None = None,
) -> (
Callable[
[_CallableReturningToolResponse],
[_CallableReturningToolResponse | _AsyncCallableReturningToolResponse],
AgentMiddleware,
]
| AgentMiddleware
@@ -2001,18 +2065,20 @@ def wrap_tool_call(
"""
def decorator(
func: _CallableReturningToolResponse,
func: _CallableReturningToolResponse | _AsyncCallableReturningToolResponse,
) -> AgentMiddleware:
is_async = iscoroutinefunction(func)
if is_async:
# iscoroutinefunction is not a TypeGuard, so narrow manually
_async_func = cast("_AsyncCallableReturningToolResponse", func)
async def async_wrapped(
_self: AgentMiddleware,
request: ToolCallRequest,
handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command[Any]]],
) -> ToolMessage | Command[Any]:
return await func(request, handler) # type: ignore[arg-type,misc]
return await _async_func(request, handler)
middleware_name = name or cast(
"str", getattr(func, "__name__", "WrapToolCallMiddleware")
@@ -2028,12 +2094,14 @@ def wrap_tool_call(
},
)()
_sync_func = cast("_CallableReturningToolResponse", func)
def wrapped(
_self: AgentMiddleware,
request: ToolCallRequest,
handler: Callable[[ToolCallRequest], ToolMessage | Command[Any]],
) -> ToolMessage | Command[Any]:
return func(request, handler)
return _sync_func(request, handler)
middleware_name = name or cast("str", getattr(func, "__name__", "WrapToolCallMiddleware"))

View File

@@ -39,7 +39,7 @@ if TYPE_CHECKING:
class OldStyleMiddleware1(AgentMiddleware):
"""Middleware with no type parameters at all - most common old pattern."""
def before_model(self, state: AgentState[Any], runtime: Runtime[None]) -> dict[str, Any] | None:
def before_model(self, state: AgentState[Any], runtime: Runtime[Any]) -> dict[str, Any] | None:
# Simple middleware that just logs or does something
return None
@@ -104,7 +104,7 @@ class OldStyleMiddleware4(AgentMiddleware[AgentState[Any], MyContext]):
# OLD PATTERN 5: Decorator-based middleware
# =============================================================================
@before_model
def old_style_decorator(state: AgentState[Any], runtime: Runtime[None]) -> dict[str, Any] | None:
def old_style_decorator(state: AgentState[Any], runtime: Runtime[Any]) -> dict[str, Any] | None:
"""Decorator middleware - old pattern."""
return None

View File

@@ -7,8 +7,10 @@ Expected errors:
1. TypedDict "UserContext" has no key "session_id" - accessing wrong context field
2. Argument incompatible with supertype - mismatched ModelRequest type
3. Cannot infer value of type parameter - middleware/context_schema mismatch
4. "AnalysisResult" has no attribute "summary" - accessing wrong response field
5. Handler returns wrong ResponseT type
4. (No longer an error) Backwards compatible middleware with context_schema is OK
because ContextT defaults to Any, which is compatible with any context_schema
5. "AnalysisResult" has no attribute "summary" - accessing wrong response field
6. Handler returns wrong ResponseT type
"""
from __future__ import annotations
@@ -107,7 +109,9 @@ def test_mismatched_context_schema() -> None:
# =============================================================================
# ERROR 4: Backwards compatible middleware with typed context_schema
# NOTE 4: Backwards compatible middleware with typed context_schema
# This is NOT a type error because ContextT defaults to Any, which is
# compatible with any context_schema.
# =============================================================================
class BackwardsCompatibleMiddleware(AgentMiddleware):
def wrap_model_call(
@@ -119,10 +123,10 @@ class BackwardsCompatibleMiddleware(AgentMiddleware):
def test_backwards_compat_with_context_schema() -> None:
# TYPE ERROR: BackwardsCompatibleMiddleware is AgentMiddleware[..., None]
# but context_schema=UserContext expects AgentMiddleware[..., UserContext]
# BackwardsCompatibleMiddleware uses default type params (Any),
# so it is compatible with any context_schema including UserContext
fake_model = FakeToolCallingModel()
_agent = create_agent( # type: ignore[misc]
_agent = create_agent(
model=fake_model,
middleware=[BackwardsCompatibleMiddleware()],
context_schema=UserContext,

View File

@@ -18,7 +18,7 @@ from typing import TYPE_CHECKING, Any
import pytest
from langchain_core.language_models.fake_chat_models import GenericFakeChatModel
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
from pydantic import BaseModel
from typing_extensions import TypedDict
@@ -31,13 +31,17 @@ from langchain.agents.middleware.types import (
ModelResponse,
ResponseT,
before_model,
wrap_model_call,
wrap_tool_call,
)
if TYPE_CHECKING:
from collections.abc import Awaitable, Callable
from langgraph.graph.state import CompiledStateGraph
from langgraph.prebuilt.tool_node import ToolCallRequest
from langgraph.runtime import Runtime
from langgraph.types import Command
# =============================================================================
@@ -78,7 +82,7 @@ class SummaryResult(BaseModel):
class BackwardsCompatibleMiddleware(AgentMiddleware):
"""Middleware that doesn't specify type parameters - backwards compatible."""
def before_model(self, state: AgentState[Any], runtime: Runtime[None]) -> dict[str, Any] | None:
def before_model(self, state: AgentState[Any], runtime: Runtime[Any]) -> dict[str, Any] | None:
return None
def wrap_model_call(
@@ -103,7 +107,7 @@ class BackwardsCompatibleMiddleware2(AgentMiddleware):
@before_model
def backwards_compatible_decorator(
state: AgentState[Any], runtime: Runtime[None]
state: AgentState[Any], runtime: Runtime[Any]
) -> dict[str, Any] | None:
"""Decorator middleware without explicit type parameters."""
return None
@@ -237,7 +241,7 @@ def fake_model() -> GenericFakeChatModel:
def test_create_agent_no_context_schema(fake_model: GenericFakeChatModel) -> None:
"""Backwards compatible: No context_schema means ContextT=None."""
"""Backwards compatible: No context_schema means ContextT=Any."""
agent: CompiledStateGraph[Any, None, Any, Any] = create_agent(
model=fake_model,
middleware=[
@@ -441,3 +445,148 @@ def test_model_response_backwards_compatible() -> None:
)
assert response.structured_response is None
# =============================================================================
# 9. ASYNC DECORATOR VARIANTS FOR wrap_model_call AND wrap_tool_call
# =============================================================================
@wrap_model_call
async def async_wrap_model_retry(
request: ModelRequest[UserContext],
handler: Callable[[ModelRequest[UserContext]], Awaitable[ModelResponse[Any]]],
) -> ModelResponse[Any] | AIMessage:
"""Async wrap_model_call decorator should type-check correctly."""
for attempt in range(3):
try:
return await handler(request)
except Exception:
if attempt == 2:
raise
return await handler(request)
@wrap_model_call
async def async_wrap_model_unparameterized(
request: ModelRequest,
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
) -> ModelResponse | AIMessage:
"""Async wrap_model_call with unparameterized types (backwards compat)."""
return await handler(request)
@wrap_tool_call
async def async_wrap_tool_retry(
request: ToolCallRequest,
handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command[Any]]],
) -> ToolMessage | Command[Any]:
"""Async wrap_tool_call decorator should type-check correctly."""
for attempt in range(3):
try:
return await handler(request)
except Exception:
if attempt == 2:
raise
return await handler(request)
def test_async_wrap_model_call_decorator(fake_model: GenericFakeChatModel) -> None:
"""Async @wrap_model_call decorator produces valid AgentMiddleware."""
agent = create_agent(
model=fake_model,
middleware=[async_wrap_model_retry],
context_schema=UserContext,
)
assert agent is not None
def test_async_wrap_model_call_unparameterized(fake_model: GenericFakeChatModel) -> None:
"""Async @wrap_model_call with unparameterized types works."""
agent = create_agent(
model=fake_model,
middleware=[async_wrap_model_unparameterized],
)
assert agent is not None
def test_async_wrap_tool_call_decorator(fake_model: GenericFakeChatModel) -> None:
"""Async @wrap_tool_call decorator produces valid AgentMiddleware."""
agent = create_agent(
model=fake_model,
middleware=[async_wrap_tool_retry],
)
assert agent is not None
# Test sync decorators still work (regression check)
@wrap_model_call
def sync_wrap_model(
request: ModelRequest[UserContext],
handler: Callable[[ModelRequest[UserContext]], ModelResponse[Any]],
) -> ModelResponse[Any] | AIMessage:
"""Sync wrap_model_call should still type-check correctly."""
return handler(request)
@wrap_tool_call
def sync_wrap_tool(
request: ToolCallRequest,
handler: Callable[[ToolCallRequest], ToolMessage | Command[Any]],
) -> ToolMessage | Command[Any]:
"""Sync wrap_tool_call should still type-check correctly."""
return handler(request)
def test_sync_wrap_model_call_decorator(fake_model: GenericFakeChatModel) -> None:
"""Sync @wrap_model_call decorator still works."""
agent = create_agent(
model=fake_model,
middleware=[sync_wrap_model],
context_schema=UserContext,
)
assert agent is not None
def test_sync_wrap_tool_call_decorator(fake_model: GenericFakeChatModel) -> None:
"""Sync @wrap_tool_call decorator still works."""
agent = create_agent(
model=fake_model,
middleware=[sync_wrap_tool],
)
assert agent is not None
# Test with func=None pattern (parenthesized decorator)
@wrap_model_call()
async def async_wrap_model_with_parens(
request: ModelRequest,
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
) -> ModelResponse | AIMessage:
"""Async wrap_model_call with parentheses should type-check correctly."""
return await handler(request)
@wrap_tool_call()
async def async_wrap_tool_with_parens(
request: ToolCallRequest,
handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command[Any]]],
) -> ToolMessage | Command[Any]:
"""Async wrap_tool_call with parentheses should type-check correctly."""
return await handler(request)
def test_async_wrap_model_call_with_parens(fake_model: GenericFakeChatModel) -> None:
"""Async @wrap_model_call() with parens produces valid AgentMiddleware."""
agent = create_agent(
model=fake_model,
middleware=[async_wrap_model_with_parens],
)
assert agent is not None
def test_async_wrap_tool_call_with_parens(fake_model: GenericFakeChatModel) -> None:
"""Async @wrap_tool_call() with parens produces valid AgentMiddleware."""
agent = create_agent(
model=fake_model,
middleware=[async_wrap_tool_with_parens],
)
assert agent is not None