diff --git a/libs/langchain_v1/langchain/agents/middleware/__init__.py b/libs/langchain_v1/langchain/agents/middleware/__init__.py index e554d2aed4e..014989516ae 100644 --- a/libs/langchain_v1/langchain/agents/middleware/__init__.py +++ b/libs/langchain_v1/langchain/agents/middleware/__init__.py @@ -1,6 +1,5 @@ """Middleware plugins for agents.""" -from .dynamic_system_prompt import DynamicSystemPromptMiddleware from .human_in_the_loop import HumanInTheLoopMiddleware from .prompt_caching import AnthropicPromptCachingMiddleware from .summarization import SummarizationMiddleware @@ -11,7 +10,6 @@ __all__ = [ "AgentState", # should move to langchain-anthropic if we decide to keep it "AnthropicPromptCachingMiddleware", - "DynamicSystemPromptMiddleware", "HumanInTheLoopMiddleware", "ModelRequest", "SummarizationMiddleware", diff --git a/libs/langchain_v1/langchain/agents/middleware/dynamic_system_prompt.py b/libs/langchain_v1/langchain/agents/middleware/dynamic_system_prompt.py deleted file mode 100644 index d1f3f8b03b5..00000000000 --- a/libs/langchain_v1/langchain/agents/middleware/dynamic_system_prompt.py +++ /dev/null @@ -1,105 +0,0 @@ -"""Dynamic System Prompt Middleware. - -Allows setting the system prompt dynamically right before each model invocation. -Useful when the prompt depends on the current agent state or per-invocation context. -""" - -from __future__ import annotations - -from inspect import signature -from typing import TYPE_CHECKING, Protocol, TypeAlias, cast - -from langgraph.typing import ContextT - -from langchain.agents.middleware.types import ( - AgentMiddleware, - AgentState, - ModelRequest, -) - -if TYPE_CHECKING: - from langgraph.runtime import Runtime - - -class DynamicSystemPromptWithoutRuntime(Protocol): - """Dynamic system prompt without runtime in call signature.""" - - def __call__(self, state: AgentState) -> str: - """Return the system prompt for the next model call.""" - ... - - -class DynamicSystemPromptWithRuntime(Protocol[ContextT]): - """Dynamic system prompt with runtime in call signature.""" - - def __call__(self, state: AgentState, runtime: Runtime[ContextT]) -> str: - """Return the system prompt for the next model call.""" - ... - - -DynamicSystemPrompt: TypeAlias = ( - DynamicSystemPromptWithoutRuntime | DynamicSystemPromptWithRuntime[ContextT] -) - - -class DynamicSystemPromptMiddleware(AgentMiddleware): - """Dynamic System Prompt Middleware. - - Allows setting the system prompt dynamically right before each model invocation. - Useful when the prompt depends on the current agent state or per-invocation context. - - Example: - ```python - from langchain.agents.middleware import DynamicSystemPromptMiddleware - - - class Context(TypedDict): - user_name: str - - - def system_prompt(state: AgentState, runtime: Runtime[Context]) -> str: - user_name = runtime.context.get("user_name", "n/a") - return ( - f"You are a helpful assistant. Always address the user by their name: {user_name}" - ) - - - middleware = DynamicSystemPromptMiddleware(system_prompt) - ``` - """ - - _accepts_runtime: bool - - def __init__( - self, - dynamic_system_prompt: DynamicSystemPrompt[ContextT], - ) -> None: - """Initialize the dynamic system prompt middleware. - - Args: - dynamic_system_prompt: Function that receives the current agent state - and optionally runtime with context, and returns the system prompt for - the next model call. Returns a string. 
- """ - super().__init__() - self.dynamic_system_prompt = dynamic_system_prompt - self._accepts_runtime = "runtime" in signature(dynamic_system_prompt).parameters - - def modify_model_request( - self, - request: ModelRequest, - state: AgentState, - runtime: Runtime[ContextT], - ) -> ModelRequest: - """Modify the model request to include the dynamic system prompt.""" - if self._accepts_runtime: - system_prompt = cast( - "DynamicSystemPromptWithRuntime[ContextT]", self.dynamic_system_prompt - )(state, runtime) - else: - system_prompt = cast("DynamicSystemPromptWithoutRuntime", self.dynamic_system_prompt)( - state - ) - - request.system_prompt = system_prompt - return request diff --git a/libs/langchain_v1/langchain/agents/middleware/types.py b/libs/langchain_v1/langchain/agents/middleware/types.py index 5a6c7ecf56c..cd55681a421 100644 --- a/libs/langchain_v1/langchain/agents/middleware/types.py +++ b/libs/langchain_v1/langchain/agents/middleware/types.py @@ -3,7 +3,20 @@ from __future__ import annotations from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Annotated, Any, Generic, Literal, cast +from inspect import signature +from typing import ( + TYPE_CHECKING, + Annotated, + Any, + ClassVar, + Generic, + Literal, + Protocol, + TypeAlias, + TypeGuard, + cast, + overload, +) # needed as top level import for pydantic schema generation on AgentState from langchain_core.messages import AnyMessage # noqa: TC002 @@ -14,9 +27,12 @@ from langgraph.typing import ContextT from typing_extensions import NotRequired, Required, TypedDict, TypeVar if TYPE_CHECKING: + from collections.abc import Callable + from langchain_core.language_models.chat_models import BaseChatModel from langchain_core.tools import BaseTool from langgraph.runtime import Runtime + from langgraph.types import Command from langchain.agents.structured_output import ResponseFormat @@ -88,6 +104,7 @@ class PublicAgentState(TypedDict, Generic[ResponseT]): StateT = TypeVar("StateT", bound=AgentState, default=AgentState) +StateT_contra = TypeVar("StateT_contra", bound=AgentState, contravariant=True) class AgentMiddleware(Generic[StateT, ContextT]): @@ -103,6 +120,12 @@ class AgentMiddleware(Generic[StateT, ContextT]): tools: list[BaseTool] """Additional tools registered by the middleware.""" + before_model_jump_to: ClassVar[list[JumpTo]] = [] + """Valid jump destinations for before_model hook. Used to establish conditional edges.""" + + after_model_jump_to: ClassVar[list[JumpTo]] = [] + """Valid jump destinations for after_model hook. Used to establish conditional edges.""" + def before_model(self, state: StateT, runtime: Runtime[ContextT]) -> dict[str, Any] | None: """Logic to run before the model is called.""" @@ -117,3 +140,404 @@ class AgentMiddleware(Generic[StateT, ContextT]): def after_model(self, state: StateT, runtime: Runtime[ContextT]) -> dict[str, Any] | None: """Logic to run after the model is called.""" + + +class _CallableWithState(Protocol[StateT_contra]): + """Callable with AgentState as argument.""" + + def __call__(self, state: StateT_contra) -> dict[str, Any] | Command | None: + """Perform some logic with the state.""" + ... + + +class _CallableWithStateAndRuntime(Protocol[StateT_contra, ContextT]): + """Callable with AgentState and Runtime as arguments.""" + + def __call__( + self, state: StateT_contra, runtime: Runtime[ContextT] + ) -> dict[str, Any] | Command | None: + """Perform some logic with the state and runtime.""" + ... 
+
+
+class _CallableWithModelRequestAndState(Protocol[StateT_contra]):
+    """Callable with ModelRequest and AgentState as arguments."""
+
+    def __call__(self, request: ModelRequest, state: StateT_contra) -> ModelRequest:
+        """Perform some logic with the model request and state."""
+        ...
+
+
+class _CallableWithModelRequestAndStateAndRuntime(Protocol[StateT_contra, ContextT]):
+    """Callable with ModelRequest, AgentState, and Runtime as arguments."""
+
+    def __call__(
+        self, request: ModelRequest, state: StateT_contra, runtime: Runtime[ContextT]
+    ) -> ModelRequest:
+        """Perform some logic with the model request, state, and runtime."""
+        ...
+
+
+_NodeSignature: TypeAlias = (
+    _CallableWithState[StateT] | _CallableWithStateAndRuntime[StateT, ContextT]
+)
+_ModelRequestSignature: TypeAlias = (
+    _CallableWithModelRequestAndState[StateT]
+    | _CallableWithModelRequestAndStateAndRuntime[StateT, ContextT]
+)
+
+
+def is_callable_with_runtime(
+    func: _NodeSignature[StateT, ContextT],
+) -> TypeGuard[_CallableWithStateAndRuntime[StateT, ContextT]]:
+    """Check whether a state hook declares a `runtime` parameter."""
+    return "runtime" in signature(func).parameters
+
+
+def is_callable_with_runtime_and_request(
+    func: _ModelRequestSignature[StateT, ContextT],
+) -> TypeGuard[_CallableWithModelRequestAndStateAndRuntime[StateT, ContextT]]:
+    """Check whether a model-request hook declares a `runtime` parameter."""
+    return "runtime" in signature(func).parameters
+
+
+@overload
+def before_model(
+    func: _NodeSignature[StateT, ContextT],
+) -> AgentMiddleware[StateT, ContextT]: ...
+
+
+@overload
+def before_model(
+    func: None = None,
+    *,
+    state_schema: type[StateT] | None = None,
+    tools: list[BaseTool] | None = None,
+    jump_to: list[JumpTo] | None = None,
+    name: str | None = None,
+) -> Callable[[_NodeSignature[StateT, ContextT]], AgentMiddleware[StateT, ContextT]]: ...
+
+
+def before_model(
+    func: _NodeSignature[StateT, ContextT] | None = None,
+    *,
+    state_schema: type[StateT] | None = None,
+    tools: list[BaseTool] | None = None,
+    jump_to: list[JumpTo] | None = None,
+    name: str | None = None,
+) -> (
+    Callable[[_NodeSignature[StateT, ContextT]], AgentMiddleware[StateT, ContextT]]
+    | AgentMiddleware[StateT, ContextT]
+):
+    """Decorator used to dynamically create a middleware with the before_model hook.
+
+    Args:
+        func: The function to be decorated. Can accept either:
+            - `state: StateT` - Just the agent state
+            - `state: StateT, runtime: Runtime[ContextT]` - State and runtime context
+        state_schema: Optional custom state schema type. If not provided, uses the default
+            AgentState schema.
+        tools: Optional list of additional tools to register with this middleware.
+        jump_to: Optional list of valid jump destinations for conditional edges.
+            Valid values are: "tools", "model", "__end__".
+        name: Optional name for the generated middleware class. If not provided,
+            uses the decorated function's name.
+
+    Returns:
+        Either an AgentMiddleware instance (if func is provided directly) or a decorator
+        function that can be applied to the function it wraps.
+ + The decorated function should return: + - `dict[str, Any]` - State updates to merge into the agent state + - `Command` - A command to control flow (e.g., jump to different node) + - `None` - No state updates or flow control + + Examples: + Basic usage with state only: + ```python + @before_model + def log_before_model(state: AgentState) -> None: + print(f"About to call model with {len(state['messages'])} messages") + ``` + + Advanced usage with runtime and conditional jumping: + ```python + @before_model(jump_to=["__end__"]) + def conditional_before_model(state: AgentState, runtime: Runtime) -> dict[str, Any] | None: + if some_condition(state): + return {"jump_to": "__end__"} + return None + ``` + + With custom state schema: + ```python + @before_model( + state_schema=MyCustomState, + ) + def custom_before_model(state: MyCustomState) -> dict[str, Any]: + return {"custom_field": "updated_value"} + ``` + """ + + def decorator(func: _NodeSignature[StateT, ContextT]) -> AgentMiddleware[StateT, ContextT]: + if is_callable_with_runtime(func): + + def wrapped_with_runtime( + self: AgentMiddleware[StateT, ContextT], # noqa: ARG001 + state: StateT, + runtime: Runtime[ContextT], + ) -> dict[str, Any] | Command | None: + return func(state, runtime) + + wrapped = wrapped_with_runtime + else: + + def wrapped_without_runtime( + self: AgentMiddleware[StateT, ContextT], # noqa: ARG001 + state: StateT, + ) -> dict[str, Any] | Command | None: + return func(state) # type: ignore[call-arg] + + wrapped = wrapped_without_runtime # type: ignore[assignment] + + # Use function name as default if no name provided + middleware_name = name or cast("str", getattr(func, "__name__", "BeforeModelMiddleware")) + + return type( + middleware_name, + (AgentMiddleware,), + { + "state_schema": state_schema or AgentState, + "tools": tools or [], + "before_model_jump_to": jump_to or [], + "before_model": wrapped, + }, + )() + + if func is not None: + return decorator(func) + return decorator + + +@overload +def modify_model_request( + func: _ModelRequestSignature[StateT, ContextT], +) -> AgentMiddleware[StateT, ContextT]: ... + + +@overload +def modify_model_request( + func: None = None, + *, + state_schema: type[StateT] | None = None, + tools: list[BaseTool] | None = None, + name: str | None = None, +) -> Callable[[_ModelRequestSignature[StateT, ContextT]], AgentMiddleware[StateT, ContextT]]: ... + + +def modify_model_request( + func: _ModelRequestSignature[StateT, ContextT] | None = None, + *, + state_schema: type[StateT] | None = None, + tools: list[BaseTool] | None = None, + name: str | None = None, +) -> ( + Callable[[_ModelRequestSignature[StateT, ContextT]], AgentMiddleware[StateT, ContextT]] + | AgentMiddleware[StateT, ContextT] +): + r"""Decorator used to dynamically create a middleware with the modify_model_request hook. + + Args: + func: The function to be decorated. Can accept either: + - `request: ModelRequest, state: StateT` - Model request and agent state + - `request: ModelRequest, state: StateT, runtime: Runtime[ContextT]` - + Model request, state, and runtime context + state_schema: Optional custom state schema type. If not provided, uses the default + AgentState schema. + tools: Optional list of additional tools to register with this middleware. + name: Optional name for the generated middleware class. If not provided, + uses the decorated function's name. + + Returns: + Either an AgentMiddleware instance (if func is provided) or a decorator function + that can be applied to a function. 
+ + The decorated function should return: + - `ModelRequest` - The modified model request to be sent to the language model + + Examples: + Basic usage to modify system prompt: + ```python + @modify_model_request + def add_context_to_prompt(request: ModelRequest, state: AgentState) -> ModelRequest: + if request.system_prompt: + request.system_prompt += "\n\nAdditional context: ..." + else: + request.system_prompt = "Additional context: ..." + return request + ``` + + Advanced usage with runtime and custom model settings: + ```python + @modify_model_request + def dynamic_model_settings( + request: ModelRequest, state: AgentState, runtime: Runtime + ) -> ModelRequest: + # Use a different model based on user subscription tier + if runtime.context.get("subscription_tier") == "premium": + request.model = "gpt-4o" + else: + request.model = "gpt-4o-mini" + + return request + ``` + """ + + def decorator( + func: _ModelRequestSignature[StateT, ContextT], + ) -> AgentMiddleware[StateT, ContextT]: + if is_callable_with_runtime_and_request(func): + + def wrapped_with_runtime( + self: AgentMiddleware[StateT, ContextT], # noqa: ARG001 + request: ModelRequest, + state: StateT, + runtime: Runtime[ContextT], + ) -> ModelRequest: + return func(request, state, runtime) + + wrapped = wrapped_with_runtime + else: + + def wrapped_without_runtime( + self: AgentMiddleware[StateT, ContextT], # noqa: ARG001 + request: ModelRequest, + state: StateT, + ) -> ModelRequest: + return func(request, state) # type: ignore[call-arg] + + wrapped = wrapped_without_runtime # type: ignore[assignment] + + # Use function name as default if no name provided + middleware_name = name or cast( + "str", getattr(func, "__name__", "ModifyModelRequestMiddleware") + ) + + return type( + middleware_name, + (AgentMiddleware,), + { + "state_schema": state_schema or AgentState, + "tools": tools or [], + "modify_model_request": wrapped, + }, + )() + + if func is not None: + return decorator(func) + return decorator + + +@overload +def after_model( + func: _NodeSignature[StateT, ContextT], +) -> AgentMiddleware[StateT, ContextT]: ... + + +@overload +def after_model( + func: None = None, + *, + state_schema: type[StateT] | None = None, + tools: list[BaseTool] | None = None, + jump_to: list[JumpTo] | None = None, + name: str | None = None, +) -> Callable[[_NodeSignature[StateT, ContextT]], AgentMiddleware[StateT, ContextT]]: ... + + +def after_model( + func: _NodeSignature[StateT, ContextT] | None = None, + *, + state_schema: type[StateT] | None = None, + tools: list[BaseTool] | None = None, + jump_to: list[JumpTo] | None = None, + name: str | None = None, +) -> ( + Callable[[_NodeSignature[StateT, ContextT]], AgentMiddleware[StateT, ContextT]] + | AgentMiddleware[StateT, ContextT] +): + """Decorator used to dynamically create a middleware with the after_model hook. + + Args: + func: The function to be decorated. Can accept either: + - `state: StateT` - Just the agent state (includes model response) + - `state: StateT, runtime: Runtime[ContextT]` - State and runtime context + state_schema: Optional custom state schema type. If not provided, uses the default + AgentState schema. + tools: Optional list of additional tools to register with this middleware. + jump_to: Optional list of valid jump destinations for conditional edges. + Valid values are: "tools", "model", "__end__" + name: Optional name for the generated middleware class. If not provided, + uses the decorated function's name. 
+ + Returns: + Either an AgentMiddleware instance (if func is provided) or a decorator function + that can be applied to a function. + + The decorated function should return: + - `dict[str, Any]` - State updates to merge into the agent state + - `Command` - A command to control flow (e.g., jump to different node) + - `None` - No state updates or flow control + + Examples: + Basic usage for logging model responses: + ```python + @after_model + def log_latest_message(state: AgentState) -> None: + print(state["messages"][-1].content) + ``` + + With custom state schema: + ```python + @after_model(state_schema=MyCustomState, name="MyAfterModelMiddleware") + def custom_after_model(state: MyCustomState) -> dict[str, Any]: + return {"custom_field": "updated_after_model"} + ``` + """ + + def decorator(func: _NodeSignature[StateT, ContextT]) -> AgentMiddleware[StateT, ContextT]: + if is_callable_with_runtime(func): + + def wrapped_with_runtime( + self: AgentMiddleware[StateT, ContextT], # noqa: ARG001 + state: StateT, + runtime: Runtime[ContextT], + ) -> dict[str, Any] | Command | None: + return func(state, runtime) + + wrapped = wrapped_with_runtime + else: + + def wrapped_without_runtime( + self: AgentMiddleware[StateT, ContextT], # noqa: ARG001 + state: StateT, + ) -> dict[str, Any] | Command | None: + return func(state) # type: ignore[call-arg] + + wrapped = wrapped_without_runtime # type: ignore[assignment] + + # Use function name as default if no name provided + middleware_name = name or cast("str", getattr(func, "__name__", "AfterModelMiddleware")) + + return type( + middleware_name, + (AgentMiddleware,), + { + "state_schema": state_schema or AgentState, + "tools": tools or [], + "after_model_jump_to": jump_to or [], + "after_model": wrapped, + }, + )() + + if func is not None: + return decorator(func) + return decorator diff --git a/libs/langchain_v1/langchain/agents/middleware_agent.py b/libs/langchain_v1/langchain/agents/middleware_agent.py index 718e4a692e6..0f319931069 100644 --- a/libs/langchain_v1/langchain/agents/middleware_agent.py +++ b/libs/langchain_v1/langchain/agents/middleware_agent.py @@ -464,7 +464,7 @@ def create_agent( # noqa: PLR0915 f"{middleware_w_after[0].__class__.__name__}.after_model", END, first_node, - tools_available=tool_node is not None, + jump_to=middleware_w_after[0].after_model_jump_to, ) # Add middleware edges (same as before) @@ -475,7 +475,7 @@ def create_agent( # noqa: PLR0915 f"{m1.__class__.__name__}.before_model", f"{m2.__class__.__name__}.before_model", first_node, - tools_available=tool_node is not None, + jump_to=m1.before_model_jump_to, ) # Go directly to model_request after the last before_model _add_middleware_edge( @@ -483,7 +483,7 @@ def create_agent( # noqa: PLR0915 f"{middleware_w_before[-1].__class__.__name__}.before_model", "model_request", first_node, - tools_available=tool_node is not None, + jump_to=middleware_w_before[-1].before_model_jump_to, ) if middleware_w_after: @@ -496,7 +496,7 @@ def create_agent( # noqa: PLR0915 f"{m1.__class__.__name__}.after_model", f"{m2.__class__.__name__}.after_model", first_node, - tools_available=tool_node is not None, + jump_to=m1.after_model_jump_to, ) return graph @@ -584,7 +584,7 @@ def _add_middleware_edge( name: str, default_destination: str, model_destination: str, - tools_available: bool, # noqa: FBT001 + jump_to: list[JumpTo] | None, ) -> None: """Add an edge to the graph for a middleware node. @@ -594,18 +594,23 @@ def _add_middleware_edge( name: The name of the middleware node. 
        default_destination: The default destination for the edge.
        model_destination: The destination for the edge to the model.
-        tools_available: Whether tools are available for the edge to potentially route to.
+        jump_to: Valid jump destinations for the conditional edge; if empty or None,
+            a static edge to the default destination is added instead.
    """
+    if jump_to:
 
-    def jump_edge(state: AgentState) -> str:
-        return _resolve_jump(state.get("jump_to"), model_destination) or default_destination
+        def jump_edge(state: AgentState) -> str:
+            return _resolve_jump(state.get("jump_to"), model_destination) or default_destination
 
-    destinations = [default_destination]
-    if default_destination != END:
-        destinations.append(END)
-    if tools_available:
-        destinations.append("tools")
-    if name != model_destination:
-        destinations.append(model_destination)
+        destinations = [default_destination]
 
-    graph.add_conditional_edges(name, jump_edge, destinations)
+        if "__end__" in jump_to:
+            destinations.append(END)
+        if "tools" in jump_to:
+            destinations.append("tools")
+        if "model" in jump_to and name != model_destination:
+            destinations.append(model_destination)
+
+        graph.add_conditional_edges(name, jump_edge, destinations)
+
+    else:
+        graph.add_edge(name, default_destination)
diff --git a/libs/langchain_v1/tests/unit_tests/agents/__snapshots__/test_middleware_agent.ambr b/libs/langchain_v1/tests/unit_tests/agents/__snapshots__/test_middleware_agent.ambr
index f07cc34c1af..2581a1ea074 100644
--- a/libs/langchain_v1/tests/unit_tests/agents/__snapshots__/test_middleware_agent.ambr
+++ b/libs/langchain_v1/tests/unit_tests/agents/__snapshots__/test_middleware_agent.ambr
@@ -30,8 +30,7 @@
     model_request(model_request)
     NoopOne\2ebefore_model(NoopOne.before_model)
     __end__([
__end__
]):::last - NoopOne\2ebefore_model -.-> __end__; - NoopOne\2ebefore_model -.-> model_request; + NoopOne\2ebefore_model --> model_request; __start__ --> NoopOne\2ebefore_model; model_request --> __end__; classDef default fill:#f2f0ff,line-height:1.2 @@ -53,12 +52,10 @@ NoopTen\2ebefore_model(NoopTen.before_model) NoopTen\2eafter_model(NoopTen.after_model) __end__([__end__
]):::last - NoopTen\2eafter_model -.-> NoopTen\2ebefore_model; - NoopTen\2eafter_model -.-> __end__; - NoopTen\2ebefore_model -.-> __end__; - NoopTen\2ebefore_model -.-> model_request; + NoopTen\2ebefore_model --> model_request; __start__ --> NoopTen\2ebefore_model; model_request --> NoopTen\2eafter_model; + NoopTen\2eafter_model --> __end__; classDef default fill:#f2f0ff,line-height:1.2 classDef first fill-opacity:0 classDef last fill:#bfb6fc @@ -80,18 +77,12 @@ NoopEleven\2ebefore_model(NoopEleven.before_model) NoopEleven\2eafter_model(NoopEleven.after_model) __end__([__end__
]):::last - NoopEleven\2eafter_model -.-> NoopTen\2eafter_model; - NoopEleven\2eafter_model -.-> NoopTen\2ebefore_model; - NoopEleven\2eafter_model -.-> __end__; - NoopEleven\2ebefore_model -.-> NoopTen\2ebefore_model; - NoopEleven\2ebefore_model -.-> __end__; - NoopEleven\2ebefore_model -.-> model_request; - NoopTen\2eafter_model -.-> NoopTen\2ebefore_model; - NoopTen\2eafter_model -.-> __end__; - NoopTen\2ebefore_model -.-> NoopEleven\2ebefore_model; - NoopTen\2ebefore_model -.-> __end__; + NoopEleven\2eafter_model --> NoopTen\2eafter_model; + NoopEleven\2ebefore_model --> model_request; + NoopTen\2ebefore_model --> NoopEleven\2ebefore_model; __start__ --> NoopTen\2ebefore_model; model_request --> NoopEleven\2eafter_model; + NoopTen\2eafter_model --> __end__; classDef default fill:#f2f0ff,line-height:1.2 classDef first fill-opacity:0 classDef last fill:#bfb6fc @@ -111,11 +102,8 @@ NoopOne\2ebefore_model(NoopOne.before_model) NoopTwo\2ebefore_model(NoopTwo.before_model) __end__([__end__
]):::last - NoopOne\2ebefore_model -.-> NoopTwo\2ebefore_model; - NoopOne\2ebefore_model -.-> __end__; - NoopTwo\2ebefore_model -.-> NoopOne\2ebefore_model; - NoopTwo\2ebefore_model -.-> __end__; - NoopTwo\2ebefore_model -.-> model_request; + NoopOne\2ebefore_model --> NoopTwo\2ebefore_model; + NoopTwo\2ebefore_model --> model_request; __start__ --> NoopOne\2ebefore_model; model_request --> __end__; classDef default fill:#f2f0ff,line-height:1.2 @@ -138,14 +126,9 @@ NoopTwo\2ebefore_model(NoopTwo.before_model) NoopThree\2ebefore_model(NoopThree.before_model) __end__([__end__
]):::last - NoopOne\2ebefore_model -.-> NoopTwo\2ebefore_model; - NoopOne\2ebefore_model -.-> __end__; - NoopThree\2ebefore_model -.-> NoopOne\2ebefore_model; - NoopThree\2ebefore_model -.-> __end__; - NoopThree\2ebefore_model -.-> model_request; - NoopTwo\2ebefore_model -.-> NoopOne\2ebefore_model; - NoopTwo\2ebefore_model -.-> NoopThree\2ebefore_model; - NoopTwo\2ebefore_model -.-> __end__; + NoopOne\2ebefore_model --> NoopTwo\2ebefore_model; + NoopThree\2ebefore_model --> model_request; + NoopTwo\2ebefore_model --> NoopThree\2ebefore_model; __start__ --> NoopOne\2ebefore_model; model_request --> __end__; classDef default fill:#f2f0ff,line-height:1.2 @@ -166,10 +149,9 @@ model_request(model_request) NoopFour\2eafter_model(NoopFour.after_model) __end__([__end__
]):::last - NoopFour\2eafter_model -.-> __end__; - NoopFour\2eafter_model -.-> model_request; __start__ --> model_request; model_request --> NoopFour\2eafter_model; + NoopFour\2eafter_model --> __end__; classDef default fill:#f2f0ff,line-height:1.2 classDef first fill-opacity:0 classDef last fill:#bfb6fc @@ -189,13 +171,10 @@ NoopFour\2eafter_model(NoopFour.after_model) NoopFive\2eafter_model(NoopFive.after_model) __end__([__end__
]):::last - NoopFive\2eafter_model -.-> NoopFour\2eafter_model; - NoopFive\2eafter_model -.-> __end__; - NoopFive\2eafter_model -.-> model_request; - NoopFour\2eafter_model -.-> __end__; - NoopFour\2eafter_model -.-> model_request; + NoopFive\2eafter_model --> NoopFour\2eafter_model; __start__ --> model_request; model_request --> NoopFive\2eafter_model; + NoopFour\2eafter_model --> __end__; classDef default fill:#f2f0ff,line-height:1.2 classDef first fill-opacity:0 classDef last fill:#bfb6fc @@ -216,16 +195,11 @@ NoopFive\2eafter_model(NoopFive.after_model) NoopSix\2eafter_model(NoopSix.after_model) __end__([__end__
]):::last - NoopFive\2eafter_model -.-> NoopFour\2eafter_model; - NoopFive\2eafter_model -.-> __end__; - NoopFive\2eafter_model -.-> model_request; - NoopFour\2eafter_model -.-> __end__; - NoopFour\2eafter_model -.-> model_request; - NoopSix\2eafter_model -.-> NoopFive\2eafter_model; - NoopSix\2eafter_model -.-> __end__; - NoopSix\2eafter_model -.-> model_request; + NoopFive\2eafter_model --> NoopFour\2eafter_model; + NoopSix\2eafter_model --> NoopFive\2eafter_model; __start__ --> model_request; model_request --> NoopSix\2eafter_model; + NoopFour\2eafter_model --> __end__; classDef default fill:#f2f0ff,line-height:1.2 classDef first fill-opacity:0 classDef last fill:#bfb6fc @@ -245,12 +219,10 @@ NoopSeven\2ebefore_model(NoopSeven.before_model) NoopSeven\2eafter_model(NoopSeven.after_model) __end__([__end__
]):::last - NoopSeven\2eafter_model -.-> NoopSeven\2ebefore_model; - NoopSeven\2eafter_model -.-> __end__; - NoopSeven\2ebefore_model -.-> __end__; - NoopSeven\2ebefore_model -.-> model_request; + NoopSeven\2ebefore_model --> model_request; __start__ --> NoopSeven\2ebefore_model; model_request --> NoopSeven\2eafter_model; + NoopSeven\2eafter_model --> __end__; classDef default fill:#f2f0ff,line-height:1.2 classDef first fill-opacity:0 classDef last fill:#bfb6fc @@ -272,18 +244,12 @@ NoopEight\2ebefore_model(NoopEight.before_model) NoopEight\2eafter_model(NoopEight.after_model) __end__([__end__
]):::last - NoopEight\2eafter_model -.-> NoopSeven\2eafter_model; - NoopEight\2eafter_model -.-> NoopSeven\2ebefore_model; - NoopEight\2eafter_model -.-> __end__; - NoopEight\2ebefore_model -.-> NoopSeven\2ebefore_model; - NoopEight\2ebefore_model -.-> __end__; - NoopEight\2ebefore_model -.-> model_request; - NoopSeven\2eafter_model -.-> NoopSeven\2ebefore_model; - NoopSeven\2eafter_model -.-> __end__; - NoopSeven\2ebefore_model -.-> NoopEight\2ebefore_model; - NoopSeven\2ebefore_model -.-> __end__; + NoopEight\2eafter_model --> NoopSeven\2eafter_model; + NoopEight\2ebefore_model --> model_request; + NoopSeven\2ebefore_model --> NoopEight\2ebefore_model; __start__ --> NoopSeven\2ebefore_model; model_request --> NoopEight\2eafter_model; + NoopSeven\2eafter_model --> __end__; classDef default fill:#f2f0ff,line-height:1.2 classDef first fill-opacity:0 classDef last fill:#bfb6fc @@ -307,24 +273,14 @@ NoopNine\2ebefore_model(NoopNine.before_model) NoopNine\2eafter_model(NoopNine.after_model) __end__([__end__
]):::last - NoopEight\2eafter_model -.-> NoopSeven\2eafter_model; - NoopEight\2eafter_model -.-> NoopSeven\2ebefore_model; - NoopEight\2eafter_model -.-> __end__; - NoopEight\2ebefore_model -.-> NoopNine\2ebefore_model; - NoopEight\2ebefore_model -.-> NoopSeven\2ebefore_model; - NoopEight\2ebefore_model -.-> __end__; - NoopNine\2eafter_model -.-> NoopEight\2eafter_model; - NoopNine\2eafter_model -.-> NoopSeven\2ebefore_model; - NoopNine\2eafter_model -.-> __end__; - NoopNine\2ebefore_model -.-> NoopSeven\2ebefore_model; - NoopNine\2ebefore_model -.-> __end__; - NoopNine\2ebefore_model -.-> model_request; - NoopSeven\2eafter_model -.-> NoopSeven\2ebefore_model; - NoopSeven\2eafter_model -.-> __end__; - NoopSeven\2ebefore_model -.-> NoopEight\2ebefore_model; - NoopSeven\2ebefore_model -.-> __end__; + NoopEight\2eafter_model --> NoopSeven\2eafter_model; + NoopEight\2ebefore_model --> NoopNine\2ebefore_model; + NoopNine\2eafter_model --> NoopEight\2eafter_model; + NoopNine\2ebefore_model --> model_request; + NoopSeven\2ebefore_model --> NoopEight\2ebefore_model; __start__ --> NoopSeven\2ebefore_model; model_request --> NoopNine\2eafter_model; + NoopSeven\2eafter_model --> __end__; classDef default fill:#f2f0ff,line-height:1.2 classDef first fill-opacity:0 classDef last fill:#bfb6fc @@ -347,20 +303,13 @@ NoopEight\2ebefore_model(NoopEight.before_model) NoopEight\2eafter_model(NoopEight.after_model) __end__([__end__
]):::last - NoopEight\2eafter_model -.-> NoopSeven\2eafter_model; - NoopEight\2eafter_model -.-> NoopSeven\2ebefore_model; - NoopEight\2eafter_model -.-> __end__; - NoopEight\2eafter_model -.-> tools; - NoopEight\2ebefore_model -.-> NoopSeven\2ebefore_model; + NoopEight\2eafter_model --> NoopSeven\2eafter_model; NoopEight\2ebefore_model -.-> __end__; NoopEight\2ebefore_model -.-> model_request; - NoopEight\2ebefore_model -.-> tools; NoopSeven\2eafter_model -.-> NoopSeven\2ebefore_model; NoopSeven\2eafter_model -.-> __end__; NoopSeven\2eafter_model -.-> tools; - NoopSeven\2ebefore_model -.-> NoopEight\2ebefore_model; - NoopSeven\2ebefore_model -.-> __end__; - NoopSeven\2ebefore_model -.-> tools; + NoopSeven\2ebefore_model --> NoopEight\2ebefore_model; __start__ --> NoopSeven\2ebefore_model; model_request --> NoopEight\2eafter_model; tools -.-> NoopSeven\2ebefore_model; @@ -387,20 +336,13 @@ NoopEight\2ebefore_model(NoopEight.before_model) NoopEight\2eafter_model(NoopEight.after_model) __end__([__end__
]):::last - NoopEight\2eafter_model -.-> NoopSeven\2eafter_model; - NoopEight\2eafter_model -.-> NoopSeven\2ebefore_model; - NoopEight\2eafter_model -.-> __end__; - NoopEight\2eafter_model -.-> tools; - NoopEight\2ebefore_model -.-> NoopSeven\2ebefore_model; + NoopEight\2eafter_model --> NoopSeven\2eafter_model; NoopEight\2ebefore_model -.-> __end__; NoopEight\2ebefore_model -.-> model_request; - NoopEight\2ebefore_model -.-> tools; NoopSeven\2eafter_model -.-> NoopSeven\2ebefore_model; NoopSeven\2eafter_model -.-> __end__; NoopSeven\2eafter_model -.-> tools; - NoopSeven\2ebefore_model -.-> NoopEight\2ebefore_model; - NoopSeven\2ebefore_model -.-> __end__; - NoopSeven\2ebefore_model -.-> tools; + NoopSeven\2ebefore_model --> NoopEight\2ebefore_model; __start__ --> NoopSeven\2ebefore_model; model_request --> NoopEight\2eafter_model; tools -.-> NoopSeven\2ebefore_model; @@ -427,20 +369,13 @@ NoopEight\2ebefore_model(NoopEight.before_model) NoopEight\2eafter_model(NoopEight.after_model) __end__([__end__
]):::last - NoopEight\2eafter_model -.-> NoopSeven\2eafter_model; - NoopEight\2eafter_model -.-> NoopSeven\2ebefore_model; - NoopEight\2eafter_model -.-> __end__; - NoopEight\2eafter_model -.-> tools; - NoopEight\2ebefore_model -.-> NoopSeven\2ebefore_model; + NoopEight\2eafter_model --> NoopSeven\2eafter_model; NoopEight\2ebefore_model -.-> __end__; NoopEight\2ebefore_model -.-> model_request; - NoopEight\2ebefore_model -.-> tools; NoopSeven\2eafter_model -.-> NoopSeven\2ebefore_model; NoopSeven\2eafter_model -.-> __end__; NoopSeven\2eafter_model -.-> tools; - NoopSeven\2ebefore_model -.-> NoopEight\2ebefore_model; - NoopSeven\2ebefore_model -.-> __end__; - NoopSeven\2ebefore_model -.-> tools; + NoopSeven\2ebefore_model --> NoopEight\2ebefore_model; __start__ --> NoopSeven\2ebefore_model; model_request --> NoopEight\2eafter_model; tools -.-> NoopSeven\2ebefore_model; @@ -467,20 +402,13 @@ NoopEight\2ebefore_model(NoopEight.before_model) NoopEight\2eafter_model(NoopEight.after_model) __end__([__end__
]):::last - NoopEight\2eafter_model -.-> NoopSeven\2eafter_model; - NoopEight\2eafter_model -.-> NoopSeven\2ebefore_model; - NoopEight\2eafter_model -.-> __end__; - NoopEight\2eafter_model -.-> tools; - NoopEight\2ebefore_model -.-> NoopSeven\2ebefore_model; + NoopEight\2eafter_model --> NoopSeven\2eafter_model; NoopEight\2ebefore_model -.-> __end__; NoopEight\2ebefore_model -.-> model_request; - NoopEight\2ebefore_model -.-> tools; NoopSeven\2eafter_model -.-> NoopSeven\2ebefore_model; NoopSeven\2eafter_model -.-> __end__; NoopSeven\2eafter_model -.-> tools; - NoopSeven\2ebefore_model -.-> NoopEight\2ebefore_model; - NoopSeven\2ebefore_model -.-> __end__; - NoopSeven\2ebefore_model -.-> tools; + NoopSeven\2ebefore_model --> NoopEight\2ebefore_model; __start__ --> NoopSeven\2ebefore_model; model_request --> NoopEight\2eafter_model; tools -.-> NoopSeven\2ebefore_model; @@ -507,20 +435,13 @@ NoopEight\2ebefore_model(NoopEight.before_model) NoopEight\2eafter_model(NoopEight.after_model) __end__([__end__
]):::last - NoopEight\2eafter_model -.-> NoopSeven\2eafter_model; - NoopEight\2eafter_model -.-> NoopSeven\2ebefore_model; - NoopEight\2eafter_model -.-> __end__; - NoopEight\2eafter_model -.-> tools; - NoopEight\2ebefore_model -.-> NoopSeven\2ebefore_model; + NoopEight\2eafter_model --> NoopSeven\2eafter_model; NoopEight\2ebefore_model -.-> __end__; NoopEight\2ebefore_model -.-> model_request; - NoopEight\2ebefore_model -.-> tools; NoopSeven\2eafter_model -.-> NoopSeven\2ebefore_model; NoopSeven\2eafter_model -.-> __end__; NoopSeven\2eafter_model -.-> tools; - NoopSeven\2ebefore_model -.-> NoopEight\2ebefore_model; - NoopSeven\2ebefore_model -.-> __end__; - NoopSeven\2ebefore_model -.-> tools; + NoopSeven\2ebefore_model --> NoopEight\2ebefore_model; __start__ --> NoopSeven\2ebefore_model; model_request --> NoopEight\2eafter_model; tools -.-> NoopSeven\2ebefore_model; diff --git a/libs/langchain_v1/tests/unit_tests/agents/test_middleware_agent.py b/libs/langchain_v1/tests/unit_tests/agents/test_middleware_agent.py index f353ba8c12a..f2c42bb3467 100644 --- a/libs/langchain_v1/tests/unit_tests/agents/test_middleware_agent.py +++ b/libs/langchain_v1/tests/unit_tests/agents/test_middleware_agent.py @@ -34,7 +34,6 @@ from langchain.agents.middleware.types import ( OmitFromOutput, PrivateStateAttr, ) -from langchain.agents.middleware.dynamic_system_prompt import DynamicSystemPromptMiddleware from langgraph.checkpoint.base import BaseCheckpointSaver from langgraph.checkpoint.memory import InMemorySaver @@ -332,6 +331,8 @@ def test_create_agent_jump( calls.append("NoopSeven.after_model") class NoopEight(AgentMiddleware): + before_model_jump_to = [END] + def before_model(self, state) -> dict[str, Any]: calls.append("NoopEight.before_model") return {"jump_to": END} @@ -1221,74 +1222,6 @@ def test_tools_to_model_edge_with_structured_and_regular_tool_calls(): assert result["response"].condition == "sunny" -# Tests for DynamicSystemPromptMiddleware -def test_dynamic_system_prompt_middleware_basic() -> None: - """Test basic functionality of DynamicSystemPromptMiddleware.""" - - def dynamic_system_prompt(state: AgentState) -> str: - messages = state.get("messages", []) - if messages: - return f"You are a helpful assistant. Message count: {len(messages)}" - return "You are a helpful assistant. No messages yet." - - middleware = DynamicSystemPromptMiddleware(dynamic_system_prompt) - - # Test with empty state - empty_state = {"messages": []} - request = ModelRequest( - model=FakeToolCallingModel(), - system_prompt="Original prompt", - messages=[], - tool_choice=None, - tools=[], - response_format=None, - ) - - modified_request = middleware.modify_model_request(request, empty_state, None) - assert modified_request.system_prompt == "You are a helpful assistant. No messages yet." - - state_with_messages = {"messages": [HumanMessage("Hello"), AIMessage("Hi")]} - modified_request = middleware.modify_model_request(request, state_with_messages, None) - assert modified_request.system_prompt == "You are a helpful assistant. Message count: 2" - - -def test_dynamic_system_prompt_middleware_with_context() -> None: - """Test DynamicSystemPromptMiddleware with runtime context.""" - - class MockContext(TypedDict): - user_role: str - - def dynamic_system_prompt(state: AgentState, runtime: Runtime[MockContext]) -> str: - base_prompt = "You are a helpful assistant." 
- if runtime and hasattr(runtime, "context"): - user_role = runtime.context.get("user_role", "user") - return f"{base_prompt} User role: {user_role}" - return base_prompt - - middleware = DynamicSystemPromptMiddleware(dynamic_system_prompt) - - # Create a mock runtime with context - class MockRuntime: - def __init__(self, context): - self.context = context - - mock_runtime = MockRuntime(context={"user_role": "admin"}) - - request = ModelRequest( - model=FakeToolCallingModel(), - system_prompt="Original prompt", - messages=[HumanMessage("Test")], - tool_choice=None, - tools=[], - response_format=None, - ) - - state = {"messages": [HumanMessage("Test")]} - modified_request = middleware.modify_model_request(request, state, mock_runtime) - - assert modified_request.system_prompt == "You are a helpful assistant. User role: admin" - - def test_public_private_state_for_custom_middleware() -> None: """Test public and private state for custom middleware.""" diff --git a/libs/langchain_v1/tests/unit_tests/agents/test_middleware_decorators.py b/libs/langchain_v1/tests/unit_tests/agents/test_middleware_decorators.py new file mode 100644 index 00000000000..a0e97731d25 --- /dev/null +++ b/libs/langchain_v1/tests/unit_tests/agents/test_middleware_decorators.py @@ -0,0 +1,152 @@ +"""Consolidated tests for middleware decorators: before_model, after_model, and modify_model_request.""" + +from typing import Any +from typing_extensions import NotRequired + +from langchain_core.messages import HumanMessage, AIMessage +from langchain_core.tools import tool +from langgraph.types import Command + +from langchain.agents.middleware.types import ( + AgentMiddleware, + AgentState, + ModelRequest, + before_model, + after_model, + modify_model_request, +) +from langchain.agents.middleware_agent import create_agent +from .model import FakeToolCallingModel + + +class CustomState(AgentState): + """Custom state schema for testing.""" + + custom_field: NotRequired[str] + + +@tool +def test_tool(input: str) -> str: + """A test tool for middleware testing.""" + return f"Tool result: {input}" + + +def test_before_model_decorator() -> None: + """Test before_model decorator with all configuration options.""" + + @before_model( + state_schema=CustomState, tools=[test_tool], jump_to=["__end__"], name="CustomBeforeModel" + ) + def custom_before_model(state: CustomState) -> dict[str, Any]: + return {"jump_to": "__end__"} + + assert isinstance(custom_before_model, AgentMiddleware) + assert custom_before_model.state_schema == CustomState + assert custom_before_model.tools == [test_tool] + assert custom_before_model.before_model_jump_to == ["__end__"] + assert custom_before_model.__class__.__name__ == "CustomBeforeModel" + + result = custom_before_model.before_model({"messages": [HumanMessage("Hello")]}) + assert result == {"jump_to": "__end__"} + + +def test_after_model_decorator() -> None: + """Test after_model decorator with all configuration options.""" + + @after_model( + state_schema=CustomState, + tools=[test_tool], + jump_to=["model", "__end__"], + name="CustomAfterModel", + ) + def custom_after_model(state: CustomState) -> dict[str, Any]: + return {"jump_to": "model"} + + # Verify all options were applied + assert isinstance(custom_after_model, AgentMiddleware) + assert custom_after_model.state_schema == CustomState + assert custom_after_model.tools == [test_tool] + assert custom_after_model.after_model_jump_to == ["model", "__end__"] + assert custom_after_model.__class__.__name__ == "CustomAfterModel" + + # Verify it works + 
result = custom_after_model.after_model({"messages": [HumanMessage("Hello"), AIMessage("Hi!")]}) + assert result == {"jump_to": "model"} + + +def test_modify_model_request_decorator() -> None: + """Test modify_model_request decorator with all configuration options.""" + + @modify_model_request(state_schema=CustomState, tools=[test_tool], name="CustomModifyRequest") + def custom_modify_request(request: ModelRequest, state: CustomState) -> ModelRequest: + request.system_prompt = "Modified" + return request + + # Verify all options were applied + assert isinstance(custom_modify_request, AgentMiddleware) + assert custom_modify_request.state_schema == CustomState + assert custom_modify_request.tools == [test_tool] + assert custom_modify_request.__class__.__name__ == "CustomModifyRequest" + + # Verify it works + original_request = ModelRequest( + model="test-model", + system_prompt="Original", + messages=[HumanMessage("Hello")], + tool_choice=None, + tools=[], + response_format=None, + ) + result = custom_modify_request.modify_model_request( + original_request, {"messages": [HumanMessage("Hello")]} + ) + assert result.system_prompt == "Modified" + + +def test_all_decorators_integration() -> None: + """Test all three decorators working together in an agent.""" + call_order = [] + + @before_model + def track_before(state: AgentState) -> None: + call_order.append("before") + return None + + @modify_model_request + def track_modify(request: ModelRequest, state: AgentState) -> ModelRequest: + call_order.append("modify") + return request + + @after_model + def track_after(state: AgentState) -> None: + call_order.append("after") + return None + + agent = create_agent( + model=FakeToolCallingModel(), middleware=[track_before, track_modify, track_after] + ) + agent = agent.compile() + agent.invoke({"messages": [HumanMessage("Hello")]}) + + assert call_order == ["before", "modify", "after"] + + +def test_decorators_use_function_names_as_default() -> None: + """Test that decorators use function names as default middleware names.""" + + @before_model + def my_before_hook(state: AgentState) -> None: + return None + + @modify_model_request + def my_modify_hook(request: ModelRequest, state: AgentState) -> ModelRequest: + return request + + @after_model + def my_after_hook(state: AgentState) -> None: + return None + + # Verify that function names are used as middleware class names + assert my_before_hook.__class__.__name__ == "my_before_hook" + assert my_modify_hook.__class__.__name__ == "my_modify_hook" + assert my_after_hook.__class__.__name__ == "my_after_hook"
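+
+
+# Minimal sketch exercising the runtime-detection path: the decorators inspect
+# the wrapped function's signature for a `runtime` parameter and forward the
+# runtime object unchanged, so a bare `object()` suffices here as a
+# hypothetical stand-in for langgraph's `Runtime`.
+def test_before_model_decorator_detects_runtime_parameter() -> None:
+    """Test that a hook declaring `runtime` receives the forwarded runtime object."""
+    seen: dict[str, Any] = {}
+
+    @before_model
+    def runtime_aware(state: AgentState, runtime: Any) -> None:
+        seen["runtime"] = runtime
+        return None
+
+    assert isinstance(runtime_aware, AgentMiddleware)
+    fake_runtime = object()  # hypothetical stand-in; the wrapper only forwards it
+    runtime_aware.before_model({"messages": [HumanMessage("Hello")]}, fake_runtime)
+    assert seen["runtime"] is fake_runtime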