feat(langchain): new decorator pattern for dynamically generated middleware (#33053)

# Main Changes

1. Added decorator utilities for dynamically defining middleware from single
hook functions (see the dynamic system prompt example below)
2. Added better conditional edge drawing, with jump configuration attached to
middleware. Jumps can be registered with the new decorators!

## Decorator Utilities

```py
from langchain.agents.middleware_agent import create_agent, AgentState, ModelRequest
from langchain.agents.middleware.types import modify_model_request
from langgraph.checkpoint.memory import InMemorySaver


@modify_model_request
def modify_system_prompt(request: ModelRequest, state: AgentState) -> ModelRequest:
    request.system_prompt = (
        "You are a helpful assistant. "
        f"Please record the number of previous messages in your response: {len(state['messages'])}"
    )
    return request


agent = create_agent(
    model="openai:gpt-4o-mini",
    middleware=[modify_system_prompt],
).compile(checkpointer=InMemorySaver())
```
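
For reference, invoking the compiled agent could look like this (the thread ID is arbitrary; it is only needed because a checkpointer is attached):

```py
from langchain_core.messages import HumanMessage

config = {"configurable": {"thread_id": "demo"}}
result = agent.invoke({"messages": [HumanMessage("Hello!")]}, config)

# The reply should mention the message count injected by the middleware.
print(result["messages"][-1].content)
```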

## Visualization and Routing Improvements

We now require middleware to declare the valid jump destinations for each hook.

If using the new decorator syntax, this can be done with:

```py
@before_model(jump_to=["__end__"])
@after_model(jump_to=["tools", "__end__"])
```
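
For example, a complete hook registered this way might look like the following (a minimal sketch; the early-exit condition is made up):

```py
from typing import Any

from langchain.agents.middleware.types import AgentState, before_model


@before_model(jump_to=["__end__"])
def stop_long_threads(state: AgentState) -> dict[str, Any] | None:
    # Hypothetical guard: end the run once the conversation gets long.
    if len(state["messages"]) > 20:
        return {"jump_to": "__end__"}
    return None
```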

If using the subclassing syntax, you can use these two class vars:

```py
class MyMiddleware(AgentMiddleware):
    before_model_jump_to = ["__end__"]
    after_model_jump_to = ["tools", "__end__"]
```
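
Fleshed out, a subclass declaring jumps might look like this (a sketch; the jump conditions are placeholders):

```py
from typing import Any

from langgraph.runtime import Runtime

from langchain.agents.middleware.types import AgentMiddleware, AgentState


class MyMiddleware(AgentMiddleware):
    before_model_jump_to = ["__end__"]
    after_model_jump_to = ["tools", "__end__"]

    def before_model(self, state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
        # Placeholder condition: end the run before ever calling the model.
        if not state["messages"]:
            return {"jump_to": "__end__"}
        return None

    def after_model(self, state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
        # Placeholder condition: stop after the model responds.
        return {"jump_to": "__end__"}
```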

Open for debate whether we want to bundle these into a single jump map / config
per middleware. It's easy to migrate later if we decide to add more hooks.

We will need to **really clearly document** that these must be
explicitly set in order to enable conditional edges.

Note that in the case below, only `Middleware2` actually enables jumps.

<table>
  <thead>
    <tr>
      <th>Before (broken), adding conditional edges unconditionally</th>
      <th>After (fixed), adding conditional edges sparingly</th>
    </tr>
  </thead>
  <tbody>
    <tr>
      <td>
<img width="619" height="508" alt="Screenshot 2025-09-23 at 10 23 23 AM"
src="https://github.com/user-attachments/assets/bba2d098-a839-4335-8e8c-b50dd8090959"
/>
      </td>
      <td>
<img width="469" height="490" alt="Screenshot 2025-09-23 at 10 23 13 AM"
src="https://github.com/user-attachments/assets/717abf0b-fc73-4d5f-9313-b81247d8fe26"
/>
      </td>
    </tr>
  </tbody>
</table>

<details>
<summary>Snippet for the above</summary>

```py
from langgraph.runtime import Runtime
from langchain.agents.middleware.types import AgentMiddleware, AgentState
from langchain.agents.middleware_agent import create_agent
from langchain_core.tools import tool

@tool
def simple_tool(input: str) -> str:
    """A simple tool."""
    return "successful tool call"


class Middleware1(AgentMiddleware):
    """Custom middleware that adds a simple tool."""

    tools = [simple_tool]

    def before_model(self, state: AgentState, runtime: Runtime) -> None:
        return None

    def after_model(self, state: AgentState, runtime: Runtime) -> None:
        return None

class Middleware2(AgentMiddleware):
    """Middleware that declares jump destinations, so it gets conditional edges."""

    before_model_jump_to = ["tools", "__end__"]

    def before_model(self, state: AgentState, runtime: Runtime) -> None:
        return None

    def after_model(self, state: AgentState, runtime: Runtime) -> None:
        return None

class Middleware3(AgentMiddleware):
    """No-op middleware with no declared jumps, so it gets plain edges."""

    def before_model(self, state: AgentState, runtime: Runtime) -> None:
        return None

    def after_model(self, state: AgentState, runtime: Runtime) -> None:
        return None

builder = create_agent(
    model="openai:gpt-4o-mini",
    middleware=[Middleware1(), Middleware2(), Middleware3()],
    system_prompt="You are a helpful assistant.",
)
agent = builder.compile()
```

</details>
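
To reproduce renderings like the ones above, the compiled agent's graph can be drawn with LangGraph's built-in helpers (a quick sketch):

```py
# Mermaid source for the compiled graph; draw_mermaid_png() renders an image instead.
print(agent.get_graph().draw_mermaid())
```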

## More Examples

### Guardrails `after_model`

<img width="379" height="335" alt="Screenshot 2025-09-23 at 10 40 09 AM"
src="https://github.com/user-attachments/assets/45bac7dd-398e-45d1-ae58-6ecfa27dfc87"
/>

<details>
<summary>Code</summary>

```py
from langchain.agents.middleware_agent import create_agent, AgentState, ModelRequest
from langchain.agents.middleware.types import after_model
from langchain_core.messages import HumanMessage, AIMessage
from langgraph.checkpoint.memory import InMemorySaver
from typing import cast, Any

@after_model(jump_to=["model", "__end__"])
def after_model_hook(state: AgentState) -> dict[str, Any]:
    """Check the last AI message for safety violations."""
    last_message_content = cast(AIMessage, state["messages"][-1]).content.lower()
    print(last_message_content)

    unsafe_keywords = ["pineapple"]
    if any(keyword in last_message_content for keyword in unsafe_keywords):

        # Jump back to model to regenerate response
        return {"jump_to": "model", "messages": [HumanMessage("Please regenerate your response, and don't talk about pineapples. You can talk about apples instead.")]}

    return {"jump_to": "__end__"}

# Create agent with guardrails middleware
agent = create_agent(
    model="openai:gpt-4o-mini",
    middleware=[after_model_hook],
    system_prompt="Keep your responses to one sentence please!"
).compile()

# Test with potentially unsafe input
result = agent.invoke(
    {"messages": [HumanMessage("Tell me something about pineapples")]},
)

for msg in result["messages"]:
    msg.pretty_print()

"""
================================ Human Message =================================

Tell me something about pineapples
================================== Ai Message ==================================

Pineapples are tropical fruits known for their sweet, tangy flavor and distinctive spiky exterior.
================================ Human Message =================================

Please regenerate your response, and don't talk about pineapples. You can talk about apples instead.
================================== Ai Message ==================================

Apples are popular fruits that come in various varieties, known for their crisp texture and sweetness, and are often used in cooking and baking.
"""
```

</details>
Commit 89079ad411 (parent 2c95586f2a) by Sydney Runkle, 2025-09-23 13:25:55 -04:00, committed by GitHub.
7 changed files with 640 additions and 312 deletions.


@@ -1,6 +1,5 @@
 """Middleware plugins for agents."""

-from .dynamic_system_prompt import DynamicSystemPromptMiddleware
 from .human_in_the_loop import HumanInTheLoopMiddleware
 from .prompt_caching import AnthropicPromptCachingMiddleware
 from .summarization import SummarizationMiddleware
@@ -11,7 +10,6 @@ __all__ = [
     "AgentState",
     # should move to langchain-anthropic if we decide to keep it
     "AnthropicPromptCachingMiddleware",
-    "DynamicSystemPromptMiddleware",
     "HumanInTheLoopMiddleware",
     "ModelRequest",
     "SummarizationMiddleware",


@@ -1,105 +0,0 @@
"""Dynamic System Prompt Middleware.

Allows setting the system prompt dynamically right before each model invocation.
Useful when the prompt depends on the current agent state or per-invocation context.
"""

from __future__ import annotations

from inspect import signature
from typing import TYPE_CHECKING, Protocol, TypeAlias, cast

from langgraph.typing import ContextT

from langchain.agents.middleware.types import (
    AgentMiddleware,
    AgentState,
    ModelRequest,
)

if TYPE_CHECKING:
    from langgraph.runtime import Runtime


class DynamicSystemPromptWithoutRuntime(Protocol):
    """Dynamic system prompt without runtime in call signature."""

    def __call__(self, state: AgentState) -> str:
        """Return the system prompt for the next model call."""
        ...


class DynamicSystemPromptWithRuntime(Protocol[ContextT]):
    """Dynamic system prompt with runtime in call signature."""

    def __call__(self, state: AgentState, runtime: Runtime[ContextT]) -> str:
        """Return the system prompt for the next model call."""
        ...


DynamicSystemPrompt: TypeAlias = (
    DynamicSystemPromptWithoutRuntime | DynamicSystemPromptWithRuntime[ContextT]
)


class DynamicSystemPromptMiddleware(AgentMiddleware):
    """Dynamic System Prompt Middleware.

    Allows setting the system prompt dynamically right before each model invocation.
    Useful when the prompt depends on the current agent state or per-invocation context.

    Example:
        ```python
        from langchain.agents.middleware import DynamicSystemPromptMiddleware


        class Context(TypedDict):
            user_name: str


        def system_prompt(state: AgentState, runtime: Runtime[Context]) -> str:
            user_name = runtime.context.get("user_name", "n/a")
            return (
                f"You are a helpful assistant. Always address the user by their name: {user_name}"
            )


        middleware = DynamicSystemPromptMiddleware(system_prompt)
        ```
    """

    _accepts_runtime: bool

    def __init__(
        self,
        dynamic_system_prompt: DynamicSystemPrompt[ContextT],
    ) -> None:
        """Initialize the dynamic system prompt middleware.

        Args:
            dynamic_system_prompt: Function that receives the current agent state
                and optionally runtime with context, and returns the system prompt for
                the next model call. Returns a string.
        """
        super().__init__()
        self.dynamic_system_prompt = dynamic_system_prompt
        self._accepts_runtime = "runtime" in signature(dynamic_system_prompt).parameters

    def modify_model_request(
        self,
        request: ModelRequest,
        state: AgentState,
        runtime: Runtime[ContextT],
    ) -> ModelRequest:
        """Modify the model request to include the dynamic system prompt."""
        if self._accepts_runtime:
            system_prompt = cast(
                "DynamicSystemPromptWithRuntime[ContextT]", self.dynamic_system_prompt
            )(state, runtime)
        else:
            system_prompt = cast("DynamicSystemPromptWithoutRuntime", self.dynamic_system_prompt)(
                state
            )

        request.system_prompt = system_prompt
        return request


@@ -3,7 +3,20 @@
 from __future__ import annotations

 from dataclasses import dataclass, field
-from typing import TYPE_CHECKING, Annotated, Any, Generic, Literal, cast
+from inspect import signature
+from typing import (
+    TYPE_CHECKING,
+    Annotated,
+    Any,
+    ClassVar,
+    Generic,
+    Literal,
+    Protocol,
+    TypeAlias,
+    TypeGuard,
+    cast,
+    overload,
+)

 # needed as top level import for pydantic schema generation on AgentState
 from langchain_core.messages import AnyMessage  # noqa: TC002
@@ -14,9 +27,12 @@ from langgraph.typing import ContextT
 from typing_extensions import NotRequired, Required, TypedDict, TypeVar

 if TYPE_CHECKING:
+    from collections.abc import Callable
+
     from langchain_core.language_models.chat_models import BaseChatModel
     from langchain_core.tools import BaseTool
     from langgraph.runtime import Runtime
+    from langgraph.types import Command

     from langchain.agents.structured_output import ResponseFormat
@@ -88,6 +104,7 @@ class PublicAgentState(TypedDict, Generic[ResponseT]):

 StateT = TypeVar("StateT", bound=AgentState, default=AgentState)
+StateT_contra = TypeVar("StateT_contra", bound=AgentState, contravariant=True)


 class AgentMiddleware(Generic[StateT, ContextT]):
@@ -103,6 +120,12 @@ class AgentMiddleware(Generic[StateT, ContextT]):
     tools: list[BaseTool]
     """Additional tools registered by the middleware."""

+    before_model_jump_to: ClassVar[list[JumpTo]] = []
+    """Valid jump destinations for before_model hook. Used to establish conditional edges."""
+
+    after_model_jump_to: ClassVar[list[JumpTo]] = []
+    """Valid jump destinations for after_model hook. Used to establish conditional edges."""
+
     def before_model(self, state: StateT, runtime: Runtime[ContextT]) -> dict[str, Any] | None:
         """Logic to run before the model is called."""
@@ -117,3 +140,404 @@ class AgentMiddleware(Generic[StateT, ContextT]):
    def after_model(self, state: StateT, runtime: Runtime[ContextT]) -> dict[str, Any] | None:
        """Logic to run after the model is called."""


class _CallableWithState(Protocol[StateT_contra]):
    """Callable with AgentState as argument."""

    def __call__(self, state: StateT_contra) -> dict[str, Any] | Command | None:
        """Perform some logic with the state."""
        ...


class _CallableWithStateAndRuntime(Protocol[StateT_contra, ContextT]):
    """Callable with AgentState and Runtime as arguments."""

    def __call__(
        self, state: StateT_contra, runtime: Runtime[ContextT]
    ) -> dict[str, Any] | Command | None:
        """Perform some logic with the state and runtime."""
        ...


class _CallableWithModelRequestAndState(Protocol[StateT_contra]):
    """Callable with ModelRequest and AgentState as arguments."""

    def __call__(self, request: ModelRequest, state: StateT_contra) -> ModelRequest:
        """Perform some logic with the model request and state."""
        ...


class _CallableWithModelRequestAndStateAndRuntime(Protocol[StateT_contra, ContextT]):
    """Callable with ModelRequest, AgentState, and Runtime as arguments."""

    def __call__(
        self, request: ModelRequest, state: StateT_contra, runtime: Runtime[ContextT]
    ) -> ModelRequest:
        """Perform some logic with the model request, state, and runtime."""
        ...


_NodeSignature: TypeAlias = (
    _CallableWithState[StateT] | _CallableWithStateAndRuntime[StateT, ContextT]
)
_ModelRequestSignature: TypeAlias = (
    _CallableWithModelRequestAndState[StateT]
    | _CallableWithModelRequestAndStateAndRuntime[StateT, ContextT]
)


def is_callable_with_runtime(
    func: _NodeSignature[StateT, ContextT],
) -> TypeGuard[_CallableWithStateAndRuntime[StateT, ContextT]]:
    return "runtime" in signature(func).parameters


def is_callable_with_runtime_and_request(
    func: _ModelRequestSignature[StateT, ContextT],
) -> TypeGuard[_CallableWithModelRequestAndStateAndRuntime[StateT, ContextT]]:
    return "runtime" in signature(func).parameters


@overload
def before_model(
    func: _NodeSignature[StateT, ContextT],
) -> AgentMiddleware[StateT, ContextT]: ...


@overload
def before_model(
    func: None = None,
    *,
    state_schema: type[StateT] | None = None,
    tools: list[BaseTool] | None = None,
    jump_to: list[JumpTo] | None = None,
    name: str | None = None,
) -> Callable[[_NodeSignature[StateT, ContextT]], AgentMiddleware[StateT, ContextT]]: ...


def before_model(
    func: _NodeSignature[StateT, ContextT] | None = None,
    *,
    state_schema: type[StateT] | None = None,
    tools: list[BaseTool] | None = None,
    jump_to: list[JumpTo] | None = None,
    name: str | None = None,
) -> (
    Callable[[_NodeSignature[StateT, ContextT]], AgentMiddleware[StateT, ContextT]]
    | AgentMiddleware[StateT, ContextT]
):
    """Decorator used to dynamically create a middleware with the before_model hook.

    Args:
        func: The function to be decorated. Can accept either:
            - `state: StateT` - Just the agent state
            - `state: StateT, runtime: Runtime[ContextT]` - State and runtime context
        state_schema: Optional custom state schema type. If not provided, uses the default
            AgentState schema.
        tools: Optional list of additional tools to register with this middleware.
        jump_to: Optional list of valid jump destinations for conditional edges.
            Valid values are: "tools", "model", "__end__"
        name: Optional name for the generated middleware class. If not provided,
            uses the decorated function's name.

    Returns:
        Either an AgentMiddleware instance (if func is provided directly) or a decorator function
        that can be applied to the function it wraps.

    The decorated function should return:
        - `dict[str, Any]` - State updates to merge into the agent state
        - `Command` - A command to control flow (e.g., jump to different node)
        - `None` - No state updates or flow control

    Examples:
        Basic usage with state only:

        ```python
        @before_model
        def log_before_model(state: AgentState) -> None:
            print(f"About to call model with {len(state['messages'])} messages")
        ```

        Advanced usage with runtime and conditional jumping:

        ```python
        @before_model(jump_to=["__end__"])
        def conditional_before_model(state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
            if some_condition(state):
                return {"jump_to": "__end__"}
            return None
        ```

        With custom state schema:

        ```python
        @before_model(
            state_schema=MyCustomState,
        )
        def custom_before_model(state: MyCustomState) -> dict[str, Any]:
            return {"custom_field": "updated_value"}
        ```
    """

    def decorator(func: _NodeSignature[StateT, ContextT]) -> AgentMiddleware[StateT, ContextT]:
        if is_callable_with_runtime(func):

            def wrapped_with_runtime(
                self: AgentMiddleware[StateT, ContextT],  # noqa: ARG001
                state: StateT,
                runtime: Runtime[ContextT],
            ) -> dict[str, Any] | Command | None:
                return func(state, runtime)

            wrapped = wrapped_with_runtime
        else:

            def wrapped_without_runtime(
                self: AgentMiddleware[StateT, ContextT],  # noqa: ARG001
                state: StateT,
            ) -> dict[str, Any] | Command | None:
                return func(state)  # type: ignore[call-arg]

            wrapped = wrapped_without_runtime  # type: ignore[assignment]

        # Use function name as default if no name provided
        middleware_name = name or cast("str", getattr(func, "__name__", "BeforeModelMiddleware"))

        return type(
            middleware_name,
            (AgentMiddleware,),
            {
                "state_schema": state_schema or AgentState,
                "tools": tools or [],
                "before_model_jump_to": jump_to or [],
                "before_model": wrapped,
            },
        )()

    if func is not None:
        return decorator(func)
    return decorator


@overload
def modify_model_request(
    func: _ModelRequestSignature[StateT, ContextT],
) -> AgentMiddleware[StateT, ContextT]: ...


@overload
def modify_model_request(
    func: None = None,
    *,
    state_schema: type[StateT] | None = None,
    tools: list[BaseTool] | None = None,
    name: str | None = None,
) -> Callable[[_ModelRequestSignature[StateT, ContextT]], AgentMiddleware[StateT, ContextT]]: ...


def modify_model_request(
    func: _ModelRequestSignature[StateT, ContextT] | None = None,
    *,
    state_schema: type[StateT] | None = None,
    tools: list[BaseTool] | None = None,
    name: str | None = None,
) -> (
    Callable[[_ModelRequestSignature[StateT, ContextT]], AgentMiddleware[StateT, ContextT]]
    | AgentMiddleware[StateT, ContextT]
):
    r"""Decorator used to dynamically create a middleware with the modify_model_request hook.

    Args:
        func: The function to be decorated. Can accept either:
            - `request: ModelRequest, state: StateT` - Model request and agent state
            - `request: ModelRequest, state: StateT, runtime: Runtime[ContextT]` -
              Model request, state, and runtime context
        state_schema: Optional custom state schema type. If not provided, uses the default
            AgentState schema.
        tools: Optional list of additional tools to register with this middleware.
        name: Optional name for the generated middleware class. If not provided,
            uses the decorated function's name.

    Returns:
        Either an AgentMiddleware instance (if func is provided) or a decorator function
        that can be applied to a function.

    The decorated function should return:
        - `ModelRequest` - The modified model request to be sent to the language model

    Examples:
        Basic usage to modify system prompt:

        ```python
        @modify_model_request
        def add_context_to_prompt(request: ModelRequest, state: AgentState) -> ModelRequest:
            if request.system_prompt:
                request.system_prompt += "\n\nAdditional context: ..."
            else:
                request.system_prompt = "Additional context: ..."
            return request
        ```

        Advanced usage with runtime and custom model settings:

        ```python
        @modify_model_request
        def dynamic_model_settings(
            request: ModelRequest, state: AgentState, runtime: Runtime
        ) -> ModelRequest:
            # Use a different model based on user subscription tier
            if runtime.context.get("subscription_tier") == "premium":
                request.model = "gpt-4o"
            else:
                request.model = "gpt-4o-mini"
            return request
        ```
    """

    def decorator(
        func: _ModelRequestSignature[StateT, ContextT],
    ) -> AgentMiddleware[StateT, ContextT]:
        if is_callable_with_runtime_and_request(func):

            def wrapped_with_runtime(
                self: AgentMiddleware[StateT, ContextT],  # noqa: ARG001
                request: ModelRequest,
                state: StateT,
                runtime: Runtime[ContextT],
            ) -> ModelRequest:
                return func(request, state, runtime)

            wrapped = wrapped_with_runtime
        else:

            def wrapped_without_runtime(
                self: AgentMiddleware[StateT, ContextT],  # noqa: ARG001
                request: ModelRequest,
                state: StateT,
            ) -> ModelRequest:
                return func(request, state)  # type: ignore[call-arg]

            wrapped = wrapped_without_runtime  # type: ignore[assignment]

        # Use function name as default if no name provided
        middleware_name = name or cast(
            "str", getattr(func, "__name__", "ModifyModelRequestMiddleware")
        )

        return type(
            middleware_name,
            (AgentMiddleware,),
            {
                "state_schema": state_schema or AgentState,
                "tools": tools or [],
                "modify_model_request": wrapped,
            },
        )()

    if func is not None:
        return decorator(func)
    return decorator


@overload
def after_model(
    func: _NodeSignature[StateT, ContextT],
) -> AgentMiddleware[StateT, ContextT]: ...


@overload
def after_model(
    func: None = None,
    *,
    state_schema: type[StateT] | None = None,
    tools: list[BaseTool] | None = None,
    jump_to: list[JumpTo] | None = None,
    name: str | None = None,
) -> Callable[[_NodeSignature[StateT, ContextT]], AgentMiddleware[StateT, ContextT]]: ...


def after_model(
    func: _NodeSignature[StateT, ContextT] | None = None,
    *,
    state_schema: type[StateT] | None = None,
    tools: list[BaseTool] | None = None,
    jump_to: list[JumpTo] | None = None,
    name: str | None = None,
) -> (
    Callable[[_NodeSignature[StateT, ContextT]], AgentMiddleware[StateT, ContextT]]
    | AgentMiddleware[StateT, ContextT]
):
    """Decorator used to dynamically create a middleware with the after_model hook.

    Args:
        func: The function to be decorated. Can accept either:
            - `state: StateT` - Just the agent state (includes model response)
            - `state: StateT, runtime: Runtime[ContextT]` - State and runtime context
        state_schema: Optional custom state schema type. If not provided, uses the default
            AgentState schema.
        tools: Optional list of additional tools to register with this middleware.
        jump_to: Optional list of valid jump destinations for conditional edges.
            Valid values are: "tools", "model", "__end__"
        name: Optional name for the generated middleware class. If not provided,
            uses the decorated function's name.

    Returns:
        Either an AgentMiddleware instance (if func is provided) or a decorator function
        that can be applied to a function.

    The decorated function should return:
        - `dict[str, Any]` - State updates to merge into the agent state
        - `Command` - A command to control flow (e.g., jump to different node)
        - `None` - No state updates or flow control

    Examples:
        Basic usage for logging model responses:

        ```python
        @after_model
        def log_latest_message(state: AgentState) -> None:
            print(state["messages"][-1].content)
        ```

        With custom state schema:

        ```python
        @after_model(state_schema=MyCustomState, name="MyAfterModelMiddleware")
        def custom_after_model(state: MyCustomState) -> dict[str, Any]:
            return {"custom_field": "updated_after_model"}
        ```
    """

    def decorator(func: _NodeSignature[StateT, ContextT]) -> AgentMiddleware[StateT, ContextT]:
        if is_callable_with_runtime(func):

            def wrapped_with_runtime(
                self: AgentMiddleware[StateT, ContextT],  # noqa: ARG001
                state: StateT,
                runtime: Runtime[ContextT],
            ) -> dict[str, Any] | Command | None:
                return func(state, runtime)

            wrapped = wrapped_with_runtime
        else:

            def wrapped_without_runtime(
                self: AgentMiddleware[StateT, ContextT],  # noqa: ARG001
                state: StateT,
            ) -> dict[str, Any] | Command | None:
                return func(state)  # type: ignore[call-arg]

            wrapped = wrapped_without_runtime  # type: ignore[assignment]

        # Use function name as default if no name provided
        middleware_name = name or cast("str", getattr(func, "__name__", "AfterModelMiddleware"))

        return type(
            middleware_name,
            (AgentMiddleware,),
            {
                "state_schema": state_schema or AgentState,
                "tools": tools or [],
                "after_model_jump_to": jump_to or [],
                "after_model": wrapped,
            },
        )()

    if func is not None:
        return decorator(func)
    return decorator


@@ -464,7 +464,7 @@ def create_agent(  # noqa: PLR0915
             f"{middleware_w_after[0].__class__.__name__}.after_model",
             END,
             first_node,
-            tools_available=tool_node is not None,
+            jump_to=middleware_w_after[0].after_model_jump_to,
         )

     # Add middleware edges (same as before)
@@ -475,7 +475,7 @@ def create_agent(  # noqa: PLR0915
                 f"{m1.__class__.__name__}.before_model",
                 f"{m2.__class__.__name__}.before_model",
                 first_node,
-                tools_available=tool_node is not None,
+                jump_to=m1.before_model_jump_to,
             )
         # Go directly to model_request after the last before_model
         _add_middleware_edge(
@@ -483,7 +483,7 @@ def create_agent(  # noqa: PLR0915
             f"{middleware_w_before[-1].__class__.__name__}.before_model",
             "model_request",
             first_node,
-            tools_available=tool_node is not None,
+            jump_to=middleware_w_before[-1].before_model_jump_to,
         )

     if middleware_w_after:
@@ -496,7 +496,7 @@ def create_agent(  # noqa: PLR0915
                 f"{m1.__class__.__name__}.after_model",
                 f"{m2.__class__.__name__}.after_model",
                 first_node,
-                tools_available=tool_node is not None,
+                jump_to=m1.after_model_jump_to,
             )

     return graph
@@ -584,7 +584,7 @@ def _add_middleware_edge(
     name: str,
     default_destination: str,
     model_destination: str,
-    tools_available: bool,  # noqa: FBT001
+    jump_to: list[JumpTo] | None,
 ) -> None:
     """Add an edge to the graph for a middleware node.
@@ -594,18 +594,23 @@ def _add_middleware_edge(
         name: The name of the middleware node.
         default_destination: The default destination for the edge.
         model_destination: The destination for the edge to the model.
-        tools_available: Whether tools are available for the edge to potentially route to.
+        jump_to: The conditionally jumpable destinations for the edge.
     """
-
-    def jump_edge(state: AgentState) -> str:
-        return _resolve_jump(state.get("jump_to"), model_destination) or default_destination
-
-    destinations = [default_destination]
-    if default_destination != END:
-        destinations.append(END)
-    if tools_available:
-        destinations.append("tools")
-    if name != model_destination:
-        destinations.append(model_destination)
-
-    graph.add_conditional_edges(name, jump_edge, destinations)
+    if jump_to:
+
+        def jump_edge(state: AgentState) -> str:
+            return _resolve_jump(state.get("jump_to"), model_destination) or default_destination
+
+        destinations = [default_destination]
+        if "__end__" in jump_to:
+            destinations.append(END)
+        if "tools" in jump_to:
+            destinations.append("tools")
+        if "model" in jump_to and name != model_destination:
+            destinations.append(model_destination)
+
+        graph.add_conditional_edges(name, jump_edge, destinations)
+    else:
+        graph.add_edge(name, default_destination)


@@ -30,8 +30,7 @@
 model_request(model_request)
 NoopOne\2ebefore_model(NoopOne.before_model)
 __end__([<p>__end__</p>]):::last
-NoopOne\2ebefore_model -.-> __end__;
-NoopOne\2ebefore_model -.-> model_request;
+NoopOne\2ebefore_model --> model_request;
 __start__ --> NoopOne\2ebefore_model;
 model_request --> __end__;
 classDef default fill:#f2f0ff,line-height:1.2
@@ -53,12 +52,10 @@
 NoopTen\2ebefore_model(NoopTen.before_model)
 NoopTen\2eafter_model(NoopTen.after_model)
 __end__([<p>__end__</p>]):::last
-NoopTen\2eafter_model -.-> NoopTen\2ebefore_model;
-NoopTen\2eafter_model -.-> __end__;
-NoopTen\2ebefore_model -.-> __end__;
-NoopTen\2ebefore_model -.-> model_request;
+NoopTen\2ebefore_model --> model_request;
 __start__ --> NoopTen\2ebefore_model;
 model_request --> NoopTen\2eafter_model;
+NoopTen\2eafter_model --> __end__;
 classDef default fill:#f2f0ff,line-height:1.2
 classDef first fill-opacity:0
 classDef last fill:#bfb6fc
@@ -80,18 +77,12 @@
 NoopEleven\2ebefore_model(NoopEleven.before_model)
 NoopEleven\2eafter_model(NoopEleven.after_model)
 __end__([<p>__end__</p>]):::last
-NoopEleven\2eafter_model -.-> NoopTen\2eafter_model;
-NoopEleven\2eafter_model -.-> NoopTen\2ebefore_model;
-NoopEleven\2eafter_model -.-> __end__;
-NoopEleven\2ebefore_model -.-> NoopTen\2ebefore_model;
-NoopEleven\2ebefore_model -.-> __end__;
-NoopEleven\2ebefore_model -.-> model_request;
-NoopTen\2eafter_model -.-> NoopTen\2ebefore_model;
-NoopTen\2eafter_model -.-> __end__;
-NoopTen\2ebefore_model -.-> NoopEleven\2ebefore_model;
-NoopTen\2ebefore_model -.-> __end__;
+NoopEleven\2eafter_model --> NoopTen\2eafter_model;
+NoopEleven\2ebefore_model --> model_request;
+NoopTen\2ebefore_model --> NoopEleven\2ebefore_model;
 __start__ --> NoopTen\2ebefore_model;
 model_request --> NoopEleven\2eafter_model;
+NoopTen\2eafter_model --> __end__;
 classDef default fill:#f2f0ff,line-height:1.2
 classDef first fill-opacity:0
 classDef last fill:#bfb6fc
@@ -111,11 +102,8 @@
 NoopOne\2ebefore_model(NoopOne.before_model)
 NoopTwo\2ebefore_model(NoopTwo.before_model)
 __end__([<p>__end__</p>]):::last
-NoopOne\2ebefore_model -.-> NoopTwo\2ebefore_model;
-NoopOne\2ebefore_model -.-> __end__;
-NoopTwo\2ebefore_model -.-> NoopOne\2ebefore_model;
-NoopTwo\2ebefore_model -.-> __end__;
-NoopTwo\2ebefore_model -.-> model_request;
+NoopOne\2ebefore_model --> NoopTwo\2ebefore_model;
+NoopTwo\2ebefore_model --> model_request;
 __start__ --> NoopOne\2ebefore_model;
 model_request --> __end__;
 classDef default fill:#f2f0ff,line-height:1.2
@@ -138,14 +126,9 @@
 NoopTwo\2ebefore_model(NoopTwo.before_model)
 NoopThree\2ebefore_model(NoopThree.before_model)
 __end__([<p>__end__</p>]):::last
-NoopOne\2ebefore_model -.-> NoopTwo\2ebefore_model;
-NoopOne\2ebefore_model -.-> __end__;
-NoopThree\2ebefore_model -.-> NoopOne\2ebefore_model;
-NoopThree\2ebefore_model -.-> __end__;
-NoopThree\2ebefore_model -.-> model_request;
-NoopTwo\2ebefore_model -.-> NoopOne\2ebefore_model;
-NoopTwo\2ebefore_model -.-> NoopThree\2ebefore_model;
-NoopTwo\2ebefore_model -.-> __end__;
+NoopOne\2ebefore_model --> NoopTwo\2ebefore_model;
+NoopThree\2ebefore_model --> model_request;
+NoopTwo\2ebefore_model --> NoopThree\2ebefore_model;
 __start__ --> NoopOne\2ebefore_model;
 model_request --> __end__;
 classDef default fill:#f2f0ff,line-height:1.2
@@ -166,10 +149,9 @@
 model_request(model_request)
 NoopFour\2eafter_model(NoopFour.after_model)
 __end__([<p>__end__</p>]):::last
-NoopFour\2eafter_model -.-> __end__;
-NoopFour\2eafter_model -.-> model_request;
 __start__ --> model_request;
 model_request --> NoopFour\2eafter_model;
+NoopFour\2eafter_model --> __end__;
 classDef default fill:#f2f0ff,line-height:1.2
 classDef first fill-opacity:0
 classDef last fill:#bfb6fc
@@ -189,13 +171,10 @@
 NoopFour\2eafter_model(NoopFour.after_model)
 NoopFive\2eafter_model(NoopFive.after_model)
 __end__([<p>__end__</p>]):::last
-NoopFive\2eafter_model -.-> NoopFour\2eafter_model;
-NoopFive\2eafter_model -.-> __end__;
-NoopFive\2eafter_model -.-> model_request;
-NoopFour\2eafter_model -.-> __end__;
-NoopFour\2eafter_model -.-> model_request;
+NoopFive\2eafter_model --> NoopFour\2eafter_model;
 __start__ --> model_request;
 model_request --> NoopFive\2eafter_model;
+NoopFour\2eafter_model --> __end__;
 classDef default fill:#f2f0ff,line-height:1.2
 classDef first fill-opacity:0
 classDef last fill:#bfb6fc
@@ -216,16 +195,11 @@
 NoopFive\2eafter_model(NoopFive.after_model)
 NoopSix\2eafter_model(NoopSix.after_model)
 __end__([<p>__end__</p>]):::last
-NoopFive\2eafter_model -.-> NoopFour\2eafter_model;
-NoopFive\2eafter_model -.-> __end__;
-NoopFive\2eafter_model -.-> model_request;
-NoopFour\2eafter_model -.-> __end__;
-NoopFour\2eafter_model -.-> model_request;
-NoopSix\2eafter_model -.-> NoopFive\2eafter_model;
-NoopSix\2eafter_model -.-> __end__;
-NoopSix\2eafter_model -.-> model_request;
+NoopFive\2eafter_model --> NoopFour\2eafter_model;
+NoopSix\2eafter_model --> NoopFive\2eafter_model;
 __start__ --> model_request;
 model_request --> NoopSix\2eafter_model;
+NoopFour\2eafter_model --> __end__;
 classDef default fill:#f2f0ff,line-height:1.2
 classDef first fill-opacity:0
 classDef last fill:#bfb6fc
@@ -245,12 +219,10 @@
 NoopSeven\2ebefore_model(NoopSeven.before_model)
 NoopSeven\2eafter_model(NoopSeven.after_model)
 __end__([<p>__end__</p>]):::last
-NoopSeven\2eafter_model -.-> NoopSeven\2ebefore_model;
-NoopSeven\2eafter_model -.-> __end__;
-NoopSeven\2ebefore_model -.-> __end__;
-NoopSeven\2ebefore_model -.-> model_request;
+NoopSeven\2ebefore_model --> model_request;
 __start__ --> NoopSeven\2ebefore_model;
 model_request --> NoopSeven\2eafter_model;
+NoopSeven\2eafter_model --> __end__;
 classDef default fill:#f2f0ff,line-height:1.2
 classDef first fill-opacity:0
 classDef last fill:#bfb6fc
@@ -272,18 +244,12 @@
 NoopEight\2ebefore_model(NoopEight.before_model)
 NoopEight\2eafter_model(NoopEight.after_model)
 __end__([<p>__end__</p>]):::last
-NoopEight\2eafter_model -.-> NoopSeven\2eafter_model;
-NoopEight\2eafter_model -.-> NoopSeven\2ebefore_model;
-NoopEight\2eafter_model -.-> __end__;
-NoopEight\2ebefore_model -.-> NoopSeven\2ebefore_model;
-NoopEight\2ebefore_model -.-> __end__;
-NoopEight\2ebefore_model -.-> model_request;
-NoopSeven\2eafter_model -.-> NoopSeven\2ebefore_model;
-NoopSeven\2eafter_model -.-> __end__;
-NoopSeven\2ebefore_model -.-> NoopEight\2ebefore_model;
-NoopSeven\2ebefore_model -.-> __end__;
+NoopEight\2eafter_model --> NoopSeven\2eafter_model;
+NoopEight\2ebefore_model --> model_request;
+NoopSeven\2ebefore_model --> NoopEight\2ebefore_model;
 __start__ --> NoopSeven\2ebefore_model;
 model_request --> NoopEight\2eafter_model;
+NoopSeven\2eafter_model --> __end__;
 classDef default fill:#f2f0ff,line-height:1.2
 classDef first fill-opacity:0
 classDef last fill:#bfb6fc
@@ -307,24 +273,14 @@
 NoopNine\2ebefore_model(NoopNine.before_model)
 NoopNine\2eafter_model(NoopNine.after_model)
 __end__([<p>__end__</p>]):::last
-NoopEight\2eafter_model -.-> NoopSeven\2eafter_model;
-NoopEight\2eafter_model -.-> NoopSeven\2ebefore_model;
-NoopEight\2eafter_model -.-> __end__;
-NoopEight\2ebefore_model -.-> NoopNine\2ebefore_model;
-NoopEight\2ebefore_model -.-> NoopSeven\2ebefore_model;
-NoopEight\2ebefore_model -.-> __end__;
-NoopNine\2eafter_model -.-> NoopEight\2eafter_model;
-NoopNine\2eafter_model -.-> NoopSeven\2ebefore_model;
-NoopNine\2eafter_model -.-> __end__;
-NoopNine\2ebefore_model -.-> NoopSeven\2ebefore_model;
-NoopNine\2ebefore_model -.-> __end__;
-NoopNine\2ebefore_model -.-> model_request;
-NoopSeven\2eafter_model -.-> NoopSeven\2ebefore_model;
-NoopSeven\2eafter_model -.-> __end__;
-NoopSeven\2ebefore_model -.-> NoopEight\2ebefore_model;
-NoopSeven\2ebefore_model -.-> __end__;
+NoopEight\2eafter_model --> NoopSeven\2eafter_model;
+NoopEight\2ebefore_model --> NoopNine\2ebefore_model;
+NoopNine\2eafter_model --> NoopEight\2eafter_model;
+NoopNine\2ebefore_model --> model_request;
+NoopSeven\2ebefore_model --> NoopEight\2ebefore_model;
 __start__ --> NoopSeven\2ebefore_model;
 model_request --> NoopNine\2eafter_model;
+NoopSeven\2eafter_model --> __end__;
 classDef default fill:#f2f0ff,line-height:1.2
 classDef first fill-opacity:0
 classDef last fill:#bfb6fc
@@ -347,20 +303,13 @@
 NoopEight\2ebefore_model(NoopEight.before_model)
 NoopEight\2eafter_model(NoopEight.after_model)
 __end__([<p>__end__</p>]):::last
-NoopEight\2eafter_model -.-> NoopSeven\2eafter_model;
-NoopEight\2eafter_model -.-> NoopSeven\2ebefore_model;
-NoopEight\2eafter_model -.-> __end__;
-NoopEight\2eafter_model -.-> tools;
-NoopEight\2ebefore_model -.-> NoopSeven\2ebefore_model;
+NoopEight\2eafter_model --> NoopSeven\2eafter_model;
 NoopEight\2ebefore_model -.-> __end__;
 NoopEight\2ebefore_model -.-> model_request;
-NoopEight\2ebefore_model -.-> tools;
 NoopSeven\2eafter_model -.-> NoopSeven\2ebefore_model;
 NoopSeven\2eafter_model -.-> __end__;
 NoopSeven\2eafter_model -.-> tools;
-NoopSeven\2ebefore_model -.-> NoopEight\2ebefore_model;
-NoopSeven\2ebefore_model -.-> __end__;
-NoopSeven\2ebefore_model -.-> tools;
+NoopSeven\2ebefore_model --> NoopEight\2ebefore_model;
 __start__ --> NoopSeven\2ebefore_model;
 model_request --> NoopEight\2eafter_model;
 tools -.-> NoopSeven\2ebefore_model;
@@ -387,20 +336,13 @@
 NoopEight\2ebefore_model(NoopEight.before_model)
 NoopEight\2eafter_model(NoopEight.after_model)
 __end__([<p>__end__</p>]):::last
-NoopEight\2eafter_model -.-> NoopSeven\2eafter_model;
-NoopEight\2eafter_model -.-> NoopSeven\2ebefore_model;
-NoopEight\2eafter_model -.-> __end__;
-NoopEight\2eafter_model -.-> tools;
-NoopEight\2ebefore_model -.-> NoopSeven\2ebefore_model;
+NoopEight\2eafter_model --> NoopSeven\2eafter_model;
 NoopEight\2ebefore_model -.-> __end__;
 NoopEight\2ebefore_model -.-> model_request;
-NoopEight\2ebefore_model -.-> tools;
 NoopSeven\2eafter_model -.-> NoopSeven\2ebefore_model;
 NoopSeven\2eafter_model -.-> __end__;
 NoopSeven\2eafter_model -.-> tools;
-NoopSeven\2ebefore_model -.-> NoopEight\2ebefore_model;
-NoopSeven\2ebefore_model -.-> __end__;
-NoopSeven\2ebefore_model -.-> tools;
+NoopSeven\2ebefore_model --> NoopEight\2ebefore_model;
 __start__ --> NoopSeven\2ebefore_model;
 model_request --> NoopEight\2eafter_model;
 tools -.-> NoopSeven\2ebefore_model;
@@ -427,20 +369,13 @@
 NoopEight\2ebefore_model(NoopEight.before_model)
 NoopEight\2eafter_model(NoopEight.after_model)
 __end__([<p>__end__</p>]):::last
-NoopEight\2eafter_model -.-> NoopSeven\2eafter_model;
-NoopEight\2eafter_model -.-> NoopSeven\2ebefore_model;
-NoopEight\2eafter_model -.-> __end__;
-NoopEight\2eafter_model -.-> tools;
-NoopEight\2ebefore_model -.-> NoopSeven\2ebefore_model;
+NoopEight\2eafter_model --> NoopSeven\2eafter_model;
 NoopEight\2ebefore_model -.-> __end__;
 NoopEight\2ebefore_model -.-> model_request;
-NoopEight\2ebefore_model -.-> tools;
 NoopSeven\2eafter_model -.-> NoopSeven\2ebefore_model;
 NoopSeven\2eafter_model -.-> __end__;
 NoopSeven\2eafter_model -.-> tools;
-NoopSeven\2ebefore_model -.-> NoopEight\2ebefore_model;
-NoopSeven\2ebefore_model -.-> __end__;
-NoopSeven\2ebefore_model -.-> tools;
+NoopSeven\2ebefore_model --> NoopEight\2ebefore_model;
 __start__ --> NoopSeven\2ebefore_model;
 model_request --> NoopEight\2eafter_model;
 tools -.-> NoopSeven\2ebefore_model;
@@ -467,20 +402,13 @@
 NoopEight\2ebefore_model(NoopEight.before_model)
 NoopEight\2eafter_model(NoopEight.after_model)
 __end__([<p>__end__</p>]):::last
-NoopEight\2eafter_model -.-> NoopSeven\2eafter_model;
-NoopEight\2eafter_model -.-> NoopSeven\2ebefore_model;
-NoopEight\2eafter_model -.-> __end__;
-NoopEight\2eafter_model -.-> tools;
-NoopEight\2ebefore_model -.-> NoopSeven\2ebefore_model;
+NoopEight\2eafter_model --> NoopSeven\2eafter_model;
 NoopEight\2ebefore_model -.-> __end__;
 NoopEight\2ebefore_model -.-> model_request;
-NoopEight\2ebefore_model -.-> tools;
 NoopSeven\2eafter_model -.-> NoopSeven\2ebefore_model;
 NoopSeven\2eafter_model -.-> __end__;
 NoopSeven\2eafter_model -.-> tools;
-NoopSeven\2ebefore_model -.-> NoopEight\2ebefore_model;
-NoopSeven\2ebefore_model -.-> __end__;
-NoopSeven\2ebefore_model -.-> tools;
+NoopSeven\2ebefore_model --> NoopEight\2ebefore_model;
 __start__ --> NoopSeven\2ebefore_model;
 model_request --> NoopEight\2eafter_model;
 tools -.-> NoopSeven\2ebefore_model;
@@ -507,20 +435,13 @@
 NoopEight\2ebefore_model(NoopEight.before_model)
 NoopEight\2eafter_model(NoopEight.after_model)
 __end__([<p>__end__</p>]):::last
-NoopEight\2eafter_model -.-> NoopSeven\2eafter_model;
-NoopEight\2eafter_model -.-> NoopSeven\2ebefore_model;
-NoopEight\2eafter_model -.-> __end__;
-NoopEight\2eafter_model -.-> tools;
-NoopEight\2ebefore_model -.-> NoopSeven\2ebefore_model;
+NoopEight\2eafter_model --> NoopSeven\2eafter_model;
 NoopEight\2ebefore_model -.-> __end__;
 NoopEight\2ebefore_model -.-> model_request;
-NoopEight\2ebefore_model -.-> tools;
 NoopSeven\2eafter_model -.-> NoopSeven\2ebefore_model;
 NoopSeven\2eafter_model -.-> __end__;
 NoopSeven\2eafter_model -.-> tools;
-NoopSeven\2ebefore_model -.-> NoopEight\2ebefore_model;
-NoopSeven\2ebefore_model -.-> __end__;
-NoopSeven\2ebefore_model -.-> tools;
+NoopSeven\2ebefore_model --> NoopEight\2ebefore_model;
 __start__ --> NoopSeven\2ebefore_model;
 model_request --> NoopEight\2eafter_model;
 tools -.-> NoopSeven\2ebefore_model;


@@ -34,7 +34,6 @@ from langchain.agents.middleware.types import (
     OmitFromOutput,
     PrivateStateAttr,
 )
-from langchain.agents.middleware.dynamic_system_prompt import DynamicSystemPromptMiddleware
 from langgraph.checkpoint.base import BaseCheckpointSaver
 from langgraph.checkpoint.memory import InMemorySaver
@@ -332,6 +331,8 @@ def test_create_agent_jump(
         calls.append("NoopSeven.after_model")

     class NoopEight(AgentMiddleware):
+        before_model_jump_to = [END]
+
         def before_model(self, state) -> dict[str, Any]:
             calls.append("NoopEight.before_model")
             return {"jump_to": END}
@@ -1221,74 +1222,6 @@ def test_tools_to_model_edge_with_structured_and_regular_tool_calls():
     assert result["response"].condition == "sunny"


-# Tests for DynamicSystemPromptMiddleware
-def test_dynamic_system_prompt_middleware_basic() -> None:
-    """Test basic functionality of DynamicSystemPromptMiddleware."""
-
-    def dynamic_system_prompt(state: AgentState) -> str:
-        messages = state.get("messages", [])
-        if messages:
-            return f"You are a helpful assistant. Message count: {len(messages)}"
-        return "You are a helpful assistant. No messages yet."
-
-    middleware = DynamicSystemPromptMiddleware(dynamic_system_prompt)
-
-    # Test with empty state
-    empty_state = {"messages": []}
-    request = ModelRequest(
-        model=FakeToolCallingModel(),
-        system_prompt="Original prompt",
-        messages=[],
-        tool_choice=None,
-        tools=[],
-        response_format=None,
-    )
-
-    modified_request = middleware.modify_model_request(request, empty_state, None)
-    assert modified_request.system_prompt == "You are a helpful assistant. No messages yet."
-
-    state_with_messages = {"messages": [HumanMessage("Hello"), AIMessage("Hi")]}
-    modified_request = middleware.modify_model_request(request, state_with_messages, None)
-    assert modified_request.system_prompt == "You are a helpful assistant. Message count: 2"
-
-
-def test_dynamic_system_prompt_middleware_with_context() -> None:
-    """Test DynamicSystemPromptMiddleware with runtime context."""
-
-    class MockContext(TypedDict):
-        user_role: str
-
-    def dynamic_system_prompt(state: AgentState, runtime: Runtime[MockContext]) -> str:
-        base_prompt = "You are a helpful assistant."
-        if runtime and hasattr(runtime, "context"):
-            user_role = runtime.context.get("user_role", "user")
-            return f"{base_prompt} User role: {user_role}"
-        return base_prompt
-
-    middleware = DynamicSystemPromptMiddleware(dynamic_system_prompt)
-
-    # Create a mock runtime with context
-    class MockRuntime:
-        def __init__(self, context):
-            self.context = context
-
-    mock_runtime = MockRuntime(context={"user_role": "admin"})
-
-    request = ModelRequest(
-        model=FakeToolCallingModel(),
-        system_prompt="Original prompt",
-        messages=[HumanMessage("Test")],
-        tool_choice=None,
-        tools=[],
-        response_format=None,
-    )
-
-    state = {"messages": [HumanMessage("Test")]}
-    modified_request = middleware.modify_model_request(request, state, mock_runtime)
-    assert modified_request.system_prompt == "You are a helpful assistant. User role: admin"
-
-
 def test_public_private_state_for_custom_middleware() -> None:
     """Test public and private state for custom middleware."""


@@ -0,0 +1,152 @@
"""Consolidated tests for middleware decorators: before_model, after_model, and modify_model_request."""

from typing import Any

from typing_extensions import NotRequired

from langchain_core.messages import HumanMessage, AIMessage
from langchain_core.tools import tool
from langgraph.types import Command

from langchain.agents.middleware.types import (
    AgentMiddleware,
    AgentState,
    ModelRequest,
    before_model,
    after_model,
    modify_model_request,
)
from langchain.agents.middleware_agent import create_agent

from .model import FakeToolCallingModel


class CustomState(AgentState):
    """Custom state schema for testing."""

    custom_field: NotRequired[str]


@tool
def test_tool(input: str) -> str:
    """A test tool for middleware testing."""
    return f"Tool result: {input}"


def test_before_model_decorator() -> None:
    """Test before_model decorator with all configuration options."""

    @before_model(
        state_schema=CustomState, tools=[test_tool], jump_to=["__end__"], name="CustomBeforeModel"
    )
    def custom_before_model(state: CustomState) -> dict[str, Any]:
        return {"jump_to": "__end__"}

    assert isinstance(custom_before_model, AgentMiddleware)
    assert custom_before_model.state_schema == CustomState
    assert custom_before_model.tools == [test_tool]
    assert custom_before_model.before_model_jump_to == ["__end__"]
    assert custom_before_model.__class__.__name__ == "CustomBeforeModel"

    result = custom_before_model.before_model({"messages": [HumanMessage("Hello")]})
    assert result == {"jump_to": "__end__"}


def test_after_model_decorator() -> None:
    """Test after_model decorator with all configuration options."""

    @after_model(
        state_schema=CustomState,
        tools=[test_tool],
        jump_to=["model", "__end__"],
        name="CustomAfterModel",
    )
    def custom_after_model(state: CustomState) -> dict[str, Any]:
        return {"jump_to": "model"}

    # Verify all options were applied
    assert isinstance(custom_after_model, AgentMiddleware)
    assert custom_after_model.state_schema == CustomState
    assert custom_after_model.tools == [test_tool]
    assert custom_after_model.after_model_jump_to == ["model", "__end__"]
    assert custom_after_model.__class__.__name__ == "CustomAfterModel"

    # Verify it works
    result = custom_after_model.after_model({"messages": [HumanMessage("Hello"), AIMessage("Hi!")]})
    assert result == {"jump_to": "model"}


def test_modify_model_request_decorator() -> None:
    """Test modify_model_request decorator with all configuration options."""

    @modify_model_request(state_schema=CustomState, tools=[test_tool], name="CustomModifyRequest")
    def custom_modify_request(request: ModelRequest, state: CustomState) -> ModelRequest:
        request.system_prompt = "Modified"
        return request

    # Verify all options were applied
    assert isinstance(custom_modify_request, AgentMiddleware)
    assert custom_modify_request.state_schema == CustomState
    assert custom_modify_request.tools == [test_tool]
    assert custom_modify_request.__class__.__name__ == "CustomModifyRequest"

    # Verify it works
    original_request = ModelRequest(
        model="test-model",
        system_prompt="Original",
        messages=[HumanMessage("Hello")],
        tool_choice=None,
        tools=[],
        response_format=None,
    )

    result = custom_modify_request.modify_model_request(
        original_request, {"messages": [HumanMessage("Hello")]}
    )
    assert result.system_prompt == "Modified"


def test_all_decorators_integration() -> None:
    """Test all three decorators working together in an agent."""
    call_order = []

    @before_model
    def track_before(state: AgentState) -> None:
        call_order.append("before")
        return None

    @modify_model_request
    def track_modify(request: ModelRequest, state: AgentState) -> ModelRequest:
        call_order.append("modify")
        return request

    @after_model
    def track_after(state: AgentState) -> None:
        call_order.append("after")
        return None

    agent = create_agent(
        model=FakeToolCallingModel(), middleware=[track_before, track_modify, track_after]
    )
    agent = agent.compile()

    agent.invoke({"messages": [HumanMessage("Hello")]})

    assert call_order == ["before", "modify", "after"]


def test_decorators_use_function_names_as_default() -> None:
    """Test that decorators use function names as default middleware names."""

    @before_model
    def my_before_hook(state: AgentState) -> None:
        return None

    @modify_model_request
    def my_modify_hook(request: ModelRequest, state: AgentState) -> ModelRequest:
        return request

    @after_model
    def my_after_hook(state: AgentState) -> None:
        return None

    # Verify that function names are used as middleware class names
    assert my_before_hook.__class__.__name__ == "my_before_hook"
    assert my_modify_hook.__class__.__name__ == "my_modify_hook"
    assert my_after_hook.__class__.__name__ == "my_after_hook"