Mirror of https://github.com/hwchase17/langchain.git, synced 2025-09-22 02:50:31 +00:00
feat(langchain): dynamic system prompt middleware (#33006)
# Changes

## Adds support for `DynamicSystemPromptMiddleware`

```py
from langchain.agents.middleware import DynamicSystemPromptMiddleware
from langgraph.runtime import Runtime
from typing_extensions import TypedDict


class Context(TypedDict):
    user_name: str


def system_prompt(state: AgentState, runtime: Runtime[Context]) -> str:
    user_name = runtime.context.get("user_name", "n/a")
    return f"You are a helpful assistant. Always address the user by their name: {user_name}"


middleware = DynamicSystemPromptMiddleware(system_prompt)
```

## Adds support for `runtime` in middleware hooks

```py
class AgentMiddleware(Generic[StateT, ContextT]):
    def modify_model_request(
        self,
        request: ModelRequest,
        state: StateT,
        runtime: Runtime[ContextT],  # Optional runtime parameter
    ) -> ModelRequest:
        # e.g. upgrade the model if runtime.context.subscription is "top-tier"
        ...
```

## Adds support for omitting state attributes from input / output schemas

```py
from typing import Annotated, NotRequired

from langchain.agents.middleware.types import PrivateStateAttr, OmitFromInput, OmitFromOutput


class CustomState(AgentState):
    # Private field - not in input or output schemas
    internal_counter: NotRequired[Annotated[int, PrivateStateAttr]]

    # Input-only field - not in output schema
    user_input: NotRequired[Annotated[str, OmitFromOutput]]

    # Output-only field - not in input schema
    computed_result: NotRequired[Annotated[str, OmitFromInput]]
```

## Additionally

* Removes filtering of state before passing it into middleware hooks.

Typing is not foolproof here; we still need to figure out some of the generics work with state and context schema extensions for middleware.

TODO:

* More docs for middleware; should hold off on this until other priorities like MCP and deepagents are met.

---------

Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>
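Putting the three changes together, end-to-end usage looks roughly like the sketch below. This is illustrative, not part of the diff: the model id, the message payload, and the `context=` kwarg value are assumptions (the kwarg follows langgraph's Runtime API).

```py
from typing_extensions import TypedDict

from langchain.agents.middleware import DynamicSystemPromptMiddleware
from langchain.agents.middleware.types import AgentState
from langchain.agents.middleware_agent import create_agent
from langgraph.runtime import Runtime


class Context(TypedDict):
    user_name: str


def system_prompt(state: AgentState, runtime: Runtime[Context]) -> str:
    # The hook sees the full (unfiltered) state and, because it declares a
    # `runtime` parameter, the per-invocation context as well.
    return f"Address the user as {runtime.context.get('user_name', 'n/a')}."


graph = create_agent(
    model="openai:gpt-4o-mini",  # assumed model id; any chat model works
    middleware=[DynamicSystemPromptMiddleware(system_prompt)],
    context_schema=Context,
).compile()

result = graph.invoke(
    {"messages": [{"role": "user", "content": "Hi!"}]},
    context={"user_name": "Ada"},  # surfaced to hooks via Runtime[Context]
)
```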
`langchain/agents/middleware/__init__.py`:

```diff
@@ -1,5 +1,6 @@
 """Middleware plugins for agents."""
 
+from .dynamic_system_prompt import DynamicSystemPromptMiddleware
 from .human_in_the_loop import HumanInTheLoopMiddleware
 from .prompt_caching import AnthropicPromptCachingMiddleware
 from .summarization import SummarizationMiddleware
@@ -8,7 +9,9 @@ from .types import AgentMiddleware, AgentState, ModelRequest
 __all__ = [
     "AgentMiddleware",
     "AgentState",
+    # should move to langchain-anthropic if we decide to keep it
     "AnthropicPromptCachingMiddleware",
+    "DynamicSystemPromptMiddleware",
     "HumanInTheLoopMiddleware",
     "ModelRequest",
     "SummarizationMiddleware",
```
`langchain/agents/middleware/dynamic_system_prompt.py` (new file, `@@ -0,0 +1,105 @@`):

````python
"""Dynamic System Prompt Middleware.

Allows setting the system prompt dynamically right before each model invocation.
Useful when the prompt depends on the current agent state or per-invocation context.
"""

from __future__ import annotations

from inspect import signature
from typing import TYPE_CHECKING, Protocol, TypeAlias, cast

from langgraph.typing import ContextT

from langchain.agents.middleware.types import (
    AgentMiddleware,
    AgentState,
    ModelRequest,
)

if TYPE_CHECKING:
    from langgraph.runtime import Runtime


class DynamicSystemPromptWithoutRuntime(Protocol):
    """Dynamic system prompt without runtime in call signature."""

    def __call__(self, state: AgentState) -> str:
        """Return the system prompt for the next model call."""
        ...


class DynamicSystemPromptWithRuntime(Protocol[ContextT]):
    """Dynamic system prompt with runtime in call signature."""

    def __call__(self, state: AgentState, runtime: Runtime[ContextT]) -> str:
        """Return the system prompt for the next model call."""
        ...


DynamicSystemPrompt: TypeAlias = (
    DynamicSystemPromptWithoutRuntime | DynamicSystemPromptWithRuntime[ContextT]
)


class DynamicSystemPromptMiddleware(AgentMiddleware):
    """Dynamic System Prompt Middleware.

    Allows setting the system prompt dynamically right before each model invocation.
    Useful when the prompt depends on the current agent state or per-invocation context.

    Example:
        ```python
        from langchain.agents.middleware import DynamicSystemPromptMiddleware


        class Context(TypedDict):
            user_name: str


        def system_prompt(state: AgentState, runtime: Runtime[Context]) -> str:
            user_name = runtime.context.get("user_name", "n/a")
            return (
                f"You are a helpful assistant. Always address the user by their name: {user_name}"
            )


        middleware = DynamicSystemPromptMiddleware(system_prompt)
        ```
    """

    _accepts_runtime: bool

    def __init__(
        self,
        dynamic_system_prompt: DynamicSystemPrompt[ContextT],
    ) -> None:
        """Initialize the dynamic system prompt middleware.

        Args:
            dynamic_system_prompt: Function that receives the current agent state
                and optionally runtime with context, and returns the system prompt for
                the next model call. Returns a string.
        """
        super().__init__()
        self.dynamic_system_prompt = dynamic_system_prompt
        self._accepts_runtime = "runtime" in signature(dynamic_system_prompt).parameters

    def modify_model_request(
        self,
        request: ModelRequest,
        state: AgentState,
        runtime: Runtime[ContextT],
    ) -> ModelRequest:
        """Modify the model request to include the dynamic system prompt."""
        if self._accepts_runtime:
            system_prompt = cast(
                "DynamicSystemPromptWithRuntime[ContextT]", self.dynamic_system_prompt
            )(state, runtime)
        else:
            system_prompt = cast("DynamicSystemPromptWithoutRuntime", self.dynamic_system_prompt)(
                state
            )

        request.system_prompt = system_prompt
        return request
````
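The class dispatches on the prompt function's shape rather than on a flag: `__init__` checks once whether the callable declares a `runtime` parameter, and `modify_model_request` picks the matching call form. A standalone sketch of the `inspect.signature` check it relies on (function names here are illustrative):

```python
from inspect import signature


def without_runtime(state) -> str:
    return "You are a helpful assistant."


def with_runtime(state, runtime) -> str:
    return f"You are a helpful assistant for {runtime.context.get('user_name')}."


# Mirrors the check in DynamicSystemPromptMiddleware.__init__: the result is
# computed once and cached as _accepts_runtime, so no signature inspection
# happens on the per-request hot path.
assert "runtime" not in signature(without_runtime).parameters
assert "runtime" in signature(with_runtime).parameters
```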
`langchain/agents/middleware/human_in_the_loop.py`:

```diff
@@ -143,7 +143,7 @@ class HumanInTheLoopMiddleware(AgentMiddleware):
         self.tool_configs = resolved_tool_configs
         self.description_prefix = description_prefix
 
-    def after_model(self, state: AgentState) -> dict[str, Any] | None:
+    def after_model(self, state: AgentState) -> dict[str, Any] | None:  # type: ignore[override]
         """Trigger HITL flows for relevant tool calls after an AIMessage."""
         messages = state["messages"]
         if not messages:
```
`langchain/agents/middleware/prompt_caching.py`:

```diff
@@ -2,7 +2,7 @@
 
 from typing import Literal
 
-from langchain.agents.middleware.types import AgentMiddleware, AgentState, ModelRequest
+from langchain.agents.middleware.types import AgentMiddleware, ModelRequest
 
 
 class AnthropicPromptCachingMiddleware(AgentMiddleware):
@@ -32,7 +32,10 @@ class AnthropicPromptCachingMiddleware(AgentMiddleware):
         self.ttl = ttl
         self.min_messages_to_cache = min_messages_to_cache
 
-    def modify_model_request(self, request: ModelRequest, state: AgentState) -> ModelRequest:  # noqa: ARG002
+    def modify_model_request(  # type: ignore[override]
+        self,
+        request: ModelRequest,
+    ) -> ModelRequest:
         """Modify the model request to add cache control blocks."""
         try:
             from langchain_anthropic import ChatAnthropic
```
`langchain/agents/middleware/summarization.py`:

```diff
@@ -98,7 +98,7 @@ class SummarizationMiddleware(AgentMiddleware):
         self.summary_prompt = summary_prompt
         self.summary_prefix = summary_prefix
 
-    def before_model(self, state: AgentState) -> dict[str, Any] | None:
+    def before_model(self, state: AgentState) -> dict[str, Any] | None:  # type: ignore[override]
         """Process messages before model invocation, potentially triggering summarization."""
         messages = state["messages"]
         self._ensure_message_ids(messages)
```
`langchain/agents/middleware/types.py`:

```diff
@@ -8,15 +8,27 @@ from typing import TYPE_CHECKING, Annotated, Any, Generic, Literal, cast
 # needed as top level import for pydantic schema generation on AgentState
 from langchain_core.messages import AnyMessage  # noqa: TC002
 from langgraph.channels.ephemeral_value import EphemeralValue
-from langgraph.graph.message import Messages, add_messages
+from langgraph.graph.message import add_messages
+from langgraph.runtime import Runtime
+from langgraph.typing import ContextT
 from typing_extensions import NotRequired, Required, TypedDict, TypeVar
 
 if TYPE_CHECKING:
     from langchain_core.language_models.chat_models import BaseChatModel
     from langchain_core.tools import BaseTool
+    from langgraph.runtime import Runtime
 
     from langchain.agents.structured_output import ResponseFormat
 
+__all__ = [
+    "AgentMiddleware",
+    "AgentState",
+    "ContextT",
+    "ModelRequest",
+    "OmitFromSchema",
+    "PublicAgentState",
+]
+
 
 JumpTo = Literal["tools", "model", "__end__"]
 """Destination to jump to when a middleware node returns."""
@@ -36,26 +48,49 @@ class ModelRequest:
     model_settings: dict[str, Any] = field(default_factory=dict)
 
 
+@dataclass
+class OmitFromSchema:
+    """Annotation used to mark state attributes as omitted from input or output schemas."""
+
+    input: bool = True
+    """Whether to omit the attribute from the input schema."""
+
+    output: bool = True
+    """Whether to omit the attribute from the output schema."""
+
+
+OmitFromInput = OmitFromSchema(input=True, output=False)
+"""Annotation used to mark state attributes as omitted from input schema."""
+
+OmitFromOutput = OmitFromSchema(input=False, output=True)
+"""Annotation used to mark state attributes as omitted from output schema."""
+
+PrivateStateAttr = OmitFromSchema(input=True, output=True)
+"""Annotation used to mark state attributes as purely internal for a given middleware."""
+
+
 class AgentState(TypedDict, Generic[ResponseT]):
     """State schema for the agent."""
 
     messages: Required[Annotated[list[AnyMessage], add_messages]]
-    model_request: NotRequired[Annotated[ModelRequest | None, EphemeralValue]]
-    jump_to: NotRequired[Annotated[JumpTo | None, EphemeralValue]]
+    jump_to: NotRequired[Annotated[JumpTo | None, EphemeralValue, PrivateStateAttr]]
     response: NotRequired[ResponseT]
 
 
 class PublicAgentState(TypedDict, Generic[ResponseT]):
-    """Input / output schema for the agent."""
+    """Public state schema for the agent.
+
+    Just used for typing purposes.
+    """
 
-    messages: Required[Messages]
+    messages: Required[Annotated[list[AnyMessage], add_messages]]
     response: NotRequired[ResponseT]
 
 
-StateT = TypeVar("StateT", bound=AgentState)
+StateT = TypeVar("StateT", bound=AgentState, default=AgentState)
 
 
-class AgentMiddleware(Generic[StateT]):
+class AgentMiddleware(Generic[StateT, ContextT]):
     """Base middleware class for an agent.
 
     Subclass this and implement any of the defined methods to customize agent behavior
@@ -68,12 +103,17 @@ class AgentMiddleware(Generic[StateT]):
     tools: list[BaseTool]
     """Additional tools registered by the middleware."""
 
-    def before_model(self, state: StateT) -> dict[str, Any] | None:
+    def before_model(self, state: StateT, runtime: Runtime[ContextT]) -> dict[str, Any] | None:
         """Logic to run before the model is called."""
 
-    def modify_model_request(self, request: ModelRequest, state: StateT) -> ModelRequest:  # noqa: ARG002
+    def modify_model_request(
+        self,
+        request: ModelRequest,
+        state: StateT,  # noqa: ARG002
+        runtime: Runtime[ContextT],  # noqa: ARG002
+    ) -> ModelRequest:
         """Logic to modify request kwargs before the model is called."""
         return request
 
-    def after_model(self, state: StateT) -> dict[str, Any] | None:
+    def after_model(self, state: StateT, runtime: Runtime[ContextT]) -> dict[str, Any] | None:
         """Logic to run after the model is called."""
```
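With `runtime` now part of every hook signature, a subclass can branch on per-invocation context directly; a minimal sketch (the `subscription` context key and the `max_tokens` tweak are assumptions for illustration, not defined in this PR):

```python
from typing import Any

from langgraph.runtime import Runtime

from langchain.agents.middleware.types import AgentMiddleware, AgentState, ModelRequest


class SubscriptionAwareMiddleware(AgentMiddleware[AgentState, dict[str, Any]]):
    """Sketch: adjust model settings based on a hypothetical subscription tier."""

    def modify_model_request(
        self,
        request: ModelRequest,
        state: AgentState,
        runtime: Runtime[dict[str, Any]],
    ) -> ModelRequest:
        # "subscription" is a hypothetical context key for this example.
        if (runtime.context or {}).get("subscription") == "top-tier":
            # model_settings is passed through to model.bind(**model_settings).
            request.model_settings["max_tokens"] = 4096
        return request
```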
`langchain/agents/middleware_agent.py`:

```diff
@@ -2,7 +2,8 @@
 
 import itertools
 from collections.abc import Callable, Sequence
-from typing import Any, cast
+from inspect import signature
+from typing import Annotated, Any, cast, get_args, get_origin, get_type_hints
 
 from langchain_core.language_models.chat_models import BaseChatModel
 from langchain_core.messages import AIMessage, AnyMessage, SystemMessage, ToolMessage
@@ -10,19 +11,19 @@ from langchain_core.runnables import Runnable
 from langchain_core.tools import BaseTool
 from langgraph.constants import END, START
 from langgraph.graph.state import StateGraph
+from langgraph.runtime import Runtime
 from langgraph.types import Send
 from langgraph.typing import ContextT
-from typing_extensions import TypedDict, TypeVar
+from typing_extensions import NotRequired, Required, TypedDict, TypeVar
 
 from langchain.agents.middleware.types import (
     AgentMiddleware,
     AgentState,
     JumpTo,
     ModelRequest,
+    OmitFromSchema,
     PublicAgentState,
 )
 
-# Import structured output classes from the old implementation
 from langchain.agents.structured_output import (
     MultipleStructuredOutputsError,
     OutputToolBinding,
@@ -38,26 +39,49 @@ from langchain.chat_models import init_chat_model
 STRUCTURED_OUTPUT_ERROR_TEMPLATE = "Error: {error}\n Please fix your mistakes."
 
 
-def _merge_state_schemas(schemas: list[type]) -> type:
-    """Merge multiple TypedDict schemas into a single schema with all fields."""
-    if not schemas:
-        return AgentState
-
+def _resolve_schema(schemas: set[type], schema_name: str, omit_flag: str | None = None) -> type:
+    """Resolve schema by merging schemas and optionally respecting OmitFromSchema annotations.
+
+    Args:
+        schemas: List of schema types to merge
+        schema_name: Name for the generated TypedDict
+        omit_flag: If specified, omit fields with this flag set ('input' or 'output')
+    """
     all_annotations = {}
 
     for schema in schemas:
-        all_annotations.update(schema.__annotations__)
-
-    return TypedDict("MergedState", all_annotations)  # type: ignore[operator]
+        hints = get_type_hints(schema, include_extras=True)
+
+        for field_name, field_type in hints.items():
+            should_omit = False
+
+            if omit_flag:
+                # Check for omission in the annotation metadata
+                metadata = _extract_metadata(field_type)
+                for meta in metadata:
+                    if isinstance(meta, OmitFromSchema) and getattr(meta, omit_flag) is True:
+                        should_omit = True
+                        break
+
+            if not should_omit:
+                all_annotations[field_name] = field_type
+
+    return TypedDict(schema_name, all_annotations)  # type: ignore[operator]
 
 
-def _filter_state_for_schema(state: dict[str, Any], schema: type) -> dict[str, Any]:
-    """Filter state to only include fields defined in the given schema."""
-    if not hasattr(schema, "__annotations__"):
-        return state
-
-    schema_fields = set(schema.__annotations__.keys())
-    return {k: v for k, v in state.items() if k in schema_fields}
+def _extract_metadata(type_: type) -> list:
+    """Extract metadata from a field type, handling Required/NotRequired and Annotated wrappers."""
+    # Handle Required[Annotated[...]] or NotRequired[Annotated[...]]
+    if get_origin(type_) in (Required, NotRequired):
+        inner_type = get_args(type_)[0]
+        if get_origin(inner_type) is Annotated:
+            return list(get_args(inner_type)[1:])
+
+    # Handle direct Annotated[...]
+    elif get_origin(type_) is Annotated:
+        return list(get_args(type_)[1:])
+
+    return []
 
 
 def _supports_native_structured_output(model: str | BaseChatModel) -> bool:
```
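The omission machinery hinges on `get_type_hints(..., include_extras=True)` preserving both the `NotRequired` qualifier and `Annotated` metadata, which `_extract_metadata` then unwraps. A standalone sketch of that unwrapping, with the `OmitFromSchema` marker stubbed by a plain string and Python 3.11+ `typing` semantics assumed:

```python
from typing import Annotated, get_args, get_origin, get_type_hints

from typing_extensions import NotRequired, TypedDict


class ExampleState(TypedDict):
    visible: str
    hidden: NotRequired[Annotated[int, "omit-marker"]]


hints = get_type_hints(ExampleState, include_extras=True)

# Unwrap NotRequired[Annotated[int, "omit-marker"]] the way _extract_metadata does.
field_type = hints["hidden"]
assert get_origin(field_type) is NotRequired
inner = get_args(field_type)[0]
assert get_origin(inner) is Annotated
assert list(get_args(inner)[1:]) == ["omit-marker"]
# _resolve_schema drops a field when this metadata list contains an
# OmitFromSchema instance whose input/output flag matches omit_flag.
```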
```diff
@@ -114,7 +138,7 @@ def create_agent(  # noqa: PLR0915
     model: str | BaseChatModel,
     tools: Sequence[BaseTool | Callable | dict[str, Any]] | ToolNode | None = None,
     system_prompt: str | None = None,
-    middleware: Sequence[AgentMiddleware] = (),
+    middleware: Sequence[AgentMiddleware[AgentState[ResponseT], ContextT]] = (),
     response_format: ResponseFormat[ResponseT] | type[ResponseT] | None = None,
     context_schema: type[ContextT] | None = None,
 ) -> StateGraph[
@@ -199,16 +223,16 @@ def create_agent(  # noqa: PLR0915
         m for m in middleware if m.__class__.after_model is not AgentMiddleware.after_model
     ]
 
-    # Collect all middleware state schemas and create merged schema
-    merged_state_schema: type[AgentState] = _merge_state_schemas(
-        [m.state_schema for m in middleware]
-    )
+    state_schemas = {m.state_schema for m in middleware}
+    state_schemas.add(AgentState)
 
     # create graph, add nodes
-    graph = StateGraph(
-        merged_state_schema,
-        input_schema=PublicAgentState,
-        output_schema=PublicAgentState,
+    graph: StateGraph[
+        AgentState[ResponseT], ContextT, PublicAgentState[ResponseT], PublicAgentState[ResponseT]
+    ] = StateGraph(
+        state_schema=_resolve_schema(state_schemas, "StateSchema", None),
+        input_schema=_resolve_schema(state_schemas, "InputSchema", "input"),
+        output_schema=_resolve_schema(state_schemas, "OutputSchema", "output"),
         context_schema=context_schema,
     )
@@ -318,7 +342,14 @@ def create_agent(  # noqa: PLR0915
         )
         return request.model.bind(**request.model_settings)
 
-    def model_request(state: dict[str, Any]) -> dict[str, Any]:
+    model_request_signatures: list[
+        tuple[bool, AgentMiddleware[AgentState[ResponseT], ContextT]]
+    ] = [
+        ("runtime" in signature(m.modify_model_request).parameters, m)
+        for m in middleware_w_modify_model_request
+    ]
+
+    def model_request(state: AgentState, runtime: Runtime[ContextT]) -> dict[str, Any]:
         """Sync model request handler with sequential middleware processing."""
         request = ModelRequest(
             model=model,
@@ -330,10 +361,11 @@ def create_agent(  # noqa: PLR0915
         )
 
         # Apply modify_model_request middleware in sequence
-        for m in middleware_w_modify_model_request:
-            # Filter state to only include fields defined in this middleware's schema
-            filtered_state = _filter_state_for_schema(state, m.state_schema)
-            request = m.modify_model_request(request, filtered_state)
+        for use_runtime, m in model_request_signatures:
+            if use_runtime:
+                m.modify_model_request(request, state, runtime)
+            else:
+                m.modify_model_request(request, state)  # type: ignore[call-arg]
 
         # Get the final model and messages
         model_ = _get_bound_model(request)
@@ -344,7 +376,7 @@ def create_agent(  # noqa: PLR0915
         output = model_.invoke(messages)
         return _handle_model_output(output)
 
-    async def amodel_request(state: dict[str, Any]) -> dict[str, Any]:
+    async def amodel_request(state: AgentState, runtime: Runtime[ContextT]) -> dict[str, Any]:
         """Async model request handler with sequential middleware processing."""
         # Start with the base model request
         request = ModelRequest(
@@ -357,10 +389,11 @@ def create_agent(  # noqa: PLR0915
         )
 
         # Apply modify_model_request middleware in sequence
-        for m in middleware_w_modify_model_request:
-            # Filter state to only include fields defined in this middleware's schema
-            filtered_state = _filter_state_for_schema(state, m.state_schema)
-            request = m.modify_model_request(request, filtered_state)
+        for use_runtime, m in model_request_signatures:
+            if use_runtime:
+                m.modify_model_request(request, state, runtime)
+            else:
+                m.modify_model_request(request, state)  # type: ignore[call-arg]
 
         # Get the final model and messages
         model_ = _get_bound_model(request)
```
Unit tests for the middleware agent:

```diff
@@ -1,9 +1,12 @@
-from typing_extensions import TypedDict
 import pytest
 from typing import Any
 from unittest.mock import patch
 
 from syrupy.assertion import SnapshotAssertion
 
+from langgraph.runtime import Runtime
+from typing_extensions import Annotated, TypedDict
 from pydantic import BaseModel, Field
 from langchain_core.language_models import BaseChatModel
 from langchain_core.language_models.chat_models import BaseChatModel
@@ -15,17 +18,23 @@ from langchain_core.messages import (
     ToolMessage,
 )
 from langchain_core.tools import tool
-from langgraph.types import Command
 
 from langchain.agents.middleware_agent import create_agent
 from langchain.agents.middleware.human_in_the_loop import (
     HumanInTheLoopMiddleware,
-    HumanInTheLoopConfig,
     ActionRequest,
 )
 from langchain.agents.middleware.prompt_caching import AnthropicPromptCachingMiddleware
 from langchain.agents.middleware.summarization import SummarizationMiddleware
-from langchain.agents.middleware.types import AgentMiddleware, ModelRequest, AgentState
+from langchain.agents.middleware.types import (
+    AgentMiddleware,
+    ModelRequest,
+    AgentState,
+    OmitFromInput,
+    OmitFromOutput,
+    PrivateStateAttr,
+)
+from langchain.agents.middleware.dynamic_system_prompt import DynamicSystemPromptMiddleware
 
 from langgraph.checkpoint.base import BaseCheckpointSaver
 from langgraph.checkpoint.memory import InMemorySaver
@@ -1201,3 +1210,128 @@ def test_tools_to_model_edge_with_structured_and_regular_tool_calls():
     assert hasattr(result["response"], "temperature")
     assert result["response"].temperature == 72.0
     assert result["response"].condition == "sunny"
+
+
+# Tests for DynamicSystemPromptMiddleware
+def test_dynamic_system_prompt_middleware_basic() -> None:
+    """Test basic functionality of DynamicSystemPromptMiddleware."""
+
+    def dynamic_system_prompt(state: AgentState) -> str:
+        messages = state.get("messages", [])
+        if messages:
+            return f"You are a helpful assistant. Message count: {len(messages)}"
+        return "You are a helpful assistant. No messages yet."
+
+    middleware = DynamicSystemPromptMiddleware(dynamic_system_prompt)
+
+    # Test with empty state
+    empty_state = {"messages": []}
+    request = ModelRequest(
+        model=FakeToolCallingModel(),
+        system_prompt="Original prompt",
+        messages=[],
+        tool_choice=None,
+        tools=[],
+        response_format=None,
+    )
+
+    modified_request = middleware.modify_model_request(request, empty_state, None)
+    assert modified_request.system_prompt == "You are a helpful assistant. No messages yet."
+
+    state_with_messages = {"messages": [HumanMessage("Hello"), AIMessage("Hi")]}
+    modified_request = middleware.modify_model_request(request, state_with_messages, None)
+    assert modified_request.system_prompt == "You are a helpful assistant. Message count: 2"
+
+
+def test_dynamic_system_prompt_middleware_with_context() -> None:
+    """Test DynamicSystemPromptMiddleware with runtime context."""
+
+    class MockContext(TypedDict):
+        user_role: str
+
+    def dynamic_system_prompt(state: AgentState, runtime: Runtime[MockContext]) -> str:
+        base_prompt = "You are a helpful assistant."
+        if runtime and hasattr(runtime, "context"):
+            user_role = runtime.context.get("user_role", "user")
+            return f"{base_prompt} User role: {user_role}"
+        return base_prompt
+
+    middleware = DynamicSystemPromptMiddleware(dynamic_system_prompt)
+
+    # Create a mock runtime with context
+    class MockRuntime:
+        def __init__(self, context):
+            self.context = context
+
+    mock_runtime = MockRuntime(context={"user_role": "admin"})
+
+    request = ModelRequest(
+        model=FakeToolCallingModel(),
+        system_prompt="Original prompt",
+        messages=[HumanMessage("Test")],
+        tool_choice=None,
+        tools=[],
+        response_format=None,
+    )
+
+    state = {"messages": [HumanMessage("Test")]}
+    modified_request = middleware.modify_model_request(request, state, mock_runtime)
+
+    assert modified_request.system_prompt == "You are a helpful assistant. User role: admin"
+
+
+def test_public_private_state_for_custom_middleware() -> None:
+    """Test public and private state for custom middleware."""
+
+    class CustomState(AgentState):
+        omit_input: Annotated[str, OmitFromInput]
+        omit_output: Annotated[str, OmitFromOutput]
+        private_state: Annotated[str, PrivateStateAttr]
+
+    class CustomMiddleware(AgentMiddleware[CustomState]):
+        state_schema: type[CustomState] = CustomState
+
+        def before_model(self, state: CustomState) -> dict[str, Any]:
+            assert "omit_input" not in state
+            assert "omit_output" in state
+            assert "private_state" not in state
+            return {"omit_input": "test", "omit_output": "test", "private_state": "test"}
+
+    agent = create_agent(model=FakeToolCallingModel(), middleware=[CustomMiddleware()])
+    agent = agent.compile()
+    result = agent.invoke(
+        {
+            "messages": [HumanMessage("Hello")],
+            "omit_input": "test in",
+            "private_state": "test in",
+            "omit_output": "test in",
+        }
+    )
+    assert "omit_input" in result
+    assert "omit_output" not in result
+    assert "private_state" not in result
+
+
+def test_runtime_injected_into_middleware() -> None:
+    """Test that the runtime is injected into the middleware."""
+
+    class CustomMiddleware(AgentMiddleware):
+        def before_model(self, state: AgentState, runtime: Runtime) -> None:
+            assert runtime is not None
+            return None
+
+        def modify_model_request(
+            self, request: ModelRequest, state: AgentState, runtime: Runtime
+        ) -> ModelRequest:
+            assert runtime is not None
+            return request
+
+        def after_model(self, state: AgentState, runtime: Runtime) -> None:
+            assert runtime is not None
+            return None
+
+    middleware = CustomMiddleware()
+
+    agent = create_agent(model=FakeToolCallingModel(), middleware=[CustomMiddleware()])
+    agent = agent.compile()
+    agent.invoke({"messages": [HumanMessage("Hello")]})
```