feat(langchain): dynamic system prompt middleware (#33006)

# Changes

## Adds support for `DynamicSystemPromptMiddleware`

```py
from langchain.agents.middleware import DynamicSystemPromptMiddleware
from langgraph.runtime import Runtime
from typing_extensions import TypedDict

class Context(TypedDict):
    user_name: str

def system_prompt(state: AgentState, runtime: Runtime[Context]) -> str:
    user_name = runtime.context.get("user_name", "n/a")
    return f"You are a helpful assistant. Always address the user by their name: {user_name}"

middleware = DynamicSystemPromptMiddleware(system_prompt)
```

## Adds support for `runtime` in middleware hooks

```py
class AgentMiddleware(Generic[StateT, ContextT]):
    def modify_model_request(
        self,
        request: ModelRequest,
        state: StateT,
        runtime: Runtime[ContextT],  # Optional runtime parameter
    ) -> ModelRequest:
        # upgrade model if runtime.context.subscription is `top-tier` or whatever
```

## Adds support for omitting state attributes from input / output
schemas

```py
from typing import Annotated, NotRequired
from langchain.agents.middleware.types import PrivateStateAttr, OmitFromInput, OmitFromOutput

class CustomState(AgentState):
    # Private field - not in input or output schemas
    internal_counter: NotRequired[Annotated[int, PrivateStateAttr]]
    
    # Input-only field - not in output schema
    user_input: NotRequired[Annotated[str, OmitFromOutput]]
    
    # Output-only field - not in input schema  
    computed_result: NotRequired[Annotated[str, OmitFromInput]]
```

## Additionally
* Removes filtering of state before passing into middleware hooks

Typing is not foolproof here, still need to figure out some of the
generics stuff w/ state and context schema extensions for middleware.

TODO:
* More docs for middleware, should hold off on this until other prios
like MCP and deepagents are met

---------

Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>
This commit is contained in:
Sydney Runkle
2025-09-18 16:07:16 -04:00
committed by GitHub
parent f158cea1e8
commit 4d118777bc
8 changed files with 370 additions and 52 deletions

View File

@@ -1,5 +1,6 @@
"""Middleware plugins for agents.""" """Middleware plugins for agents."""
from .dynamic_system_prompt import DynamicSystemPromptMiddleware
from .human_in_the_loop import HumanInTheLoopMiddleware from .human_in_the_loop import HumanInTheLoopMiddleware
from .prompt_caching import AnthropicPromptCachingMiddleware from .prompt_caching import AnthropicPromptCachingMiddleware
from .summarization import SummarizationMiddleware from .summarization import SummarizationMiddleware
@@ -8,7 +9,9 @@ from .types import AgentMiddleware, AgentState, ModelRequest
__all__ = [ __all__ = [
"AgentMiddleware", "AgentMiddleware",
"AgentState", "AgentState",
# should move to langchain-anthropic if we decide to keep it
"AnthropicPromptCachingMiddleware", "AnthropicPromptCachingMiddleware",
"DynamicSystemPromptMiddleware",
"HumanInTheLoopMiddleware", "HumanInTheLoopMiddleware",
"ModelRequest", "ModelRequest",
"SummarizationMiddleware", "SummarizationMiddleware",

View File

@@ -0,0 +1,105 @@
"""Dynamic System Prompt Middleware.
Allows setting the system prompt dynamically right before each model invocation.
Useful when the prompt depends on the current agent state or per-invocation context.
"""
from __future__ import annotations
from inspect import signature
from typing import TYPE_CHECKING, Protocol, TypeAlias, cast
from langgraph.typing import ContextT
from langchain.agents.middleware.types import (
AgentMiddleware,
AgentState,
ModelRequest,
)
if TYPE_CHECKING:
from langgraph.runtime import Runtime
class DynamicSystemPromptWithoutRuntime(Protocol):
"""Dynamic system prompt without runtime in call signature."""
def __call__(self, state: AgentState) -> str:
"""Return the system prompt for the next model call."""
...
class DynamicSystemPromptWithRuntime(Protocol[ContextT]):
"""Dynamic system prompt with runtime in call signature."""
def __call__(self, state: AgentState, runtime: Runtime[ContextT]) -> str:
"""Return the system prompt for the next model call."""
...
DynamicSystemPrompt: TypeAlias = (
DynamicSystemPromptWithoutRuntime | DynamicSystemPromptWithRuntime[ContextT]
)
class DynamicSystemPromptMiddleware(AgentMiddleware):
"""Dynamic System Prompt Middleware.
Allows setting the system prompt dynamically right before each model invocation.
Useful when the prompt depends on the current agent state or per-invocation context.
Example:
```python
from langchain.agents.middleware import DynamicSystemPromptMiddleware
class Context(TypedDict):
user_name: str
def system_prompt(state: AgentState, runtime: Runtime[Context]) -> str:
user_name = runtime.context.get("user_name", "n/a")
return (
f"You are a helpful assistant. Always address the user by their name: {user_name}"
)
middleware = DynamicSystemPromptMiddleware(system_prompt)
```
"""
_accepts_runtime: bool
def __init__(
self,
dynamic_system_prompt: DynamicSystemPrompt[ContextT],
) -> None:
"""Initialize the dynamic system prompt middleware.
Args:
dynamic_system_prompt: Function that receives the current agent state
and optionally runtime with context, and returns the system prompt for
the next model call. Returns a string.
"""
super().__init__()
self.dynamic_system_prompt = dynamic_system_prompt
self._accepts_runtime = "runtime" in signature(dynamic_system_prompt).parameters
def modify_model_request(
self,
request: ModelRequest,
state: AgentState,
runtime: Runtime[ContextT],
) -> ModelRequest:
"""Modify the model request to include the dynamic system prompt."""
if self._accepts_runtime:
system_prompt = cast(
"DynamicSystemPromptWithRuntime[ContextT]", self.dynamic_system_prompt
)(state, runtime)
else:
system_prompt = cast("DynamicSystemPromptWithoutRuntime", self.dynamic_system_prompt)(
state
)
request.system_prompt = system_prompt
return request

View File

@@ -143,7 +143,7 @@ class HumanInTheLoopMiddleware(AgentMiddleware):
self.tool_configs = resolved_tool_configs self.tool_configs = resolved_tool_configs
self.description_prefix = description_prefix self.description_prefix = description_prefix
def after_model(self, state: AgentState) -> dict[str, Any] | None: def after_model(self, state: AgentState) -> dict[str, Any] | None: # type: ignore[override]
"""Trigger HITL flows for relevant tool calls after an AIMessage.""" """Trigger HITL flows for relevant tool calls after an AIMessage."""
messages = state["messages"] messages = state["messages"]
if not messages: if not messages:

View File

@@ -2,7 +2,7 @@
from typing import Literal from typing import Literal
from langchain.agents.middleware.types import AgentMiddleware, AgentState, ModelRequest from langchain.agents.middleware.types import AgentMiddleware, ModelRequest
class AnthropicPromptCachingMiddleware(AgentMiddleware): class AnthropicPromptCachingMiddleware(AgentMiddleware):
@@ -32,7 +32,10 @@ class AnthropicPromptCachingMiddleware(AgentMiddleware):
self.ttl = ttl self.ttl = ttl
self.min_messages_to_cache = min_messages_to_cache self.min_messages_to_cache = min_messages_to_cache
def modify_model_request(self, request: ModelRequest, state: AgentState) -> ModelRequest: # noqa: ARG002 def modify_model_request( # type: ignore[override]
self,
request: ModelRequest,
) -> ModelRequest:
"""Modify the model request to add cache control blocks.""" """Modify the model request to add cache control blocks."""
try: try:
from langchain_anthropic import ChatAnthropic from langchain_anthropic import ChatAnthropic

View File

@@ -98,7 +98,7 @@ class SummarizationMiddleware(AgentMiddleware):
self.summary_prompt = summary_prompt self.summary_prompt = summary_prompt
self.summary_prefix = summary_prefix self.summary_prefix = summary_prefix
def before_model(self, state: AgentState) -> dict[str, Any] | None: def before_model(self, state: AgentState) -> dict[str, Any] | None: # type: ignore[override]
"""Process messages before model invocation, potentially triggering summarization.""" """Process messages before model invocation, potentially triggering summarization."""
messages = state["messages"] messages = state["messages"]
self._ensure_message_ids(messages) self._ensure_message_ids(messages)

View File

@@ -8,15 +8,27 @@ from typing import TYPE_CHECKING, Annotated, Any, Generic, Literal, cast
# needed as top level import for pydantic schema generation on AgentState # needed as top level import for pydantic schema generation on AgentState
from langchain_core.messages import AnyMessage # noqa: TC002 from langchain_core.messages import AnyMessage # noqa: TC002
from langgraph.channels.ephemeral_value import EphemeralValue from langgraph.channels.ephemeral_value import EphemeralValue
from langgraph.graph.message import Messages, add_messages from langgraph.graph.message import add_messages
from langgraph.runtime import Runtime
from langgraph.typing import ContextT
from typing_extensions import NotRequired, Required, TypedDict, TypeVar from typing_extensions import NotRequired, Required, TypedDict, TypeVar
if TYPE_CHECKING: if TYPE_CHECKING:
from langchain_core.language_models.chat_models import BaseChatModel from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.tools import BaseTool from langchain_core.tools import BaseTool
from langgraph.runtime import Runtime
from langchain.agents.structured_output import ResponseFormat from langchain.agents.structured_output import ResponseFormat
__all__ = [
"AgentMiddleware",
"AgentState",
"ContextT",
"ModelRequest",
"OmitFromSchema",
"PublicAgentState",
]
JumpTo = Literal["tools", "model", "__end__"] JumpTo = Literal["tools", "model", "__end__"]
"""Destination to jump to when a middleware node returns.""" """Destination to jump to when a middleware node returns."""
@@ -36,26 +48,49 @@ class ModelRequest:
model_settings: dict[str, Any] = field(default_factory=dict) model_settings: dict[str, Any] = field(default_factory=dict)
@dataclass
class OmitFromSchema:
"""Annotation used to mark state attributes as omitted from input or output schemas."""
input: bool = True
"""Whether to omit the attribute from the input schema."""
output: bool = True
"""Whether to omit the attribute from the output schema."""
OmitFromInput = OmitFromSchema(input=True, output=False)
"""Annotation used to mark state attributes as omitted from input schema."""
OmitFromOutput = OmitFromSchema(input=False, output=True)
"""Annotation used to mark state attributes as omitted from output schema."""
PrivateStateAttr = OmitFromSchema(input=True, output=True)
"""Annotation used to mark state attributes as purely internal for a given middleware."""
class AgentState(TypedDict, Generic[ResponseT]): class AgentState(TypedDict, Generic[ResponseT]):
"""State schema for the agent.""" """State schema for the agent."""
messages: Required[Annotated[list[AnyMessage], add_messages]] messages: Required[Annotated[list[AnyMessage], add_messages]]
model_request: NotRequired[Annotated[ModelRequest | None, EphemeralValue]] jump_to: NotRequired[Annotated[JumpTo | None, EphemeralValue, PrivateStateAttr]]
jump_to: NotRequired[Annotated[JumpTo | None, EphemeralValue]]
response: NotRequired[ResponseT] response: NotRequired[ResponseT]
class PublicAgentState(TypedDict, Generic[ResponseT]): class PublicAgentState(TypedDict, Generic[ResponseT]):
"""Input / output schema for the agent.""" """Public state schema for the agent.
messages: Required[Messages] Just used for typing purposes.
"""
messages: Required[Annotated[list[AnyMessage], add_messages]]
response: NotRequired[ResponseT] response: NotRequired[ResponseT]
StateT = TypeVar("StateT", bound=AgentState) StateT = TypeVar("StateT", bound=AgentState, default=AgentState)
class AgentMiddleware(Generic[StateT]): class AgentMiddleware(Generic[StateT, ContextT]):
"""Base middleware class for an agent. """Base middleware class for an agent.
Subclass this and implement any of the defined methods to customize agent behavior Subclass this and implement any of the defined methods to customize agent behavior
@@ -68,12 +103,17 @@ class AgentMiddleware(Generic[StateT]):
tools: list[BaseTool] tools: list[BaseTool]
"""Additional tools registered by the middleware.""" """Additional tools registered by the middleware."""
def before_model(self, state: StateT) -> dict[str, Any] | None: def before_model(self, state: StateT, runtime: Runtime[ContextT]) -> dict[str, Any] | None:
"""Logic to run before the model is called.""" """Logic to run before the model is called."""
def modify_model_request(self, request: ModelRequest, state: StateT) -> ModelRequest: # noqa: ARG002 def modify_model_request(
self,
request: ModelRequest,
state: StateT, # noqa: ARG002
runtime: Runtime[ContextT], # noqa: ARG002
) -> ModelRequest:
"""Logic to modify request kwargs before the model is called.""" """Logic to modify request kwargs before the model is called."""
return request return request
def after_model(self, state: StateT) -> dict[str, Any] | None: def after_model(self, state: StateT, runtime: Runtime[ContextT]) -> dict[str, Any] | None:
"""Logic to run after the model is called.""" """Logic to run after the model is called."""

View File

@@ -2,7 +2,8 @@
import itertools import itertools
from collections.abc import Callable, Sequence from collections.abc import Callable, Sequence
from typing import Any, cast from inspect import signature
from typing import Annotated, Any, cast, get_args, get_origin, get_type_hints
from langchain_core.language_models.chat_models import BaseChatModel from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import AIMessage, AnyMessage, SystemMessage, ToolMessage from langchain_core.messages import AIMessage, AnyMessage, SystemMessage, ToolMessage
@@ -10,19 +11,19 @@ from langchain_core.runnables import Runnable
from langchain_core.tools import BaseTool from langchain_core.tools import BaseTool
from langgraph.constants import END, START from langgraph.constants import END, START
from langgraph.graph.state import StateGraph from langgraph.graph.state import StateGraph
from langgraph.runtime import Runtime
from langgraph.types import Send from langgraph.types import Send
from langgraph.typing import ContextT from langgraph.typing import ContextT
from typing_extensions import TypedDict, TypeVar from typing_extensions import NotRequired, Required, TypedDict, TypeVar
from langchain.agents.middleware.types import ( from langchain.agents.middleware.types import (
AgentMiddleware, AgentMiddleware,
AgentState, AgentState,
JumpTo, JumpTo,
ModelRequest, ModelRequest,
OmitFromSchema,
PublicAgentState, PublicAgentState,
) )
# Import structured output classes from the old implementation
from langchain.agents.structured_output import ( from langchain.agents.structured_output import (
MultipleStructuredOutputsError, MultipleStructuredOutputsError,
OutputToolBinding, OutputToolBinding,
@@ -38,26 +39,49 @@ from langchain.chat_models import init_chat_model
STRUCTURED_OUTPUT_ERROR_TEMPLATE = "Error: {error}\n Please fix your mistakes." STRUCTURED_OUTPUT_ERROR_TEMPLATE = "Error: {error}\n Please fix your mistakes."
def _merge_state_schemas(schemas: list[type]) -> type: def _resolve_schema(schemas: set[type], schema_name: str, omit_flag: str | None = None) -> type:
"""Merge multiple TypedDict schemas into a single schema with all fields.""" """Resolve schema by merging schemas and optionally respecting OmitFromSchema annotations.
if not schemas:
return AgentState
Args:
schemas: List of schema types to merge
schema_name: Name for the generated TypedDict
omit_flag: If specified, omit fields with this flag set ('input' or 'output')
"""
all_annotations = {} all_annotations = {}
for schema in schemas: for schema in schemas:
all_annotations.update(schema.__annotations__) hints = get_type_hints(schema, include_extras=True)
return TypedDict("MergedState", all_annotations) # type: ignore[operator] for field_name, field_type in hints.items():
should_omit = False
if omit_flag:
# Check for omission in the annotation metadata
metadata = _extract_metadata(field_type)
for meta in metadata:
if isinstance(meta, OmitFromSchema) and getattr(meta, omit_flag) is True:
should_omit = True
break
if not should_omit:
all_annotations[field_name] = field_type
return TypedDict(schema_name, all_annotations) # type: ignore[operator]
def _filter_state_for_schema(state: dict[str, Any], schema: type) -> dict[str, Any]: def _extract_metadata(type_: type) -> list:
"""Filter state to only include fields defined in the given schema.""" """Extract metadata from a field type, handling Required/NotRequired and Annotated wrappers."""
if not hasattr(schema, "__annotations__"): # Handle Required[Annotated[...]] or NotRequired[Annotated[...]]
return state if get_origin(type_) in (Required, NotRequired):
inner_type = get_args(type_)[0]
if get_origin(inner_type) is Annotated:
return list(get_args(inner_type)[1:])
schema_fields = set(schema.__annotations__.keys()) # Handle direct Annotated[...]
return {k: v for k, v in state.items() if k in schema_fields} elif get_origin(type_) is Annotated:
return list(get_args(type_)[1:])
return []
def _supports_native_structured_output(model: str | BaseChatModel) -> bool: def _supports_native_structured_output(model: str | BaseChatModel) -> bool:
@@ -114,7 +138,7 @@ def create_agent( # noqa: PLR0915
model: str | BaseChatModel, model: str | BaseChatModel,
tools: Sequence[BaseTool | Callable | dict[str, Any]] | ToolNode | None = None, tools: Sequence[BaseTool | Callable | dict[str, Any]] | ToolNode | None = None,
system_prompt: str | None = None, system_prompt: str | None = None,
middleware: Sequence[AgentMiddleware] = (), middleware: Sequence[AgentMiddleware[AgentState[ResponseT], ContextT]] = (),
response_format: ResponseFormat[ResponseT] | type[ResponseT] | None = None, response_format: ResponseFormat[ResponseT] | type[ResponseT] | None = None,
context_schema: type[ContextT] | None = None, context_schema: type[ContextT] | None = None,
) -> StateGraph[ ) -> StateGraph[
@@ -199,16 +223,16 @@ def create_agent( # noqa: PLR0915
m for m in middleware if m.__class__.after_model is not AgentMiddleware.after_model m for m in middleware if m.__class__.after_model is not AgentMiddleware.after_model
] ]
# Collect all middleware state schemas and create merged schema state_schemas = {m.state_schema for m in middleware}
merged_state_schema: type[AgentState] = _merge_state_schemas( state_schemas.add(AgentState)
[m.state_schema for m in middleware]
)
# create graph, add nodes # create graph, add nodes
graph = StateGraph( graph: StateGraph[
merged_state_schema, AgentState[ResponseT], ContextT, PublicAgentState[ResponseT], PublicAgentState[ResponseT]
input_schema=PublicAgentState, ] = StateGraph(
output_schema=PublicAgentState, state_schema=_resolve_schema(state_schemas, "StateSchema", None),
input_schema=_resolve_schema(state_schemas, "InputSchema", "input"),
output_schema=_resolve_schema(state_schemas, "OutputSchema", "output"),
context_schema=context_schema, context_schema=context_schema,
) )
@@ -318,7 +342,14 @@ def create_agent( # noqa: PLR0915
) )
return request.model.bind(**request.model_settings) return request.model.bind(**request.model_settings)
def model_request(state: dict[str, Any]) -> dict[str, Any]: model_request_signatures: list[
tuple[bool, AgentMiddleware[AgentState[ResponseT], ContextT]]
] = [
("runtime" in signature(m.modify_model_request).parameters, m)
for m in middleware_w_modify_model_request
]
def model_request(state: AgentState, runtime: Runtime[ContextT]) -> dict[str, Any]:
"""Sync model request handler with sequential middleware processing.""" """Sync model request handler with sequential middleware processing."""
request = ModelRequest( request = ModelRequest(
model=model, model=model,
@@ -330,10 +361,11 @@ def create_agent( # noqa: PLR0915
) )
# Apply modify_model_request middleware in sequence # Apply modify_model_request middleware in sequence
for m in middleware_w_modify_model_request: for use_runtime, m in model_request_signatures:
# Filter state to only include fields defined in this middleware's schema if use_runtime:
filtered_state = _filter_state_for_schema(state, m.state_schema) m.modify_model_request(request, state, runtime)
request = m.modify_model_request(request, filtered_state) else:
m.modify_model_request(request, state) # type: ignore[call-arg]
# Get the final model and messages # Get the final model and messages
model_ = _get_bound_model(request) model_ = _get_bound_model(request)
@@ -344,7 +376,7 @@ def create_agent( # noqa: PLR0915
output = model_.invoke(messages) output = model_.invoke(messages)
return _handle_model_output(output) return _handle_model_output(output)
async def amodel_request(state: dict[str, Any]) -> dict[str, Any]: async def amodel_request(state: AgentState, runtime: Runtime[ContextT]) -> dict[str, Any]:
"""Async model request handler with sequential middleware processing.""" """Async model request handler with sequential middleware processing."""
# Start with the base model request # Start with the base model request
request = ModelRequest( request = ModelRequest(
@@ -357,10 +389,11 @@ def create_agent( # noqa: PLR0915
) )
# Apply modify_model_request middleware in sequence # Apply modify_model_request middleware in sequence
for m in middleware_w_modify_model_request: for use_runtime, m in model_request_signatures:
# Filter state to only include fields defined in this middleware's schema if use_runtime:
filtered_state = _filter_state_for_schema(state, m.state_schema) m.modify_model_request(request, state, runtime)
request = m.modify_model_request(request, filtered_state) else:
m.modify_model_request(request, state) # type: ignore[call-arg]
# Get the final model and messages # Get the final model and messages
model_ = _get_bound_model(request) model_ = _get_bound_model(request)

View File

@@ -1,9 +1,12 @@
from typing_extensions import TypedDict
import pytest import pytest
from typing import Any from typing import Any
from unittest.mock import patch from unittest.mock import patch
from syrupy.assertion import SnapshotAssertion from syrupy.assertion import SnapshotAssertion
from langgraph.runtime import Runtime
from typing_extensions import Annotated, TypedDict
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from langchain_core.language_models import BaseChatModel from langchain_core.language_models import BaseChatModel
from langchain_core.language_models.chat_models import BaseChatModel from langchain_core.language_models.chat_models import BaseChatModel
@@ -15,17 +18,23 @@ from langchain_core.messages import (
ToolMessage, ToolMessage,
) )
from langchain_core.tools import tool from langchain_core.tools import tool
from langgraph.types import Command
from langchain.agents.middleware_agent import create_agent from langchain.agents.middleware_agent import create_agent
from langchain.agents.middleware.human_in_the_loop import ( from langchain.agents.middleware.human_in_the_loop import (
HumanInTheLoopMiddleware, HumanInTheLoopMiddleware,
HumanInTheLoopConfig,
ActionRequest, ActionRequest,
) )
from langchain.agents.middleware.prompt_caching import AnthropicPromptCachingMiddleware from langchain.agents.middleware.prompt_caching import AnthropicPromptCachingMiddleware
from langchain.agents.middleware.summarization import SummarizationMiddleware from langchain.agents.middleware.summarization import SummarizationMiddleware
from langchain.agents.middleware.types import AgentMiddleware, ModelRequest, AgentState from langchain.agents.middleware.types import (
AgentMiddleware,
ModelRequest,
AgentState,
OmitFromInput,
OmitFromOutput,
PrivateStateAttr,
)
from langchain.agents.middleware.dynamic_system_prompt import DynamicSystemPromptMiddleware
from langgraph.checkpoint.base import BaseCheckpointSaver from langgraph.checkpoint.base import BaseCheckpointSaver
from langgraph.checkpoint.memory import InMemorySaver from langgraph.checkpoint.memory import InMemorySaver
@@ -1201,3 +1210,128 @@ def test_tools_to_model_edge_with_structured_and_regular_tool_calls():
assert hasattr(result["response"], "temperature") assert hasattr(result["response"], "temperature")
assert result["response"].temperature == 72.0 assert result["response"].temperature == 72.0
assert result["response"].condition == "sunny" assert result["response"].condition == "sunny"
# Tests for DynamicSystemPromptMiddleware
def test_dynamic_system_prompt_middleware_basic() -> None:
"""Test basic functionality of DynamicSystemPromptMiddleware."""
def dynamic_system_prompt(state: AgentState) -> str:
messages = state.get("messages", [])
if messages:
return f"You are a helpful assistant. Message count: {len(messages)}"
return "You are a helpful assistant. No messages yet."
middleware = DynamicSystemPromptMiddleware(dynamic_system_prompt)
# Test with empty state
empty_state = {"messages": []}
request = ModelRequest(
model=FakeToolCallingModel(),
system_prompt="Original prompt",
messages=[],
tool_choice=None,
tools=[],
response_format=None,
)
modified_request = middleware.modify_model_request(request, empty_state, None)
assert modified_request.system_prompt == "You are a helpful assistant. No messages yet."
state_with_messages = {"messages": [HumanMessage("Hello"), AIMessage("Hi")]}
modified_request = middleware.modify_model_request(request, state_with_messages, None)
assert modified_request.system_prompt == "You are a helpful assistant. Message count: 2"
def test_dynamic_system_prompt_middleware_with_context() -> None:
"""Test DynamicSystemPromptMiddleware with runtime context."""
class MockContext(TypedDict):
user_role: str
def dynamic_system_prompt(state: AgentState, runtime: Runtime[MockContext]) -> str:
base_prompt = "You are a helpful assistant."
if runtime and hasattr(runtime, "context"):
user_role = runtime.context.get("user_role", "user")
return f"{base_prompt} User role: {user_role}"
return base_prompt
middleware = DynamicSystemPromptMiddleware(dynamic_system_prompt)
# Create a mock runtime with context
class MockRuntime:
def __init__(self, context):
self.context = context
mock_runtime = MockRuntime(context={"user_role": "admin"})
request = ModelRequest(
model=FakeToolCallingModel(),
system_prompt="Original prompt",
messages=[HumanMessage("Test")],
tool_choice=None,
tools=[],
response_format=None,
)
state = {"messages": [HumanMessage("Test")]}
modified_request = middleware.modify_model_request(request, state, mock_runtime)
assert modified_request.system_prompt == "You are a helpful assistant. User role: admin"
def test_public_private_state_for_custom_middleware() -> None:
"""Test public and private state for custom middleware."""
class CustomState(AgentState):
omit_input: Annotated[str, OmitFromInput]
omit_output: Annotated[str, OmitFromOutput]
private_state: Annotated[str, PrivateStateAttr]
class CustomMiddleware(AgentMiddleware[CustomState]):
state_schema: type[CustomState] = CustomState
def before_model(self, state: CustomState) -> dict[str, Any]:
assert "omit_input" not in state
assert "omit_output" in state
assert "private_state" not in state
return {"omit_input": "test", "omit_output": "test", "private_state": "test"}
agent = create_agent(model=FakeToolCallingModel(), middleware=[CustomMiddleware()])
agent = agent.compile()
result = agent.invoke(
{
"messages": [HumanMessage("Hello")],
"omit_input": "test in",
"private_state": "test in",
"omit_output": "test in",
}
)
assert "omit_input" in result
assert "omit_output" not in result
assert "private_state" not in result
def test_runtime_injected_into_middleware() -> None:
"""Test that the runtime is injected into the middleware."""
class CustomMiddleware(AgentMiddleware):
def before_model(self, state: AgentState, runtime: Runtime) -> None:
assert runtime is not None
return None
def modify_model_request(
self, request: ModelRequest, state: AgentState, runtime: Runtime
) -> ModelRequest:
assert runtime is not None
return request
def after_model(self, state: AgentState, runtime: Runtime) -> None:
assert runtime is not None
return None
middleware = CustomMiddleware()
agent = create_agent(model=FakeToolCallingModel(), middleware=[CustomMiddleware()])
agent = agent.compile()
agent.invoke({"messages": [HumanMessage("Hello")]})