From 4d118777bc6f9c26f02b5ff843b7e0ab074b2fc2 Mon Sep 17 00:00:00 2001 From: Sydney Runkle <54324534+sydney-runkle@users.noreply.github.com> Date: Thu, 18 Sep 2025 16:07:16 -0400 Subject: [PATCH] feat(langchain): dynamic system prompt middleware (#33006) # Changes ## Adds support for `DynamicSystemPromptMiddleware` ```py from langchain.agents.middleware import DynamicSystemPromptMiddleware from langgraph.runtime import Runtime from typing_extensions import TypedDict class Context(TypedDict): user_name: str def system_prompt(state: AgentState, runtime: Runtime[Context]) -> str: user_name = runtime.context.get("user_name", "n/a") return f"You are a helpful assistant. Always address the user by their name: {user_name}" middleware = DynamicSystemPromptMiddleware(system_prompt) ``` ## Adds support for `runtime` in middleware hooks ```py class AgentMiddleware(Generic[StateT, ContextT]): def modify_model_request( self, request: ModelRequest, state: StateT, runtime: Runtime[ContextT], # Optional runtime parameter ) -> ModelRequest: # upgrade model if runtime.context.subscription is `top-tier` or whatever ``` ## Adds support for omitting state attributes from input / output schemas ```py from typing import Annotated, NotRequired from langchain.agents.middleware.types import PrivateStateAttr, OmitFromInput, OmitFromOutput class CustomState(AgentState): # Private field - not in input or output schemas internal_counter: NotRequired[Annotated[int, PrivateStateAttr]] # Input-only field - not in output schema user_input: NotRequired[Annotated[str, OmitFromOutput]] # Output-only field - not in input schema computed_result: NotRequired[Annotated[str, OmitFromInput]] ``` ## Additionally * Removes filtering of state before passing into middleware hooks Typing is not foolproof here, still need to figure out some of the generics stuff w/ state and context schema extensions for middleware. TODO: * More docs for middleware, should hold off on this until other prios like MCP and deepagents are met --------- Co-authored-by: Eugene Yurtsev --- .../langchain/agents/middleware/__init__.py | 3 + .../middleware/dynamic_system_prompt.py | 105 +++++++++++++ .../agents/middleware/human_in_the_loop.py | 2 +- .../agents/middleware/prompt_caching.py | 7 +- .../agents/middleware/summarization.py | 2 +- .../langchain/agents/middleware/types.py | 60 ++++++-- .../langchain/agents/middleware_agent.py | 103 ++++++++----- .../agents/test_middleware_agent.py | 140 +++++++++++++++++- 8 files changed, 370 insertions(+), 52 deletions(-) create mode 100644 libs/langchain_v1/langchain/agents/middleware/dynamic_system_prompt.py diff --git a/libs/langchain_v1/langchain/agents/middleware/__init__.py b/libs/langchain_v1/langchain/agents/middleware/__init__.py index efc31b6eb73..e554d2aed4e 100644 --- a/libs/langchain_v1/langchain/agents/middleware/__init__.py +++ b/libs/langchain_v1/langchain/agents/middleware/__init__.py @@ -1,5 +1,6 @@ """Middleware plugins for agents.""" +from .dynamic_system_prompt import DynamicSystemPromptMiddleware from .human_in_the_loop import HumanInTheLoopMiddleware from .prompt_caching import AnthropicPromptCachingMiddleware from .summarization import SummarizationMiddleware @@ -8,7 +9,9 @@ from .types import AgentMiddleware, AgentState, ModelRequest __all__ = [ "AgentMiddleware", "AgentState", + # should move to langchain-anthropic if we decide to keep it "AnthropicPromptCachingMiddleware", + "DynamicSystemPromptMiddleware", "HumanInTheLoopMiddleware", "ModelRequest", "SummarizationMiddleware", diff --git a/libs/langchain_v1/langchain/agents/middleware/dynamic_system_prompt.py b/libs/langchain_v1/langchain/agents/middleware/dynamic_system_prompt.py new file mode 100644 index 00000000000..d1f3f8b03b5 --- /dev/null +++ b/libs/langchain_v1/langchain/agents/middleware/dynamic_system_prompt.py @@ -0,0 +1,105 @@ +"""Dynamic System Prompt Middleware. + +Allows setting the system prompt dynamically right before each model invocation. +Useful when the prompt depends on the current agent state or per-invocation context. +""" + +from __future__ import annotations + +from inspect import signature +from typing import TYPE_CHECKING, Protocol, TypeAlias, cast + +from langgraph.typing import ContextT + +from langchain.agents.middleware.types import ( + AgentMiddleware, + AgentState, + ModelRequest, +) + +if TYPE_CHECKING: + from langgraph.runtime import Runtime + + +class DynamicSystemPromptWithoutRuntime(Protocol): + """Dynamic system prompt without runtime in call signature.""" + + def __call__(self, state: AgentState) -> str: + """Return the system prompt for the next model call.""" + ... + + +class DynamicSystemPromptWithRuntime(Protocol[ContextT]): + """Dynamic system prompt with runtime in call signature.""" + + def __call__(self, state: AgentState, runtime: Runtime[ContextT]) -> str: + """Return the system prompt for the next model call.""" + ... + + +DynamicSystemPrompt: TypeAlias = ( + DynamicSystemPromptWithoutRuntime | DynamicSystemPromptWithRuntime[ContextT] +) + + +class DynamicSystemPromptMiddleware(AgentMiddleware): + """Dynamic System Prompt Middleware. + + Allows setting the system prompt dynamically right before each model invocation. + Useful when the prompt depends on the current agent state or per-invocation context. + + Example: + ```python + from langchain.agents.middleware import DynamicSystemPromptMiddleware + + + class Context(TypedDict): + user_name: str + + + def system_prompt(state: AgentState, runtime: Runtime[Context]) -> str: + user_name = runtime.context.get("user_name", "n/a") + return ( + f"You are a helpful assistant. Always address the user by their name: {user_name}" + ) + + + middleware = DynamicSystemPromptMiddleware(system_prompt) + ``` + """ + + _accepts_runtime: bool + + def __init__( + self, + dynamic_system_prompt: DynamicSystemPrompt[ContextT], + ) -> None: + """Initialize the dynamic system prompt middleware. + + Args: + dynamic_system_prompt: Function that receives the current agent state + and optionally runtime with context, and returns the system prompt for + the next model call. Returns a string. + """ + super().__init__() + self.dynamic_system_prompt = dynamic_system_prompt + self._accepts_runtime = "runtime" in signature(dynamic_system_prompt).parameters + + def modify_model_request( + self, + request: ModelRequest, + state: AgentState, + runtime: Runtime[ContextT], + ) -> ModelRequest: + """Modify the model request to include the dynamic system prompt.""" + if self._accepts_runtime: + system_prompt = cast( + "DynamicSystemPromptWithRuntime[ContextT]", self.dynamic_system_prompt + )(state, runtime) + else: + system_prompt = cast("DynamicSystemPromptWithoutRuntime", self.dynamic_system_prompt)( + state + ) + + request.system_prompt = system_prompt + return request diff --git a/libs/langchain_v1/langchain/agents/middleware/human_in_the_loop.py b/libs/langchain_v1/langchain/agents/middleware/human_in_the_loop.py index 46ac1a75990..18e58868917 100644 --- a/libs/langchain_v1/langchain/agents/middleware/human_in_the_loop.py +++ b/libs/langchain_v1/langchain/agents/middleware/human_in_the_loop.py @@ -143,7 +143,7 @@ class HumanInTheLoopMiddleware(AgentMiddleware): self.tool_configs = resolved_tool_configs self.description_prefix = description_prefix - def after_model(self, state: AgentState) -> dict[str, Any] | None: + def after_model(self, state: AgentState) -> dict[str, Any] | None: # type: ignore[override] """Trigger HITL flows for relevant tool calls after an AIMessage.""" messages = state["messages"] if not messages: diff --git a/libs/langchain_v1/langchain/agents/middleware/prompt_caching.py b/libs/langchain_v1/langchain/agents/middleware/prompt_caching.py index 2f7a1817995..ab043e6fcf0 100644 --- a/libs/langchain_v1/langchain/agents/middleware/prompt_caching.py +++ b/libs/langchain_v1/langchain/agents/middleware/prompt_caching.py @@ -2,7 +2,7 @@ from typing import Literal -from langchain.agents.middleware.types import AgentMiddleware, AgentState, ModelRequest +from langchain.agents.middleware.types import AgentMiddleware, ModelRequest class AnthropicPromptCachingMiddleware(AgentMiddleware): @@ -32,7 +32,10 @@ class AnthropicPromptCachingMiddleware(AgentMiddleware): self.ttl = ttl self.min_messages_to_cache = min_messages_to_cache - def modify_model_request(self, request: ModelRequest, state: AgentState) -> ModelRequest: # noqa: ARG002 + def modify_model_request( # type: ignore[override] + self, + request: ModelRequest, + ) -> ModelRequest: """Modify the model request to add cache control blocks.""" try: from langchain_anthropic import ChatAnthropic diff --git a/libs/langchain_v1/langchain/agents/middleware/summarization.py b/libs/langchain_v1/langchain/agents/middleware/summarization.py index c716e2c2dec..814cc83a908 100644 --- a/libs/langchain_v1/langchain/agents/middleware/summarization.py +++ b/libs/langchain_v1/langchain/agents/middleware/summarization.py @@ -98,7 +98,7 @@ class SummarizationMiddleware(AgentMiddleware): self.summary_prompt = summary_prompt self.summary_prefix = summary_prefix - def before_model(self, state: AgentState) -> dict[str, Any] | None: + def before_model(self, state: AgentState) -> dict[str, Any] | None: # type: ignore[override] """Process messages before model invocation, potentially triggering summarization.""" messages = state["messages"] self._ensure_message_ids(messages) diff --git a/libs/langchain_v1/langchain/agents/middleware/types.py b/libs/langchain_v1/langchain/agents/middleware/types.py index c84f4954ed1..5a6c7ecf56c 100644 --- a/libs/langchain_v1/langchain/agents/middleware/types.py +++ b/libs/langchain_v1/langchain/agents/middleware/types.py @@ -8,15 +8,27 @@ from typing import TYPE_CHECKING, Annotated, Any, Generic, Literal, cast # needed as top level import for pydantic schema generation on AgentState from langchain_core.messages import AnyMessage # noqa: TC002 from langgraph.channels.ephemeral_value import EphemeralValue -from langgraph.graph.message import Messages, add_messages +from langgraph.graph.message import add_messages +from langgraph.runtime import Runtime +from langgraph.typing import ContextT from typing_extensions import NotRequired, Required, TypedDict, TypeVar if TYPE_CHECKING: from langchain_core.language_models.chat_models import BaseChatModel from langchain_core.tools import BaseTool + from langgraph.runtime import Runtime from langchain.agents.structured_output import ResponseFormat +__all__ = [ + "AgentMiddleware", + "AgentState", + "ContextT", + "ModelRequest", + "OmitFromSchema", + "PublicAgentState", +] + JumpTo = Literal["tools", "model", "__end__"] """Destination to jump to when a middleware node returns.""" @@ -36,26 +48,49 @@ class ModelRequest: model_settings: dict[str, Any] = field(default_factory=dict) +@dataclass +class OmitFromSchema: + """Annotation used to mark state attributes as omitted from input or output schemas.""" + + input: bool = True + """Whether to omit the attribute from the input schema.""" + + output: bool = True + """Whether to omit the attribute from the output schema.""" + + +OmitFromInput = OmitFromSchema(input=True, output=False) +"""Annotation used to mark state attributes as omitted from input schema.""" + +OmitFromOutput = OmitFromSchema(input=False, output=True) +"""Annotation used to mark state attributes as omitted from output schema.""" + +PrivateStateAttr = OmitFromSchema(input=True, output=True) +"""Annotation used to mark state attributes as purely internal for a given middleware.""" + + class AgentState(TypedDict, Generic[ResponseT]): """State schema for the agent.""" messages: Required[Annotated[list[AnyMessage], add_messages]] - model_request: NotRequired[Annotated[ModelRequest | None, EphemeralValue]] - jump_to: NotRequired[Annotated[JumpTo | None, EphemeralValue]] + jump_to: NotRequired[Annotated[JumpTo | None, EphemeralValue, PrivateStateAttr]] response: NotRequired[ResponseT] class PublicAgentState(TypedDict, Generic[ResponseT]): - """Input / output schema for the agent.""" + """Public state schema for the agent. - messages: Required[Messages] + Just used for typing purposes. + """ + + messages: Required[Annotated[list[AnyMessage], add_messages]] response: NotRequired[ResponseT] -StateT = TypeVar("StateT", bound=AgentState) +StateT = TypeVar("StateT", bound=AgentState, default=AgentState) -class AgentMiddleware(Generic[StateT]): +class AgentMiddleware(Generic[StateT, ContextT]): """Base middleware class for an agent. Subclass this and implement any of the defined methods to customize agent behavior @@ -68,12 +103,17 @@ class AgentMiddleware(Generic[StateT]): tools: list[BaseTool] """Additional tools registered by the middleware.""" - def before_model(self, state: StateT) -> dict[str, Any] | None: + def before_model(self, state: StateT, runtime: Runtime[ContextT]) -> dict[str, Any] | None: """Logic to run before the model is called.""" - def modify_model_request(self, request: ModelRequest, state: StateT) -> ModelRequest: # noqa: ARG002 + def modify_model_request( + self, + request: ModelRequest, + state: StateT, # noqa: ARG002 + runtime: Runtime[ContextT], # noqa: ARG002 + ) -> ModelRequest: """Logic to modify request kwargs before the model is called.""" return request - def after_model(self, state: StateT) -> dict[str, Any] | None: + def after_model(self, state: StateT, runtime: Runtime[ContextT]) -> dict[str, Any] | None: """Logic to run after the model is called.""" diff --git a/libs/langchain_v1/langchain/agents/middleware_agent.py b/libs/langchain_v1/langchain/agents/middleware_agent.py index 4caa6c8af62..f94ab417ba7 100644 --- a/libs/langchain_v1/langchain/agents/middleware_agent.py +++ b/libs/langchain_v1/langchain/agents/middleware_agent.py @@ -2,7 +2,8 @@ import itertools from collections.abc import Callable, Sequence -from typing import Any, cast +from inspect import signature +from typing import Annotated, Any, cast, get_args, get_origin, get_type_hints from langchain_core.language_models.chat_models import BaseChatModel from langchain_core.messages import AIMessage, AnyMessage, SystemMessage, ToolMessage @@ -10,19 +11,19 @@ from langchain_core.runnables import Runnable from langchain_core.tools import BaseTool from langgraph.constants import END, START from langgraph.graph.state import StateGraph +from langgraph.runtime import Runtime from langgraph.types import Send from langgraph.typing import ContextT -from typing_extensions import TypedDict, TypeVar +from typing_extensions import NotRequired, Required, TypedDict, TypeVar from langchain.agents.middleware.types import ( AgentMiddleware, AgentState, JumpTo, ModelRequest, + OmitFromSchema, PublicAgentState, ) - -# Import structured output classes from the old implementation from langchain.agents.structured_output import ( MultipleStructuredOutputsError, OutputToolBinding, @@ -38,26 +39,49 @@ from langchain.chat_models import init_chat_model STRUCTURED_OUTPUT_ERROR_TEMPLATE = "Error: {error}\n Please fix your mistakes." -def _merge_state_schemas(schemas: list[type]) -> type: - """Merge multiple TypedDict schemas into a single schema with all fields.""" - if not schemas: - return AgentState +def _resolve_schema(schemas: set[type], schema_name: str, omit_flag: str | None = None) -> type: + """Resolve schema by merging schemas and optionally respecting OmitFromSchema annotations. + Args: + schemas: List of schema types to merge + schema_name: Name for the generated TypedDict + omit_flag: If specified, omit fields with this flag set ('input' or 'output') + """ all_annotations = {} for schema in schemas: - all_annotations.update(schema.__annotations__) + hints = get_type_hints(schema, include_extras=True) - return TypedDict("MergedState", all_annotations) # type: ignore[operator] + for field_name, field_type in hints.items(): + should_omit = False + + if omit_flag: + # Check for omission in the annotation metadata + metadata = _extract_metadata(field_type) + for meta in metadata: + if isinstance(meta, OmitFromSchema) and getattr(meta, omit_flag) is True: + should_omit = True + break + + if not should_omit: + all_annotations[field_name] = field_type + + return TypedDict(schema_name, all_annotations) # type: ignore[operator] -def _filter_state_for_schema(state: dict[str, Any], schema: type) -> dict[str, Any]: - """Filter state to only include fields defined in the given schema.""" - if not hasattr(schema, "__annotations__"): - return state +def _extract_metadata(type_: type) -> list: + """Extract metadata from a field type, handling Required/NotRequired and Annotated wrappers.""" + # Handle Required[Annotated[...]] or NotRequired[Annotated[...]] + if get_origin(type_) in (Required, NotRequired): + inner_type = get_args(type_)[0] + if get_origin(inner_type) is Annotated: + return list(get_args(inner_type)[1:]) - schema_fields = set(schema.__annotations__.keys()) - return {k: v for k, v in state.items() if k in schema_fields} + # Handle direct Annotated[...] + elif get_origin(type_) is Annotated: + return list(get_args(type_)[1:]) + + return [] def _supports_native_structured_output(model: str | BaseChatModel) -> bool: @@ -114,7 +138,7 @@ def create_agent( # noqa: PLR0915 model: str | BaseChatModel, tools: Sequence[BaseTool | Callable | dict[str, Any]] | ToolNode | None = None, system_prompt: str | None = None, - middleware: Sequence[AgentMiddleware] = (), + middleware: Sequence[AgentMiddleware[AgentState[ResponseT], ContextT]] = (), response_format: ResponseFormat[ResponseT] | type[ResponseT] | None = None, context_schema: type[ContextT] | None = None, ) -> StateGraph[ @@ -199,16 +223,16 @@ def create_agent( # noqa: PLR0915 m for m in middleware if m.__class__.after_model is not AgentMiddleware.after_model ] - # Collect all middleware state schemas and create merged schema - merged_state_schema: type[AgentState] = _merge_state_schemas( - [m.state_schema for m in middleware] - ) + state_schemas = {m.state_schema for m in middleware} + state_schemas.add(AgentState) # create graph, add nodes - graph = StateGraph( - merged_state_schema, - input_schema=PublicAgentState, - output_schema=PublicAgentState, + graph: StateGraph[ + AgentState[ResponseT], ContextT, PublicAgentState[ResponseT], PublicAgentState[ResponseT] + ] = StateGraph( + state_schema=_resolve_schema(state_schemas, "StateSchema", None), + input_schema=_resolve_schema(state_schemas, "InputSchema", "input"), + output_schema=_resolve_schema(state_schemas, "OutputSchema", "output"), context_schema=context_schema, ) @@ -318,7 +342,14 @@ def create_agent( # noqa: PLR0915 ) return request.model.bind(**request.model_settings) - def model_request(state: dict[str, Any]) -> dict[str, Any]: + model_request_signatures: list[ + tuple[bool, AgentMiddleware[AgentState[ResponseT], ContextT]] + ] = [ + ("runtime" in signature(m.modify_model_request).parameters, m) + for m in middleware_w_modify_model_request + ] + + def model_request(state: AgentState, runtime: Runtime[ContextT]) -> dict[str, Any]: """Sync model request handler with sequential middleware processing.""" request = ModelRequest( model=model, @@ -330,10 +361,11 @@ def create_agent( # noqa: PLR0915 ) # Apply modify_model_request middleware in sequence - for m in middleware_w_modify_model_request: - # Filter state to only include fields defined in this middleware's schema - filtered_state = _filter_state_for_schema(state, m.state_schema) - request = m.modify_model_request(request, filtered_state) + for use_runtime, m in model_request_signatures: + if use_runtime: + m.modify_model_request(request, state, runtime) + else: + m.modify_model_request(request, state) # type: ignore[call-arg] # Get the final model and messages model_ = _get_bound_model(request) @@ -344,7 +376,7 @@ def create_agent( # noqa: PLR0915 output = model_.invoke(messages) return _handle_model_output(output) - async def amodel_request(state: dict[str, Any]) -> dict[str, Any]: + async def amodel_request(state: AgentState, runtime: Runtime[ContextT]) -> dict[str, Any]: """Async model request handler with sequential middleware processing.""" # Start with the base model request request = ModelRequest( @@ -357,10 +389,11 @@ def create_agent( # noqa: PLR0915 ) # Apply modify_model_request middleware in sequence - for m in middleware_w_modify_model_request: - # Filter state to only include fields defined in this middleware's schema - filtered_state = _filter_state_for_schema(state, m.state_schema) - request = m.modify_model_request(request, filtered_state) + for use_runtime, m in model_request_signatures: + if use_runtime: + m.modify_model_request(request, state, runtime) + else: + m.modify_model_request(request, state) # type: ignore[call-arg] # Get the final model and messages model_ = _get_bound_model(request) diff --git a/libs/langchain_v1/tests/unit_tests/agents/test_middleware_agent.py b/libs/langchain_v1/tests/unit_tests/agents/test_middleware_agent.py index 58138ac5227..a381dc42f96 100644 --- a/libs/langchain_v1/tests/unit_tests/agents/test_middleware_agent.py +++ b/libs/langchain_v1/tests/unit_tests/agents/test_middleware_agent.py @@ -1,9 +1,12 @@ +from typing_extensions import TypedDict import pytest from typing import Any from unittest.mock import patch from syrupy.assertion import SnapshotAssertion +from langgraph.runtime import Runtime +from typing_extensions import Annotated, TypedDict from pydantic import BaseModel, Field from langchain_core.language_models import BaseChatModel from langchain_core.language_models.chat_models import BaseChatModel @@ -15,17 +18,23 @@ from langchain_core.messages import ( ToolMessage, ) from langchain_core.tools import tool -from langgraph.types import Command from langchain.agents.middleware_agent import create_agent from langchain.agents.middleware.human_in_the_loop import ( HumanInTheLoopMiddleware, - HumanInTheLoopConfig, ActionRequest, ) from langchain.agents.middleware.prompt_caching import AnthropicPromptCachingMiddleware from langchain.agents.middleware.summarization import SummarizationMiddleware -from langchain.agents.middleware.types import AgentMiddleware, ModelRequest, AgentState +from langchain.agents.middleware.types import ( + AgentMiddleware, + ModelRequest, + AgentState, + OmitFromInput, + OmitFromOutput, + PrivateStateAttr, +) +from langchain.agents.middleware.dynamic_system_prompt import DynamicSystemPromptMiddleware from langgraph.checkpoint.base import BaseCheckpointSaver from langgraph.checkpoint.memory import InMemorySaver @@ -1201,3 +1210,128 @@ def test_tools_to_model_edge_with_structured_and_regular_tool_calls(): assert hasattr(result["response"], "temperature") assert result["response"].temperature == 72.0 assert result["response"].condition == "sunny" + + +# Tests for DynamicSystemPromptMiddleware +def test_dynamic_system_prompt_middleware_basic() -> None: + """Test basic functionality of DynamicSystemPromptMiddleware.""" + + def dynamic_system_prompt(state: AgentState) -> str: + messages = state.get("messages", []) + if messages: + return f"You are a helpful assistant. Message count: {len(messages)}" + return "You are a helpful assistant. No messages yet." + + middleware = DynamicSystemPromptMiddleware(dynamic_system_prompt) + + # Test with empty state + empty_state = {"messages": []} + request = ModelRequest( + model=FakeToolCallingModel(), + system_prompt="Original prompt", + messages=[], + tool_choice=None, + tools=[], + response_format=None, + ) + + modified_request = middleware.modify_model_request(request, empty_state, None) + assert modified_request.system_prompt == "You are a helpful assistant. No messages yet." + + state_with_messages = {"messages": [HumanMessage("Hello"), AIMessage("Hi")]} + modified_request = middleware.modify_model_request(request, state_with_messages, None) + assert modified_request.system_prompt == "You are a helpful assistant. Message count: 2" + + +def test_dynamic_system_prompt_middleware_with_context() -> None: + """Test DynamicSystemPromptMiddleware with runtime context.""" + + class MockContext(TypedDict): + user_role: str + + def dynamic_system_prompt(state: AgentState, runtime: Runtime[MockContext]) -> str: + base_prompt = "You are a helpful assistant." + if runtime and hasattr(runtime, "context"): + user_role = runtime.context.get("user_role", "user") + return f"{base_prompt} User role: {user_role}" + return base_prompt + + middleware = DynamicSystemPromptMiddleware(dynamic_system_prompt) + + # Create a mock runtime with context + class MockRuntime: + def __init__(self, context): + self.context = context + + mock_runtime = MockRuntime(context={"user_role": "admin"}) + + request = ModelRequest( + model=FakeToolCallingModel(), + system_prompt="Original prompt", + messages=[HumanMessage("Test")], + tool_choice=None, + tools=[], + response_format=None, + ) + + state = {"messages": [HumanMessage("Test")]} + modified_request = middleware.modify_model_request(request, state, mock_runtime) + + assert modified_request.system_prompt == "You are a helpful assistant. User role: admin" + + +def test_public_private_state_for_custom_middleware() -> None: + """Test public and private state for custom middleware.""" + + class CustomState(AgentState): + omit_input: Annotated[str, OmitFromInput] + omit_output: Annotated[str, OmitFromOutput] + private_state: Annotated[str, PrivateStateAttr] + + class CustomMiddleware(AgentMiddleware[CustomState]): + state_schema: type[CustomState] = CustomState + + def before_model(self, state: CustomState) -> dict[str, Any]: + assert "omit_input" not in state + assert "omit_output" in state + assert "private_state" not in state + return {"omit_input": "test", "omit_output": "test", "private_state": "test"} + + agent = create_agent(model=FakeToolCallingModel(), middleware=[CustomMiddleware()]) + agent = agent.compile() + result = agent.invoke( + { + "messages": [HumanMessage("Hello")], + "omit_input": "test in", + "private_state": "test in", + "omit_output": "test in", + } + ) + assert "omit_input" in result + assert "omit_output" not in result + assert "private_state" not in result + + +def test_runtime_injected_into_middleware() -> None: + """Test that the runtime is injected into the middleware.""" + + class CustomMiddleware(AgentMiddleware): + def before_model(self, state: AgentState, runtime: Runtime) -> None: + assert runtime is not None + return None + + def modify_model_request( + self, request: ModelRequest, state: AgentState, runtime: Runtime + ) -> ModelRequest: + assert runtime is not None + return request + + def after_model(self, state: AgentState, runtime: Runtime) -> None: + assert runtime is not None + return None + + middleware = CustomMiddleware() + + agent = create_agent(model=FakeToolCallingModel(), middleware=[CustomMiddleware()]) + agent = agent.compile() + agent.invoke({"messages": [HumanMessage("Hello")]})