feat(langchain): dynamic system prompt middleware (#33006)

# Changes

## Adds support for `DynamicSystemPromptMiddleware`

```py
from langchain.agents.middleware import AgentState, DynamicSystemPromptMiddleware
from langgraph.runtime import Runtime
from typing_extensions import TypedDict

class Context(TypedDict):
    user_name: str

def system_prompt(state: AgentState, runtime: Runtime[Context]) -> str:
    user_name = runtime.context.get("user_name", "n/a")
    return f"You are a helpful assistant. Always address the user by their name: {user_name}"

middleware = DynamicSystemPromptMiddleware(system_prompt)
```
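
Wired into an agent, the prompt is then recomputed on every model call. A minimal usage sketch, continuing the example above (the model identifier is a placeholder, and passing context via the `context` kwarg assumes langgraph's runtime context API):

```py
from langchain.agents.middleware_agent import create_agent

agent = create_agent(
    model="openai:gpt-4o",  # placeholder model identifier
    middleware=[middleware],
    context_schema=Context,
).compile()

# runtime.context inside system_prompt sees {"user_name": "Ada"}
agent.invoke(
    {"messages": [{"role": "user", "content": "Hello!"}]},
    context={"user_name": "Ada"},
)
```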

## Adds support for `runtime` in middleware hooks

```py
class AgentMiddleware(Generic[StateT, ContextT]):
    def modify_model_request(
        self,
        request: ModelRequest,
        state: StateT,
        runtime: Runtime[ContextT],  # runtime (with context) is now injected into hooks
    ) -> ModelRequest:
        # e.g. upgrade the model when runtime.context indicates a premium subscription
        return request
```
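
For example, a middleware can swap models per request based on the runtime context. A sketch under assumed names (`Context`, `ModelUpgradeMiddleware`, the `subscription` field, and the model identifier are all illustrative):

```py
from typing_extensions import TypedDict

from langchain.agents.middleware.types import AgentMiddleware, AgentState, ModelRequest
from langchain.chat_models import init_chat_model
from langgraph.runtime import Runtime


class Context(TypedDict):
    subscription: str


class ModelUpgradeMiddleware(AgentMiddleware[AgentState, Context]):
    def modify_model_request(
        self,
        request: ModelRequest,
        state: AgentState,
        runtime: Runtime[Context],
    ) -> ModelRequest:
        # Upgrade to a stronger (placeholder) model for top-tier subscribers
        if runtime.context.get("subscription") == "top-tier":
            request.model = init_chat_model("openai:gpt-4o")
        return request
```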

## Adds support for omitting state attributes from input / output schemas

```py
from typing import Annotated, NotRequired
from langchain.agents.middleware.types import AgentState, PrivateStateAttr, OmitFromInput, OmitFromOutput

class CustomState(AgentState):
    # Private field - not in input or output schemas
    internal_counter: NotRequired[Annotated[int, PrivateStateAttr]]
    
    # Input-only field - not in output schema
    user_input: NotRequired[Annotated[str, OmitFromOutput]]
    
    # Output-only field - not in input schema  
    computed_result: NotRequired[Annotated[str, OmitFromInput]]
```
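
Because the schemas are resolved when the graph is built, omitted fields are simply absent at the agent's boundary. A sketch mirroring the new tests (the middleware wrapper and model identifier are illustrative):

```py
from langchain.agents.middleware.types import AgentMiddleware
from langchain.agents.middleware_agent import create_agent


class CustomMiddleware(AgentMiddleware[CustomState]):
    state_schema: type[CustomState] = CustomState


agent = create_agent(
    model="openai:gpt-4o",  # placeholder model identifier
    middleware=[CustomMiddleware()],
).compile()

result = agent.invoke(
    {"messages": [{"role": "user", "content": "Hi"}], "user_input": "accepted on input"}
)
# `user_input` and `internal_counter` never appear in `result`;
# `computed_result` does (once some middleware sets it).
```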

## Additionally
* Removes filtering of state before passing into middleware hooks

Typing is not foolproof here; we still need to work out some of the generics
for state and context schema extensions in middleware.

TODO:
* More docs for middleware; we should hold off on this until other priorities
like MCP and deepagents are addressed

---------

Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>
Commit 4d118777bc (parent f158cea1e8), authored by Sydney Runkle on 2025-09-18 16:07:16 -04:00 and committed by GitHub.
8 changed files with 370 additions and 52 deletions

langchain/agents/middleware/__init__.py

@@ -1,5 +1,6 @@
 """Middleware plugins for agents."""
 
+from .dynamic_system_prompt import DynamicSystemPromptMiddleware
 from .human_in_the_loop import HumanInTheLoopMiddleware
 from .prompt_caching import AnthropicPromptCachingMiddleware
 from .summarization import SummarizationMiddleware
@@ -8,7 +9,9 @@ from .types import AgentMiddleware, AgentState, ModelRequest
 __all__ = [
     "AgentMiddleware",
     "AgentState",
+    # should move to langchain-anthropic if we decide to keep it
     "AnthropicPromptCachingMiddleware",
+    "DynamicSystemPromptMiddleware",
     "HumanInTheLoopMiddleware",
     "ModelRequest",
     "SummarizationMiddleware",

langchain/agents/middleware/dynamic_system_prompt.py

@@ -0,0 +1,105 @@
"""Dynamic System Prompt Middleware.
Allows setting the system prompt dynamically right before each model invocation.
Useful when the prompt depends on the current agent state or per-invocation context.
"""
from __future__ import annotations
from inspect import signature
from typing import TYPE_CHECKING, Protocol, TypeAlias, cast
from langgraph.typing import ContextT
from langchain.agents.middleware.types import (
AgentMiddleware,
AgentState,
ModelRequest,
)
if TYPE_CHECKING:
from langgraph.runtime import Runtime
class DynamicSystemPromptWithoutRuntime(Protocol):
"""Dynamic system prompt without runtime in call signature."""
def __call__(self, state: AgentState) -> str:
"""Return the system prompt for the next model call."""
...
class DynamicSystemPromptWithRuntime(Protocol[ContextT]):
"""Dynamic system prompt with runtime in call signature."""
def __call__(self, state: AgentState, runtime: Runtime[ContextT]) -> str:
"""Return the system prompt for the next model call."""
...
DynamicSystemPrompt: TypeAlias = (
DynamicSystemPromptWithoutRuntime | DynamicSystemPromptWithRuntime[ContextT]
)
class DynamicSystemPromptMiddleware(AgentMiddleware):
"""Dynamic System Prompt Middleware.
Allows setting the system prompt dynamically right before each model invocation.
Useful when the prompt depends on the current agent state or per-invocation context.
Example:
```python
from langchain.agents.middleware import DynamicSystemPromptMiddleware
class Context(TypedDict):
user_name: str
def system_prompt(state: AgentState, runtime: Runtime[Context]) -> str:
user_name = runtime.context.get("user_name", "n/a")
return (
f"You are a helpful assistant. Always address the user by their name: {user_name}"
)
middleware = DynamicSystemPromptMiddleware(system_prompt)
```
"""
_accepts_runtime: bool
def __init__(
self,
dynamic_system_prompt: DynamicSystemPrompt[ContextT],
) -> None:
"""Initialize the dynamic system prompt middleware.
Args:
dynamic_system_prompt: Function that receives the current agent state
and optionally runtime with context, and returns the system prompt for
the next model call. Returns a string.
"""
super().__init__()
self.dynamic_system_prompt = dynamic_system_prompt
self._accepts_runtime = "runtime" in signature(dynamic_system_prompt).parameters
def modify_model_request(
self,
request: ModelRequest,
state: AgentState,
runtime: Runtime[ContextT],
) -> ModelRequest:
"""Modify the model request to include the dynamic system prompt."""
if self._accepts_runtime:
system_prompt = cast(
"DynamicSystemPromptWithRuntime[ContextT]", self.dynamic_system_prompt
)(state, runtime)
else:
system_prompt = cast("DynamicSystemPromptWithoutRuntime", self.dynamic_system_prompt)(
state
)
request.system_prompt = system_prompt
return request

langchain/agents/middleware/human_in_the_loop.py

@@ -143,7 +143,7 @@ class HumanInTheLoopMiddleware(AgentMiddleware):
         self.tool_configs = resolved_tool_configs
         self.description_prefix = description_prefix
 
-    def after_model(self, state: AgentState) -> dict[str, Any] | None:
+    def after_model(self, state: AgentState) -> dict[str, Any] | None:  # type: ignore[override]
         """Trigger HITL flows for relevant tool calls after an AIMessage."""
         messages = state["messages"]
         if not messages:

langchain/agents/middleware/prompt_caching.py

@@ -2,7 +2,7 @@
 from typing import Literal
 
-from langchain.agents.middleware.types import AgentMiddleware, AgentState, ModelRequest
+from langchain.agents.middleware.types import AgentMiddleware, ModelRequest
 
 
 class AnthropicPromptCachingMiddleware(AgentMiddleware):
@@ -32,7 +32,10 @@ class AnthropicPromptCachingMiddleware(AgentMiddleware):
         self.ttl = ttl
         self.min_messages_to_cache = min_messages_to_cache
 
-    def modify_model_request(self, request: ModelRequest, state: AgentState) -> ModelRequest:  # noqa: ARG002
+    def modify_model_request(  # type: ignore[override]
+        self,
+        request: ModelRequest,
+    ) -> ModelRequest:
         """Modify the model request to add cache control blocks."""
         try:
             from langchain_anthropic import ChatAnthropic

langchain/agents/middleware/summarization.py

@@ -98,7 +98,7 @@ class SummarizationMiddleware(AgentMiddleware):
         self.summary_prompt = summary_prompt
         self.summary_prefix = summary_prefix
 
-    def before_model(self, state: AgentState) -> dict[str, Any] | None:
+    def before_model(self, state: AgentState) -> dict[str, Any] | None:  # type: ignore[override]
         """Process messages before model invocation, potentially triggering summarization."""
         messages = state["messages"]
         self._ensure_message_ids(messages)

langchain/agents/middleware/types.py

@@ -8,15 +8,27 @@ from typing import TYPE_CHECKING, Annotated, Any, Generic, Literal, cast
 # needed as top level import for pydantic schema generation on AgentState
 from langchain_core.messages import AnyMessage  # noqa: TC002
 from langgraph.channels.ephemeral_value import EphemeralValue
-from langgraph.graph.message import Messages, add_messages
+from langgraph.graph.message import add_messages
+from langgraph.runtime import Runtime
+from langgraph.typing import ContextT
 from typing_extensions import NotRequired, Required, TypedDict, TypeVar
 
 if TYPE_CHECKING:
     from langchain_core.language_models.chat_models import BaseChatModel
     from langchain_core.tools import BaseTool
-    from langgraph.runtime import Runtime
 
     from langchain.agents.structured_output import ResponseFormat
 
+__all__ = [
+    "AgentMiddleware",
+    "AgentState",
+    "ContextT",
+    "ModelRequest",
+    "OmitFromSchema",
+    "PublicAgentState",
+]
+
 JumpTo = Literal["tools", "model", "__end__"]
 """Destination to jump to when a middleware node returns."""
@@ -36,26 +48,49 @@ class ModelRequest:
     model_settings: dict[str, Any] = field(default_factory=dict)
 
 
+@dataclass
+class OmitFromSchema:
+    """Annotation used to mark state attributes as omitted from input or output schemas."""
+
+    input: bool = True
+    """Whether to omit the attribute from the input schema."""
+
+    output: bool = True
+    """Whether to omit the attribute from the output schema."""
+
+
+OmitFromInput = OmitFromSchema(input=True, output=False)
+"""Annotation used to mark state attributes as omitted from input schema."""
+
+OmitFromOutput = OmitFromSchema(input=False, output=True)
+"""Annotation used to mark state attributes as omitted from output schema."""
+
+PrivateStateAttr = OmitFromSchema(input=True, output=True)
+"""Annotation used to mark state attributes as purely internal for a given middleware."""
+
+
 class AgentState(TypedDict, Generic[ResponseT]):
     """State schema for the agent."""
 
     messages: Required[Annotated[list[AnyMessage], add_messages]]
     model_request: NotRequired[Annotated[ModelRequest | None, EphemeralValue]]
-    jump_to: NotRequired[Annotated[JumpTo | None, EphemeralValue]]
+    jump_to: NotRequired[Annotated[JumpTo | None, EphemeralValue, PrivateStateAttr]]
     response: NotRequired[ResponseT]
 
 
 class PublicAgentState(TypedDict, Generic[ResponseT]):
-    """Input / output schema for the agent."""
+    """Public state schema for the agent.
+
+    Just used for typing purposes.
+    """
 
-    messages: Required[Messages]
+    messages: Required[Annotated[list[AnyMessage], add_messages]]
     response: NotRequired[ResponseT]
 
 
-StateT = TypeVar("StateT", bound=AgentState)
+StateT = TypeVar("StateT", bound=AgentState, default=AgentState)
@@ -68,12 +103,17 @@ class AgentMiddleware(Generic[StateT]):
-class AgentMiddleware(Generic[StateT]):
+class AgentMiddleware(Generic[StateT, ContextT]):
     """Base middleware class for an agent.
 
     Subclass this and implement any of the defined methods to customize agent behavior
 
     tools: list[BaseTool]
     """Additional tools registered by the middleware."""
 
-    def before_model(self, state: StateT) -> dict[str, Any] | None:
+    def before_model(self, state: StateT, runtime: Runtime[ContextT]) -> dict[str, Any] | None:
         """Logic to run before the model is called."""
 
-    def modify_model_request(self, request: ModelRequest, state: StateT) -> ModelRequest:  # noqa: ARG002
+    def modify_model_request(
+        self,
+        request: ModelRequest,
+        state: StateT,  # noqa: ARG002
+        runtime: Runtime[ContextT],  # noqa: ARG002
+    ) -> ModelRequest:
         """Logic to modify request kwargs before the model is called."""
         return request
 
-    def after_model(self, state: StateT) -> dict[str, Any] | None:
+    def after_model(self, state: StateT, runtime: Runtime[ContextT]) -> dict[str, Any] | None:
         """Logic to run after the model is called."""

langchain/agents/middleware_agent.py

@@ -2,7 +2,8 @@
 import itertools
 from collections.abc import Callable, Sequence
-from typing import Any, cast
+from inspect import signature
+from typing import Annotated, Any, cast, get_args, get_origin, get_type_hints
 
 from langchain_core.language_models.chat_models import BaseChatModel
 from langchain_core.messages import AIMessage, AnyMessage, SystemMessage, ToolMessage
@@ -10,19 +11,19 @@ from langchain_core.runnables import Runnable
 from langchain_core.tools import BaseTool
 from langgraph.constants import END, START
 from langgraph.graph.state import StateGraph
+from langgraph.runtime import Runtime
 from langgraph.types import Send
+from langgraph.typing import ContextT
-from typing_extensions import TypedDict, TypeVar
+from typing_extensions import NotRequired, Required, TypedDict, TypeVar
 
 from langchain.agents.middleware.types import (
     AgentMiddleware,
     AgentState,
     JumpTo,
     ModelRequest,
+    OmitFromSchema,
     PublicAgentState,
 )
 
-# Import structured output classes from the old implementation
 from langchain.agents.structured_output import (
     MultipleStructuredOutputsError,
     OutputToolBinding,
@@ -38,26 +39,49 @@ from langchain.chat_models import init_chat_model
 STRUCTURED_OUTPUT_ERROR_TEMPLATE = "Error: {error}\n Please fix your mistakes."
 
 
-def _merge_state_schemas(schemas: list[type]) -> type:
-    """Merge multiple TypedDict schemas into a single schema with all fields."""
-    if not schemas:
-        return AgentState
+def _resolve_schema(schemas: set[type], schema_name: str, omit_flag: str | None = None) -> type:
+    """Resolve schema by merging schemas and optionally respecting OmitFromSchema annotations.
+
+    Args:
+        schemas: List of schema types to merge
+        schema_name: Name for the generated TypedDict
+        omit_flag: If specified, omit fields with this flag set ('input' or 'output')
+    """
     all_annotations = {}
     for schema in schemas:
-        all_annotations.update(schema.__annotations__)
+        hints = get_type_hints(schema, include_extras=True)
+        for field_name, field_type in hints.items():
+            should_omit = False
+            if omit_flag:
+                # Check for omission in the annotation metadata
+                metadata = _extract_metadata(field_type)
+                for meta in metadata:
+                    if isinstance(meta, OmitFromSchema) and getattr(meta, omit_flag) is True:
+                        should_omit = True
+                        break
+            if not should_omit:
+                all_annotations[field_name] = field_type
 
-    return TypedDict("MergedState", all_annotations)  # type: ignore[operator]
+    return TypedDict(schema_name, all_annotations)  # type: ignore[operator]
 
 
-def _filter_state_for_schema(state: dict[str, Any], schema: type) -> dict[str, Any]:
-    """Filter state to only include fields defined in the given schema."""
-    if not hasattr(schema, "__annotations__"):
-        return state
-
-    schema_fields = set(schema.__annotations__.keys())
-    return {k: v for k, v in state.items() if k in schema_fields}
+def _extract_metadata(type_: type) -> list:
+    """Extract metadata from a field type, handling Required/NotRequired and Annotated wrappers."""
+    # Handle Required[Annotated[...]] or NotRequired[Annotated[...]]
+    if get_origin(type_) in (Required, NotRequired):
+        inner_type = get_args(type_)[0]
+        if get_origin(inner_type) is Annotated:
+            return list(get_args(inner_type)[1:])
+    # Handle direct Annotated[...]
+    elif get_origin(type_) is Annotated:
+        return list(get_args(type_)[1:])
+    return []
 
 
 def _supports_native_structured_output(model: str | BaseChatModel) -> bool:
@@ -114,7 +138,7 @@ def create_agent(  # noqa: PLR0915
     model: str | BaseChatModel,
     tools: Sequence[BaseTool | Callable | dict[str, Any]] | ToolNode | None = None,
     system_prompt: str | None = None,
-    middleware: Sequence[AgentMiddleware] = (),
+    middleware: Sequence[AgentMiddleware[AgentState[ResponseT], ContextT]] = (),
     response_format: ResponseFormat[ResponseT] | type[ResponseT] | None = None,
     context_schema: type[ContextT] | None = None,
 ) -> StateGraph[
@@ -199,16 +223,16 @@ def create_agent(  # noqa: PLR0915
         m for m in middleware if m.__class__.after_model is not AgentMiddleware.after_model
     ]
 
-    # Collect all middleware state schemas and create merged schema
-    merged_state_schema: type[AgentState] = _merge_state_schemas(
-        [m.state_schema for m in middleware]
-    )
+    state_schemas = {m.state_schema for m in middleware}
+    state_schemas.add(AgentState)
 
     # create graph, add nodes
-    graph = StateGraph(
-        merged_state_schema,
-        input_schema=PublicAgentState,
-        output_schema=PublicAgentState,
+    graph: StateGraph[
+        AgentState[ResponseT], ContextT, PublicAgentState[ResponseT], PublicAgentState[ResponseT]
+    ] = StateGraph(
+        state_schema=_resolve_schema(state_schemas, "StateSchema", None),
+        input_schema=_resolve_schema(state_schemas, "InputSchema", "input"),
+        output_schema=_resolve_schema(state_schemas, "OutputSchema", "output"),
         context_schema=context_schema,
     )
@@ -318,7 +342,14 @@
             )
         return request.model.bind(**request.model_settings)
 
-    def model_request(state: dict[str, Any]) -> dict[str, Any]:
+    model_request_signatures: list[
+        tuple[bool, AgentMiddleware[AgentState[ResponseT], ContextT]]
+    ] = [
+        ("runtime" in signature(m.modify_model_request).parameters, m)
+        for m in middleware_w_modify_model_request
+    ]
+
+    def model_request(state: AgentState, runtime: Runtime[ContextT]) -> dict[str, Any]:
         """Sync model request handler with sequential middleware processing."""
         request = ModelRequest(
             model=model,
@@ -330,10 +361,11 @@
         )
 
         # Apply modify_model_request middleware in sequence
-        for m in middleware_w_modify_model_request:
-            # Filter state to only include fields defined in this middleware's schema
-            filtered_state = _filter_state_for_schema(state, m.state_schema)
-            request = m.modify_model_request(request, filtered_state)
+        for use_runtime, m in model_request_signatures:
+            if use_runtime:
+                m.modify_model_request(request, state, runtime)
+            else:
+                m.modify_model_request(request, state)  # type: ignore[call-arg]
 
         # Get the final model and messages
         model_ = _get_bound_model(request)
@@ -344,7 +376,7 @@
         output = model_.invoke(messages)
         return _handle_model_output(output)
 
-    async def amodel_request(state: dict[str, Any]) -> dict[str, Any]:
+    async def amodel_request(state: AgentState, runtime: Runtime[ContextT]) -> dict[str, Any]:
         """Async model request handler with sequential middleware processing."""
         # Start with the base model request
         request = ModelRequest(
@@ -357,10 +389,11 @@
         )
 
         # Apply modify_model_request middleware in sequence
-        for m in middleware_w_modify_model_request:
-            # Filter state to only include fields defined in this middleware's schema
-            filtered_state = _filter_state_for_schema(state, m.state_schema)
-            request = m.modify_model_request(request, filtered_state)
+        for use_runtime, m in model_request_signatures:
+            if use_runtime:
+                m.modify_model_request(request, state, runtime)
+            else:
+                m.modify_model_request(request, state)  # type: ignore[call-arg]
 
         # Get the final model and messages
         model_ = _get_bound_model(request)

tests/unit_tests/agents/test_middleware_agent.py

@@ -1,9 +1,12 @@
-from typing_extensions import TypedDict
+import pytest
 from typing import Any
 from unittest.mock import patch
 
 from syrupy.assertion import SnapshotAssertion
+from langgraph.runtime import Runtime
+from typing_extensions import Annotated, TypedDict
 from pydantic import BaseModel, Field
 
-from langchain_core.language_models import BaseChatModel
+from langchain_core.language_models.chat_models import BaseChatModel
@@ -15,17 +18,23 @@ from langchain_core.messages import (
     ToolMessage,
 )
 from langchain_core.tools import tool
-from langgraph.types import Command
 from langchain.agents.middleware_agent import create_agent
 from langchain.agents.middleware.human_in_the_loop import (
     HumanInTheLoopMiddleware,
     HumanInTheLoopConfig,
     ActionRequest,
 )
 from langchain.agents.middleware.prompt_caching import AnthropicPromptCachingMiddleware
 from langchain.agents.middleware.summarization import SummarizationMiddleware
-from langchain.agents.middleware.types import AgentMiddleware, ModelRequest, AgentState
+from langchain.agents.middleware.types import (
+    AgentMiddleware,
+    ModelRequest,
+    AgentState,
+    OmitFromInput,
+    OmitFromOutput,
+    PrivateStateAttr,
+)
+from langchain.agents.middleware.dynamic_system_prompt import DynamicSystemPromptMiddleware
 from langgraph.checkpoint.base import BaseCheckpointSaver
 from langgraph.checkpoint.memory import InMemorySaver
@@ -1201,3 +1210,128 @@ def test_tools_to_model_edge_with_structured_and_regular_tool_calls():
     assert hasattr(result["response"], "temperature")
     assert result["response"].temperature == 72.0
     assert result["response"].condition == "sunny"
+
+
+# Tests for DynamicSystemPromptMiddleware
+def test_dynamic_system_prompt_middleware_basic() -> None:
+    """Test basic functionality of DynamicSystemPromptMiddleware."""
+
+    def dynamic_system_prompt(state: AgentState) -> str:
+        messages = state.get("messages", [])
+        if messages:
+            return f"You are a helpful assistant. Message count: {len(messages)}"
+        return "You are a helpful assistant. No messages yet."
+
+    middleware = DynamicSystemPromptMiddleware(dynamic_system_prompt)
+
+    # Test with empty state
+    empty_state = {"messages": []}
+    request = ModelRequest(
+        model=FakeToolCallingModel(),
+        system_prompt="Original prompt",
+        messages=[],
+        tool_choice=None,
+        tools=[],
+        response_format=None,
+    )
+
+    modified_request = middleware.modify_model_request(request, empty_state, None)
+    assert modified_request.system_prompt == "You are a helpful assistant. No messages yet."
+
+    state_with_messages = {"messages": [HumanMessage("Hello"), AIMessage("Hi")]}
+    modified_request = middleware.modify_model_request(request, state_with_messages, None)
+    assert modified_request.system_prompt == "You are a helpful assistant. Message count: 2"
+
+
+def test_dynamic_system_prompt_middleware_with_context() -> None:
+    """Test DynamicSystemPromptMiddleware with runtime context."""
+
+    class MockContext(TypedDict):
+        user_role: str
+
+    def dynamic_system_prompt(state: AgentState, runtime: Runtime[MockContext]) -> str:
+        base_prompt = "You are a helpful assistant."
+        if runtime and hasattr(runtime, "context"):
+            user_role = runtime.context.get("user_role", "user")
+            return f"{base_prompt} User role: {user_role}"
+        return base_prompt
+
+    middleware = DynamicSystemPromptMiddleware(dynamic_system_prompt)
+
+    # Create a mock runtime with context
+    class MockRuntime:
+        def __init__(self, context):
+            self.context = context
+
+    mock_runtime = MockRuntime(context={"user_role": "admin"})
+
+    request = ModelRequest(
+        model=FakeToolCallingModel(),
+        system_prompt="Original prompt",
+        messages=[HumanMessage("Test")],
+        tool_choice=None,
+        tools=[],
+        response_format=None,
+    )
+
+    state = {"messages": [HumanMessage("Test")]}
+    modified_request = middleware.modify_model_request(request, state, mock_runtime)
+    assert modified_request.system_prompt == "You are a helpful assistant. User role: admin"
+
+
+def test_public_private_state_for_custom_middleware() -> None:
+    """Test public and private state for custom middleware."""
+
+    class CustomState(AgentState):
+        omit_input: Annotated[str, OmitFromInput]
+        omit_output: Annotated[str, OmitFromOutput]
+        private_state: Annotated[str, PrivateStateAttr]
+
+    class CustomMiddleware(AgentMiddleware[CustomState]):
+        state_schema: type[CustomState] = CustomState
+
+        def before_model(self, state: CustomState) -> dict[str, Any]:
+            assert "omit_input" not in state
+            assert "omit_output" in state
+            assert "private_state" not in state
+            return {"omit_input": "test", "omit_output": "test", "private_state": "test"}
+
+    agent = create_agent(model=FakeToolCallingModel(), middleware=[CustomMiddleware()])
+    agent = agent.compile()
+    result = agent.invoke(
+        {
+            "messages": [HumanMessage("Hello")],
+            "omit_input": "test in",
+            "private_state": "test in",
+            "omit_output": "test in",
+        }
+    )
+
+    assert "omit_input" in result
+    assert "omit_output" not in result
+    assert "private_state" not in result
+
+
+def test_runtime_injected_into_middleware() -> None:
+    """Test that the runtime is injected into the middleware."""
+
+    class CustomMiddleware(AgentMiddleware):
+        def before_model(self, state: AgentState, runtime: Runtime) -> None:
+            assert runtime is not None
+            return None
+
+        def modify_model_request(
+            self, request: ModelRequest, state: AgentState, runtime: Runtime
+        ) -> ModelRequest:
+            assert runtime is not None
+            return request
+
+        def after_model(self, state: AgentState, runtime: Runtime) -> None:
+            assert runtime is not None
+            return None
+
+    middleware = CustomMiddleware()
+    agent = create_agent(model=FakeToolCallingModel(), middleware=[CustomMiddleware()])
+    agent = agent.compile()
+    agent.invoke({"messages": [HumanMessage("Hello")]})