From dde2012b832c9d370a2ea76fd18a79624a182a4c Mon Sep 17 00:00:00 2001 From: Sydney Runkle <54324534+sydney-runkle@users.noreply.github.com> Date: Thu, 5 Feb 2026 07:41:27 -0500 Subject: [PATCH] feat: threading context through `create_agent` flows + middleware (#34978) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Closes https://github.com/langchain-ai/langchain/issues/33956 * Making `ModelRequest` generic on `ContextT` and `ModelResponse` generic on `ResponseT` so that we can thread type information through to `wrap_model_call` * Making builtin middlewares generic on `ContextT` and `ResponseT` so their context and response types can be inferred from the `create_agent` signature See new tests that verify backwards compatibility (for cases where folks use custom middleware that wasn't parametrized). This fixes: 1. Lack of access to context and response types in `wrap_model_call` 2. Lack of cohesion between middleware context + response types with those specified in `create_agent` See examples below: ### Type-safe context and response access ```python class MyMiddleware(AgentMiddleware[AgentState[AnalysisResult], UserContext, AnalysisResult]): def wrap_model_call( self, request: ModelRequest[UserContext], handler: Callable[[ModelRequest[UserContext]], ModelResponse[AnalysisResult]], ) -> ModelResponse[AnalysisResult]: # ✅ Now type-safe: IDE knows user_id exists and is str user_id: str = request.runtime.context["user_id"] # ❌ mypy error: "session_id" doesn't exist on UserContext request.runtime.context["session_id"] response = handler(request) if response.structured_response is not None: # ✅ Now type-safe: IDE knows sentiment exists and is str sentiment: str = response.structured_response.sentiment # ❌ mypy error: "summary" doesn't exist on AnalysisResult response.structured_response.summary return response ``` ### Mismatched middleware/schema caught at `create_agent` ```python class SessionMiddleware(AgentMiddleware[AgentState[Any], SessionContext, 
Any]): ... # ❌ mypy error: SessionMiddleware expects SessionContext, not UserContext create_agent( model=model, middleware=[SessionMiddleware()], context_schema=UserContext, # mismatch! ) class AnalysisMiddleware(AgentMiddleware[AgentState[AnalysisResult], ContextT, AnalysisResult]): ... # ❌ mypy error: AnalysisMiddleware expects AnalysisResult, not SummaryResult create_agent( model=model, middleware=[AnalysisMiddleware()], response_format=SummaryResult, # mismatch! ) ``` --- libs/langchain_v1/langchain/agents/factory.py | 56 +-- .../agents/middleware/context_editing.py | 18 +- .../agents/middleware/file_search.py | 4 +- .../agents/middleware/human_in_the_loop.py | 10 +- .../agents/middleware/model_call_limit.py | 29 +- .../agents/middleware/model_fallback.py | 19 +- .../agents/middleware/model_retry.py | 27 +- .../langchain/agents/middleware/pii.py | 18 +- .../langchain/agents/middleware/shell_tool.py | 36 +- .../agents/middleware/summarization.py | 10 +- .../langchain/agents/middleware/todo.py | 35 +- .../agents/middleware/tool_call_limit.py | 12 +- .../agents/middleware/tool_emulator.py | 6 +- .../langchain/agents/middleware/tool_retry.py | 4 +- .../agents/middleware/tool_selection.py | 28 +- .../langchain/agents/middleware/types.py | 96 ++-- libs/langchain_v1/pyproject.toml | 7 +- .../agents/middleware_typing/__init__.py | 0 .../test_middleware_backwards_compat.py | 275 +++++++++++ .../test_middleware_type_errors.py | 201 ++++++++ .../test_middleware_typing.py | 443 ++++++++++++++++++ 21 files changed, 1167 insertions(+), 167 deletions(-) create mode 100644 libs/langchain_v1/tests/unit_tests/agents/middleware_typing/__init__.py create mode 100644 libs/langchain_v1/tests/unit_tests/agents/middleware_typing/test_middleware_backwards_compat.py create mode 100644 libs/langchain_v1/tests/unit_tests/agents/middleware_typing/test_middleware_type_errors.py create mode 100644 libs/langchain_v1/tests/unit_tests/agents/middleware_typing/test_middleware_typing.py diff 
--git a/libs/langchain_v1/langchain/agents/factory.py b/libs/langchain_v1/langchain/agents/factory.py index 7f9979fa8fc..314d9b6ee71 100644 --- a/libs/langchain_v1/langchain/agents/factory.py +++ b/libs/langchain_v1/langchain/agents/factory.py @@ -26,6 +26,7 @@ from typing_extensions import NotRequired, Required, TypedDict from langchain.agents.middleware.types import ( AgentMiddleware, AgentState, + ContextT, JumpTo, ModelRequest, ModelResponse, @@ -57,7 +58,6 @@ if TYPE_CHECKING: from langgraph.runtime import Runtime from langgraph.store.base import BaseStore from langgraph.types import Checkpointer - from langgraph.typing import ContextT from langchain.agents.middleware.types import ToolCallRequest, ToolCallWrapper @@ -112,13 +112,13 @@ def _normalize_to_model_response(result: ModelResponse | AIMessage) -> ModelResp def _chain_model_call_handlers( handlers: Sequence[ Callable[ - [ModelRequest, Callable[[ModelRequest], ModelResponse]], + [ModelRequest[ContextT], Callable[[ModelRequest[ContextT]], ModelResponse]], ModelResponse | AIMessage, ] ], ) -> ( Callable[ - [ModelRequest, Callable[[ModelRequest], ModelResponse]], + [ModelRequest[ContextT], Callable[[ModelRequest[ContextT]], ModelResponse]], ModelResponse, ] | None @@ -168,8 +168,8 @@ def _chain_model_call_handlers( single_handler = handlers[0] def normalized_single( - request: ModelRequest, - handler: Callable[[ModelRequest], ModelResponse], + request: ModelRequest[ContextT], + handler: Callable[[ModelRequest[ContextT]], ModelResponse], ) -> ModelResponse: result = single_handler(request, handler) return _normalize_to_model_response(result) @@ -178,25 +178,25 @@ def _chain_model_call_handlers( def compose_two( outer: Callable[ - [ModelRequest, Callable[[ModelRequest], ModelResponse]], + [ModelRequest[ContextT], Callable[[ModelRequest[ContextT]], ModelResponse]], ModelResponse | AIMessage, ], inner: Callable[ - [ModelRequest, Callable[[ModelRequest], ModelResponse]], + [ModelRequest[ContextT], 
Callable[[ModelRequest[ContextT]], ModelResponse]], ModelResponse | AIMessage, ], ) -> Callable[ - [ModelRequest, Callable[[ModelRequest], ModelResponse]], + [ModelRequest[ContextT], Callable[[ModelRequest[ContextT]], ModelResponse]], ModelResponse, ]: """Compose two handlers where outer wraps inner.""" def composed( - request: ModelRequest, - handler: Callable[[ModelRequest], ModelResponse], + request: ModelRequest[ContextT], + handler: Callable[[ModelRequest[ContextT]], ModelResponse], ) -> ModelResponse: # Create a wrapper that calls inner with the base handler and normalizes - def inner_handler(req: ModelRequest) -> ModelResponse: + def inner_handler(req: ModelRequest[ContextT]) -> ModelResponse: inner_result = inner(req, handler) return _normalize_to_model_response(inner_result) @@ -213,8 +213,8 @@ def _chain_model_call_handlers( # Wrap to ensure final return type is exactly ModelResponse def final_normalized( - request: ModelRequest, - handler: Callable[[ModelRequest], ModelResponse], + request: ModelRequest[ContextT], + handler: Callable[[ModelRequest[ContextT]], ModelResponse], ) -> ModelResponse: # result here is typed as returning ModelResponse | AIMessage but compose_two normalizes final_result = result(request, handler) @@ -226,13 +226,13 @@ def _chain_model_call_handlers( def _chain_async_model_call_handlers( handlers: Sequence[ Callable[ - [ModelRequest, Callable[[ModelRequest], Awaitable[ModelResponse]]], + [ModelRequest[ContextT], Callable[[ModelRequest[ContextT]], Awaitable[ModelResponse]]], Awaitable[ModelResponse | AIMessage], ] ], ) -> ( Callable[ - [ModelRequest, Callable[[ModelRequest], Awaitable[ModelResponse]]], + [ModelRequest[ContextT], Callable[[ModelRequest[ContextT]], Awaitable[ModelResponse]]], Awaitable[ModelResponse], ] | None @@ -255,8 +255,8 @@ def _chain_async_model_call_handlers( single_handler = handlers[0] async def normalized_single( - request: ModelRequest, - handler: Callable[[ModelRequest], Awaitable[ModelResponse]], + 
request: ModelRequest[ContextT], + handler: Callable[[ModelRequest[ContextT]], Awaitable[ModelResponse]], ) -> ModelResponse: result = await single_handler(request, handler) return _normalize_to_model_response(result) @@ -265,25 +265,25 @@ def _chain_async_model_call_handlers( def compose_two( outer: Callable[ - [ModelRequest, Callable[[ModelRequest], Awaitable[ModelResponse]]], + [ModelRequest[ContextT], Callable[[ModelRequest[ContextT]], Awaitable[ModelResponse]]], Awaitable[ModelResponse | AIMessage], ], inner: Callable[ - [ModelRequest, Callable[[ModelRequest], Awaitable[ModelResponse]]], + [ModelRequest[ContextT], Callable[[ModelRequest[ContextT]], Awaitable[ModelResponse]]], Awaitable[ModelResponse | AIMessage], ], ) -> Callable[ - [ModelRequest, Callable[[ModelRequest], Awaitable[ModelResponse]]], + [ModelRequest[ContextT], Callable[[ModelRequest[ContextT]], Awaitable[ModelResponse]]], Awaitable[ModelResponse], ]: """Compose two async handlers where outer wraps inner.""" async def composed( - request: ModelRequest, - handler: Callable[[ModelRequest], Awaitable[ModelResponse]], + request: ModelRequest[ContextT], + handler: Callable[[ModelRequest[ContextT]], Awaitable[ModelResponse]], ) -> ModelResponse: # Create a wrapper that calls inner with the base handler and normalizes - async def inner_handler(req: ModelRequest) -> ModelResponse: + async def inner_handler(req: ModelRequest[ContextT]) -> ModelResponse: inner_result = await inner(req, handler) return _normalize_to_model_response(inner_result) @@ -300,8 +300,8 @@ def _chain_async_model_call_handlers( # Wrap to ensure final return type is exactly ModelResponse async def final_normalized( - request: ModelRequest, - handler: Callable[[ModelRequest], Awaitable[ModelResponse]], + request: ModelRequest[ContextT], + handler: Callable[[ModelRequest[ContextT]], Awaitable[ModelResponse]], ) -> ModelResponse: # result here is typed as returning ModelResponse | AIMessage but compose_two normalizes final_result = 
await result(request, handler) @@ -1015,7 +1015,7 @@ def create_agent( return {"messages": [output]} def _get_bound_model( - request: ModelRequest, + request: ModelRequest[ContextT], ) -> tuple[Runnable[Any, Any], ResponseFormat[Any] | None]: """Get the model with appropriate tool bindings. @@ -1138,7 +1138,7 @@ def create_agent( ) return request.model.bind(**request.model_settings), None - def _execute_model_sync(request: ModelRequest) -> ModelResponse: + def _execute_model_sync(request: ModelRequest[ContextT]) -> ModelResponse: """Execute model and return response. This is the core model execution logic wrapped by `wrap_model_call` handlers. @@ -1192,7 +1192,7 @@ def create_agent( return state_updates - async def _execute_model_async(request: ModelRequest) -> ModelResponse: + async def _execute_model_async(request: ModelRequest[ContextT]) -> ModelResponse: """Execute model asynchronously and return response. This is the core async model execution logic wrapped by `wrap_model_call` diff --git a/libs/langchain_v1/langchain/agents/middleware/context_editing.py b/libs/langchain_v1/langchain/agents/middleware/context_editing.py index a7a0a8e6032..b70c3287cd9 100644 --- a/libs/langchain_v1/langchain/agents/middleware/context_editing.py +++ b/libs/langchain_v1/langchain/agents/middleware/context_editing.py @@ -25,9 +25,11 @@ from typing_extensions import Protocol from langchain.agents.middleware.types import ( AgentMiddleware, - ModelCallResult, + AgentState, + ContextT, ModelRequest, ModelResponse, + ResponseT, ) DEFAULT_TOOL_PLACEHOLDER = "[cleared]" @@ -182,7 +184,7 @@ class ClearToolUsesEdit(ContextEdit): ) -class ContextEditingMiddleware(AgentMiddleware): +class ContextEditingMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, ResponseT]): """Automatically prune tool results to manage context size. 
The middleware applies a sequence of edits when the total input token count exceeds @@ -217,9 +219,9 @@ class ContextEditingMiddleware(AgentMiddleware): def wrap_model_call( self, - request: ModelRequest, - handler: Callable[[ModelRequest], ModelResponse], - ) -> ModelCallResult: + request: ModelRequest[ContextT], + handler: Callable[[ModelRequest[ContextT]], ModelResponse[ResponseT]], + ) -> ModelResponse[ResponseT] | AIMessage: """Apply context edits before invoking the model via handler. Args: @@ -254,9 +256,9 @@ class ContextEditingMiddleware(AgentMiddleware): async def awrap_model_call( self, - request: ModelRequest, - handler: Callable[[ModelRequest], Awaitable[ModelResponse]], - ) -> ModelCallResult: + request: ModelRequest[ContextT], + handler: Callable[[ModelRequest[ContextT]], Awaitable[ModelResponse[ResponseT]]], + ) -> ModelResponse[ResponseT] | AIMessage: """Apply context edits before invoking the model via handler. Args: diff --git a/libs/langchain_v1/langchain/agents/middleware/file_search.py b/libs/langchain_v1/langchain/agents/middleware/file_search.py index b6dcb775412..ffbf3c61eea 100644 --- a/libs/langchain_v1/langchain/agents/middleware/file_search.py +++ b/libs/langchain_v1/langchain/agents/middleware/file_search.py @@ -17,7 +17,7 @@ from typing import Literal from langchain_core.tools import tool -from langchain.agents.middleware.types import AgentMiddleware +from langchain.agents.middleware.types import AgentMiddleware, AgentState, ContextT, ResponseT def _expand_include_patterns(pattern: str) -> list[str] | None: @@ -84,7 +84,7 @@ def _match_include_pattern(basename: str, pattern: str) -> bool: return any(fnmatch.fnmatch(basename, candidate) for candidate in expanded) -class FilesystemFileSearchMiddleware(AgentMiddleware): +class FilesystemFileSearchMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, ResponseT]): """Provides Glob and Grep search over filesystem files. 
This middleware adds two tools that search through local filesystem: diff --git a/libs/langchain_v1/langchain/agents/middleware/human_in_the_loop.py b/libs/langchain_v1/langchain/agents/middleware/human_in_the_loop.py index 7fbb6e964df..acb81c58c6d 100644 --- a/libs/langchain_v1/langchain/agents/middleware/human_in_the_loop.py +++ b/libs/langchain_v1/langchain/agents/middleware/human_in_the_loop.py @@ -7,7 +7,13 @@ from langgraph.runtime import Runtime from langgraph.types import interrupt from typing_extensions import NotRequired, TypedDict -from langchain.agents.middleware.types import AgentMiddleware, AgentState, ContextT, StateT +from langchain.agents.middleware.types import ( + AgentMiddleware, + AgentState, + ContextT, + ResponseT, + StateT, +) class Action(TypedDict): @@ -158,7 +164,7 @@ class InterruptOnConfig(TypedDict): """JSON schema for the args associated with the action, if edits are allowed.""" -class HumanInTheLoopMiddleware(AgentMiddleware[StateT, ContextT]): +class HumanInTheLoopMiddleware(AgentMiddleware[StateT, ContextT, ResponseT]): """Human in the loop middleware.""" def __init__( diff --git a/libs/langchain_v1/langchain/agents/middleware/model_call_limit.py b/libs/langchain_v1/langchain/agents/middleware/model_call_limit.py index 3f1295260fb..fb231ee5fbb 100644 --- a/libs/langchain_v1/langchain/agents/middleware/model_call_limit.py +++ b/libs/langchain_v1/langchain/agents/middleware/model_call_limit.py @@ -11,7 +11,9 @@ from typing_extensions import NotRequired, override from langchain.agents.middleware.types import ( AgentMiddleware, AgentState, + ContextT, PrivateStateAttr, + ResponseT, hook_config, ) @@ -19,10 +21,13 @@ if TYPE_CHECKING: from langgraph.runtime import Runtime -class ModelCallLimitState(AgentState[Any]): +class ModelCallLimitState(AgentState[ResponseT]): """State schema for `ModelCallLimitMiddleware`. Extends `AgentState` with model call tracking fields. + + Type Parameters: + ResponseT: The type of the structured response. 
Defaults to `Any`. """ thread_model_call_count: NotRequired[Annotated[int, PrivateStateAttr]] @@ -86,7 +91,9 @@ class ModelCallLimitExceededError(Exception): super().__init__(msg) -class ModelCallLimitMiddleware(AgentMiddleware[ModelCallLimitState, Any]): +class ModelCallLimitMiddleware( + AgentMiddleware[ModelCallLimitState[ResponseT], ContextT, ResponseT] +): """Tracks model call counts and enforces limits. This middleware monitors the number of model calls made during agent execution @@ -114,7 +121,7 @@ class ModelCallLimitMiddleware(AgentMiddleware[ModelCallLimitState, Any]): ``` """ - state_schema = ModelCallLimitState + state_schema = ModelCallLimitState # type: ignore[assignment] def __init__( self, @@ -158,7 +165,9 @@ class ModelCallLimitMiddleware(AgentMiddleware[ModelCallLimitState, Any]): @hook_config(can_jump_to=["end"]) @override - def before_model(self, state: ModelCallLimitState, runtime: Runtime) -> dict[str, Any] | None: + def before_model( + self, state: ModelCallLimitState[ResponseT], runtime: Runtime[ContextT] + ) -> dict[str, Any] | None: """Check model call limits before making a model call. Args: @@ -203,8 +212,8 @@ class ModelCallLimitMiddleware(AgentMiddleware[ModelCallLimitState, Any]): @hook_config(can_jump_to=["end"]) async def abefore_model( self, - state: ModelCallLimitState, - runtime: Runtime, + state: ModelCallLimitState[ResponseT], + runtime: Runtime[ContextT], ) -> dict[str, Any] | None: """Async check model call limits before making a model call. @@ -224,7 +233,9 @@ class ModelCallLimitMiddleware(AgentMiddleware[ModelCallLimitState, Any]): return self.before_model(state, runtime) @override - def after_model(self, state: ModelCallLimitState, runtime: Runtime) -> dict[str, Any] | None: + def after_model( + self, state: ModelCallLimitState[ResponseT], runtime: Runtime[ContextT] + ) -> dict[str, Any] | None: """Increment model call counts after a model call. 
Args: @@ -241,8 +252,8 @@ class ModelCallLimitMiddleware(AgentMiddleware[ModelCallLimitState, Any]): async def aafter_model( self, - state: ModelCallLimitState, - runtime: Runtime, + state: ModelCallLimitState[ResponseT], + runtime: Runtime[ContextT], ) -> dict[str, Any] | None: """Async increment model call counts after a model call. diff --git a/libs/langchain_v1/langchain/agents/middleware/model_fallback.py b/libs/langchain_v1/langchain/agents/middleware/model_fallback.py index 7bae4ce6ea5..66a202350ad 100644 --- a/libs/langchain_v1/langchain/agents/middleware/model_fallback.py +++ b/libs/langchain_v1/langchain/agents/middleware/model_fallback.py @@ -6,9 +6,11 @@ from typing import TYPE_CHECKING from langchain.agents.middleware.types import ( AgentMiddleware, - ModelCallResult, + AgentState, + ContextT, ModelRequest, ModelResponse, + ResponseT, ) from langchain.chat_models import init_chat_model @@ -16,9 +18,10 @@ if TYPE_CHECKING: from collections.abc import Awaitable, Callable from langchain_core.language_models.chat_models import BaseChatModel + from langchain_core.messages import AIMessage -class ModelFallbackMiddleware(AgentMiddleware): +class ModelFallbackMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, ResponseT]): """Automatic fallback to alternative models on errors. Retries failed model calls with alternative models in sequence until @@ -68,9 +71,9 @@ class ModelFallbackMiddleware(AgentMiddleware): def wrap_model_call( self, - request: ModelRequest, - handler: Callable[[ModelRequest], ModelResponse], - ) -> ModelCallResult: + request: ModelRequest[ContextT], + handler: Callable[[ModelRequest[ContextT]], ModelResponse[ResponseT]], + ) -> ModelResponse[ResponseT] | AIMessage: """Try fallback models in sequence on errors. 
Args: @@ -102,9 +105,9 @@ class ModelFallbackMiddleware(AgentMiddleware): async def awrap_model_call( self, - request: ModelRequest, - handler: Callable[[ModelRequest], Awaitable[ModelResponse]], - ) -> ModelCallResult: + request: ModelRequest[ContextT], + handler: Callable[[ModelRequest[ContextT]], Awaitable[ModelResponse[ResponseT]]], + ) -> ModelResponse[ResponseT] | AIMessage: """Try fallback models in sequence on errors (async version). Args: diff --git a/libs/langchain_v1/langchain/agents/middleware/model_retry.py b/libs/langchain_v1/langchain/agents/middleware/model_retry.py index 86e971544e2..81a54a2ebb6 100644 --- a/libs/langchain_v1/langchain/agents/middleware/model_retry.py +++ b/libs/langchain_v1/langchain/agents/middleware/model_retry.py @@ -15,15 +15,20 @@ from langchain.agents.middleware._retry import ( should_retry_exception, validate_retry_params, ) -from langchain.agents.middleware.types import AgentMiddleware, ModelResponse +from langchain.agents.middleware.types import ( + AgentMiddleware, + AgentState, + ContextT, + ModelRequest, + ModelResponse, + ResponseT, +) if TYPE_CHECKING: from collections.abc import Awaitable, Callable - from langchain.agents.middleware.types import ModelRequest - -class ModelRetryMiddleware(AgentMiddleware): +class ModelRetryMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, ResponseT]): """Middleware that automatically retries failed model calls with configurable backoff. Supports retrying on specific exceptions and exponential backoff. @@ -182,7 +187,7 @@ class ModelRetryMiddleware(AgentMiddleware): ) return AIMessage(content=content) - def _handle_failure(self, exc: Exception, attempts_made: int) -> ModelResponse: + def _handle_failure(self, exc: Exception, attempts_made: int) -> ModelResponse[ResponseT]: """Handle failure when all retries are exhausted. 
Args: @@ -208,9 +213,9 @@ class ModelRetryMiddleware(AgentMiddleware): def wrap_model_call( self, - request: ModelRequest, - handler: Callable[[ModelRequest], ModelResponse], - ) -> ModelResponse | AIMessage: + request: ModelRequest[ContextT], + handler: Callable[[ModelRequest[ContextT]], ModelResponse[ResponseT]], + ) -> ModelResponse[ResponseT] | AIMessage: """Intercept model execution and retry on failure. Args: @@ -258,9 +263,9 @@ class ModelRetryMiddleware(AgentMiddleware): async def awrap_model_call( self, - request: ModelRequest, - handler: Callable[[ModelRequest], Awaitable[ModelResponse]], - ) -> ModelResponse | AIMessage: + request: ModelRequest[ContextT], + handler: Callable[[ModelRequest[ContextT]], Awaitable[ModelResponse[ResponseT]]], + ) -> ModelResponse[ResponseT] | AIMessage: """Intercept and control async model execution with retry logic. Args: diff --git a/libs/langchain_v1/langchain/agents/middleware/pii.py b/libs/langchain_v1/langchain/agents/middleware/pii.py index 06b5a764e69..d15e682c259 100644 --- a/libs/langchain_v1/langchain/agents/middleware/pii.py +++ b/libs/langchain_v1/langchain/agents/middleware/pii.py @@ -19,7 +19,13 @@ from langchain.agents.middleware._redaction import ( detect_mac_address, detect_url, ) -from langchain.agents.middleware.types import AgentMiddleware, AgentState, hook_config +from langchain.agents.middleware.types import ( + AgentMiddleware, + AgentState, + ContextT, + ResponseT, + hook_config, +) if TYPE_CHECKING: from collections.abc import Callable @@ -27,7 +33,7 @@ if TYPE_CHECKING: from langgraph.runtime import Runtime -class PIIMiddleware(AgentMiddleware): +class PIIMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, ResponseT]): """Detect and handle Personally Identifiable Information (PII) in conversations. 
This middleware detects common PII types and applies configurable strategies @@ -165,7 +171,7 @@ class PIIMiddleware(AgentMiddleware): def before_model( self, state: AgentState[Any], - runtime: Runtime, + runtime: Runtime[ContextT], ) -> dict[str, Any] | None: """Check user messages and tool results for PII before model invocation. @@ -260,7 +266,7 @@ class PIIMiddleware(AgentMiddleware): async def abefore_model( self, state: AgentState[Any], - runtime: Runtime, + runtime: Runtime[ContextT], ) -> dict[str, Any] | None: """Async check user messages and tool results for PII before model invocation. @@ -281,7 +287,7 @@ class PIIMiddleware(AgentMiddleware): def after_model( self, state: AgentState[Any], - runtime: Runtime, + runtime: Runtime[ContextT], ) -> dict[str, Any] | None: """Check AI messages for PII after model invocation. @@ -340,7 +346,7 @@ class PIIMiddleware(AgentMiddleware): async def aafter_model( self, state: AgentState[Any], - runtime: Runtime, + runtime: Runtime[ContextT], ) -> dict[str, Any] | None: """Async check AI messages for PII after model invocation. 
diff --git a/libs/langchain_v1/langchain/agents/middleware/shell_tool.py b/libs/langchain_v1/langchain/agents/middleware/shell_tool.py index 647b68837dc..93fd978ca4f 100644 --- a/libs/langchain_v1/langchain/agents/middleware/shell_tool.py +++ b/libs/langchain_v1/langchain/agents/middleware/shell_tool.py @@ -38,7 +38,13 @@ from langchain.agents.middleware._redaction import ( RedactionRule, ResolvedRedactionRule, ) -from langchain.agents.middleware.types import AgentMiddleware, AgentState, PrivateStateAttr +from langchain.agents.middleware.types import ( + AgentMiddleware, + AgentState, + ContextT, + PrivateStateAttr, + ResponseT, +) from langchain.tools import ToolRuntime, tool if TYPE_CHECKING: @@ -91,8 +97,12 @@ class _SessionResources: ) -class ShellToolState(AgentState[Any]): - """Agent state extension for tracking shell session resources.""" +class ShellToolState(AgentState[ResponseT]): + """Agent state extension for tracking shell session resources. + + Type Parameters: + ResponseT: The type of the structured response. Defaults to `Any`. + """ shell_session_resources: NotRequired[ Annotated[_SessionResources | None, UntrackedValue, PrivateStateAttr] @@ -476,7 +486,7 @@ class _ShellToolInput(BaseModel): return self -class ShellToolMiddleware(AgentMiddleware[ShellToolState, Any]): +class ShellToolMiddleware(AgentMiddleware[ShellToolState[ResponseT], ContextT, ResponseT]): """Middleware that registers a persistent shell tool for agents. The middleware exposes a single long-lived shell session. Use the execution policy @@ -493,7 +503,7 @@ class ShellToolMiddleware(AgentMiddleware[ShellToolState, Any]): When no policy is provided the middleware defaults to `HostExecutionPolicy`. 
""" - state_schema = ShellToolState + state_schema = ShellToolState # type: ignore[assignment] def __init__( self, @@ -615,7 +625,9 @@ class ShellToolMiddleware(AgentMiddleware[ShellToolState, Any]): return normalized @override - def before_agent(self, state: ShellToolState, runtime: Runtime) -> dict[str, Any] | None: + def before_agent( + self, state: ShellToolState[ResponseT], runtime: Runtime[ContextT] + ) -> dict[str, Any] | None: """Start the shell session and run startup commands. Args: @@ -628,7 +640,9 @@ class ShellToolMiddleware(AgentMiddleware[ShellToolState, Any]): resources = self._get_or_create_resources(state) return {"shell_session_resources": resources} - async def abefore_agent(self, state: ShellToolState, runtime: Runtime) -> dict[str, Any] | None: + async def abefore_agent( + self, state: ShellToolState[ResponseT], runtime: Runtime[ContextT] + ) -> dict[str, Any] | None: """Async start the shell session and run startup commands. Args: @@ -641,7 +655,7 @@ class ShellToolMiddleware(AgentMiddleware[ShellToolState, Any]): return await run_in_executor(None, self.before_agent, state, runtime) @override - def after_agent(self, state: ShellToolState, runtime: Runtime) -> None: + def after_agent(self, state: ShellToolState[ResponseT], runtime: Runtime[ContextT]) -> None: """Run shutdown commands and release resources when an agent completes.""" resources = state.get("shell_session_resources") if not isinstance(resources, _SessionResources): @@ -652,11 +666,13 @@ class ShellToolMiddleware(AgentMiddleware[ShellToolState, Any]): finally: resources.finalizer() - async def aafter_agent(self, state: ShellToolState, runtime: Runtime) -> None: + async def aafter_agent( + self, state: ShellToolState[ResponseT], runtime: Runtime[ContextT] + ) -> None: """Async run shutdown commands and release resources when an agent completes.""" return self.after_agent(state, runtime) - def _get_or_create_resources(self, state: ShellToolState) -> _SessionResources: + def 
_get_or_create_resources(self, state: ShellToolState[ResponseT]) -> _SessionResources: """Get existing resources from state or create new ones if they don't exist. This method enables resumability by checking if resources already exist in the state diff --git a/libs/langchain_v1/langchain/agents/middleware/summarization.py b/libs/langchain_v1/langchain/agents/middleware/summarization.py index 8e9edca3e60..cd5a38b0b71 100644 --- a/libs/langchain_v1/langchain/agents/middleware/summarization.py +++ b/libs/langchain_v1/langchain/agents/middleware/summarization.py @@ -25,7 +25,7 @@ from langgraph.graph.message import ( from langgraph.runtime import Runtime from typing_extensions import override -from langchain.agents.middleware.types import AgentMiddleware, AgentState +from langchain.agents.middleware.types import AgentMiddleware, AgentState, ContextT, ResponseT from langchain.chat_models import BaseChatModel, init_chat_model TokenCounter = Callable[[Iterable[MessageLikeRepresentation]], int] @@ -148,7 +148,7 @@ def _get_approximate_token_counter(model: BaseChatModel) -> TokenCounter: return count_tokens_approximately -class SummarizationMiddleware(AgentMiddleware): +class SummarizationMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, ResponseT]): """Summarizes conversation history when token limits are approached. This middleware monitors message token counts and automatically summarizes older @@ -284,7 +284,9 @@ class SummarizationMiddleware(AgentMiddleware): raise ValueError(msg) @override - def before_model(self, state: AgentState[Any], runtime: Runtime) -> dict[str, Any] | None: + def before_model( + self, state: AgentState[Any], runtime: Runtime[ContextT] + ) -> dict[str, Any] | None: """Process messages before model invocation, potentially triggering summarization. 
Args: @@ -321,7 +323,7 @@ class SummarizationMiddleware(AgentMiddleware): @override async def abefore_model( - self, state: AgentState[Any], runtime: Runtime + self, state: AgentState[Any], runtime: Runtime[ContextT] ) -> dict[str, Any] | None: """Process messages before model invocation, potentially triggering summarization. diff --git a/libs/langchain_v1/langchain/agents/middleware/todo.py b/libs/langchain_v1/langchain/agents/middleware/todo.py index 1f1d0e9e57f..ba826d55396 100644 --- a/libs/langchain_v1/langchain/agents/middleware/todo.py +++ b/libs/langchain_v1/langchain/agents/middleware/todo.py @@ -17,10 +17,11 @@ from typing_extensions import NotRequired, TypedDict, override from langchain.agents.middleware.types import ( AgentMiddleware, AgentState, - ModelCallResult, + ContextT, ModelRequest, ModelResponse, OmitFromInput, + ResponseT, ) from langchain.tools import InjectedToolCallId @@ -35,8 +36,12 @@ class Todo(TypedDict): """The current status of the todo item.""" -class PlanningState(AgentState[Any]): - """State schema for the todo middleware.""" +class PlanningState(AgentState[ResponseT]): + """State schema for the todo middleware. + + Type Parameters: + ResponseT: The type of the structured response. Defaults to `Any`. + """ todos: Annotated[NotRequired[list[Todo]], OmitFromInput] """List of todo items for tracking task progress.""" @@ -130,7 +135,7 @@ def write_todos( ) -class TodoListMiddleware(AgentMiddleware): +class TodoListMiddleware(AgentMiddleware[PlanningState[ResponseT], ContextT, ResponseT]): """Middleware that provides todo list management capabilities to agents. 
This middleware adds a `write_todos` tool that allows agents to create and manage @@ -157,7 +162,7 @@ class TodoListMiddleware(AgentMiddleware): ``` """ - state_schema = PlanningState + state_schema = PlanningState # type: ignore[assignment] def __init__( self, @@ -195,9 +200,9 @@ class TodoListMiddleware(AgentMiddleware): def wrap_model_call( self, - request: ModelRequest, - handler: Callable[[ModelRequest], ModelResponse], - ) -> ModelCallResult: + request: ModelRequest[ContextT], + handler: Callable[[ModelRequest[ContextT]], ModelResponse[ResponseT]], + ) -> ModelResponse[ResponseT] | AIMessage: """Update the system message to include the todo system prompt. Args: @@ -222,9 +227,9 @@ class TodoListMiddleware(AgentMiddleware): async def awrap_model_call( self, - request: ModelRequest, - handler: Callable[[ModelRequest], Awaitable[ModelResponse]], - ) -> ModelCallResult: + request: ModelRequest[ContextT], + handler: Callable[[ModelRequest[ContextT]], Awaitable[ModelResponse[ResponseT]]], + ) -> ModelResponse[ResponseT] | AIMessage: """Update the system message to include the todo system prompt. Args: @@ -248,7 +253,9 @@ class TodoListMiddleware(AgentMiddleware): return await handler(request.override(system_message=new_system_message)) @override - def after_model(self, state: AgentState[Any], runtime: Runtime) -> dict[str, Any] | None: + def after_model( + self, state: PlanningState[ResponseT], runtime: Runtime[ContextT] + ) -> dict[str, Any] | None: """Check for parallel write_todos tool calls and return errors if detected. The todo list is designed to be updated at most once per model turn. 
Since @@ -298,7 +305,9 @@ class TodoListMiddleware(AgentMiddleware): return None @override - async def aafter_model(self, state: AgentState[Any], runtime: Runtime) -> dict[str, Any] | None: + async def aafter_model( + self, state: PlanningState[ResponseT], runtime: Runtime[ContextT] + ) -> dict[str, Any] | None: """Check for parallel write_todos tool calls and return errors if detected. Async version of `after_model`. The todo list is designed to be updated at diff --git a/libs/langchain_v1/langchain/agents/middleware/tool_call_limit.py b/libs/langchain_v1/langchain/agents/middleware/tool_call_limit.py index 3df165be5c7..503205866ad 100644 --- a/libs/langchain_v1/langchain/agents/middleware/tool_call_limit.py +++ b/libs/langchain_v1/langchain/agents/middleware/tool_call_limit.py @@ -2,7 +2,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Annotated, Any, Generic, Literal +from typing import TYPE_CHECKING, Annotated, Any, Literal from langchain_core.messages import AIMessage, ToolCall, ToolMessage from langgraph.channels.untracked_value import UntrackedValue @@ -32,7 +32,7 @@ ExitBehavior = Literal["continue", "error", "end"] """ -class ToolCallLimitState(AgentState[ResponseT], Generic[ResponseT]): +class ToolCallLimitState(AgentState[ResponseT]): """State schema for `ToolCallLimitMiddleware`. Extends `AgentState` with tool call tracking fields. @@ -40,6 +40,9 @@ class ToolCallLimitState(AgentState[ResponseT], Generic[ResponseT]): The count fields are dictionaries mapping tool names to execution counts. This allows multiple middleware instances to track different tools independently. The special key `'__all__'` is used for tracking all tool calls globally. + + Type Parameters: + ResponseT: The type of the structured response. Defaults to `Any`. 
""" thread_tool_call_count: NotRequired[Annotated[dict[str, int], PrivateStateAttr]] @@ -134,10 +137,7 @@ class ToolCallLimitExceededError(Exception): super().__init__(msg) -class ToolCallLimitMiddleware( - AgentMiddleware[ToolCallLimitState[ResponseT], ContextT], - Generic[ResponseT, ContextT], -): +class ToolCallLimitMiddleware(AgentMiddleware[ToolCallLimitState[ResponseT], ContextT, ResponseT]): """Track tool call counts and enforces limits during agent execution. This middleware monitors the number of tool calls made and can terminate or diff --git a/libs/langchain_v1/langchain/agents/middleware/tool_emulator.py b/libs/langchain_v1/langchain/agents/middleware/tool_emulator.py index 967ece03611..c5605f40af0 100644 --- a/libs/langchain_v1/langchain/agents/middleware/tool_emulator.py +++ b/libs/langchain_v1/langchain/agents/middleware/tool_emulator.py @@ -2,12 +2,12 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Generic from langchain_core.language_models.chat_models import BaseChatModel from langchain_core.messages import HumanMessage, ToolMessage -from langchain.agents.middleware.types import AgentMiddleware +from langchain.agents.middleware.types import AgentMiddleware, AgentState, ContextT from langchain.chat_models.base import init_chat_model if TYPE_CHECKING: @@ -19,7 +19,7 @@ if TYPE_CHECKING: from langchain.tools import BaseTool -class LLMToolEmulator(AgentMiddleware): +class LLMToolEmulator(AgentMiddleware[AgentState[Any], ContextT], Generic[ContextT]): """Emulates specified tools using an LLM instead of executing them. This middleware allows selective emulation of tools for testing purposes. 
diff --git a/libs/langchain_v1/langchain/agents/middleware/tool_retry.py b/libs/langchain_v1/langchain/agents/middleware/tool_retry.py index c162b61c96c..c6fa238eabf 100644 --- a/libs/langchain_v1/langchain/agents/middleware/tool_retry.py +++ b/libs/langchain_v1/langchain/agents/middleware/tool_retry.py @@ -16,7 +16,7 @@ from langchain.agents.middleware._retry import ( should_retry_exception, validate_retry_params, ) -from langchain.agents.middleware.types import AgentMiddleware +from langchain.agents.middleware.types import AgentMiddleware, AgentState, ContextT, ResponseT if TYPE_CHECKING: from collections.abc import Awaitable, Callable @@ -27,7 +27,7 @@ if TYPE_CHECKING: from langchain.tools import BaseTool -class ToolRetryMiddleware(AgentMiddleware): +class ToolRetryMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, ResponseT]): """Middleware that automatically retries failed tool calls with configurable backoff. Supports retrying on specific exceptions and exponential backoff. 
diff --git a/libs/langchain_v1/langchain/agents/middleware/tool_selection.py b/libs/langchain_v1/langchain/agents/middleware/tool_selection.py index cbe926c15fd..f046ac8adac 100644 --- a/libs/langchain_v1/langchain/agents/middleware/tool_selection.py +++ b/libs/langchain_v1/langchain/agents/middleware/tool_selection.py @@ -7,15 +7,17 @@ from dataclasses import dataclass from typing import TYPE_CHECKING, Annotated, Any, Literal, Union from langchain_core.language_models.chat_models import BaseChatModel -from langchain_core.messages import HumanMessage +from langchain_core.messages import AIMessage, HumanMessage from pydantic import Field, TypeAdapter from typing_extensions import TypedDict from langchain.agents.middleware.types import ( AgentMiddleware, - ModelCallResult, + AgentState, + ContextT, ModelRequest, ModelResponse, + ResponseT, ) from langchain.chat_models.base import init_chat_model @@ -88,7 +90,7 @@ def _render_tool_list(tools: list[BaseTool]) -> str: return "\n".join(f"- {tool.name}: {tool.description}" for tool in tools) -class LLMToolSelectorMiddleware(AgentMiddleware): +class LLMToolSelectorMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, ResponseT]): """Uses an LLM to select relevant tools before calling the main model. When an agent has many tools available, this middleware filters them down @@ -153,7 +155,9 @@ class LLMToolSelectorMiddleware(AgentMiddleware): else: self.model = init_chat_model(model) - def _prepare_selection_request(self, request: ModelRequest) -> _SelectionRequest | None: + def _prepare_selection_request( + self, request: ModelRequest[ContextT] + ) -> _SelectionRequest | None: """Prepare inputs for tool selection. 
Args: @@ -230,8 +234,8 @@ class LLMToolSelectorMiddleware(AgentMiddleware): response: dict[str, Any], available_tools: list[BaseTool], valid_tool_names: list[str], - request: ModelRequest, - ) -> ModelRequest: + request: ModelRequest[ContextT], + ) -> ModelRequest[ContextT]: """Process the selection response and return filtered `ModelRequest`.""" selected_tool_names: list[str] = [] invalid_tool_selections = [] @@ -269,9 +273,9 @@ class LLMToolSelectorMiddleware(AgentMiddleware): def wrap_model_call( self, - request: ModelRequest, - handler: Callable[[ModelRequest], ModelResponse], - ) -> ModelCallResult: + request: ModelRequest[ContextT], + handler: Callable[[ModelRequest[ContextT]], ModelResponse[ResponseT]], + ) -> ModelResponse[ResponseT] | AIMessage: """Filter tools based on LLM selection before invoking the model via handler. Args: @@ -312,9 +316,9 @@ class LLMToolSelectorMiddleware(AgentMiddleware): async def awrap_model_call( self, - request: ModelRequest, - handler: Callable[[ModelRequest], Awaitable[ModelResponse]], - ) -> ModelCallResult: + request: ModelRequest[ContextT], + handler: Callable[[ModelRequest[ContextT]], Awaitable[ModelResponse[ResponseT]]], + ) -> ModelResponse[ResponseT] | AIMessage: """Filter tools based on LLM selection before invoking the model via handler. 
Args: diff --git a/libs/langchain_v1/langchain/agents/middleware/types.py b/libs/langchain_v1/langchain/agents/middleware/types.py index cc09f91b7e7..a48147a11ee 100644 --- a/libs/langchain_v1/langchain/agents/middleware/types.py +++ b/libs/langchain_v1/langchain/agents/middleware/types.py @@ -68,7 +68,7 @@ __all__ = [ JumpTo = Literal["tools", "model", "end"] """Destination to jump to when a middleware node returns.""" -ResponseT = TypeVar("ResponseT") +ResponseT = TypeVar("ResponseT", default=Any) class _ModelRequestOverrides(TypedDict, total=False): @@ -85,8 +85,12 @@ class _ModelRequestOverrides(TypedDict, total=False): @dataclass(init=False) -class ModelRequest: - """Model request information for the agent.""" +class ModelRequest(Generic[ContextT]): + """Model request information for the agent. + + Type Parameters: + ContextT: The type of the runtime context. Defaults to `None` if not specified. + """ model: BaseChatModel messages: list[AnyMessage] # excluding system message @@ -95,7 +99,7 @@ class ModelRequest: tools: list[BaseTool | dict[str, Any]] response_format: ResponseFormat[Any] | None state: AgentState[Any] - runtime: Runtime[ContextT] # type: ignore[valid-type] + runtime: Runtime[ContextT] model_settings: dict[str, Any] = field(default_factory=dict) def __init__( @@ -194,7 +198,7 @@ class ModelRequest: ) object.__setattr__(self, name, value) - def override(self, **overrides: Unpack[_ModelRequestOverrides]) -> ModelRequest: + def override(self, **overrides: Unpack[_ModelRequestOverrides]) -> ModelRequest[ContextT]: """Replace the request with a new request with the given overrides. Returns a new `ModelRequest` instance with the specified attributes replaced. @@ -264,22 +268,25 @@ class ModelRequest: @dataclass -class ModelResponse: +class ModelResponse(Generic[ResponseT]): """Response from model execution including messages and optional structured output. 
The result will usually contain a single `AIMessage`, but may include an additional `ToolMessage` if the model used a tool for structured output. + + Type Parameters: + ResponseT: The type of the structured response. Defaults to `Any` if not specified. """ result: list[BaseMessage] """List of messages from model execution.""" - structured_response: Any = None + structured_response: ResponseT | None = None """Parsed structured output if `response_format` was specified, `None` otherwise.""" # Type alias for middleware return type - allows returning either full response or just AIMessage -ModelCallResult: TypeAlias = ModelResponse | AIMessage +ModelCallResult: TypeAlias = "ModelResponse[ResponseT] | AIMessage" """`TypeAlias` for model call handler return value. Middleware can return either: @@ -340,11 +347,16 @@ class _DefaultAgentState(AgentState[Any]): """AgentMiddleware default state.""" -class AgentMiddleware(Generic[StateT, ContextT]): +class AgentMiddleware(Generic[StateT, ContextT, ResponseT]): """Base middleware class for an agent. Subclass this and implement any of the defined methods to customize agent behavior between steps in the main agent loop. + + Type Parameters: + StateT: The type of the agent state. Defaults to `AgentState[Any]`. + ContextT: The type of the runtime context. Defaults to `None`. + ResponseT: The type of the structured response. Defaults to `Any`. """ state_schema: type[StateT] = cast("type[StateT]", _DefaultAgentState) @@ -435,9 +447,9 @@ class AgentMiddleware(Generic[StateT, ContextT]): def wrap_model_call( self, - request: ModelRequest, - handler: Callable[[ModelRequest], ModelResponse], - ) -> ModelCallResult: + request: ModelRequest[ContextT], + handler: Callable[[ModelRequest[ContextT]], ModelResponse[ResponseT]], + ) -> ModelResponse[ResponseT] | AIMessage: """Intercept and control model execution via handler callback. 
Async version is `awrap_model_call` @@ -530,9 +542,9 @@ class AgentMiddleware(Generic[StateT, ContextT]): async def awrap_model_call( self, - request: ModelRequest, - handler: Callable[[ModelRequest], Awaitable[ModelResponse]], - ) -> ModelCallResult: + request: ModelRequest[ContextT], + handler: Callable[[ModelRequest[ContextT]], Awaitable[ModelResponse[ResponseT]]], + ) -> ModelResponse[ResponseT] | AIMessage: """Intercept and control async model execution via handler callback. The handler callback executes the model request and returns a `ModelResponse`. @@ -770,13 +782,13 @@ class _CallableReturningSystemMessage(Protocol[StateT_contra, ContextT]): # typ """Callable that returns a prompt string or SystemMessage given `ModelRequest`.""" def __call__( - self, request: ModelRequest + self, request: ModelRequest[ContextT] ) -> str | SystemMessage | Awaitable[str | SystemMessage]: """Generate a system prompt string or SystemMessage based on the request.""" ... -class _CallableReturningModelResponse(Protocol[StateT_contra, ContextT]): # type: ignore[misc] +class _CallableReturningModelResponse(Protocol[StateT_contra, ContextT, ResponseT]): # type: ignore[misc] """Callable for model call interception with handler callback. Receives handler callback to execute model and returns `ModelResponse` or @@ -785,9 +797,9 @@ class _CallableReturningModelResponse(Protocol[StateT_contra, ContextT]): # typ def __call__( self, - request: ModelRequest, - handler: Callable[[ModelRequest], ModelResponse], - ) -> ModelCallResult: + request: ModelRequest[ContextT], + handler: Callable[[ModelRequest[ContextT]], ModelResponse[ResponseT]], + ) -> ModelResponse[ResponseT] | AIMessage: """Intercept model execution via handler callback.""" ... 
@@ -1626,9 +1638,9 @@ def dynamic_prompt( async def async_wrapped( _self: AgentMiddleware[StateT, ContextT], - request: ModelRequest, - handler: Callable[[ModelRequest], Awaitable[ModelResponse]], - ) -> ModelCallResult: + request: ModelRequest[ContextT], + handler: Callable[[ModelRequest[ContextT]], Awaitable[ModelResponse[Any]]], + ) -> ModelResponse[Any] | AIMessage: prompt = await func(request) # type: ignore[misc] if isinstance(prompt, SystemMessage): request = request.override(system_message=prompt) @@ -1650,10 +1662,10 @@ def dynamic_prompt( def wrapped( _self: AgentMiddleware[StateT, ContextT], - request: ModelRequest, - handler: Callable[[ModelRequest], ModelResponse], - ) -> ModelCallResult: - prompt = cast("Callable[[ModelRequest], SystemMessage | str]", func)(request) + request: ModelRequest[ContextT], + handler: Callable[[ModelRequest[ContextT]], ModelResponse[Any]], + ) -> ModelResponse[Any] | AIMessage: + prompt = cast("Callable[[ModelRequest[ContextT]], SystemMessage | str]", func)(request) if isinstance(prompt, SystemMessage): request = request.override(system_message=prompt) else: @@ -1662,11 +1674,11 @@ def dynamic_prompt( async def async_wrapped_from_sync( _self: AgentMiddleware[StateT, ContextT], - request: ModelRequest, - handler: Callable[[ModelRequest], Awaitable[ModelResponse]], - ) -> ModelCallResult: + request: ModelRequest[ContextT], + handler: Callable[[ModelRequest[ContextT]], Awaitable[ModelResponse[Any]]], + ) -> ModelResponse[Any] | AIMessage: # Delegate to sync function - prompt = cast("Callable[[ModelRequest], SystemMessage | str]", func)(request) + prompt = cast("Callable[[ModelRequest[ContextT]], SystemMessage | str]", func)(request) if isinstance(prompt, SystemMessage): request = request.override(system_message=prompt) else: @@ -1693,7 +1705,7 @@ def dynamic_prompt( @overload def wrap_model_call( - func: _CallableReturningModelResponse[StateT, ContextT], + func: _CallableReturningModelResponse[StateT, ContextT, ResponseT], ) -> 
AgentMiddleware[StateT, ContextT]: ... @@ -1705,20 +1717,20 @@ def wrap_model_call( tools: list[BaseTool] | None = None, name: str | None = None, ) -> Callable[ - [_CallableReturningModelResponse[StateT, ContextT]], + [_CallableReturningModelResponse[StateT, ContextT, ResponseT]], AgentMiddleware[StateT, ContextT], ]: ... def wrap_model_call( - func: _CallableReturningModelResponse[StateT, ContextT] | None = None, + func: _CallableReturningModelResponse[StateT, ContextT, ResponseT] | None = None, *, state_schema: type[StateT] | None = None, tools: list[BaseTool] | None = None, name: str | None = None, ) -> ( Callable[ - [_CallableReturningModelResponse[StateT, ContextT]], + [_CallableReturningModelResponse[StateT, ContextT, ResponseT]], AgentMiddleware[StateT, ContextT], ] | AgentMiddleware[StateT, ContextT] @@ -1799,7 +1811,7 @@ def wrap_model_call( """ def decorator( - func: _CallableReturningModelResponse[StateT, ContextT], + func: _CallableReturningModelResponse[StateT, ContextT, ResponseT], ) -> AgentMiddleware[StateT, ContextT]: is_async = iscoroutinefunction(func) @@ -1807,9 +1819,9 @@ def wrap_model_call( async def async_wrapped( _self: AgentMiddleware[StateT, ContextT], - request: ModelRequest, - handler: Callable[[ModelRequest], Awaitable[ModelResponse]], - ) -> ModelCallResult: + request: ModelRequest[ContextT], + handler: Callable[[ModelRequest[ContextT]], Awaitable[ModelResponse[ResponseT]]], + ) -> ModelResponse[ResponseT] | AIMessage: return await func(request, handler) # type: ignore[misc, arg-type] middleware_name = name or cast( @@ -1828,9 +1840,9 @@ def wrap_model_call( def wrapped( _self: AgentMiddleware[StateT, ContextT], - request: ModelRequest, - handler: Callable[[ModelRequest], ModelResponse], - ) -> ModelCallResult: + request: ModelRequest[ContextT], + handler: Callable[[ModelRequest[ContextT]], ModelResponse[ResponseT]], + ) -> ModelResponse[ResponseT] | AIMessage: return func(request, handler) middleware_name = name or cast("str", 
getattr(func, "__name__", "WrapModelCallMiddleware")) diff --git a/libs/langchain_v1/pyproject.toml b/libs/langchain_v1/pyproject.toml index b68c316a532..c050f4f7be9 100644 --- a/libs/langchain_v1/pyproject.toml +++ b/libs/langchain_v1/pyproject.toml @@ -107,7 +107,12 @@ line-length = 100 strict = true enable_error_code = "deprecated" warn_unreachable = true -exclude = ["tests/unit_tests/agents/*"] +exclude = [ + # Exclude agents tests except middleware_typing/ which has type-checked tests + "tests/unit_tests/agents/middleware/", + "tests/unit_tests/agents/specifications/", + "tests/unit_tests/agents/test_.*\\.py", +] # TODO: activate for 'strict' checking warn_return_any = false diff --git a/libs/langchain_v1/tests/unit_tests/agents/middleware_typing/__init__.py b/libs/langchain_v1/tests/unit_tests/agents/middleware_typing/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/libs/langchain_v1/tests/unit_tests/agents/middleware_typing/test_middleware_backwards_compat.py b/libs/langchain_v1/tests/unit_tests/agents/middleware_typing/test_middleware_backwards_compat.py new file mode 100644 index 00000000000..ae85390c877 --- /dev/null +++ b/libs/langchain_v1/tests/unit_tests/agents/middleware_typing/test_middleware_backwards_compat.py @@ -0,0 +1,275 @@ +"""Test backwards compatibility for middleware type parameters. + +This file verifies that middlewares written BEFORE the ResponseT change still work. +All patterns that were valid before should remain valid. 
+ +Run type check: uv run --group typing mypy +Run tests: uv run --group test pytest -v +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +import pytest +from langchain_core.language_models.fake_chat_models import GenericFakeChatModel +from langchain_core.messages import AIMessage, HumanMessage +from typing_extensions import TypedDict + +from langchain.agents import create_agent +from langchain.agents.middleware.types import ( + AgentMiddleware, + AgentState, + ContextT, + ModelRequest, + ModelResponse, + before_model, +) + +if TYPE_CHECKING: + from collections.abc import Awaitable, Callable + + from langgraph.runtime import Runtime + + +# ============================================================================= +# OLD PATTERN 1: Completely unparameterized AgentMiddleware +# This was the most common pattern for simple middlewares +# ============================================================================= +class OldStyleMiddleware1(AgentMiddleware): + """Middleware with no type parameters at all - most common old pattern.""" + + def before_model(self, state: AgentState[Any], runtime: Runtime[None]) -> dict[str, Any] | None: + # Simple middleware that just logs or does something + return None + + def wrap_model_call( + self, + request: ModelRequest, # No type param + handler: Callable[[ModelRequest], ModelResponse], # No type params + ) -> ModelResponse: # No type param + return handler(request) + + +# ============================================================================= +# OLD PATTERN 2: AgentMiddleware with only 2 type parameters (StateT, ContextT) +# This was the pattern before ResponseT was added +# ============================================================================= +class OldStyleMiddleware2(AgentMiddleware[AgentState[Any], ContextT]): + """Middleware with 2 type params - the old signature before ResponseT.""" + + def wrap_model_call( + self, + request: ModelRequest[ContextT], + handler: 
Callable[[ModelRequest[ContextT]], ModelResponse], + ) -> ModelResponse: + return handler(request) + + +# ============================================================================= +# OLD PATTERN 3: Middleware with explicit None context +# ============================================================================= +class OldStyleMiddleware3(AgentMiddleware[AgentState[Any], None]): + """Middleware explicitly typed for no context.""" + + def wrap_model_call( + self, + request: ModelRequest[None], + handler: Callable[[ModelRequest[None]], ModelResponse], + ) -> ModelResponse: + return handler(request) + + +# ============================================================================= +# OLD PATTERN 4: Middleware with specific context type (2 params) +# ============================================================================= +class MyContext(TypedDict): + user_id: str + + +class OldStyleMiddleware4(AgentMiddleware[AgentState[Any], MyContext]): + """Middleware with specific context - old 2-param pattern.""" + + def wrap_model_call( + self, + request: ModelRequest[MyContext], + handler: Callable[[ModelRequest[MyContext]], ModelResponse], + ) -> ModelResponse: + # Access context fields + _user_id: str = request.runtime.context["user_id"] + return handler(request) + + +# ============================================================================= +# OLD PATTERN 5: Decorator-based middleware +# ============================================================================= +@before_model +def old_style_decorator(state: AgentState[Any], runtime: Runtime[None]) -> dict[str, Any] | None: + """Decorator middleware - old pattern.""" + return None + + +# ============================================================================= +# OLD PATTERN 6: Async middleware (2 params) +# ============================================================================= +class OldStyleAsyncMiddleware(AgentMiddleware[AgentState[Any], ContextT]): + """Async middleware with old 2-param 
pattern.""" + + async def awrap_model_call( + self, + request: ModelRequest[ContextT], + handler: Callable[[ModelRequest[ContextT]], Awaitable[ModelResponse]], + ) -> ModelResponse: + return await handler(request) + + +# ============================================================================= +# OLD PATTERN 7: ModelResponse without type parameter +# ============================================================================= +class OldStyleModelResponseMiddleware(AgentMiddleware): + """Middleware using ModelResponse without type param.""" + + def wrap_model_call( + self, + request: ModelRequest, + handler: Callable[[ModelRequest], ModelResponse], + ) -> ModelResponse: + response = handler(request) + # Access result - this always worked + _ = response.result + # structured_response was Any before, still works + _ = response.structured_response + return response + + +# ============================================================================= +# TESTS: Verify all old patterns still work at runtime +# ============================================================================= +@pytest.fixture +def fake_model() -> GenericFakeChatModel: + """Create a fake model for testing.""" + return GenericFakeChatModel(messages=iter([AIMessage(content="Hello")])) + + +def test_old_pattern_1_unparameterized(fake_model: GenericFakeChatModel) -> None: + """Old pattern 1: Completely unparameterized middleware.""" + agent = create_agent( + model=fake_model, + middleware=[OldStyleMiddleware1()], + ) + result = agent.invoke({"messages": [HumanMessage(content="hi")]}) + assert "messages" in result + assert len(result["messages"]) >= 1 + + +def test_old_pattern_2_two_params(fake_model: GenericFakeChatModel) -> None: + """Old pattern 2: AgentMiddleware[StateT, ContextT] - 2 params.""" + agent = create_agent( + model=fake_model, + middleware=[OldStyleMiddleware2()], + ) + result = agent.invoke({"messages": [HumanMessage(content="hi")]}) + assert "messages" in result + assert 
len(result["messages"]) >= 1 + + +def test_old_pattern_3_explicit_none(fake_model: GenericFakeChatModel) -> None: + """Old pattern 3: Explicit None context.""" + agent = create_agent( + model=fake_model, + middleware=[OldStyleMiddleware3()], + ) + result = agent.invoke({"messages": [HumanMessage(content="hi")]}) + assert "messages" in result + assert len(result["messages"]) >= 1 + + +def test_old_pattern_4_specific_context(fake_model: GenericFakeChatModel) -> None: + """Old pattern 4: Specific context type with 2 params.""" + agent = create_agent( + model=fake_model, + middleware=[OldStyleMiddleware4()], + context_schema=MyContext, + ) + result = agent.invoke( + {"messages": [HumanMessage(content="hi")]}, + context={"user_id": "test-user"}, + ) + assert "messages" in result + assert len(result["messages"]) >= 1 + + +def test_old_pattern_5_decorator(fake_model: GenericFakeChatModel) -> None: + """Old pattern 5: Decorator-based middleware.""" + agent = create_agent( + model=fake_model, + middleware=[old_style_decorator], + ) + result = agent.invoke({"messages": [HumanMessage(content="hi")]}) + assert "messages" in result + assert len(result["messages"]) >= 1 + + +async def test_old_pattern_6_async(fake_model: GenericFakeChatModel) -> None: + """Old pattern 6: Async middleware with 2 params.""" + agent = create_agent( + model=fake_model, + middleware=[OldStyleAsyncMiddleware()], + ) + result = await agent.ainvoke({"messages": [HumanMessage(content="hi")]}) + assert "messages" in result + assert len(result["messages"]) >= 1 + + +def test_old_pattern_7_model_response_unparameterized( + fake_model: GenericFakeChatModel, +) -> None: + """Old pattern 7: ModelResponse without type parameter.""" + agent = create_agent( + model=fake_model, + middleware=[OldStyleModelResponseMiddleware()], + ) + result = agent.invoke({"messages": [HumanMessage(content="hi")]}) + assert "messages" in result + assert len(result["messages"]) >= 1 + + +def 
test_multiple_old_style_middlewares(fake_model: GenericFakeChatModel) -> None: + """Multiple old-style middlewares can be combined.""" + agent = create_agent( + model=fake_model, + middleware=[ + OldStyleMiddleware1(), + OldStyleMiddleware2(), + OldStyleMiddleware3(), + old_style_decorator, + OldStyleModelResponseMiddleware(), + ], + ) + result = agent.invoke({"messages": [HumanMessage(content="hi")]}) + assert "messages" in result + assert len(result["messages"]) >= 1 + + +def test_model_response_backwards_compat() -> None: + """ModelResponse can be instantiated without type params.""" + # Old way - no type param + response = ModelResponse(result=[AIMessage(content="test")]) + assert response.structured_response is None + + # Old way - accessing fields + response2 = ModelResponse( + result=[AIMessage(content="test")], + structured_response={"key": "value"}, + ) + assert response2.structured_response == {"key": "value"} + + +def test_model_request_backwards_compat() -> None: + """ModelRequest can be instantiated without type params.""" + # Old way - no type param + request = ModelRequest( + model=None, # type: ignore[arg-type] + messages=[HumanMessage(content="test")], + ) + assert len(request.messages) == 1 diff --git a/libs/langchain_v1/tests/unit_tests/agents/middleware_typing/test_middleware_type_errors.py b/libs/langchain_v1/tests/unit_tests/agents/middleware_typing/test_middleware_type_errors.py new file mode 100644 index 00000000000..e89943e5c9b --- /dev/null +++ b/libs/langchain_v1/tests/unit_tests/agents/middleware_typing/test_middleware_type_errors.py @@ -0,0 +1,201 @@ +"""Demonstrate type errors that mypy catches for ContextT and ResponseT mismatches. + +This file contains intentional type errors to demonstrate that mypy catches them. +Run: uv run --group typing mypy + +Expected errors: +1. TypedDict "UserContext" has no key "session_id" - accessing wrong context field +2. Argument incompatible with supertype - mismatched ModelRequest type +3. 
Cannot infer value of type parameter - middleware/context_schema mismatch +4. "AnalysisResult" has no attribute "summary" - accessing wrong response field +5. Handler returns wrong ResponseT type +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +from pydantic import BaseModel +from typing_extensions import TypedDict + +from langchain.agents import create_agent +from langchain.agents.middleware.types import ( + AgentMiddleware, + AgentState, + ContextT, + ModelRequest, + ModelResponse, +) +from tests.unit_tests.agents.model import FakeToolCallingModel + +if TYPE_CHECKING: + from collections.abc import Callable + + +# ============================================================================= +# Context and Response schemas +# ============================================================================= +class UserContext(TypedDict): + user_id: str + user_name: str + + +class SessionContext(TypedDict): + session_id: str + expires_at: int + + +class AnalysisResult(BaseModel): + sentiment: str + confidence: float + + +class SummaryResult(BaseModel): + summary: str + key_points: list[str] + + +# ============================================================================= +# ERROR 1: Using wrong context fields +# ============================================================================= +class WrongContextFieldsMiddleware(AgentMiddleware[AgentState[Any], UserContext, Any]): + def wrap_model_call( + self, + request: ModelRequest[UserContext], + handler: Callable[[ModelRequest[UserContext]], ModelResponse[Any]], + ) -> ModelResponse[Any]: + # TYPE ERROR: 'session_id' doesn't exist on UserContext + session_id: str = request.runtime.context["session_id"] # type: ignore[typeddict-item] + _ = session_id + return handler(request) + + +# ============================================================================= +# ERROR 2: Mismatched ModelRequest type parameter in method signature +# 
============================================================================= +class MismatchedRequestMiddleware(AgentMiddleware[AgentState[Any], UserContext, Any]): + def wrap_model_call( # type: ignore[override] + self, + # TYPE ERROR: Should be ModelRequest[UserContext], not SessionContext + request: ModelRequest[SessionContext], + handler: Callable[[ModelRequest[SessionContext]], ModelResponse[Any]], + ) -> ModelResponse[Any]: + return handler(request) + + +# ============================================================================= +# ERROR 3: Middleware ContextT doesn't match context_schema +# ============================================================================= +class SessionContextMiddleware(AgentMiddleware[AgentState[Any], SessionContext, Any]): + def wrap_model_call( + self, + request: ModelRequest[SessionContext], + handler: Callable[[ModelRequest[SessionContext]], ModelResponse[Any]], + ) -> ModelResponse[Any]: + return handler(request) + + +def test_mismatched_context_schema() -> None: + # TYPE ERROR: SessionContextMiddleware expects SessionContext, + # but context_schema is UserContext + fake_model = FakeToolCallingModel() + _agent = create_agent( # type: ignore[misc] + model=fake_model, + middleware=[SessionContextMiddleware()], + context_schema=UserContext, + ) + + +# ============================================================================= +# ERROR 4: Backwards compatible middleware with typed context_schema +# ============================================================================= +class BackwardsCompatibleMiddleware(AgentMiddleware): + def wrap_model_call( + self, + request: ModelRequest, + handler: Callable[[ModelRequest], ModelResponse], + ) -> ModelResponse: + return handler(request) + + +def test_backwards_compat_with_context_schema() -> None: + # TYPE ERROR: BackwardsCompatibleMiddleware is AgentMiddleware[..., None] + # but context_schema=UserContext expects AgentMiddleware[..., UserContext] + fake_model = 
FakeToolCallingModel() + _agent = create_agent( # type: ignore[misc] + model=fake_model, + middleware=[BackwardsCompatibleMiddleware()], + context_schema=UserContext, + ) + + +# ============================================================================= +# ERROR 5: Using wrong response fields +# ============================================================================= +class WrongResponseFieldsMiddleware( + AgentMiddleware[AgentState[AnalysisResult], ContextT, AnalysisResult] +): + def wrap_model_call( + self, + request: ModelRequest[ContextT], + handler: Callable[[ModelRequest[ContextT]], ModelResponse[AnalysisResult]], + ) -> ModelResponse[AnalysisResult]: + response = handler(request) + if response.structured_response is not None: + # TYPE ERROR: 'summary' doesn't exist on AnalysisResult + summary: str = response.structured_response.summary # type: ignore[attr-defined] + _ = summary + return response + + +# ============================================================================= +# ERROR 6: Mismatched ResponseT in method signature +# ============================================================================= +class MismatchedResponseMiddleware( + AgentMiddleware[AgentState[AnalysisResult], ContextT, AnalysisResult] +): + def wrap_model_call( # type: ignore[override] + self, + request: ModelRequest[ContextT], + # TYPE ERROR: Handler should return ModelResponse[AnalysisResult], not SummaryResult + handler: Callable[[ModelRequest[ContextT]], ModelResponse[SummaryResult]], + ) -> ModelResponse[AnalysisResult]: + # This would fail at runtime - types don't match + return handler(request) # type: ignore[return-value] + + +# ============================================================================= +# ERROR 7: Middleware ResponseT doesn't match response_format +# ============================================================================= +class AnalysisMiddleware(AgentMiddleware[AgentState[AnalysisResult], ContextT, AnalysisResult]): + def 
wrap_model_call( + self, + request: ModelRequest[ContextT], + handler: Callable[[ModelRequest[ContextT]], ModelResponse[AnalysisResult]], + ) -> ModelResponse[AnalysisResult]: + return handler(request) + + +def test_mismatched_response_format() -> None: + # TODO: TYPE ERROR not yet detected by mypy - AnalysisMiddleware expects AnalysisResult, + # but response_format is SummaryResult. This requires more sophisticated typing. + fake_model = FakeToolCallingModel() + _agent = create_agent( + model=fake_model, + middleware=[AnalysisMiddleware()], + response_format=SummaryResult, + ) + + +# ============================================================================= +# ERROR 8: Wrong return type from wrap_model_call +# ============================================================================= +class WrongReturnTypeMiddleware( + AgentMiddleware[AgentState[AnalysisResult], ContextT, AnalysisResult] +): + def wrap_model_call( # type: ignore[override] + self, + request: ModelRequest[ContextT], + handler: Callable[[ModelRequest[ContextT]], ModelResponse[AnalysisResult]], + ) -> ModelResponse[SummaryResult]: # TYPE ERROR: Should return ModelResponse[AnalysisResult] + return handler(request) # type: ignore[return-value] diff --git a/libs/langchain_v1/tests/unit_tests/agents/middleware_typing/test_middleware_typing.py b/libs/langchain_v1/tests/unit_tests/agents/middleware_typing/test_middleware_typing.py new file mode 100644 index 00000000000..6878f404147 --- /dev/null +++ b/libs/langchain_v1/tests/unit_tests/agents/middleware_typing/test_middleware_typing.py @@ -0,0 +1,443 @@ +"""Test file to verify type safety in middleware (ContextT and ResponseT). + +This file demonstrates: +1. Backwards compatible middlewares (no type params specified) - works with defaults +2. Correctly typed middlewares (ContextT/ResponseT match) - full type safety +3. 
Type errors that are caught when types don't match + +Run type check: uv run --group typing mypy +Run tests: uv run --group test pytest -v + +To see type errors being caught, run: + uv run --group typing mypy .../test_middleware_type_errors.py +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +import pytest +from langchain_core.language_models.fake_chat_models import GenericFakeChatModel +from langchain_core.messages import AIMessage, HumanMessage +from pydantic import BaseModel +from typing_extensions import TypedDict + +from langchain.agents import create_agent +from langchain.agents.middleware.types import ( + AgentMiddleware, + AgentState, + ContextT, + ModelRequest, + ModelResponse, + ResponseT, + before_model, +) + +if TYPE_CHECKING: + from collections.abc import Awaitable, Callable + + from langgraph.graph.state import CompiledStateGraph + from langgraph.runtime import Runtime + + +# ============================================================================= +# Context and Response schemas for testing +# ============================================================================= +class UserContext(TypedDict): + """Context with user information.""" + + user_id: str + user_name: str + + +class SessionContext(TypedDict): + """Different context schema.""" + + session_id: str + expires_at: int + + +class AnalysisResult(BaseModel): + """Structured response schema.""" + + sentiment: str + confidence: float + + +class SummaryResult(BaseModel): + """Different structured response schema.""" + + summary: str + key_points: list[str] + + +# ============================================================================= +# 1. 
BACKWARDS COMPATIBLE: Middlewares without type parameters +# These work when create_agent has NO context_schema or response_format +# ============================================================================= +class BackwardsCompatibleMiddleware(AgentMiddleware): + """Middleware that doesn't specify type parameters - backwards compatible.""" + + def before_model(self, state: AgentState[Any], runtime: Runtime[None]) -> dict[str, Any] | None: + return None + + def wrap_model_call( + self, + request: ModelRequest, # No type param - backwards compatible! + handler: Callable[[ModelRequest], ModelResponse], + ) -> ModelResponse: + return handler(request) + + +class BackwardsCompatibleMiddleware2(AgentMiddleware): + """Another backwards compatible middleware using ModelRequest without params.""" + + def wrap_model_call( + self, + request: ModelRequest, # Unparameterized - defaults to ModelRequest[None] + handler: Callable[[ModelRequest], ModelResponse], + ) -> ModelResponse: + _ = request.runtime + return handler(request) + + +@before_model +def backwards_compatible_decorator( + state: AgentState[Any], runtime: Runtime[None] +) -> dict[str, Any] | None: + """Decorator middleware without explicit type parameters.""" + return None + + +# ============================================================================= +# 2. 
CORRECTLY TYPED: Middlewares with explicit ContextT +# These work when create_agent has MATCHING context_schema +# ============================================================================= +class UserContextMiddleware(AgentMiddleware[AgentState[Any], UserContext, Any]): + """Middleware with correctly specified UserContext.""" + + def before_model( + self, state: AgentState[Any], runtime: Runtime[UserContext] + ) -> dict[str, Any] | None: + # Full type safety - IDE knows these fields exist + _user_id: str = runtime.context["user_id"] + _user_name: str = runtime.context["user_name"] + return None + + def wrap_model_call( + self, + request: ModelRequest[UserContext], # Correctly parameterized! + handler: Callable[[ModelRequest[UserContext]], ModelResponse[Any]], + ) -> ModelResponse[Any]: + # request.runtime.context is UserContext - fully typed! + _user_id: str = request.runtime.context["user_id"] + return handler(request) + + +class SessionContextMiddleware(AgentMiddleware[AgentState[Any], SessionContext, Any]): + """Middleware with correctly specified SessionContext.""" + + def wrap_model_call( + self, + request: ModelRequest[SessionContext], + handler: Callable[[ModelRequest[SessionContext]], ModelResponse[Any]], + ) -> ModelResponse[Any]: + _session_id: str = request.runtime.context["session_id"] + _expires: int = request.runtime.context["expires_at"] + return handler(request) + + +# ============================================================================= +# 3. 
CORRECTLY TYPED: Middlewares with explicit ResponseT +# These work when create_agent has MATCHING response_format +# ============================================================================= +class AnalysisResponseMiddleware( + AgentMiddleware[AgentState[AnalysisResult], ContextT, AnalysisResult] +): + """Middleware with correctly specified AnalysisResult response type.""" + + def wrap_model_call( + self, + request: ModelRequest[ContextT], + handler: Callable[[ModelRequest[ContextT]], ModelResponse[AnalysisResult]], + ) -> ModelResponse[AnalysisResult]: + response = handler(request) + # Full type safety on structured_response + if response.structured_response is not None: + _sentiment: str = response.structured_response.sentiment + _confidence: float = response.structured_response.confidence + return response + + +class SummaryResponseMiddleware( + AgentMiddleware[AgentState[SummaryResult], ContextT, SummaryResult] +): + """Middleware with correctly specified SummaryResult response type.""" + + def wrap_model_call( + self, + request: ModelRequest[ContextT], + handler: Callable[[ModelRequest[ContextT]], ModelResponse[SummaryResult]], + ) -> ModelResponse[SummaryResult]: + response = handler(request) + if response.structured_response is not None: + _summary: str = response.structured_response.summary + _points: list[str] = response.structured_response.key_points + return response + + +# ============================================================================= +# 4. 
FULLY TYPED: Middlewares with both ContextT and ResponseT +# ============================================================================= +class FullyTypedMiddleware( + AgentMiddleware[AgentState[AnalysisResult], UserContext, AnalysisResult] +): + """Middleware with both ContextT and ResponseT fully specified.""" + + def wrap_model_call( + self, + request: ModelRequest[UserContext], + handler: Callable[[ModelRequest[UserContext]], ModelResponse[AnalysisResult]], + ) -> ModelResponse[AnalysisResult]: + # Access context with full type safety + _user_id: str = request.runtime.context["user_id"] + + response = handler(request) + + # Access structured response with full type safety + if response.structured_response is not None: + _sentiment: str = response.structured_response.sentiment + + return response + + +# ============================================================================= +# 5. FLEXIBLE MIDDLEWARE: Works with any ContextT/ResponseT using Generic +# ============================================================================= +class FlexibleMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, ResponseT]): + """Middleware that works with any ContextT and ResponseT.""" + + def wrap_model_call( + self, + request: ModelRequest[ContextT], + handler: Callable[[ModelRequest[ContextT]], ModelResponse[ResponseT]], + ) -> ModelResponse[ResponseT]: + # Can't access specific fields, but works with any schemas + _ = request.runtime + return handler(request) + + +# ============================================================================= +# 6. 
CREATE_AGENT INTEGRATION TESTS +# ============================================================================= +@pytest.fixture +def fake_model() -> GenericFakeChatModel: + """Create a fake model for testing.""" + return GenericFakeChatModel(messages=iter([AIMessage(content="Hello")])) + + +def test_create_agent_no_context_schema(fake_model: GenericFakeChatModel) -> None: + """Backwards compatible: No context_schema means ContextT=None.""" + agent: CompiledStateGraph[Any, None, Any, Any] = create_agent( + model=fake_model, + middleware=[ + BackwardsCompatibleMiddleware(), + BackwardsCompatibleMiddleware2(), + backwards_compatible_decorator, + ], + # No context_schema - backwards compatible + ) + assert agent is not None + + +def test_create_agent_with_user_context(fake_model: GenericFakeChatModel) -> None: + """Typed: context_schema=UserContext requires matching middleware.""" + agent: CompiledStateGraph[Any, UserContext, Any, Any] = create_agent( + model=fake_model, + middleware=[UserContextMiddleware()], # Matches UserContext + context_schema=UserContext, + ) + assert agent is not None + + +def test_create_agent_with_session_context(fake_model: GenericFakeChatModel) -> None: + """Typed: context_schema=SessionContext requires matching middleware.""" + agent: CompiledStateGraph[Any, SessionContext, Any, Any] = create_agent( + model=fake_model, + middleware=[SessionContextMiddleware()], # Matches SessionContext + context_schema=SessionContext, + ) + assert agent is not None + + +def test_create_agent_with_flexible_middleware(fake_model: GenericFakeChatModel) -> None: + """Flexible middleware works with any context_schema.""" + # With UserContext + agent1: CompiledStateGraph[Any, UserContext, Any, Any] = create_agent( + model=fake_model, + middleware=[FlexibleMiddleware[UserContext, Any]()], + context_schema=UserContext, + ) + assert agent1 is not None + + # With SessionContext + agent2: CompiledStateGraph[Any, SessionContext, Any, Any] = create_agent( + 
model=fake_model, + middleware=[FlexibleMiddleware[SessionContext, Any]()], + context_schema=SessionContext, + ) + assert agent2 is not None + + +def test_create_agent_with_response_middleware(fake_model: GenericFakeChatModel) -> None: + """Middleware with ResponseT works with response_format.""" + agent = create_agent( + model=fake_model, + middleware=[AnalysisResponseMiddleware()], + response_format=AnalysisResult, + ) + assert agent is not None + + +def test_create_agent_fully_typed(fake_model: GenericFakeChatModel) -> None: + """Fully typed middleware with both ContextT and ResponseT.""" + agent = create_agent( + model=fake_model, + middleware=[FullyTypedMiddleware()], + context_schema=UserContext, + response_format=AnalysisResult, + ) + assert agent is not None + + +# ============================================================================= +# 7. ASYNC VARIANTS +# ============================================================================= +class AsyncUserContextMiddleware(AgentMiddleware[AgentState[Any], UserContext, Any]): + """Async middleware with correctly typed ContextT.""" + + async def abefore_model( + self, state: AgentState[Any], runtime: Runtime[UserContext] + ) -> dict[str, Any] | None: + _user_name: str = runtime.context["user_name"] + return None + + async def awrap_model_call( + self, + request: ModelRequest[UserContext], + handler: Callable[[ModelRequest[UserContext]], Awaitable[ModelResponse[Any]]], + ) -> ModelResponse[Any]: + _user_id: str = request.runtime.context["user_id"] + return await handler(request) + + +class AsyncResponseMiddleware( + AgentMiddleware[AgentState[AnalysisResult], ContextT, AnalysisResult] +): + """Async middleware with correctly typed ResponseT.""" + + async def awrap_model_call( + self, + request: ModelRequest[ContextT], + handler: Callable[[ModelRequest[ContextT]], Awaitable[ModelResponse[AnalysisResult]]], + ) -> ModelResponse[AnalysisResult]: + response = await handler(request) + if 
response.structured_response is not None: + _sentiment: str = response.structured_response.sentiment + return response + + +def test_async_middleware_with_context(fake_model: GenericFakeChatModel) -> None: + """Async middleware with typed context.""" + agent: CompiledStateGraph[Any, UserContext, Any, Any] = create_agent( + model=fake_model, + middleware=[AsyncUserContextMiddleware()], + context_schema=UserContext, + ) + assert agent is not None + + +def test_async_middleware_with_response(fake_model: GenericFakeChatModel) -> None: + """Async middleware with typed response.""" + agent = create_agent( + model=fake_model, + middleware=[AsyncResponseMiddleware()], + response_format=AnalysisResult, + ) + assert agent is not None + + +# ============================================================================= +# 8. MODEL_REQUEST AND MODEL_RESPONSE TESTS +# ============================================================================= +def test_model_request_preserves_context_type() -> None: + """Test that ModelRequest.override() preserves ContextT.""" + request: ModelRequest[UserContext] = ModelRequest( + model=None, # type: ignore[arg-type] + messages=[HumanMessage(content="test")], + runtime=None, + ) + + # Override should preserve the type parameter + new_request: ModelRequest[UserContext] = request.override( + messages=[HumanMessage(content="updated")] + ) + + assert type(request) is type(new_request) + + +def test_model_request_backwards_compatible() -> None: + """Test that ModelRequest can be instantiated without type params.""" + request = ModelRequest( + model=None, # type: ignore[arg-type] + messages=[HumanMessage(content="test")], + ) + + assert request.messages[0].content == "test" + + +def test_model_request_explicit_none() -> None: + """Test ModelRequest[None] is same as unparameterized ModelRequest.""" + request1: ModelRequest[None] = ModelRequest( + model=None, # type: ignore[arg-type] + messages=[HumanMessage(content="test")], + ) + + request2: 
ModelRequest = ModelRequest( + model=None, # type: ignore[arg-type] + messages=[HumanMessage(content="test")], + ) + + assert type(request1) is type(request2) + + +def test_model_response_with_response_type() -> None: + """Test that ModelResponse preserves ResponseT.""" + response: ModelResponse[AnalysisResult] = ModelResponse( + result=[AIMessage(content="test")], + structured_response=AnalysisResult(sentiment="positive", confidence=0.9), + ) + + # Type checker knows structured_response is AnalysisResult | None + if response.structured_response is not None: + _sentiment: str = response.structured_response.sentiment + _confidence: float = response.structured_response.confidence + + +def test_model_response_without_structured() -> None: + """Test ModelResponse without structured response.""" + response: ModelResponse[Any] = ModelResponse( + result=[AIMessage(content="test")], + structured_response=None, + ) + + assert response.structured_response is None + + +def test_model_response_backwards_compatible() -> None: + """Test that ModelResponse can be instantiated without type params.""" + response = ModelResponse( + result=[AIMessage(content="test")], + ) + + assert response.structured_response is None