mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-24 12:01:54 +00:00
feat(langchain): new decorator pattern for dynamically generated middleware (#33053)
# Main Changes 1. Adding decorator utilities for dynamically defining middleware with single hook functions (see an example below for dynamic system prompt) 2. Adding better conditional edge drawing with jump configuration attached to middleware. Can be registered w/ the decorator new decorator! ## Decorator Utilities ```py from langchain.agents.middleware_agent import create_agent, AgentState, ModelRequest from langchain.agents.middleware.types import modify_model_request from langchain_core.messages import HumanMessage from langgraph.checkpoint.memory import InMemorySaver @modify_model_request def modify_system_prompt(request: ModelRequest, state: AgentState) -> ModelRequest: request.system_prompt = ( "You are a helpful assistant." f"Please record the number of previous messages in your response: {len(state['messages'])}" ) return request agent = create_agent( model="openai:gpt-4o-mini", middleware=[modify_system_prompt] ).compile(checkpointer=InMemorySaver()) ``` ## Visualization and Routing improvements We now require that middlewares define the valid jumps for each hook. If using the new decorator syntax, this can be done with: ```py @before_model(jump_to=["__end__"]) @after_model(jump_to=["tools", "__end__"]) ``` If using the subclassing syntax, you can use these two class vars: ```py class MyMiddlewareAgentMiddleware): before_model_jump_to = ["__end__"] after_model_jump_to = ["tools", "__end__"] ``` Open for debate if we want to bundle these in a single jump map / config for a middleware. Easy to migrate later if we decide to add more hooks. We will need to **really clearly document** that these must be explicitly set in order to enable conditional edges. Notice for the below case, `Middleware2` does actually enable jumps. <table> <thead> <tr> <th>Before (broken), adding conditional edges unconditionally</th> <th>After (fixed), adding conditional edges sparingly</th> </tr> </thead> <tbody> <tr> <td> <img width="619" height="508" alt="Screenshot 2025-09-23 at 10 23 23 AM" src="https://github.com/user-attachments/assets/bba2d098-a839-4335-8e8c-b50dd8090959" /> </td> <td> <img width="469" height="490" alt="Screenshot 2025-09-23 at 10 23 13 AM" src="https://github.com/user-attachments/assets/717abf0b-fc73-4d5f-9313-b81247d8fe26" /> </td> </tr> </tbody> </table> <details> <summary>Snippet for the above</summary> ```py from typing import Any from langchain.agents.tool_node import InjectedState from langgraph.runtime import Runtime from langchain.agents.middleware.types import AgentMiddleware, AgentState from langchain.agents.middleware_agent import create_agent from langchain_core.tools import tool from typing import Annotated from langchain_core.messages import HumanMessage from typing_extensions import NotRequired @tool def simple_tool(input: str) -> str: """A simple tool.""" return "successful tool call" class Middleware1(AgentMiddleware): """Custom middleware that adds a simple tool.""" tools = [simple_tool] def before_model(self, state: AgentState, runtime: Runtime) -> None: return None def after_model(self, state: AgentState, runtime: Runtime) -> None: return None class Middleware2(AgentMiddleware): before_model_jump_to = ["tools", "__end__"] def before_model(self, state: AgentState, runtime: Runtime) -> None: return None def after_model(self, state: AgentState, runtime: Runtime) -> None: return None class Middleware3(AgentMiddleware): def before_model(self, state: AgentState, runtime: Runtime) -> None: return None def after_model(self, state: AgentState, runtime: Runtime) -> None: return None builder = create_agent( model="openai:gpt-4o-mini", middleware=[Middleware1(), Middleware2(), Middleware3()], system_prompt="You are a helpful assistant.", ) agent = builder.compile() ``` </details> ## More Examples ### Guardrails `after_model` <img width="379" height="335" alt="Screenshot 2025-09-23 at 10 40 09 AM" src="https://github.com/user-attachments/assets/45bac7dd-398e-45d1-ae58-6ecfa27dfc87" /> <details> <summary>Code</summary> ```py from langchain.agents.middleware_agent import create_agent, AgentState, ModelRequest from langchain.agents.middleware.types import after_model from langchain_core.messages import HumanMessage, AIMessage from langgraph.checkpoint.memory import InMemorySaver from typing import cast, Any @after_model(jump_to=["model", "__end__"]) def after_model_hook(state: AgentState) -> dict[str, Any]: """Check the last AI message for safety violations.""" last_message_content = cast(AIMessage, state["messages"][-1]).content.lower() print(last_message_content) unsafe_keywords = ["pineapple"] if any(keyword in last_message_content for keyword in unsafe_keywords): # Jump back to model to regenerate response return {"jump_to": "model", "messages": [HumanMessage("Please regenerate your response, and don't talk about pineapples. You can talk about apples instead.")]} return {"jump_to": "__end__"} # Create agent with guardrails middleware agent = create_agent( model="openai:gpt-4o-mini", middleware=[after_model_hook], system_prompt="Keep your responses to one sentence please!" ).compile() # Test with potentially unsafe input result = agent.invoke( {"messages": [HumanMessage("Tell me something about pineapples")]}, ) for msg in result["messages"]: print(msg.pretty_print()) """ ================================ Human Message ================================= Tell me something about pineapples None ================================== Ai Message ================================== Pineapples are tropical fruits known for their sweet, tangy flavor and distinctive spiky exterior. None ================================ Human Message ================================= Please regenerate your response, and don't talk about pineapples. You can talk about apples instead. None ================================== Ai Message ================================== Apples are popular fruits that come in various varieties, known for their crisp texture and sweetness, and are often used in cooking and baking. None """ ``` </details>
This commit is contained in:
@@ -1,6 +1,5 @@
|
|||||||
"""Middleware plugins for agents."""
|
"""Middleware plugins for agents."""
|
||||||
|
|
||||||
from .dynamic_system_prompt import DynamicSystemPromptMiddleware
|
|
||||||
from .human_in_the_loop import HumanInTheLoopMiddleware
|
from .human_in_the_loop import HumanInTheLoopMiddleware
|
||||||
from .prompt_caching import AnthropicPromptCachingMiddleware
|
from .prompt_caching import AnthropicPromptCachingMiddleware
|
||||||
from .summarization import SummarizationMiddleware
|
from .summarization import SummarizationMiddleware
|
||||||
@@ -11,7 +10,6 @@ __all__ = [
|
|||||||
"AgentState",
|
"AgentState",
|
||||||
# should move to langchain-anthropic if we decide to keep it
|
# should move to langchain-anthropic if we decide to keep it
|
||||||
"AnthropicPromptCachingMiddleware",
|
"AnthropicPromptCachingMiddleware",
|
||||||
"DynamicSystemPromptMiddleware",
|
|
||||||
"HumanInTheLoopMiddleware",
|
"HumanInTheLoopMiddleware",
|
||||||
"ModelRequest",
|
"ModelRequest",
|
||||||
"SummarizationMiddleware",
|
"SummarizationMiddleware",
|
||||||
|
@@ -1,105 +0,0 @@
|
|||||||
"""Dynamic System Prompt Middleware.
|
|
||||||
|
|
||||||
Allows setting the system prompt dynamically right before each model invocation.
|
|
||||||
Useful when the prompt depends on the current agent state or per-invocation context.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from inspect import signature
|
|
||||||
from typing import TYPE_CHECKING, Protocol, TypeAlias, cast
|
|
||||||
|
|
||||||
from langgraph.typing import ContextT
|
|
||||||
|
|
||||||
from langchain.agents.middleware.types import (
|
|
||||||
AgentMiddleware,
|
|
||||||
AgentState,
|
|
||||||
ModelRequest,
|
|
||||||
)
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from langgraph.runtime import Runtime
|
|
||||||
|
|
||||||
|
|
||||||
class DynamicSystemPromptWithoutRuntime(Protocol):
|
|
||||||
"""Dynamic system prompt without runtime in call signature."""
|
|
||||||
|
|
||||||
def __call__(self, state: AgentState) -> str:
|
|
||||||
"""Return the system prompt for the next model call."""
|
|
||||||
...
|
|
||||||
|
|
||||||
|
|
||||||
class DynamicSystemPromptWithRuntime(Protocol[ContextT]):
|
|
||||||
"""Dynamic system prompt with runtime in call signature."""
|
|
||||||
|
|
||||||
def __call__(self, state: AgentState, runtime: Runtime[ContextT]) -> str:
|
|
||||||
"""Return the system prompt for the next model call."""
|
|
||||||
...
|
|
||||||
|
|
||||||
|
|
||||||
DynamicSystemPrompt: TypeAlias = (
|
|
||||||
DynamicSystemPromptWithoutRuntime | DynamicSystemPromptWithRuntime[ContextT]
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class DynamicSystemPromptMiddleware(AgentMiddleware):
|
|
||||||
"""Dynamic System Prompt Middleware.
|
|
||||||
|
|
||||||
Allows setting the system prompt dynamically right before each model invocation.
|
|
||||||
Useful when the prompt depends on the current agent state or per-invocation context.
|
|
||||||
|
|
||||||
Example:
|
|
||||||
```python
|
|
||||||
from langchain.agents.middleware import DynamicSystemPromptMiddleware
|
|
||||||
|
|
||||||
|
|
||||||
class Context(TypedDict):
|
|
||||||
user_name: str
|
|
||||||
|
|
||||||
|
|
||||||
def system_prompt(state: AgentState, runtime: Runtime[Context]) -> str:
|
|
||||||
user_name = runtime.context.get("user_name", "n/a")
|
|
||||||
return (
|
|
||||||
f"You are a helpful assistant. Always address the user by their name: {user_name}"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
middleware = DynamicSystemPromptMiddleware(system_prompt)
|
|
||||||
```
|
|
||||||
"""
|
|
||||||
|
|
||||||
_accepts_runtime: bool
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
dynamic_system_prompt: DynamicSystemPrompt[ContextT],
|
|
||||||
) -> None:
|
|
||||||
"""Initialize the dynamic system prompt middleware.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
dynamic_system_prompt: Function that receives the current agent state
|
|
||||||
and optionally runtime with context, and returns the system prompt for
|
|
||||||
the next model call. Returns a string.
|
|
||||||
"""
|
|
||||||
super().__init__()
|
|
||||||
self.dynamic_system_prompt = dynamic_system_prompt
|
|
||||||
self._accepts_runtime = "runtime" in signature(dynamic_system_prompt).parameters
|
|
||||||
|
|
||||||
def modify_model_request(
|
|
||||||
self,
|
|
||||||
request: ModelRequest,
|
|
||||||
state: AgentState,
|
|
||||||
runtime: Runtime[ContextT],
|
|
||||||
) -> ModelRequest:
|
|
||||||
"""Modify the model request to include the dynamic system prompt."""
|
|
||||||
if self._accepts_runtime:
|
|
||||||
system_prompt = cast(
|
|
||||||
"DynamicSystemPromptWithRuntime[ContextT]", self.dynamic_system_prompt
|
|
||||||
)(state, runtime)
|
|
||||||
else:
|
|
||||||
system_prompt = cast("DynamicSystemPromptWithoutRuntime", self.dynamic_system_prompt)(
|
|
||||||
state
|
|
||||||
)
|
|
||||||
|
|
||||||
request.system_prompt = system_prompt
|
|
||||||
return request
|
|
@@ -3,7 +3,20 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import TYPE_CHECKING, Annotated, Any, Generic, Literal, cast
|
from inspect import signature
|
||||||
|
from typing import (
|
||||||
|
TYPE_CHECKING,
|
||||||
|
Annotated,
|
||||||
|
Any,
|
||||||
|
ClassVar,
|
||||||
|
Generic,
|
||||||
|
Literal,
|
||||||
|
Protocol,
|
||||||
|
TypeAlias,
|
||||||
|
TypeGuard,
|
||||||
|
cast,
|
||||||
|
overload,
|
||||||
|
)
|
||||||
|
|
||||||
# needed as top level import for pydantic schema generation on AgentState
|
# needed as top level import for pydantic schema generation on AgentState
|
||||||
from langchain_core.messages import AnyMessage # noqa: TC002
|
from langchain_core.messages import AnyMessage # noqa: TC002
|
||||||
@@ -14,9 +27,12 @@ from langgraph.typing import ContextT
|
|||||||
from typing_extensions import NotRequired, Required, TypedDict, TypeVar
|
from typing_extensions import NotRequired, Required, TypedDict, TypeVar
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
from collections.abc import Callable
|
||||||
|
|
||||||
from langchain_core.language_models.chat_models import BaseChatModel
|
from langchain_core.language_models.chat_models import BaseChatModel
|
||||||
from langchain_core.tools import BaseTool
|
from langchain_core.tools import BaseTool
|
||||||
from langgraph.runtime import Runtime
|
from langgraph.runtime import Runtime
|
||||||
|
from langgraph.types import Command
|
||||||
|
|
||||||
from langchain.agents.structured_output import ResponseFormat
|
from langchain.agents.structured_output import ResponseFormat
|
||||||
|
|
||||||
@@ -88,6 +104,7 @@ class PublicAgentState(TypedDict, Generic[ResponseT]):
|
|||||||
|
|
||||||
|
|
||||||
StateT = TypeVar("StateT", bound=AgentState, default=AgentState)
|
StateT = TypeVar("StateT", bound=AgentState, default=AgentState)
|
||||||
|
StateT_contra = TypeVar("StateT_contra", bound=AgentState, contravariant=True)
|
||||||
|
|
||||||
|
|
||||||
class AgentMiddleware(Generic[StateT, ContextT]):
|
class AgentMiddleware(Generic[StateT, ContextT]):
|
||||||
@@ -103,6 +120,12 @@ class AgentMiddleware(Generic[StateT, ContextT]):
|
|||||||
tools: list[BaseTool]
|
tools: list[BaseTool]
|
||||||
"""Additional tools registered by the middleware."""
|
"""Additional tools registered by the middleware."""
|
||||||
|
|
||||||
|
before_model_jump_to: ClassVar[list[JumpTo]] = []
|
||||||
|
"""Valid jump destinations for before_model hook. Used to establish conditional edges."""
|
||||||
|
|
||||||
|
after_model_jump_to: ClassVar[list[JumpTo]] = []
|
||||||
|
"""Valid jump destinations for after_model hook. Used to establish conditional edges."""
|
||||||
|
|
||||||
def before_model(self, state: StateT, runtime: Runtime[ContextT]) -> dict[str, Any] | None:
|
def before_model(self, state: StateT, runtime: Runtime[ContextT]) -> dict[str, Any] | None:
|
||||||
"""Logic to run before the model is called."""
|
"""Logic to run before the model is called."""
|
||||||
|
|
||||||
@@ -117,3 +140,404 @@ class AgentMiddleware(Generic[StateT, ContextT]):
|
|||||||
|
|
||||||
def after_model(self, state: StateT, runtime: Runtime[ContextT]) -> dict[str, Any] | None:
|
def after_model(self, state: StateT, runtime: Runtime[ContextT]) -> dict[str, Any] | None:
|
||||||
"""Logic to run after the model is called."""
|
"""Logic to run after the model is called."""
|
||||||
|
|
||||||
|
|
||||||
|
class _CallableWithState(Protocol[StateT_contra]):
|
||||||
|
"""Callable with AgentState as argument."""
|
||||||
|
|
||||||
|
def __call__(self, state: StateT_contra) -> dict[str, Any] | Command | None:
|
||||||
|
"""Perform some logic with the state."""
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
class _CallableWithStateAndRuntime(Protocol[StateT_contra, ContextT]):
|
||||||
|
"""Callable with AgentState and Runtime as arguments."""
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self, state: StateT_contra, runtime: Runtime[ContextT]
|
||||||
|
) -> dict[str, Any] | Command | None:
|
||||||
|
"""Perform some logic with the state and runtime."""
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
class _CallableWithModelRequestAndState(Protocol[StateT_contra]):
|
||||||
|
"""Callable with ModelRequest and AgentState as arguments."""
|
||||||
|
|
||||||
|
def __call__(self, request: ModelRequest, state: StateT_contra) -> ModelRequest:
|
||||||
|
"""Perform some logic with the model request and state."""
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
class _CallableWithModelRequestAndStateAndRuntime(Protocol[StateT_contra, ContextT]):
|
||||||
|
"""Callable with ModelRequest, AgentState, and Runtime as arguments."""
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self, request: ModelRequest, state: StateT_contra, runtime: Runtime[ContextT]
|
||||||
|
) -> ModelRequest:
|
||||||
|
"""Perform some logic with the model request, state, and runtime."""
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
_NodeSignature: TypeAlias = (
|
||||||
|
_CallableWithState[StateT] | _CallableWithStateAndRuntime[StateT, ContextT]
|
||||||
|
)
|
||||||
|
_ModelRequestSignature: TypeAlias = (
|
||||||
|
_CallableWithModelRequestAndState[StateT]
|
||||||
|
| _CallableWithModelRequestAndStateAndRuntime[StateT, ContextT]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def is_callable_with_runtime(
|
||||||
|
func: _NodeSignature[StateT, ContextT],
|
||||||
|
) -> TypeGuard[_CallableWithStateAndRuntime[StateT, ContextT]]:
|
||||||
|
return "runtime" in signature(func).parameters
|
||||||
|
|
||||||
|
|
||||||
|
def is_callable_with_runtime_and_request(
|
||||||
|
func: _ModelRequestSignature[StateT, ContextT],
|
||||||
|
) -> TypeGuard[_CallableWithModelRequestAndStateAndRuntime[StateT, ContextT]]:
|
||||||
|
return "runtime" in signature(func).parameters
|
||||||
|
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def before_model(
|
||||||
|
func: _NodeSignature[StateT, ContextT],
|
||||||
|
) -> AgentMiddleware[StateT, ContextT]: ...
|
||||||
|
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def before_model(
|
||||||
|
func: None = None,
|
||||||
|
*,
|
||||||
|
state_schema: type[StateT] | None = None,
|
||||||
|
tools: list[BaseTool] | None = None,
|
||||||
|
jump_to: list[JumpTo] | None = None,
|
||||||
|
name: str | None = None,
|
||||||
|
) -> Callable[[_NodeSignature[StateT, ContextT]], AgentMiddleware[StateT, ContextT]]: ...
|
||||||
|
|
||||||
|
|
||||||
|
def before_model(
|
||||||
|
func: _NodeSignature[StateT, ContextT] | None = None,
|
||||||
|
*,
|
||||||
|
state_schema: type[StateT] | None = None,
|
||||||
|
tools: list[BaseTool] | None = None,
|
||||||
|
jump_to: list[JumpTo] | None = None,
|
||||||
|
name: str | None = None,
|
||||||
|
) -> (
|
||||||
|
Callable[[_NodeSignature[StateT, ContextT]], AgentMiddleware[StateT, ContextT]]
|
||||||
|
| AgentMiddleware[StateT, ContextT]
|
||||||
|
):
|
||||||
|
"""Decorator used to dynamically create a middleware with the before_model hook.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
func: The function to be decorated. Can accept either:
|
||||||
|
- `state: StateT` - Just the agent state
|
||||||
|
- `state: StateT, runtime: Runtime[ContextT]` - State and runtime context
|
||||||
|
state_schema: Optional custom state schema type. If not provided, uses the default
|
||||||
|
AgentState schema.
|
||||||
|
tools: Optional list of additional tools to register with this middleware.
|
||||||
|
jump_to: Optional list of valid jump destinations for conditional edges.
|
||||||
|
Valid values are: "tools", "model", "__end__"
|
||||||
|
name: Optional name for the generated middleware class. If not provided,
|
||||||
|
uses the decorated function's name.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Either an AgentMiddleware instance (if func is provided directly) or a decorator function
|
||||||
|
that can be applied to a function its wrapping.
|
||||||
|
|
||||||
|
The decorated function should return:
|
||||||
|
- `dict[str, Any]` - State updates to merge into the agent state
|
||||||
|
- `Command` - A command to control flow (e.g., jump to different node)
|
||||||
|
- `None` - No state updates or flow control
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
Basic usage with state only:
|
||||||
|
```python
|
||||||
|
@before_model
|
||||||
|
def log_before_model(state: AgentState) -> None:
|
||||||
|
print(f"About to call model with {len(state['messages'])} messages")
|
||||||
|
```
|
||||||
|
|
||||||
|
Advanced usage with runtime and conditional jumping:
|
||||||
|
```python
|
||||||
|
@before_model(jump_to=["__end__"])
|
||||||
|
def conditional_before_model(state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
|
||||||
|
if some_condition(state):
|
||||||
|
return {"jump_to": "__end__"}
|
||||||
|
return None
|
||||||
|
```
|
||||||
|
|
||||||
|
With custom state schema:
|
||||||
|
```python
|
||||||
|
@before_model(
|
||||||
|
state_schema=MyCustomState,
|
||||||
|
)
|
||||||
|
def custom_before_model(state: MyCustomState) -> dict[str, Any]:
|
||||||
|
return {"custom_field": "updated_value"}
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
|
||||||
|
def decorator(func: _NodeSignature[StateT, ContextT]) -> AgentMiddleware[StateT, ContextT]:
|
||||||
|
if is_callable_with_runtime(func):
|
||||||
|
|
||||||
|
def wrapped_with_runtime(
|
||||||
|
self: AgentMiddleware[StateT, ContextT], # noqa: ARG001
|
||||||
|
state: StateT,
|
||||||
|
runtime: Runtime[ContextT],
|
||||||
|
) -> dict[str, Any] | Command | None:
|
||||||
|
return func(state, runtime)
|
||||||
|
|
||||||
|
wrapped = wrapped_with_runtime
|
||||||
|
else:
|
||||||
|
|
||||||
|
def wrapped_without_runtime(
|
||||||
|
self: AgentMiddleware[StateT, ContextT], # noqa: ARG001
|
||||||
|
state: StateT,
|
||||||
|
) -> dict[str, Any] | Command | None:
|
||||||
|
return func(state) # type: ignore[call-arg]
|
||||||
|
|
||||||
|
wrapped = wrapped_without_runtime # type: ignore[assignment]
|
||||||
|
|
||||||
|
# Use function name as default if no name provided
|
||||||
|
middleware_name = name or cast("str", getattr(func, "__name__", "BeforeModelMiddleware"))
|
||||||
|
|
||||||
|
return type(
|
||||||
|
middleware_name,
|
||||||
|
(AgentMiddleware,),
|
||||||
|
{
|
||||||
|
"state_schema": state_schema or AgentState,
|
||||||
|
"tools": tools or [],
|
||||||
|
"before_model_jump_to": jump_to or [],
|
||||||
|
"before_model": wrapped,
|
||||||
|
},
|
||||||
|
)()
|
||||||
|
|
||||||
|
if func is not None:
|
||||||
|
return decorator(func)
|
||||||
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def modify_model_request(
|
||||||
|
func: _ModelRequestSignature[StateT, ContextT],
|
||||||
|
) -> AgentMiddleware[StateT, ContextT]: ...
|
||||||
|
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def modify_model_request(
|
||||||
|
func: None = None,
|
||||||
|
*,
|
||||||
|
state_schema: type[StateT] | None = None,
|
||||||
|
tools: list[BaseTool] | None = None,
|
||||||
|
name: str | None = None,
|
||||||
|
) -> Callable[[_ModelRequestSignature[StateT, ContextT]], AgentMiddleware[StateT, ContextT]]: ...
|
||||||
|
|
||||||
|
|
||||||
|
def modify_model_request(
|
||||||
|
func: _ModelRequestSignature[StateT, ContextT] | None = None,
|
||||||
|
*,
|
||||||
|
state_schema: type[StateT] | None = None,
|
||||||
|
tools: list[BaseTool] | None = None,
|
||||||
|
name: str | None = None,
|
||||||
|
) -> (
|
||||||
|
Callable[[_ModelRequestSignature[StateT, ContextT]], AgentMiddleware[StateT, ContextT]]
|
||||||
|
| AgentMiddleware[StateT, ContextT]
|
||||||
|
):
|
||||||
|
r"""Decorator used to dynamically create a middleware with the modify_model_request hook.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
func: The function to be decorated. Can accept either:
|
||||||
|
- `request: ModelRequest, state: StateT` - Model request and agent state
|
||||||
|
- `request: ModelRequest, state: StateT, runtime: Runtime[ContextT]` -
|
||||||
|
Model request, state, and runtime context
|
||||||
|
state_schema: Optional custom state schema type. If not provided, uses the default
|
||||||
|
AgentState schema.
|
||||||
|
tools: Optional list of additional tools to register with this middleware.
|
||||||
|
name: Optional name for the generated middleware class. If not provided,
|
||||||
|
uses the decorated function's name.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Either an AgentMiddleware instance (if func is provided) or a decorator function
|
||||||
|
that can be applied to a function.
|
||||||
|
|
||||||
|
The decorated function should return:
|
||||||
|
- `ModelRequest` - The modified model request to be sent to the language model
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
Basic usage to modify system prompt:
|
||||||
|
```python
|
||||||
|
@modify_model_request
|
||||||
|
def add_context_to_prompt(request: ModelRequest, state: AgentState) -> ModelRequest:
|
||||||
|
if request.system_prompt:
|
||||||
|
request.system_prompt += "\n\nAdditional context: ..."
|
||||||
|
else:
|
||||||
|
request.system_prompt = "Additional context: ..."
|
||||||
|
return request
|
||||||
|
```
|
||||||
|
|
||||||
|
Advanced usage with runtime and custom model settings:
|
||||||
|
```python
|
||||||
|
@modify_model_request
|
||||||
|
def dynamic_model_settings(
|
||||||
|
request: ModelRequest, state: AgentState, runtime: Runtime
|
||||||
|
) -> ModelRequest:
|
||||||
|
# Use a different model based on user subscription tier
|
||||||
|
if runtime.context.get("subscription_tier") == "premium":
|
||||||
|
request.model = "gpt-4o"
|
||||||
|
else:
|
||||||
|
request.model = "gpt-4o-mini"
|
||||||
|
|
||||||
|
return request
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
|
||||||
|
def decorator(
|
||||||
|
func: _ModelRequestSignature[StateT, ContextT],
|
||||||
|
) -> AgentMiddleware[StateT, ContextT]:
|
||||||
|
if is_callable_with_runtime_and_request(func):
|
||||||
|
|
||||||
|
def wrapped_with_runtime(
|
||||||
|
self: AgentMiddleware[StateT, ContextT], # noqa: ARG001
|
||||||
|
request: ModelRequest,
|
||||||
|
state: StateT,
|
||||||
|
runtime: Runtime[ContextT],
|
||||||
|
) -> ModelRequest:
|
||||||
|
return func(request, state, runtime)
|
||||||
|
|
||||||
|
wrapped = wrapped_with_runtime
|
||||||
|
else:
|
||||||
|
|
||||||
|
def wrapped_without_runtime(
|
||||||
|
self: AgentMiddleware[StateT, ContextT], # noqa: ARG001
|
||||||
|
request: ModelRequest,
|
||||||
|
state: StateT,
|
||||||
|
) -> ModelRequest:
|
||||||
|
return func(request, state) # type: ignore[call-arg]
|
||||||
|
|
||||||
|
wrapped = wrapped_without_runtime # type: ignore[assignment]
|
||||||
|
|
||||||
|
# Use function name as default if no name provided
|
||||||
|
middleware_name = name or cast(
|
||||||
|
"str", getattr(func, "__name__", "ModifyModelRequestMiddleware")
|
||||||
|
)
|
||||||
|
|
||||||
|
return type(
|
||||||
|
middleware_name,
|
||||||
|
(AgentMiddleware,),
|
||||||
|
{
|
||||||
|
"state_schema": state_schema or AgentState,
|
||||||
|
"tools": tools or [],
|
||||||
|
"modify_model_request": wrapped,
|
||||||
|
},
|
||||||
|
)()
|
||||||
|
|
||||||
|
if func is not None:
|
||||||
|
return decorator(func)
|
||||||
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def after_model(
|
||||||
|
func: _NodeSignature[StateT, ContextT],
|
||||||
|
) -> AgentMiddleware[StateT, ContextT]: ...
|
||||||
|
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def after_model(
|
||||||
|
func: None = None,
|
||||||
|
*,
|
||||||
|
state_schema: type[StateT] | None = None,
|
||||||
|
tools: list[BaseTool] | None = None,
|
||||||
|
jump_to: list[JumpTo] | None = None,
|
||||||
|
name: str | None = None,
|
||||||
|
) -> Callable[[_NodeSignature[StateT, ContextT]], AgentMiddleware[StateT, ContextT]]: ...
|
||||||
|
|
||||||
|
|
||||||
|
def after_model(
|
||||||
|
func: _NodeSignature[StateT, ContextT] | None = None,
|
||||||
|
*,
|
||||||
|
state_schema: type[StateT] | None = None,
|
||||||
|
tools: list[BaseTool] | None = None,
|
||||||
|
jump_to: list[JumpTo] | None = None,
|
||||||
|
name: str | None = None,
|
||||||
|
) -> (
|
||||||
|
Callable[[_NodeSignature[StateT, ContextT]], AgentMiddleware[StateT, ContextT]]
|
||||||
|
| AgentMiddleware[StateT, ContextT]
|
||||||
|
):
|
||||||
|
"""Decorator used to dynamically create a middleware with the after_model hook.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
func: The function to be decorated. Can accept either:
|
||||||
|
- `state: StateT` - Just the agent state (includes model response)
|
||||||
|
- `state: StateT, runtime: Runtime[ContextT]` - State and runtime context
|
||||||
|
state_schema: Optional custom state schema type. If not provided, uses the default
|
||||||
|
AgentState schema.
|
||||||
|
tools: Optional list of additional tools to register with this middleware.
|
||||||
|
jump_to: Optional list of valid jump destinations for conditional edges.
|
||||||
|
Valid values are: "tools", "model", "__end__"
|
||||||
|
name: Optional name for the generated middleware class. If not provided,
|
||||||
|
uses the decorated function's name.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Either an AgentMiddleware instance (if func is provided) or a decorator function
|
||||||
|
that can be applied to a function.
|
||||||
|
|
||||||
|
The decorated function should return:
|
||||||
|
- `dict[str, Any]` - State updates to merge into the agent state
|
||||||
|
- `Command` - A command to control flow (e.g., jump to different node)
|
||||||
|
- `None` - No state updates or flow control
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
Basic usage for logging model responses:
|
||||||
|
```python
|
||||||
|
@after_model
|
||||||
|
def log_latest_message(state: AgentState) -> None:
|
||||||
|
print(state["messages"][-1].content)
|
||||||
|
```
|
||||||
|
|
||||||
|
With custom state schema:
|
||||||
|
```python
|
||||||
|
@after_model(state_schema=MyCustomState, name="MyAfterModelMiddleware")
|
||||||
|
def custom_after_model(state: MyCustomState) -> dict[str, Any]:
|
||||||
|
return {"custom_field": "updated_after_model"}
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
|
||||||
|
def decorator(func: _NodeSignature[StateT, ContextT]) -> AgentMiddleware[StateT, ContextT]:
|
||||||
|
if is_callable_with_runtime(func):
|
||||||
|
|
||||||
|
def wrapped_with_runtime(
|
||||||
|
self: AgentMiddleware[StateT, ContextT], # noqa: ARG001
|
||||||
|
state: StateT,
|
||||||
|
runtime: Runtime[ContextT],
|
||||||
|
) -> dict[str, Any] | Command | None:
|
||||||
|
return func(state, runtime)
|
||||||
|
|
||||||
|
wrapped = wrapped_with_runtime
|
||||||
|
else:
|
||||||
|
|
||||||
|
def wrapped_without_runtime(
|
||||||
|
self: AgentMiddleware[StateT, ContextT], # noqa: ARG001
|
||||||
|
state: StateT,
|
||||||
|
) -> dict[str, Any] | Command | None:
|
||||||
|
return func(state) # type: ignore[call-arg]
|
||||||
|
|
||||||
|
wrapped = wrapped_without_runtime # type: ignore[assignment]
|
||||||
|
|
||||||
|
# Use function name as default if no name provided
|
||||||
|
middleware_name = name or cast("str", getattr(func, "__name__", "AfterModelMiddleware"))
|
||||||
|
|
||||||
|
return type(
|
||||||
|
middleware_name,
|
||||||
|
(AgentMiddleware,),
|
||||||
|
{
|
||||||
|
"state_schema": state_schema or AgentState,
|
||||||
|
"tools": tools or [],
|
||||||
|
"after_model_jump_to": jump_to or [],
|
||||||
|
"after_model": wrapped,
|
||||||
|
},
|
||||||
|
)()
|
||||||
|
|
||||||
|
if func is not None:
|
||||||
|
return decorator(func)
|
||||||
|
return decorator
|
||||||
|
@@ -464,7 +464,7 @@ def create_agent( # noqa: PLR0915
|
|||||||
f"{middleware_w_after[0].__class__.__name__}.after_model",
|
f"{middleware_w_after[0].__class__.__name__}.after_model",
|
||||||
END,
|
END,
|
||||||
first_node,
|
first_node,
|
||||||
tools_available=tool_node is not None,
|
jump_to=middleware_w_after[0].after_model_jump_to,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Add middleware edges (same as before)
|
# Add middleware edges (same as before)
|
||||||
@@ -475,7 +475,7 @@ def create_agent( # noqa: PLR0915
|
|||||||
f"{m1.__class__.__name__}.before_model",
|
f"{m1.__class__.__name__}.before_model",
|
||||||
f"{m2.__class__.__name__}.before_model",
|
f"{m2.__class__.__name__}.before_model",
|
||||||
first_node,
|
first_node,
|
||||||
tools_available=tool_node is not None,
|
jump_to=m1.before_model_jump_to,
|
||||||
)
|
)
|
||||||
# Go directly to model_request after the last before_model
|
# Go directly to model_request after the last before_model
|
||||||
_add_middleware_edge(
|
_add_middleware_edge(
|
||||||
@@ -483,7 +483,7 @@ def create_agent( # noqa: PLR0915
|
|||||||
f"{middleware_w_before[-1].__class__.__name__}.before_model",
|
f"{middleware_w_before[-1].__class__.__name__}.before_model",
|
||||||
"model_request",
|
"model_request",
|
||||||
first_node,
|
first_node,
|
||||||
tools_available=tool_node is not None,
|
jump_to=middleware_w_before[-1].before_model_jump_to,
|
||||||
)
|
)
|
||||||
|
|
||||||
if middleware_w_after:
|
if middleware_w_after:
|
||||||
@@ -496,7 +496,7 @@ def create_agent( # noqa: PLR0915
|
|||||||
f"{m1.__class__.__name__}.after_model",
|
f"{m1.__class__.__name__}.after_model",
|
||||||
f"{m2.__class__.__name__}.after_model",
|
f"{m2.__class__.__name__}.after_model",
|
||||||
first_node,
|
first_node,
|
||||||
tools_available=tool_node is not None,
|
jump_to=m1.after_model_jump_to,
|
||||||
)
|
)
|
||||||
|
|
||||||
return graph
|
return graph
|
||||||
@@ -584,7 +584,7 @@ def _add_middleware_edge(
|
|||||||
name: str,
|
name: str,
|
||||||
default_destination: str,
|
default_destination: str,
|
||||||
model_destination: str,
|
model_destination: str,
|
||||||
tools_available: bool, # noqa: FBT001
|
jump_to: list[JumpTo] | None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Add an edge to the graph for a middleware node.
|
"""Add an edge to the graph for a middleware node.
|
||||||
|
|
||||||
@@ -594,18 +594,23 @@ def _add_middleware_edge(
|
|||||||
name: The name of the middleware node.
|
name: The name of the middleware node.
|
||||||
default_destination: The default destination for the edge.
|
default_destination: The default destination for the edge.
|
||||||
model_destination: The destination for the edge to the model.
|
model_destination: The destination for the edge to the model.
|
||||||
tools_available: Whether tools are available for the edge to potentially route to.
|
jump_to: The conditionally jumpable destinations for the edge.
|
||||||
"""
|
"""
|
||||||
|
if jump_to:
|
||||||
|
|
||||||
def jump_edge(state: AgentState) -> str:
|
def jump_edge(state: AgentState) -> str:
|
||||||
return _resolve_jump(state.get("jump_to"), model_destination) or default_destination
|
return _resolve_jump(state.get("jump_to"), model_destination) or default_destination
|
||||||
|
|
||||||
destinations = [default_destination]
|
destinations = [default_destination]
|
||||||
if default_destination != END:
|
|
||||||
destinations.append(END)
|
|
||||||
if tools_available:
|
|
||||||
destinations.append("tools")
|
|
||||||
if name != model_destination:
|
|
||||||
destinations.append(model_destination)
|
|
||||||
|
|
||||||
graph.add_conditional_edges(name, jump_edge, destinations)
|
if "__end__" in jump_to:
|
||||||
|
destinations.append(END)
|
||||||
|
if "tools" in jump_to:
|
||||||
|
destinations.append("tools")
|
||||||
|
if "model" in jump_to and name != model_destination:
|
||||||
|
destinations.append(model_destination)
|
||||||
|
|
||||||
|
graph.add_conditional_edges(name, jump_edge, destinations)
|
||||||
|
|
||||||
|
else:
|
||||||
|
graph.add_edge(name, default_destination)
|
||||||
|
@@ -30,8 +30,7 @@
|
|||||||
model_request(model_request)
|
model_request(model_request)
|
||||||
NoopOne\2ebefore_model(NoopOne.before_model)
|
NoopOne\2ebefore_model(NoopOne.before_model)
|
||||||
__end__([<p>__end__</p>]):::last
|
__end__([<p>__end__</p>]):::last
|
||||||
NoopOne\2ebefore_model -.-> __end__;
|
NoopOne\2ebefore_model --> model_request;
|
||||||
NoopOne\2ebefore_model -.-> model_request;
|
|
||||||
__start__ --> NoopOne\2ebefore_model;
|
__start__ --> NoopOne\2ebefore_model;
|
||||||
model_request --> __end__;
|
model_request --> __end__;
|
||||||
classDef default fill:#f2f0ff,line-height:1.2
|
classDef default fill:#f2f0ff,line-height:1.2
|
||||||
@@ -53,12 +52,10 @@
|
|||||||
NoopTen\2ebefore_model(NoopTen.before_model)
|
NoopTen\2ebefore_model(NoopTen.before_model)
|
||||||
NoopTen\2eafter_model(NoopTen.after_model)
|
NoopTen\2eafter_model(NoopTen.after_model)
|
||||||
__end__([<p>__end__</p>]):::last
|
__end__([<p>__end__</p>]):::last
|
||||||
NoopTen\2eafter_model -.-> NoopTen\2ebefore_model;
|
NoopTen\2ebefore_model --> model_request;
|
||||||
NoopTen\2eafter_model -.-> __end__;
|
|
||||||
NoopTen\2ebefore_model -.-> __end__;
|
|
||||||
NoopTen\2ebefore_model -.-> model_request;
|
|
||||||
__start__ --> NoopTen\2ebefore_model;
|
__start__ --> NoopTen\2ebefore_model;
|
||||||
model_request --> NoopTen\2eafter_model;
|
model_request --> NoopTen\2eafter_model;
|
||||||
|
NoopTen\2eafter_model --> __end__;
|
||||||
classDef default fill:#f2f0ff,line-height:1.2
|
classDef default fill:#f2f0ff,line-height:1.2
|
||||||
classDef first fill-opacity:0
|
classDef first fill-opacity:0
|
||||||
classDef last fill:#bfb6fc
|
classDef last fill:#bfb6fc
|
||||||
@@ -80,18 +77,12 @@
|
|||||||
NoopEleven\2ebefore_model(NoopEleven.before_model)
|
NoopEleven\2ebefore_model(NoopEleven.before_model)
|
||||||
NoopEleven\2eafter_model(NoopEleven.after_model)
|
NoopEleven\2eafter_model(NoopEleven.after_model)
|
||||||
__end__([<p>__end__</p>]):::last
|
__end__([<p>__end__</p>]):::last
|
||||||
NoopEleven\2eafter_model -.-> NoopTen\2eafter_model;
|
NoopEleven\2eafter_model --> NoopTen\2eafter_model;
|
||||||
NoopEleven\2eafter_model -.-> NoopTen\2ebefore_model;
|
NoopEleven\2ebefore_model --> model_request;
|
||||||
NoopEleven\2eafter_model -.-> __end__;
|
NoopTen\2ebefore_model --> NoopEleven\2ebefore_model;
|
||||||
NoopEleven\2ebefore_model -.-> NoopTen\2ebefore_model;
|
|
||||||
NoopEleven\2ebefore_model -.-> __end__;
|
|
||||||
NoopEleven\2ebefore_model -.-> model_request;
|
|
||||||
NoopTen\2eafter_model -.-> NoopTen\2ebefore_model;
|
|
||||||
NoopTen\2eafter_model -.-> __end__;
|
|
||||||
NoopTen\2ebefore_model -.-> NoopEleven\2ebefore_model;
|
|
||||||
NoopTen\2ebefore_model -.-> __end__;
|
|
||||||
__start__ --> NoopTen\2ebefore_model;
|
__start__ --> NoopTen\2ebefore_model;
|
||||||
model_request --> NoopEleven\2eafter_model;
|
model_request --> NoopEleven\2eafter_model;
|
||||||
|
NoopTen\2eafter_model --> __end__;
|
||||||
classDef default fill:#f2f0ff,line-height:1.2
|
classDef default fill:#f2f0ff,line-height:1.2
|
||||||
classDef first fill-opacity:0
|
classDef first fill-opacity:0
|
||||||
classDef last fill:#bfb6fc
|
classDef last fill:#bfb6fc
|
||||||
@@ -111,11 +102,8 @@
|
|||||||
NoopOne\2ebefore_model(NoopOne.before_model)
|
NoopOne\2ebefore_model(NoopOne.before_model)
|
||||||
NoopTwo\2ebefore_model(NoopTwo.before_model)
|
NoopTwo\2ebefore_model(NoopTwo.before_model)
|
||||||
__end__([<p>__end__</p>]):::last
|
__end__([<p>__end__</p>]):::last
|
||||||
NoopOne\2ebefore_model -.-> NoopTwo\2ebefore_model;
|
NoopOne\2ebefore_model --> NoopTwo\2ebefore_model;
|
||||||
NoopOne\2ebefore_model -.-> __end__;
|
NoopTwo\2ebefore_model --> model_request;
|
||||||
NoopTwo\2ebefore_model -.-> NoopOne\2ebefore_model;
|
|
||||||
NoopTwo\2ebefore_model -.-> __end__;
|
|
||||||
NoopTwo\2ebefore_model -.-> model_request;
|
|
||||||
__start__ --> NoopOne\2ebefore_model;
|
__start__ --> NoopOne\2ebefore_model;
|
||||||
model_request --> __end__;
|
model_request --> __end__;
|
||||||
classDef default fill:#f2f0ff,line-height:1.2
|
classDef default fill:#f2f0ff,line-height:1.2
|
||||||
@@ -138,14 +126,9 @@
|
|||||||
NoopTwo\2ebefore_model(NoopTwo.before_model)
|
NoopTwo\2ebefore_model(NoopTwo.before_model)
|
||||||
NoopThree\2ebefore_model(NoopThree.before_model)
|
NoopThree\2ebefore_model(NoopThree.before_model)
|
||||||
__end__([<p>__end__</p>]):::last
|
__end__([<p>__end__</p>]):::last
|
||||||
NoopOne\2ebefore_model -.-> NoopTwo\2ebefore_model;
|
NoopOne\2ebefore_model --> NoopTwo\2ebefore_model;
|
||||||
NoopOne\2ebefore_model -.-> __end__;
|
NoopThree\2ebefore_model --> model_request;
|
||||||
NoopThree\2ebefore_model -.-> NoopOne\2ebefore_model;
|
NoopTwo\2ebefore_model --> NoopThree\2ebefore_model;
|
||||||
NoopThree\2ebefore_model -.-> __end__;
|
|
||||||
NoopThree\2ebefore_model -.-> model_request;
|
|
||||||
NoopTwo\2ebefore_model -.-> NoopOne\2ebefore_model;
|
|
||||||
NoopTwo\2ebefore_model -.-> NoopThree\2ebefore_model;
|
|
||||||
NoopTwo\2ebefore_model -.-> __end__;
|
|
||||||
__start__ --> NoopOne\2ebefore_model;
|
__start__ --> NoopOne\2ebefore_model;
|
||||||
model_request --> __end__;
|
model_request --> __end__;
|
||||||
classDef default fill:#f2f0ff,line-height:1.2
|
classDef default fill:#f2f0ff,line-height:1.2
|
||||||
@@ -166,10 +149,9 @@
|
|||||||
model_request(model_request)
|
model_request(model_request)
|
||||||
NoopFour\2eafter_model(NoopFour.after_model)
|
NoopFour\2eafter_model(NoopFour.after_model)
|
||||||
__end__([<p>__end__</p>]):::last
|
__end__([<p>__end__</p>]):::last
|
||||||
NoopFour\2eafter_model -.-> __end__;
|
|
||||||
NoopFour\2eafter_model -.-> model_request;
|
|
||||||
__start__ --> model_request;
|
__start__ --> model_request;
|
||||||
model_request --> NoopFour\2eafter_model;
|
model_request --> NoopFour\2eafter_model;
|
||||||
|
NoopFour\2eafter_model --> __end__;
|
||||||
classDef default fill:#f2f0ff,line-height:1.2
|
classDef default fill:#f2f0ff,line-height:1.2
|
||||||
classDef first fill-opacity:0
|
classDef first fill-opacity:0
|
||||||
classDef last fill:#bfb6fc
|
classDef last fill:#bfb6fc
|
||||||
@@ -189,13 +171,10 @@
|
|||||||
NoopFour\2eafter_model(NoopFour.after_model)
|
NoopFour\2eafter_model(NoopFour.after_model)
|
||||||
NoopFive\2eafter_model(NoopFive.after_model)
|
NoopFive\2eafter_model(NoopFive.after_model)
|
||||||
__end__([<p>__end__</p>]):::last
|
__end__([<p>__end__</p>]):::last
|
||||||
NoopFive\2eafter_model -.-> NoopFour\2eafter_model;
|
NoopFive\2eafter_model --> NoopFour\2eafter_model;
|
||||||
NoopFive\2eafter_model -.-> __end__;
|
|
||||||
NoopFive\2eafter_model -.-> model_request;
|
|
||||||
NoopFour\2eafter_model -.-> __end__;
|
|
||||||
NoopFour\2eafter_model -.-> model_request;
|
|
||||||
__start__ --> model_request;
|
__start__ --> model_request;
|
||||||
model_request --> NoopFive\2eafter_model;
|
model_request --> NoopFive\2eafter_model;
|
||||||
|
NoopFour\2eafter_model --> __end__;
|
||||||
classDef default fill:#f2f0ff,line-height:1.2
|
classDef default fill:#f2f0ff,line-height:1.2
|
||||||
classDef first fill-opacity:0
|
classDef first fill-opacity:0
|
||||||
classDef last fill:#bfb6fc
|
classDef last fill:#bfb6fc
|
||||||
@@ -216,16 +195,11 @@
|
|||||||
NoopFive\2eafter_model(NoopFive.after_model)
|
NoopFive\2eafter_model(NoopFive.after_model)
|
||||||
NoopSix\2eafter_model(NoopSix.after_model)
|
NoopSix\2eafter_model(NoopSix.after_model)
|
||||||
__end__([<p>__end__</p>]):::last
|
__end__([<p>__end__</p>]):::last
|
||||||
NoopFive\2eafter_model -.-> NoopFour\2eafter_model;
|
NoopFive\2eafter_model --> NoopFour\2eafter_model;
|
||||||
NoopFive\2eafter_model -.-> __end__;
|
NoopSix\2eafter_model --> NoopFive\2eafter_model;
|
||||||
NoopFive\2eafter_model -.-> model_request;
|
|
||||||
NoopFour\2eafter_model -.-> __end__;
|
|
||||||
NoopFour\2eafter_model -.-> model_request;
|
|
||||||
NoopSix\2eafter_model -.-> NoopFive\2eafter_model;
|
|
||||||
NoopSix\2eafter_model -.-> __end__;
|
|
||||||
NoopSix\2eafter_model -.-> model_request;
|
|
||||||
__start__ --> model_request;
|
__start__ --> model_request;
|
||||||
model_request --> NoopSix\2eafter_model;
|
model_request --> NoopSix\2eafter_model;
|
||||||
|
NoopFour\2eafter_model --> __end__;
|
||||||
classDef default fill:#f2f0ff,line-height:1.2
|
classDef default fill:#f2f0ff,line-height:1.2
|
||||||
classDef first fill-opacity:0
|
classDef first fill-opacity:0
|
||||||
classDef last fill:#bfb6fc
|
classDef last fill:#bfb6fc
|
||||||
@@ -245,12 +219,10 @@
|
|||||||
NoopSeven\2ebefore_model(NoopSeven.before_model)
|
NoopSeven\2ebefore_model(NoopSeven.before_model)
|
||||||
NoopSeven\2eafter_model(NoopSeven.after_model)
|
NoopSeven\2eafter_model(NoopSeven.after_model)
|
||||||
__end__([<p>__end__</p>]):::last
|
__end__([<p>__end__</p>]):::last
|
||||||
NoopSeven\2eafter_model -.-> NoopSeven\2ebefore_model;
|
NoopSeven\2ebefore_model --> model_request;
|
||||||
NoopSeven\2eafter_model -.-> __end__;
|
|
||||||
NoopSeven\2ebefore_model -.-> __end__;
|
|
||||||
NoopSeven\2ebefore_model -.-> model_request;
|
|
||||||
__start__ --> NoopSeven\2ebefore_model;
|
__start__ --> NoopSeven\2ebefore_model;
|
||||||
model_request --> NoopSeven\2eafter_model;
|
model_request --> NoopSeven\2eafter_model;
|
||||||
|
NoopSeven\2eafter_model --> __end__;
|
||||||
classDef default fill:#f2f0ff,line-height:1.2
|
classDef default fill:#f2f0ff,line-height:1.2
|
||||||
classDef first fill-opacity:0
|
classDef first fill-opacity:0
|
||||||
classDef last fill:#bfb6fc
|
classDef last fill:#bfb6fc
|
||||||
@@ -272,18 +244,12 @@
|
|||||||
NoopEight\2ebefore_model(NoopEight.before_model)
|
NoopEight\2ebefore_model(NoopEight.before_model)
|
||||||
NoopEight\2eafter_model(NoopEight.after_model)
|
NoopEight\2eafter_model(NoopEight.after_model)
|
||||||
__end__([<p>__end__</p>]):::last
|
__end__([<p>__end__</p>]):::last
|
||||||
NoopEight\2eafter_model -.-> NoopSeven\2eafter_model;
|
NoopEight\2eafter_model --> NoopSeven\2eafter_model;
|
||||||
NoopEight\2eafter_model -.-> NoopSeven\2ebefore_model;
|
NoopEight\2ebefore_model --> model_request;
|
||||||
NoopEight\2eafter_model -.-> __end__;
|
NoopSeven\2ebefore_model --> NoopEight\2ebefore_model;
|
||||||
NoopEight\2ebefore_model -.-> NoopSeven\2ebefore_model;
|
|
||||||
NoopEight\2ebefore_model -.-> __end__;
|
|
||||||
NoopEight\2ebefore_model -.-> model_request;
|
|
||||||
NoopSeven\2eafter_model -.-> NoopSeven\2ebefore_model;
|
|
||||||
NoopSeven\2eafter_model -.-> __end__;
|
|
||||||
NoopSeven\2ebefore_model -.-> NoopEight\2ebefore_model;
|
|
||||||
NoopSeven\2ebefore_model -.-> __end__;
|
|
||||||
__start__ --> NoopSeven\2ebefore_model;
|
__start__ --> NoopSeven\2ebefore_model;
|
||||||
model_request --> NoopEight\2eafter_model;
|
model_request --> NoopEight\2eafter_model;
|
||||||
|
NoopSeven\2eafter_model --> __end__;
|
||||||
classDef default fill:#f2f0ff,line-height:1.2
|
classDef default fill:#f2f0ff,line-height:1.2
|
||||||
classDef first fill-opacity:0
|
classDef first fill-opacity:0
|
||||||
classDef last fill:#bfb6fc
|
classDef last fill:#bfb6fc
|
||||||
@@ -307,24 +273,14 @@
|
|||||||
NoopNine\2ebefore_model(NoopNine.before_model)
|
NoopNine\2ebefore_model(NoopNine.before_model)
|
||||||
NoopNine\2eafter_model(NoopNine.after_model)
|
NoopNine\2eafter_model(NoopNine.after_model)
|
||||||
__end__([<p>__end__</p>]):::last
|
__end__([<p>__end__</p>]):::last
|
||||||
NoopEight\2eafter_model -.-> NoopSeven\2eafter_model;
|
NoopEight\2eafter_model --> NoopSeven\2eafter_model;
|
||||||
NoopEight\2eafter_model -.-> NoopSeven\2ebefore_model;
|
NoopEight\2ebefore_model --> NoopNine\2ebefore_model;
|
||||||
NoopEight\2eafter_model -.-> __end__;
|
NoopNine\2eafter_model --> NoopEight\2eafter_model;
|
||||||
NoopEight\2ebefore_model -.-> NoopNine\2ebefore_model;
|
NoopNine\2ebefore_model --> model_request;
|
||||||
NoopEight\2ebefore_model -.-> NoopSeven\2ebefore_model;
|
NoopSeven\2ebefore_model --> NoopEight\2ebefore_model;
|
||||||
NoopEight\2ebefore_model -.-> __end__;
|
|
||||||
NoopNine\2eafter_model -.-> NoopEight\2eafter_model;
|
|
||||||
NoopNine\2eafter_model -.-> NoopSeven\2ebefore_model;
|
|
||||||
NoopNine\2eafter_model -.-> __end__;
|
|
||||||
NoopNine\2ebefore_model -.-> NoopSeven\2ebefore_model;
|
|
||||||
NoopNine\2ebefore_model -.-> __end__;
|
|
||||||
NoopNine\2ebefore_model -.-> model_request;
|
|
||||||
NoopSeven\2eafter_model -.-> NoopSeven\2ebefore_model;
|
|
||||||
NoopSeven\2eafter_model -.-> __end__;
|
|
||||||
NoopSeven\2ebefore_model -.-> NoopEight\2ebefore_model;
|
|
||||||
NoopSeven\2ebefore_model -.-> __end__;
|
|
||||||
__start__ --> NoopSeven\2ebefore_model;
|
__start__ --> NoopSeven\2ebefore_model;
|
||||||
model_request --> NoopNine\2eafter_model;
|
model_request --> NoopNine\2eafter_model;
|
||||||
|
NoopSeven\2eafter_model --> __end__;
|
||||||
classDef default fill:#f2f0ff,line-height:1.2
|
classDef default fill:#f2f0ff,line-height:1.2
|
||||||
classDef first fill-opacity:0
|
classDef first fill-opacity:0
|
||||||
classDef last fill:#bfb6fc
|
classDef last fill:#bfb6fc
|
||||||
@@ -347,20 +303,13 @@
|
|||||||
NoopEight\2ebefore_model(NoopEight.before_model)
|
NoopEight\2ebefore_model(NoopEight.before_model)
|
||||||
NoopEight\2eafter_model(NoopEight.after_model)
|
NoopEight\2eafter_model(NoopEight.after_model)
|
||||||
__end__([<p>__end__</p>]):::last
|
__end__([<p>__end__</p>]):::last
|
||||||
NoopEight\2eafter_model -.-> NoopSeven\2eafter_model;
|
NoopEight\2eafter_model --> NoopSeven\2eafter_model;
|
||||||
NoopEight\2eafter_model -.-> NoopSeven\2ebefore_model;
|
|
||||||
NoopEight\2eafter_model -.-> __end__;
|
|
||||||
NoopEight\2eafter_model -.-> tools;
|
|
||||||
NoopEight\2ebefore_model -.-> NoopSeven\2ebefore_model;
|
|
||||||
NoopEight\2ebefore_model -.-> __end__;
|
NoopEight\2ebefore_model -.-> __end__;
|
||||||
NoopEight\2ebefore_model -.-> model_request;
|
NoopEight\2ebefore_model -.-> model_request;
|
||||||
NoopEight\2ebefore_model -.-> tools;
|
|
||||||
NoopSeven\2eafter_model -.-> NoopSeven\2ebefore_model;
|
NoopSeven\2eafter_model -.-> NoopSeven\2ebefore_model;
|
||||||
NoopSeven\2eafter_model -.-> __end__;
|
NoopSeven\2eafter_model -.-> __end__;
|
||||||
NoopSeven\2eafter_model -.-> tools;
|
NoopSeven\2eafter_model -.-> tools;
|
||||||
NoopSeven\2ebefore_model -.-> NoopEight\2ebefore_model;
|
NoopSeven\2ebefore_model --> NoopEight\2ebefore_model;
|
||||||
NoopSeven\2ebefore_model -.-> __end__;
|
|
||||||
NoopSeven\2ebefore_model -.-> tools;
|
|
||||||
__start__ --> NoopSeven\2ebefore_model;
|
__start__ --> NoopSeven\2ebefore_model;
|
||||||
model_request --> NoopEight\2eafter_model;
|
model_request --> NoopEight\2eafter_model;
|
||||||
tools -.-> NoopSeven\2ebefore_model;
|
tools -.-> NoopSeven\2ebefore_model;
|
||||||
@@ -387,20 +336,13 @@
|
|||||||
NoopEight\2ebefore_model(NoopEight.before_model)
|
NoopEight\2ebefore_model(NoopEight.before_model)
|
||||||
NoopEight\2eafter_model(NoopEight.after_model)
|
NoopEight\2eafter_model(NoopEight.after_model)
|
||||||
__end__([<p>__end__</p>]):::last
|
__end__([<p>__end__</p>]):::last
|
||||||
NoopEight\2eafter_model -.-> NoopSeven\2eafter_model;
|
NoopEight\2eafter_model --> NoopSeven\2eafter_model;
|
||||||
NoopEight\2eafter_model -.-> NoopSeven\2ebefore_model;
|
|
||||||
NoopEight\2eafter_model -.-> __end__;
|
|
||||||
NoopEight\2eafter_model -.-> tools;
|
|
||||||
NoopEight\2ebefore_model -.-> NoopSeven\2ebefore_model;
|
|
||||||
NoopEight\2ebefore_model -.-> __end__;
|
NoopEight\2ebefore_model -.-> __end__;
|
||||||
NoopEight\2ebefore_model -.-> model_request;
|
NoopEight\2ebefore_model -.-> model_request;
|
||||||
NoopEight\2ebefore_model -.-> tools;
|
|
||||||
NoopSeven\2eafter_model -.-> NoopSeven\2ebefore_model;
|
NoopSeven\2eafter_model -.-> NoopSeven\2ebefore_model;
|
||||||
NoopSeven\2eafter_model -.-> __end__;
|
NoopSeven\2eafter_model -.-> __end__;
|
||||||
NoopSeven\2eafter_model -.-> tools;
|
NoopSeven\2eafter_model -.-> tools;
|
||||||
NoopSeven\2ebefore_model -.-> NoopEight\2ebefore_model;
|
NoopSeven\2ebefore_model --> NoopEight\2ebefore_model;
|
||||||
NoopSeven\2ebefore_model -.-> __end__;
|
|
||||||
NoopSeven\2ebefore_model -.-> tools;
|
|
||||||
__start__ --> NoopSeven\2ebefore_model;
|
__start__ --> NoopSeven\2ebefore_model;
|
||||||
model_request --> NoopEight\2eafter_model;
|
model_request --> NoopEight\2eafter_model;
|
||||||
tools -.-> NoopSeven\2ebefore_model;
|
tools -.-> NoopSeven\2ebefore_model;
|
||||||
@@ -427,20 +369,13 @@
|
|||||||
NoopEight\2ebefore_model(NoopEight.before_model)
|
NoopEight\2ebefore_model(NoopEight.before_model)
|
||||||
NoopEight\2eafter_model(NoopEight.after_model)
|
NoopEight\2eafter_model(NoopEight.after_model)
|
||||||
__end__([<p>__end__</p>]):::last
|
__end__([<p>__end__</p>]):::last
|
||||||
NoopEight\2eafter_model -.-> NoopSeven\2eafter_model;
|
NoopEight\2eafter_model --> NoopSeven\2eafter_model;
|
||||||
NoopEight\2eafter_model -.-> NoopSeven\2ebefore_model;
|
|
||||||
NoopEight\2eafter_model -.-> __end__;
|
|
||||||
NoopEight\2eafter_model -.-> tools;
|
|
||||||
NoopEight\2ebefore_model -.-> NoopSeven\2ebefore_model;
|
|
||||||
NoopEight\2ebefore_model -.-> __end__;
|
NoopEight\2ebefore_model -.-> __end__;
|
||||||
NoopEight\2ebefore_model -.-> model_request;
|
NoopEight\2ebefore_model -.-> model_request;
|
||||||
NoopEight\2ebefore_model -.-> tools;
|
|
||||||
NoopSeven\2eafter_model -.-> NoopSeven\2ebefore_model;
|
NoopSeven\2eafter_model -.-> NoopSeven\2ebefore_model;
|
||||||
NoopSeven\2eafter_model -.-> __end__;
|
NoopSeven\2eafter_model -.-> __end__;
|
||||||
NoopSeven\2eafter_model -.-> tools;
|
NoopSeven\2eafter_model -.-> tools;
|
||||||
NoopSeven\2ebefore_model -.-> NoopEight\2ebefore_model;
|
NoopSeven\2ebefore_model --> NoopEight\2ebefore_model;
|
||||||
NoopSeven\2ebefore_model -.-> __end__;
|
|
||||||
NoopSeven\2ebefore_model -.-> tools;
|
|
||||||
__start__ --> NoopSeven\2ebefore_model;
|
__start__ --> NoopSeven\2ebefore_model;
|
||||||
model_request --> NoopEight\2eafter_model;
|
model_request --> NoopEight\2eafter_model;
|
||||||
tools -.-> NoopSeven\2ebefore_model;
|
tools -.-> NoopSeven\2ebefore_model;
|
||||||
@@ -467,20 +402,13 @@
|
|||||||
NoopEight\2ebefore_model(NoopEight.before_model)
|
NoopEight\2ebefore_model(NoopEight.before_model)
|
||||||
NoopEight\2eafter_model(NoopEight.after_model)
|
NoopEight\2eafter_model(NoopEight.after_model)
|
||||||
__end__([<p>__end__</p>]):::last
|
__end__([<p>__end__</p>]):::last
|
||||||
NoopEight\2eafter_model -.-> NoopSeven\2eafter_model;
|
NoopEight\2eafter_model --> NoopSeven\2eafter_model;
|
||||||
NoopEight\2eafter_model -.-> NoopSeven\2ebefore_model;
|
|
||||||
NoopEight\2eafter_model -.-> __end__;
|
|
||||||
NoopEight\2eafter_model -.-> tools;
|
|
||||||
NoopEight\2ebefore_model -.-> NoopSeven\2ebefore_model;
|
|
||||||
NoopEight\2ebefore_model -.-> __end__;
|
NoopEight\2ebefore_model -.-> __end__;
|
||||||
NoopEight\2ebefore_model -.-> model_request;
|
NoopEight\2ebefore_model -.-> model_request;
|
||||||
NoopEight\2ebefore_model -.-> tools;
|
|
||||||
NoopSeven\2eafter_model -.-> NoopSeven\2ebefore_model;
|
NoopSeven\2eafter_model -.-> NoopSeven\2ebefore_model;
|
||||||
NoopSeven\2eafter_model -.-> __end__;
|
NoopSeven\2eafter_model -.-> __end__;
|
||||||
NoopSeven\2eafter_model -.-> tools;
|
NoopSeven\2eafter_model -.-> tools;
|
||||||
NoopSeven\2ebefore_model -.-> NoopEight\2ebefore_model;
|
NoopSeven\2ebefore_model --> NoopEight\2ebefore_model;
|
||||||
NoopSeven\2ebefore_model -.-> __end__;
|
|
||||||
NoopSeven\2ebefore_model -.-> tools;
|
|
||||||
__start__ --> NoopSeven\2ebefore_model;
|
__start__ --> NoopSeven\2ebefore_model;
|
||||||
model_request --> NoopEight\2eafter_model;
|
model_request --> NoopEight\2eafter_model;
|
||||||
tools -.-> NoopSeven\2ebefore_model;
|
tools -.-> NoopSeven\2ebefore_model;
|
||||||
@@ -507,20 +435,13 @@
|
|||||||
NoopEight\2ebefore_model(NoopEight.before_model)
|
NoopEight\2ebefore_model(NoopEight.before_model)
|
||||||
NoopEight\2eafter_model(NoopEight.after_model)
|
NoopEight\2eafter_model(NoopEight.after_model)
|
||||||
__end__([<p>__end__</p>]):::last
|
__end__([<p>__end__</p>]):::last
|
||||||
NoopEight\2eafter_model -.-> NoopSeven\2eafter_model;
|
NoopEight\2eafter_model --> NoopSeven\2eafter_model;
|
||||||
NoopEight\2eafter_model -.-> NoopSeven\2ebefore_model;
|
|
||||||
NoopEight\2eafter_model -.-> __end__;
|
|
||||||
NoopEight\2eafter_model -.-> tools;
|
|
||||||
NoopEight\2ebefore_model -.-> NoopSeven\2ebefore_model;
|
|
||||||
NoopEight\2ebefore_model -.-> __end__;
|
NoopEight\2ebefore_model -.-> __end__;
|
||||||
NoopEight\2ebefore_model -.-> model_request;
|
NoopEight\2ebefore_model -.-> model_request;
|
||||||
NoopEight\2ebefore_model -.-> tools;
|
|
||||||
NoopSeven\2eafter_model -.-> NoopSeven\2ebefore_model;
|
NoopSeven\2eafter_model -.-> NoopSeven\2ebefore_model;
|
||||||
NoopSeven\2eafter_model -.-> __end__;
|
NoopSeven\2eafter_model -.-> __end__;
|
||||||
NoopSeven\2eafter_model -.-> tools;
|
NoopSeven\2eafter_model -.-> tools;
|
||||||
NoopSeven\2ebefore_model -.-> NoopEight\2ebefore_model;
|
NoopSeven\2ebefore_model --> NoopEight\2ebefore_model;
|
||||||
NoopSeven\2ebefore_model -.-> __end__;
|
|
||||||
NoopSeven\2ebefore_model -.-> tools;
|
|
||||||
__start__ --> NoopSeven\2ebefore_model;
|
__start__ --> NoopSeven\2ebefore_model;
|
||||||
model_request --> NoopEight\2eafter_model;
|
model_request --> NoopEight\2eafter_model;
|
||||||
tools -.-> NoopSeven\2ebefore_model;
|
tools -.-> NoopSeven\2ebefore_model;
|
||||||
|
@@ -34,7 +34,6 @@ from langchain.agents.middleware.types import (
|
|||||||
OmitFromOutput,
|
OmitFromOutput,
|
||||||
PrivateStateAttr,
|
PrivateStateAttr,
|
||||||
)
|
)
|
||||||
from langchain.agents.middleware.dynamic_system_prompt import DynamicSystemPromptMiddleware
|
|
||||||
|
|
||||||
from langgraph.checkpoint.base import BaseCheckpointSaver
|
from langgraph.checkpoint.base import BaseCheckpointSaver
|
||||||
from langgraph.checkpoint.memory import InMemorySaver
|
from langgraph.checkpoint.memory import InMemorySaver
|
||||||
@@ -332,6 +331,8 @@ def test_create_agent_jump(
|
|||||||
calls.append("NoopSeven.after_model")
|
calls.append("NoopSeven.after_model")
|
||||||
|
|
||||||
class NoopEight(AgentMiddleware):
|
class NoopEight(AgentMiddleware):
|
||||||
|
before_model_jump_to = [END]
|
||||||
|
|
||||||
def before_model(self, state) -> dict[str, Any]:
|
def before_model(self, state) -> dict[str, Any]:
|
||||||
calls.append("NoopEight.before_model")
|
calls.append("NoopEight.before_model")
|
||||||
return {"jump_to": END}
|
return {"jump_to": END}
|
||||||
@@ -1221,74 +1222,6 @@ def test_tools_to_model_edge_with_structured_and_regular_tool_calls():
|
|||||||
assert result["response"].condition == "sunny"
|
assert result["response"].condition == "sunny"
|
||||||
|
|
||||||
|
|
||||||
# Tests for DynamicSystemPromptMiddleware
|
|
||||||
def test_dynamic_system_prompt_middleware_basic() -> None:
|
|
||||||
"""Test basic functionality of DynamicSystemPromptMiddleware."""
|
|
||||||
|
|
||||||
def dynamic_system_prompt(state: AgentState) -> str:
|
|
||||||
messages = state.get("messages", [])
|
|
||||||
if messages:
|
|
||||||
return f"You are a helpful assistant. Message count: {len(messages)}"
|
|
||||||
return "You are a helpful assistant. No messages yet."
|
|
||||||
|
|
||||||
middleware = DynamicSystemPromptMiddleware(dynamic_system_prompt)
|
|
||||||
|
|
||||||
# Test with empty state
|
|
||||||
empty_state = {"messages": []}
|
|
||||||
request = ModelRequest(
|
|
||||||
model=FakeToolCallingModel(),
|
|
||||||
system_prompt="Original prompt",
|
|
||||||
messages=[],
|
|
||||||
tool_choice=None,
|
|
||||||
tools=[],
|
|
||||||
response_format=None,
|
|
||||||
)
|
|
||||||
|
|
||||||
modified_request = middleware.modify_model_request(request, empty_state, None)
|
|
||||||
assert modified_request.system_prompt == "You are a helpful assistant. No messages yet."
|
|
||||||
|
|
||||||
state_with_messages = {"messages": [HumanMessage("Hello"), AIMessage("Hi")]}
|
|
||||||
modified_request = middleware.modify_model_request(request, state_with_messages, None)
|
|
||||||
assert modified_request.system_prompt == "You are a helpful assistant. Message count: 2"
|
|
||||||
|
|
||||||
|
|
||||||
def test_dynamic_system_prompt_middleware_with_context() -> None:
|
|
||||||
"""Test DynamicSystemPromptMiddleware with runtime context."""
|
|
||||||
|
|
||||||
class MockContext(TypedDict):
|
|
||||||
user_role: str
|
|
||||||
|
|
||||||
def dynamic_system_prompt(state: AgentState, runtime: Runtime[MockContext]) -> str:
|
|
||||||
base_prompt = "You are a helpful assistant."
|
|
||||||
if runtime and hasattr(runtime, "context"):
|
|
||||||
user_role = runtime.context.get("user_role", "user")
|
|
||||||
return f"{base_prompt} User role: {user_role}"
|
|
||||||
return base_prompt
|
|
||||||
|
|
||||||
middleware = DynamicSystemPromptMiddleware(dynamic_system_prompt)
|
|
||||||
|
|
||||||
# Create a mock runtime with context
|
|
||||||
class MockRuntime:
|
|
||||||
def __init__(self, context):
|
|
||||||
self.context = context
|
|
||||||
|
|
||||||
mock_runtime = MockRuntime(context={"user_role": "admin"})
|
|
||||||
|
|
||||||
request = ModelRequest(
|
|
||||||
model=FakeToolCallingModel(),
|
|
||||||
system_prompt="Original prompt",
|
|
||||||
messages=[HumanMessage("Test")],
|
|
||||||
tool_choice=None,
|
|
||||||
tools=[],
|
|
||||||
response_format=None,
|
|
||||||
)
|
|
||||||
|
|
||||||
state = {"messages": [HumanMessage("Test")]}
|
|
||||||
modified_request = middleware.modify_model_request(request, state, mock_runtime)
|
|
||||||
|
|
||||||
assert modified_request.system_prompt == "You are a helpful assistant. User role: admin"
|
|
||||||
|
|
||||||
|
|
||||||
def test_public_private_state_for_custom_middleware() -> None:
|
def test_public_private_state_for_custom_middleware() -> None:
|
||||||
"""Test public and private state for custom middleware."""
|
"""Test public and private state for custom middleware."""
|
||||||
|
|
||||||
|
@@ -0,0 +1,152 @@
|
|||||||
|
"""Consolidated tests for middleware decorators: before_model, after_model, and modify_model_request."""
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
from typing_extensions import NotRequired
|
||||||
|
|
||||||
|
from langchain_core.messages import HumanMessage, AIMessage
|
||||||
|
from langchain_core.tools import tool
|
||||||
|
from langgraph.types import Command
|
||||||
|
|
||||||
|
from langchain.agents.middleware.types import (
|
||||||
|
AgentMiddleware,
|
||||||
|
AgentState,
|
||||||
|
ModelRequest,
|
||||||
|
before_model,
|
||||||
|
after_model,
|
||||||
|
modify_model_request,
|
||||||
|
)
|
||||||
|
from langchain.agents.middleware_agent import create_agent
|
||||||
|
from .model import FakeToolCallingModel
|
||||||
|
|
||||||
|
|
||||||
|
class CustomState(AgentState):
|
||||||
|
"""Custom state schema for testing."""
|
||||||
|
|
||||||
|
custom_field: NotRequired[str]
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
def test_tool(input: str) -> str:
|
||||||
|
"""A test tool for middleware testing."""
|
||||||
|
return f"Tool result: {input}"
|
||||||
|
|
||||||
|
|
||||||
|
def test_before_model_decorator() -> None:
|
||||||
|
"""Test before_model decorator with all configuration options."""
|
||||||
|
|
||||||
|
@before_model(
|
||||||
|
state_schema=CustomState, tools=[test_tool], jump_to=["__end__"], name="CustomBeforeModel"
|
||||||
|
)
|
||||||
|
def custom_before_model(state: CustomState) -> dict[str, Any]:
|
||||||
|
return {"jump_to": "__end__"}
|
||||||
|
|
||||||
|
assert isinstance(custom_before_model, AgentMiddleware)
|
||||||
|
assert custom_before_model.state_schema == CustomState
|
||||||
|
assert custom_before_model.tools == [test_tool]
|
||||||
|
assert custom_before_model.before_model_jump_to == ["__end__"]
|
||||||
|
assert custom_before_model.__class__.__name__ == "CustomBeforeModel"
|
||||||
|
|
||||||
|
result = custom_before_model.before_model({"messages": [HumanMessage("Hello")]})
|
||||||
|
assert result == {"jump_to": "__end__"}
|
||||||
|
|
||||||
|
|
||||||
|
def test_after_model_decorator() -> None:
|
||||||
|
"""Test after_model decorator with all configuration options."""
|
||||||
|
|
||||||
|
@after_model(
|
||||||
|
state_schema=CustomState,
|
||||||
|
tools=[test_tool],
|
||||||
|
jump_to=["model", "__end__"],
|
||||||
|
name="CustomAfterModel",
|
||||||
|
)
|
||||||
|
def custom_after_model(state: CustomState) -> dict[str, Any]:
|
||||||
|
return {"jump_to": "model"}
|
||||||
|
|
||||||
|
# Verify all options were applied
|
||||||
|
assert isinstance(custom_after_model, AgentMiddleware)
|
||||||
|
assert custom_after_model.state_schema == CustomState
|
||||||
|
assert custom_after_model.tools == [test_tool]
|
||||||
|
assert custom_after_model.after_model_jump_to == ["model", "__end__"]
|
||||||
|
assert custom_after_model.__class__.__name__ == "CustomAfterModel"
|
||||||
|
|
||||||
|
# Verify it works
|
||||||
|
result = custom_after_model.after_model({"messages": [HumanMessage("Hello"), AIMessage("Hi!")]})
|
||||||
|
assert result == {"jump_to": "model"}
|
||||||
|
|
||||||
|
|
||||||
|
def test_modify_model_request_decorator() -> None:
|
||||||
|
"""Test modify_model_request decorator with all configuration options."""
|
||||||
|
|
||||||
|
@modify_model_request(state_schema=CustomState, tools=[test_tool], name="CustomModifyRequest")
|
||||||
|
def custom_modify_request(request: ModelRequest, state: CustomState) -> ModelRequest:
|
||||||
|
request.system_prompt = "Modified"
|
||||||
|
return request
|
||||||
|
|
||||||
|
# Verify all options were applied
|
||||||
|
assert isinstance(custom_modify_request, AgentMiddleware)
|
||||||
|
assert custom_modify_request.state_schema == CustomState
|
||||||
|
assert custom_modify_request.tools == [test_tool]
|
||||||
|
assert custom_modify_request.__class__.__name__ == "CustomModifyRequest"
|
||||||
|
|
||||||
|
# Verify it works
|
||||||
|
original_request = ModelRequest(
|
||||||
|
model="test-model",
|
||||||
|
system_prompt="Original",
|
||||||
|
messages=[HumanMessage("Hello")],
|
||||||
|
tool_choice=None,
|
||||||
|
tools=[],
|
||||||
|
response_format=None,
|
||||||
|
)
|
||||||
|
result = custom_modify_request.modify_model_request(
|
||||||
|
original_request, {"messages": [HumanMessage("Hello")]}
|
||||||
|
)
|
||||||
|
assert result.system_prompt == "Modified"
|
||||||
|
|
||||||
|
|
||||||
|
def test_all_decorators_integration() -> None:
|
||||||
|
"""Test all three decorators working together in an agent."""
|
||||||
|
call_order = []
|
||||||
|
|
||||||
|
@before_model
|
||||||
|
def track_before(state: AgentState) -> None:
|
||||||
|
call_order.append("before")
|
||||||
|
return None
|
||||||
|
|
||||||
|
@modify_model_request
|
||||||
|
def track_modify(request: ModelRequest, state: AgentState) -> ModelRequest:
|
||||||
|
call_order.append("modify")
|
||||||
|
return request
|
||||||
|
|
||||||
|
@after_model
|
||||||
|
def track_after(state: AgentState) -> None:
|
||||||
|
call_order.append("after")
|
||||||
|
return None
|
||||||
|
|
||||||
|
agent = create_agent(
|
||||||
|
model=FakeToolCallingModel(), middleware=[track_before, track_modify, track_after]
|
||||||
|
)
|
||||||
|
agent = agent.compile()
|
||||||
|
agent.invoke({"messages": [HumanMessage("Hello")]})
|
||||||
|
|
||||||
|
assert call_order == ["before", "modify", "after"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_decorators_use_function_names_as_default() -> None:
|
||||||
|
"""Test that decorators use function names as default middleware names."""
|
||||||
|
|
||||||
|
@before_model
|
||||||
|
def my_before_hook(state: AgentState) -> None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
@modify_model_request
|
||||||
|
def my_modify_hook(request: ModelRequest, state: AgentState) -> ModelRequest:
|
||||||
|
return request
|
||||||
|
|
||||||
|
@after_model
|
||||||
|
def my_after_hook(state: AgentState) -> None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Verify that function names are used as middleware class names
|
||||||
|
assert my_before_hook.__class__.__name__ == "my_before_hook"
|
||||||
|
assert my_modify_hook.__class__.__name__ == "my_modify_hook"
|
||||||
|
assert my_after_hook.__class__.__name__ == "my_after_hook"
|
Reference in New Issue
Block a user