mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-21 18:39:57 +00:00
feat(langchain): dynamic system prompt middleware (#33006)
# Changes ## Adds support for `DynamicSystemPromptMiddleware` ```py from langchain.agents.middleware import DynamicSystemPromptMiddleware from langgraph.runtime import Runtime from typing_extensions import TypedDict class Context(TypedDict): user_name: str def system_prompt(state: AgentState, runtime: Runtime[Context]) -> str: user_name = runtime.context.get("user_name", "n/a") return f"You are a helpful assistant. Always address the user by their name: {user_name}" middleware = DynamicSystemPromptMiddleware(system_prompt) ``` ## Adds support for `runtime` in middleware hooks ```py class AgentMiddleware(Generic[StateT, ContextT]): def modify_model_request( self, request: ModelRequest, state: StateT, runtime: Runtime[ContextT], # Optional runtime parameter ) -> ModelRequest: # upgrade model if runtime.context.subscription is `top-tier` or whatever ``` ## Adds support for omitting state attributes from input / output schemas ```py from typing import Annotated, NotRequired from langchain.agents.middleware.types import PrivateStateAttr, OmitFromInput, OmitFromOutput class CustomState(AgentState): # Private field - not in input or output schemas internal_counter: NotRequired[Annotated[int, PrivateStateAttr]] # Input-only field - not in output schema user_input: NotRequired[Annotated[str, OmitFromOutput]] # Output-only field - not in input schema computed_result: NotRequired[Annotated[str, OmitFromInput]] ``` ## Additionally * Removes filtering of state before passing into middleware hooks Typing is not foolproof here, still need to figure out some of the generics stuff w/ state and context schema extensions for middleware. TODO: * More docs for middleware, should hold off on this until other prios like MCP and deepagents are met --------- Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
"""Middleware plugins for agents."""
|
||||
|
||||
from .dynamic_system_prompt import DynamicSystemPromptMiddleware
|
||||
from .human_in_the_loop import HumanInTheLoopMiddleware
|
||||
from .prompt_caching import AnthropicPromptCachingMiddleware
|
||||
from .summarization import SummarizationMiddleware
|
||||
@@ -8,7 +9,9 @@ from .types import AgentMiddleware, AgentState, ModelRequest
|
||||
__all__ = [
|
||||
"AgentMiddleware",
|
||||
"AgentState",
|
||||
# should move to langchain-anthropic if we decide to keep it
|
||||
"AnthropicPromptCachingMiddleware",
|
||||
"DynamicSystemPromptMiddleware",
|
||||
"HumanInTheLoopMiddleware",
|
||||
"ModelRequest",
|
||||
"SummarizationMiddleware",
|
||||
|
@@ -0,0 +1,105 @@
|
||||
"""Dynamic System Prompt Middleware.
|
||||
|
||||
Allows setting the system prompt dynamically right before each model invocation.
|
||||
Useful when the prompt depends on the current agent state or per-invocation context.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from inspect import signature
|
||||
from typing import TYPE_CHECKING, Protocol, TypeAlias, cast
|
||||
|
||||
from langgraph.typing import ContextT
|
||||
|
||||
from langchain.agents.middleware.types import (
|
||||
AgentMiddleware,
|
||||
AgentState,
|
||||
ModelRequest,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langgraph.runtime import Runtime
|
||||
|
||||
|
||||
class DynamicSystemPromptWithoutRuntime(Protocol):
|
||||
"""Dynamic system prompt without runtime in call signature."""
|
||||
|
||||
def __call__(self, state: AgentState) -> str:
|
||||
"""Return the system prompt for the next model call."""
|
||||
...
|
||||
|
||||
|
||||
class DynamicSystemPromptWithRuntime(Protocol[ContextT]):
|
||||
"""Dynamic system prompt with runtime in call signature."""
|
||||
|
||||
def __call__(self, state: AgentState, runtime: Runtime[ContextT]) -> str:
|
||||
"""Return the system prompt for the next model call."""
|
||||
...
|
||||
|
||||
|
||||
DynamicSystemPrompt: TypeAlias = (
|
||||
DynamicSystemPromptWithoutRuntime | DynamicSystemPromptWithRuntime[ContextT]
|
||||
)
|
||||
|
||||
|
||||
class DynamicSystemPromptMiddleware(AgentMiddleware):
|
||||
"""Dynamic System Prompt Middleware.
|
||||
|
||||
Allows setting the system prompt dynamically right before each model invocation.
|
||||
Useful when the prompt depends on the current agent state or per-invocation context.
|
||||
|
||||
Example:
|
||||
```python
|
||||
from langchain.agents.middleware import DynamicSystemPromptMiddleware
|
||||
|
||||
|
||||
class Context(TypedDict):
|
||||
user_name: str
|
||||
|
||||
|
||||
def system_prompt(state: AgentState, runtime: Runtime[Context]) -> str:
|
||||
user_name = runtime.context.get("user_name", "n/a")
|
||||
return (
|
||||
f"You are a helpful assistant. Always address the user by their name: {user_name}"
|
||||
)
|
||||
|
||||
|
||||
middleware = DynamicSystemPromptMiddleware(system_prompt)
|
||||
```
|
||||
"""
|
||||
|
||||
_accepts_runtime: bool
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dynamic_system_prompt: DynamicSystemPrompt[ContextT],
|
||||
) -> None:
|
||||
"""Initialize the dynamic system prompt middleware.
|
||||
|
||||
Args:
|
||||
dynamic_system_prompt: Function that receives the current agent state
|
||||
and optionally runtime with context, and returns the system prompt for
|
||||
the next model call. Returns a string.
|
||||
"""
|
||||
super().__init__()
|
||||
self.dynamic_system_prompt = dynamic_system_prompt
|
||||
self._accepts_runtime = "runtime" in signature(dynamic_system_prompt).parameters
|
||||
|
||||
def modify_model_request(
|
||||
self,
|
||||
request: ModelRequest,
|
||||
state: AgentState,
|
||||
runtime: Runtime[ContextT],
|
||||
) -> ModelRequest:
|
||||
"""Modify the model request to include the dynamic system prompt."""
|
||||
if self._accepts_runtime:
|
||||
system_prompt = cast(
|
||||
"DynamicSystemPromptWithRuntime[ContextT]", self.dynamic_system_prompt
|
||||
)(state, runtime)
|
||||
else:
|
||||
system_prompt = cast("DynamicSystemPromptWithoutRuntime", self.dynamic_system_prompt)(
|
||||
state
|
||||
)
|
||||
|
||||
request.system_prompt = system_prompt
|
||||
return request
|
@@ -143,7 +143,7 @@ class HumanInTheLoopMiddleware(AgentMiddleware):
|
||||
self.tool_configs = resolved_tool_configs
|
||||
self.description_prefix = description_prefix
|
||||
|
||||
def after_model(self, state: AgentState) -> dict[str, Any] | None:
|
||||
def after_model(self, state: AgentState) -> dict[str, Any] | None: # type: ignore[override]
|
||||
"""Trigger HITL flows for relevant tool calls after an AIMessage."""
|
||||
messages = state["messages"]
|
||||
if not messages:
|
||||
|
@@ -2,7 +2,7 @@
|
||||
|
||||
from typing import Literal
|
||||
|
||||
from langchain.agents.middleware.types import AgentMiddleware, AgentState, ModelRequest
|
||||
from langchain.agents.middleware.types import AgentMiddleware, ModelRequest
|
||||
|
||||
|
||||
class AnthropicPromptCachingMiddleware(AgentMiddleware):
|
||||
@@ -32,7 +32,10 @@ class AnthropicPromptCachingMiddleware(AgentMiddleware):
|
||||
self.ttl = ttl
|
||||
self.min_messages_to_cache = min_messages_to_cache
|
||||
|
||||
def modify_model_request(self, request: ModelRequest, state: AgentState) -> ModelRequest: # noqa: ARG002
|
||||
def modify_model_request( # type: ignore[override]
|
||||
self,
|
||||
request: ModelRequest,
|
||||
) -> ModelRequest:
|
||||
"""Modify the model request to add cache control blocks."""
|
||||
try:
|
||||
from langchain_anthropic import ChatAnthropic
|
||||
|
@@ -98,7 +98,7 @@ class SummarizationMiddleware(AgentMiddleware):
|
||||
self.summary_prompt = summary_prompt
|
||||
self.summary_prefix = summary_prefix
|
||||
|
||||
def before_model(self, state: AgentState) -> dict[str, Any] | None:
|
||||
def before_model(self, state: AgentState) -> dict[str, Any] | None: # type: ignore[override]
|
||||
"""Process messages before model invocation, potentially triggering summarization."""
|
||||
messages = state["messages"]
|
||||
self._ensure_message_ids(messages)
|
||||
|
@@ -8,15 +8,27 @@ from typing import TYPE_CHECKING, Annotated, Any, Generic, Literal, cast
|
||||
# needed as top level import for pydantic schema generation on AgentState
|
||||
from langchain_core.messages import AnyMessage # noqa: TC002
|
||||
from langgraph.channels.ephemeral_value import EphemeralValue
|
||||
from langgraph.graph.message import Messages, add_messages
|
||||
from langgraph.graph.message import add_messages
|
||||
from langgraph.runtime import Runtime
|
||||
from langgraph.typing import ContextT
|
||||
from typing_extensions import NotRequired, Required, TypedDict, TypeVar
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain_core.language_models.chat_models import BaseChatModel
|
||||
from langchain_core.tools import BaseTool
|
||||
from langgraph.runtime import Runtime
|
||||
|
||||
from langchain.agents.structured_output import ResponseFormat
|
||||
|
||||
__all__ = [
|
||||
"AgentMiddleware",
|
||||
"AgentState",
|
||||
"ContextT",
|
||||
"ModelRequest",
|
||||
"OmitFromSchema",
|
||||
"PublicAgentState",
|
||||
]
|
||||
|
||||
JumpTo = Literal["tools", "model", "__end__"]
|
||||
"""Destination to jump to when a middleware node returns."""
|
||||
|
||||
@@ -36,26 +48,49 @@ class ModelRequest:
|
||||
model_settings: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass
|
||||
class OmitFromSchema:
|
||||
"""Annotation used to mark state attributes as omitted from input or output schemas."""
|
||||
|
||||
input: bool = True
|
||||
"""Whether to omit the attribute from the input schema."""
|
||||
|
||||
output: bool = True
|
||||
"""Whether to omit the attribute from the output schema."""
|
||||
|
||||
|
||||
OmitFromInput = OmitFromSchema(input=True, output=False)
|
||||
"""Annotation used to mark state attributes as omitted from input schema."""
|
||||
|
||||
OmitFromOutput = OmitFromSchema(input=False, output=True)
|
||||
"""Annotation used to mark state attributes as omitted from output schema."""
|
||||
|
||||
PrivateStateAttr = OmitFromSchema(input=True, output=True)
|
||||
"""Annotation used to mark state attributes as purely internal for a given middleware."""
|
||||
|
||||
|
||||
class AgentState(TypedDict, Generic[ResponseT]):
|
||||
"""State schema for the agent."""
|
||||
|
||||
messages: Required[Annotated[list[AnyMessage], add_messages]]
|
||||
model_request: NotRequired[Annotated[ModelRequest | None, EphemeralValue]]
|
||||
jump_to: NotRequired[Annotated[JumpTo | None, EphemeralValue]]
|
||||
jump_to: NotRequired[Annotated[JumpTo | None, EphemeralValue, PrivateStateAttr]]
|
||||
response: NotRequired[ResponseT]
|
||||
|
||||
|
||||
class PublicAgentState(TypedDict, Generic[ResponseT]):
|
||||
"""Input / output schema for the agent."""
|
||||
"""Public state schema for the agent.
|
||||
|
||||
messages: Required[Messages]
|
||||
Just used for typing purposes.
|
||||
"""
|
||||
|
||||
messages: Required[Annotated[list[AnyMessage], add_messages]]
|
||||
response: NotRequired[ResponseT]
|
||||
|
||||
|
||||
StateT = TypeVar("StateT", bound=AgentState)
|
||||
StateT = TypeVar("StateT", bound=AgentState, default=AgentState)
|
||||
|
||||
|
||||
class AgentMiddleware(Generic[StateT]):
|
||||
class AgentMiddleware(Generic[StateT, ContextT]):
|
||||
"""Base middleware class for an agent.
|
||||
|
||||
Subclass this and implement any of the defined methods to customize agent behavior
|
||||
@@ -68,12 +103,17 @@ class AgentMiddleware(Generic[StateT]):
|
||||
tools: list[BaseTool]
|
||||
"""Additional tools registered by the middleware."""
|
||||
|
||||
def before_model(self, state: StateT) -> dict[str, Any] | None:
|
||||
def before_model(self, state: StateT, runtime: Runtime[ContextT]) -> dict[str, Any] | None:
|
||||
"""Logic to run before the model is called."""
|
||||
|
||||
def modify_model_request(self, request: ModelRequest, state: StateT) -> ModelRequest: # noqa: ARG002
|
||||
def modify_model_request(
|
||||
self,
|
||||
request: ModelRequest,
|
||||
state: StateT, # noqa: ARG002
|
||||
runtime: Runtime[ContextT], # noqa: ARG002
|
||||
) -> ModelRequest:
|
||||
"""Logic to modify request kwargs before the model is called."""
|
||||
return request
|
||||
|
||||
def after_model(self, state: StateT) -> dict[str, Any] | None:
|
||||
def after_model(self, state: StateT, runtime: Runtime[ContextT]) -> dict[str, Any] | None:
|
||||
"""Logic to run after the model is called."""
|
||||
|
@@ -2,7 +2,8 @@
|
||||
|
||||
import itertools
|
||||
from collections.abc import Callable, Sequence
|
||||
from typing import Any, cast
|
||||
from inspect import signature
|
||||
from typing import Annotated, Any, cast, get_args, get_origin, get_type_hints
|
||||
|
||||
from langchain_core.language_models.chat_models import BaseChatModel
|
||||
from langchain_core.messages import AIMessage, AnyMessage, SystemMessage, ToolMessage
|
||||
@@ -10,19 +11,19 @@ from langchain_core.runnables import Runnable
|
||||
from langchain_core.tools import BaseTool
|
||||
from langgraph.constants import END, START
|
||||
from langgraph.graph.state import StateGraph
|
||||
from langgraph.runtime import Runtime
|
||||
from langgraph.types import Send
|
||||
from langgraph.typing import ContextT
|
||||
from typing_extensions import TypedDict, TypeVar
|
||||
from typing_extensions import NotRequired, Required, TypedDict, TypeVar
|
||||
|
||||
from langchain.agents.middleware.types import (
|
||||
AgentMiddleware,
|
||||
AgentState,
|
||||
JumpTo,
|
||||
ModelRequest,
|
||||
OmitFromSchema,
|
||||
PublicAgentState,
|
||||
)
|
||||
|
||||
# Import structured output classes from the old implementation
|
||||
from langchain.agents.structured_output import (
|
||||
MultipleStructuredOutputsError,
|
||||
OutputToolBinding,
|
||||
@@ -38,26 +39,49 @@ from langchain.chat_models import init_chat_model
|
||||
STRUCTURED_OUTPUT_ERROR_TEMPLATE = "Error: {error}\n Please fix your mistakes."
|
||||
|
||||
|
||||
def _merge_state_schemas(schemas: list[type]) -> type:
|
||||
"""Merge multiple TypedDict schemas into a single schema with all fields."""
|
||||
if not schemas:
|
||||
return AgentState
|
||||
def _resolve_schema(schemas: set[type], schema_name: str, omit_flag: str | None = None) -> type:
|
||||
"""Resolve schema by merging schemas and optionally respecting OmitFromSchema annotations.
|
||||
|
||||
Args:
|
||||
schemas: List of schema types to merge
|
||||
schema_name: Name for the generated TypedDict
|
||||
omit_flag: If specified, omit fields with this flag set ('input' or 'output')
|
||||
"""
|
||||
all_annotations = {}
|
||||
|
||||
for schema in schemas:
|
||||
all_annotations.update(schema.__annotations__)
|
||||
hints = get_type_hints(schema, include_extras=True)
|
||||
|
||||
return TypedDict("MergedState", all_annotations) # type: ignore[operator]
|
||||
for field_name, field_type in hints.items():
|
||||
should_omit = False
|
||||
|
||||
if omit_flag:
|
||||
# Check for omission in the annotation metadata
|
||||
metadata = _extract_metadata(field_type)
|
||||
for meta in metadata:
|
||||
if isinstance(meta, OmitFromSchema) and getattr(meta, omit_flag) is True:
|
||||
should_omit = True
|
||||
break
|
||||
|
||||
if not should_omit:
|
||||
all_annotations[field_name] = field_type
|
||||
|
||||
return TypedDict(schema_name, all_annotations) # type: ignore[operator]
|
||||
|
||||
|
||||
def _filter_state_for_schema(state: dict[str, Any], schema: type) -> dict[str, Any]:
|
||||
"""Filter state to only include fields defined in the given schema."""
|
||||
if not hasattr(schema, "__annotations__"):
|
||||
return state
|
||||
def _extract_metadata(type_: type) -> list:
|
||||
"""Extract metadata from a field type, handling Required/NotRequired and Annotated wrappers."""
|
||||
# Handle Required[Annotated[...]] or NotRequired[Annotated[...]]
|
||||
if get_origin(type_) in (Required, NotRequired):
|
||||
inner_type = get_args(type_)[0]
|
||||
if get_origin(inner_type) is Annotated:
|
||||
return list(get_args(inner_type)[1:])
|
||||
|
||||
schema_fields = set(schema.__annotations__.keys())
|
||||
return {k: v for k, v in state.items() if k in schema_fields}
|
||||
# Handle direct Annotated[...]
|
||||
elif get_origin(type_) is Annotated:
|
||||
return list(get_args(type_)[1:])
|
||||
|
||||
return []
|
||||
|
||||
|
||||
def _supports_native_structured_output(model: str | BaseChatModel) -> bool:
|
||||
@@ -114,7 +138,7 @@ def create_agent( # noqa: PLR0915
|
||||
model: str | BaseChatModel,
|
||||
tools: Sequence[BaseTool | Callable | dict[str, Any]] | ToolNode | None = None,
|
||||
system_prompt: str | None = None,
|
||||
middleware: Sequence[AgentMiddleware] = (),
|
||||
middleware: Sequence[AgentMiddleware[AgentState[ResponseT], ContextT]] = (),
|
||||
response_format: ResponseFormat[ResponseT] | type[ResponseT] | None = None,
|
||||
context_schema: type[ContextT] | None = None,
|
||||
) -> StateGraph[
|
||||
@@ -199,16 +223,16 @@ def create_agent( # noqa: PLR0915
|
||||
m for m in middleware if m.__class__.after_model is not AgentMiddleware.after_model
|
||||
]
|
||||
|
||||
# Collect all middleware state schemas and create merged schema
|
||||
merged_state_schema: type[AgentState] = _merge_state_schemas(
|
||||
[m.state_schema for m in middleware]
|
||||
)
|
||||
state_schemas = {m.state_schema for m in middleware}
|
||||
state_schemas.add(AgentState)
|
||||
|
||||
# create graph, add nodes
|
||||
graph = StateGraph(
|
||||
merged_state_schema,
|
||||
input_schema=PublicAgentState,
|
||||
output_schema=PublicAgentState,
|
||||
graph: StateGraph[
|
||||
AgentState[ResponseT], ContextT, PublicAgentState[ResponseT], PublicAgentState[ResponseT]
|
||||
] = StateGraph(
|
||||
state_schema=_resolve_schema(state_schemas, "StateSchema", None),
|
||||
input_schema=_resolve_schema(state_schemas, "InputSchema", "input"),
|
||||
output_schema=_resolve_schema(state_schemas, "OutputSchema", "output"),
|
||||
context_schema=context_schema,
|
||||
)
|
||||
|
||||
@@ -318,7 +342,14 @@ def create_agent( # noqa: PLR0915
|
||||
)
|
||||
return request.model.bind(**request.model_settings)
|
||||
|
||||
def model_request(state: dict[str, Any]) -> dict[str, Any]:
|
||||
model_request_signatures: list[
|
||||
tuple[bool, AgentMiddleware[AgentState[ResponseT], ContextT]]
|
||||
] = [
|
||||
("runtime" in signature(m.modify_model_request).parameters, m)
|
||||
for m in middleware_w_modify_model_request
|
||||
]
|
||||
|
||||
def model_request(state: AgentState, runtime: Runtime[ContextT]) -> dict[str, Any]:
|
||||
"""Sync model request handler with sequential middleware processing."""
|
||||
request = ModelRequest(
|
||||
model=model,
|
||||
@@ -330,10 +361,11 @@ def create_agent( # noqa: PLR0915
|
||||
)
|
||||
|
||||
# Apply modify_model_request middleware in sequence
|
||||
for m in middleware_w_modify_model_request:
|
||||
# Filter state to only include fields defined in this middleware's schema
|
||||
filtered_state = _filter_state_for_schema(state, m.state_schema)
|
||||
request = m.modify_model_request(request, filtered_state)
|
||||
for use_runtime, m in model_request_signatures:
|
||||
if use_runtime:
|
||||
m.modify_model_request(request, state, runtime)
|
||||
else:
|
||||
m.modify_model_request(request, state) # type: ignore[call-arg]
|
||||
|
||||
# Get the final model and messages
|
||||
model_ = _get_bound_model(request)
|
||||
@@ -344,7 +376,7 @@ def create_agent( # noqa: PLR0915
|
||||
output = model_.invoke(messages)
|
||||
return _handle_model_output(output)
|
||||
|
||||
async def amodel_request(state: dict[str, Any]) -> dict[str, Any]:
|
||||
async def amodel_request(state: AgentState, runtime: Runtime[ContextT]) -> dict[str, Any]:
|
||||
"""Async model request handler with sequential middleware processing."""
|
||||
# Start with the base model request
|
||||
request = ModelRequest(
|
||||
@@ -357,10 +389,11 @@ def create_agent( # noqa: PLR0915
|
||||
)
|
||||
|
||||
# Apply modify_model_request middleware in sequence
|
||||
for m in middleware_w_modify_model_request:
|
||||
# Filter state to only include fields defined in this middleware's schema
|
||||
filtered_state = _filter_state_for_schema(state, m.state_schema)
|
||||
request = m.modify_model_request(request, filtered_state)
|
||||
for use_runtime, m in model_request_signatures:
|
||||
if use_runtime:
|
||||
m.modify_model_request(request, state, runtime)
|
||||
else:
|
||||
m.modify_model_request(request, state) # type: ignore[call-arg]
|
||||
|
||||
# Get the final model and messages
|
||||
model_ = _get_bound_model(request)
|
||||
|
@@ -1,9 +1,12 @@
|
||||
from typing_extensions import TypedDict
|
||||
import pytest
|
||||
from typing import Any
|
||||
from unittest.mock import patch
|
||||
|
||||
from syrupy.assertion import SnapshotAssertion
|
||||
|
||||
from langgraph.runtime import Runtime
|
||||
from typing_extensions import Annotated, TypedDict
|
||||
from pydantic import BaseModel, Field
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
from langchain_core.language_models.chat_models import BaseChatModel
|
||||
@@ -15,17 +18,23 @@ from langchain_core.messages import (
|
||||
ToolMessage,
|
||||
)
|
||||
from langchain_core.tools import tool
|
||||
from langgraph.types import Command
|
||||
|
||||
from langchain.agents.middleware_agent import create_agent
|
||||
from langchain.agents.middleware.human_in_the_loop import (
|
||||
HumanInTheLoopMiddleware,
|
||||
HumanInTheLoopConfig,
|
||||
ActionRequest,
|
||||
)
|
||||
from langchain.agents.middleware.prompt_caching import AnthropicPromptCachingMiddleware
|
||||
from langchain.agents.middleware.summarization import SummarizationMiddleware
|
||||
from langchain.agents.middleware.types import AgentMiddleware, ModelRequest, AgentState
|
||||
from langchain.agents.middleware.types import (
|
||||
AgentMiddleware,
|
||||
ModelRequest,
|
||||
AgentState,
|
||||
OmitFromInput,
|
||||
OmitFromOutput,
|
||||
PrivateStateAttr,
|
||||
)
|
||||
from langchain.agents.middleware.dynamic_system_prompt import DynamicSystemPromptMiddleware
|
||||
|
||||
from langgraph.checkpoint.base import BaseCheckpointSaver
|
||||
from langgraph.checkpoint.memory import InMemorySaver
|
||||
@@ -1201,3 +1210,128 @@ def test_tools_to_model_edge_with_structured_and_regular_tool_calls():
|
||||
assert hasattr(result["response"], "temperature")
|
||||
assert result["response"].temperature == 72.0
|
||||
assert result["response"].condition == "sunny"
|
||||
|
||||
|
||||
# Tests for DynamicSystemPromptMiddleware
|
||||
def test_dynamic_system_prompt_middleware_basic() -> None:
|
||||
"""Test basic functionality of DynamicSystemPromptMiddleware."""
|
||||
|
||||
def dynamic_system_prompt(state: AgentState) -> str:
|
||||
messages = state.get("messages", [])
|
||||
if messages:
|
||||
return f"You are a helpful assistant. Message count: {len(messages)}"
|
||||
return "You are a helpful assistant. No messages yet."
|
||||
|
||||
middleware = DynamicSystemPromptMiddleware(dynamic_system_prompt)
|
||||
|
||||
# Test with empty state
|
||||
empty_state = {"messages": []}
|
||||
request = ModelRequest(
|
||||
model=FakeToolCallingModel(),
|
||||
system_prompt="Original prompt",
|
||||
messages=[],
|
||||
tool_choice=None,
|
||||
tools=[],
|
||||
response_format=None,
|
||||
)
|
||||
|
||||
modified_request = middleware.modify_model_request(request, empty_state, None)
|
||||
assert modified_request.system_prompt == "You are a helpful assistant. No messages yet."
|
||||
|
||||
state_with_messages = {"messages": [HumanMessage("Hello"), AIMessage("Hi")]}
|
||||
modified_request = middleware.modify_model_request(request, state_with_messages, None)
|
||||
assert modified_request.system_prompt == "You are a helpful assistant. Message count: 2"
|
||||
|
||||
|
||||
def test_dynamic_system_prompt_middleware_with_context() -> None:
|
||||
"""Test DynamicSystemPromptMiddleware with runtime context."""
|
||||
|
||||
class MockContext(TypedDict):
|
||||
user_role: str
|
||||
|
||||
def dynamic_system_prompt(state: AgentState, runtime: Runtime[MockContext]) -> str:
|
||||
base_prompt = "You are a helpful assistant."
|
||||
if runtime and hasattr(runtime, "context"):
|
||||
user_role = runtime.context.get("user_role", "user")
|
||||
return f"{base_prompt} User role: {user_role}"
|
||||
return base_prompt
|
||||
|
||||
middleware = DynamicSystemPromptMiddleware(dynamic_system_prompt)
|
||||
|
||||
# Create a mock runtime with context
|
||||
class MockRuntime:
|
||||
def __init__(self, context):
|
||||
self.context = context
|
||||
|
||||
mock_runtime = MockRuntime(context={"user_role": "admin"})
|
||||
|
||||
request = ModelRequest(
|
||||
model=FakeToolCallingModel(),
|
||||
system_prompt="Original prompt",
|
||||
messages=[HumanMessage("Test")],
|
||||
tool_choice=None,
|
||||
tools=[],
|
||||
response_format=None,
|
||||
)
|
||||
|
||||
state = {"messages": [HumanMessage("Test")]}
|
||||
modified_request = middleware.modify_model_request(request, state, mock_runtime)
|
||||
|
||||
assert modified_request.system_prompt == "You are a helpful assistant. User role: admin"
|
||||
|
||||
|
||||
def test_public_private_state_for_custom_middleware() -> None:
|
||||
"""Test public and private state for custom middleware."""
|
||||
|
||||
class CustomState(AgentState):
|
||||
omit_input: Annotated[str, OmitFromInput]
|
||||
omit_output: Annotated[str, OmitFromOutput]
|
||||
private_state: Annotated[str, PrivateStateAttr]
|
||||
|
||||
class CustomMiddleware(AgentMiddleware[CustomState]):
|
||||
state_schema: type[CustomState] = CustomState
|
||||
|
||||
def before_model(self, state: CustomState) -> dict[str, Any]:
|
||||
assert "omit_input" not in state
|
||||
assert "omit_output" in state
|
||||
assert "private_state" not in state
|
||||
return {"omit_input": "test", "omit_output": "test", "private_state": "test"}
|
||||
|
||||
agent = create_agent(model=FakeToolCallingModel(), middleware=[CustomMiddleware()])
|
||||
agent = agent.compile()
|
||||
result = agent.invoke(
|
||||
{
|
||||
"messages": [HumanMessage("Hello")],
|
||||
"omit_input": "test in",
|
||||
"private_state": "test in",
|
||||
"omit_output": "test in",
|
||||
}
|
||||
)
|
||||
assert "omit_input" in result
|
||||
assert "omit_output" not in result
|
||||
assert "private_state" not in result
|
||||
|
||||
|
||||
def test_runtime_injected_into_middleware() -> None:
|
||||
"""Test that the runtime is injected into the middleware."""
|
||||
|
||||
class CustomMiddleware(AgentMiddleware):
|
||||
def before_model(self, state: AgentState, runtime: Runtime) -> None:
|
||||
assert runtime is not None
|
||||
return None
|
||||
|
||||
def modify_model_request(
|
||||
self, request: ModelRequest, state: AgentState, runtime: Runtime
|
||||
) -> ModelRequest:
|
||||
assert runtime is not None
|
||||
return request
|
||||
|
||||
def after_model(self, state: AgentState, runtime: Runtime) -> None:
|
||||
assert runtime is not None
|
||||
return None
|
||||
|
||||
middleware = CustomMiddleware()
|
||||
|
||||
agent = create_agent(model=FakeToolCallingModel(), middleware=[CustomMiddleware()])
|
||||
agent = agent.compile()
|
||||
agent.invoke({"messages": [HumanMessage("Hello")]})
|
||||
|
Reference in New Issue
Block a user