mirror of
https://github.com/hwchase17/langchain.git
synced 2026-02-16 01:59:52 +00:00
Compare commits
4 Commits
sr/refacto
...
sr/rename-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7fbde58ead | ||
|
|
c154cdca5e | ||
|
|
530cae3f8c | ||
|
|
b1c0b32e1b |
@@ -45,7 +45,7 @@ __all__ = [
|
||||
"PublicAgentState",
|
||||
]
|
||||
|
||||
JumpTo = Literal["tools", "model", "__end__"]
|
||||
JumpTo = Literal["tools", "model", "end"]
|
||||
"""Destination to jump to when a middleware node returns."""
|
||||
|
||||
ResponseT = TypeVar("ResponseT")
|
||||
@@ -237,7 +237,7 @@ def before_model(
|
||||
AgentState schema.
|
||||
tools: Optional list of additional tools to register with this middleware.
|
||||
jump_to: Optional list of valid jump destinations for conditional edges.
|
||||
Valid values are: "tools", "model", "__end__"
|
||||
Valid values are: "tools", "model", "end"
|
||||
name: Optional name for the generated middleware class. If not provided,
|
||||
uses the decorated function's name.
|
||||
|
||||
@@ -260,10 +260,10 @@ def before_model(
|
||||
|
||||
Advanced usage with runtime and conditional jumping:
|
||||
```python
|
||||
@before_model(jump_to=["__end__"])
|
||||
@before_model(jump_to=["end"])
|
||||
def conditional_before_model(state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
|
||||
if some_condition(state):
|
||||
return {"jump_to": "__end__"}
|
||||
return {"jump_to": "end"}
|
||||
return None
|
||||
```
|
||||
|
||||
@@ -474,7 +474,7 @@ def after_model(
|
||||
AgentState schema.
|
||||
tools: Optional list of additional tools to register with this middleware.
|
||||
jump_to: Optional list of valid jump destinations for conditional edges.
|
||||
Valid values are: "tools", "model", "__end__"
|
||||
Valid values are: "tools", "model", "end"
|
||||
name: Optional name for the generated middleware class. If not provided,
|
||||
uses the decorated function's name.
|
||||
|
||||
|
||||
@@ -505,8 +505,10 @@ def create_agent( # noqa: PLR0915
|
||||
def _resolve_jump(jump_to: JumpTo | None, first_node: str) -> str | None:
|
||||
if jump_to == "model":
|
||||
return first_node
|
||||
if jump_to:
|
||||
return jump_to
|
||||
if jump_to == "end":
|
||||
return "__end__"
|
||||
if jump_to == "tools":
|
||||
return "tools"
|
||||
return None
|
||||
|
||||
|
||||
@@ -603,7 +605,7 @@ def _add_middleware_edge(
|
||||
|
||||
destinations = [default_destination]
|
||||
|
||||
if "__end__" in jump_to:
|
||||
if "end" in jump_to:
|
||||
destinations.append(END)
|
||||
if "tools" in jump_to:
|
||||
destinations.append("tools")
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -331,11 +331,11 @@ def test_create_agent_jump(
|
||||
calls.append("NoopSeven.after_model")
|
||||
|
||||
class NoopEight(AgentMiddleware):
|
||||
before_model_jump_to = [END]
|
||||
before_model_jump_to = ["end"]
|
||||
|
||||
def before_model(self, state) -> dict[str, Any]:
|
||||
calls.append("NoopEight.before_model")
|
||||
return {"jump_to": END}
|
||||
return {"jump_to": "end"}
|
||||
|
||||
def modify_model_request(self, request, state) -> ModelRequest:
|
||||
calls.append("NoopEight.modify_model_request")
|
||||
|
||||
@@ -35,19 +35,19 @@ def test_before_model_decorator() -> None:
|
||||
"""Test before_model decorator with all configuration options."""
|
||||
|
||||
@before_model(
|
||||
state_schema=CustomState, tools=[test_tool], jump_to=["__end__"], name="CustomBeforeModel"
|
||||
state_schema=CustomState, tools=[test_tool], jump_to=["end"], name="CustomBeforeModel"
|
||||
)
|
||||
def custom_before_model(state: CustomState) -> dict[str, Any]:
|
||||
return {"jump_to": "__end__"}
|
||||
return {"jump_to": "end"}
|
||||
|
||||
assert isinstance(custom_before_model, AgentMiddleware)
|
||||
assert custom_before_model.state_schema == CustomState
|
||||
assert custom_before_model.tools == [test_tool]
|
||||
assert custom_before_model.before_model_jump_to == ["__end__"]
|
||||
assert custom_before_model.before_model_jump_to == ["end"]
|
||||
assert custom_before_model.__class__.__name__ == "CustomBeforeModel"
|
||||
|
||||
result = custom_before_model.before_model({"messages": [HumanMessage("Hello")]})
|
||||
assert result == {"jump_to": "__end__"}
|
||||
assert result == {"jump_to": "end"}
|
||||
|
||||
|
||||
def test_after_model_decorator() -> None:
|
||||
@@ -56,7 +56,7 @@ def test_after_model_decorator() -> None:
|
||||
@after_model(
|
||||
state_schema=CustomState,
|
||||
tools=[test_tool],
|
||||
jump_to=["model", "__end__"],
|
||||
jump_to=["model", "end"],
|
||||
name="CustomAfterModel",
|
||||
)
|
||||
def custom_after_model(state: CustomState) -> dict[str, Any]:
|
||||
@@ -66,7 +66,7 @@ def test_after_model_decorator() -> None:
|
||||
assert isinstance(custom_after_model, AgentMiddleware)
|
||||
assert custom_after_model.state_schema == CustomState
|
||||
assert custom_after_model.tools == [test_tool]
|
||||
assert custom_after_model.after_model_jump_to == ["model", "__end__"]
|
||||
assert custom_after_model.after_model_jump_to == ["model", "end"]
|
||||
assert custom_after_model.__class__.__name__ == "CustomAfterModel"
|
||||
|
||||
# Verify it works
|
||||
|
||||
@@ -0,0 +1,222 @@
|
||||
"""Test matrix for middleware graph structures with all combinations of jump targets."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import itertools
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
from syrupy.assertion import SnapshotAssertion
|
||||
|
||||
from langchain.agents.middleware_agent import create_agent
|
||||
from langchain.agents.middleware.types import AgentMiddleware, JumpTo
|
||||
from langchain_core.tools import tool, BaseTool
|
||||
from dataclasses import dataclass
|
||||
|
||||
from .model import FakeToolCallingModel
|
||||
|
||||
|
||||
# All possible combinations of jump targets (power set of ["tools", "model", "end"])
|
||||
JUMP_TARGETS: list[list[JumpTo]] = [
|
||||
[],
|
||||
["tools"],
|
||||
["model"],
|
||||
["end"],
|
||||
["tools", "model"],
|
||||
["tools", "end"],
|
||||
["model", "end"],
|
||||
["tools", "model", "end"],
|
||||
]
|
||||
|
||||
|
||||
def create_middleware_with_jump_to(
|
||||
name: str,
|
||||
before_model_jump_to_: list[JumpTo],
|
||||
after_model_jump_to_: list[JumpTo],
|
||||
tools_: list[BaseTool] = [],
|
||||
) -> AgentMiddleware:
|
||||
"""Create a middleware class with specified jump_to configurations."""
|
||||
|
||||
class CustomMiddleware(AgentMiddleware):
|
||||
before_model_jump_to: list[JumpTo] = before_model_jump_to_
|
||||
after_model_jump_to: list[JumpTo] = after_model_jump_to_
|
||||
tools: list[BaseTool] = tools_
|
||||
|
||||
def before_model(self, state: Any, runtime: Any) -> None:
|
||||
"""Before model hook."""
|
||||
pass
|
||||
|
||||
def after_model(self, state: Any, runtime: Any) -> None:
|
||||
"""After model hook."""
|
||||
pass
|
||||
|
||||
CustomMiddleware.__name__ = name
|
||||
return CustomMiddleware()
|
||||
|
||||
|
||||
@tool
|
||||
def some_tool() -> str:
|
||||
"""A simple test tool."""
|
||||
return "Hello, world!"
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class TestCase:
|
||||
"""Test case configuration for middleware graph testing.
|
||||
|
||||
Represents a specific combination of middleware jump targets and tool availability.
|
||||
|
||||
Example:
|
||||
TestCase(
|
||||
a_before=["tools"],
|
||||
a_after=["end"],
|
||||
b_before=[],
|
||||
b_after=["model"],
|
||||
has_tools=True
|
||||
)
|
||||
This creates two middleware instances where:
|
||||
- MiddlewareA jumps to "tools" before model, "end" after model
|
||||
- MiddlewareB has no jumps before model, jumps to "model" after model
|
||||
- Agent has tools available
|
||||
"""
|
||||
|
||||
a_before: list[JumpTo]
|
||||
a_after: list[JumpTo]
|
||||
b_before: list[JumpTo]
|
||||
b_after: list[JumpTo]
|
||||
has_tools: bool
|
||||
|
||||
|
||||
def format_jumps(jumps: list[JumpTo]) -> str:
|
||||
"""Format jump targets for test ID."""
|
||||
return "_".join(jumps) if jumps else "empty"
|
||||
|
||||
|
||||
def format_test_case_name(test_case: TestCase) -> str:
|
||||
"""Format the test case name for pytest ID."""
|
||||
return (
|
||||
f"A_before_{format_jumps(test_case.a_before)}_"
|
||||
f"A_after_{format_jumps(test_case.a_after)}_"
|
||||
f"B_before_{format_jumps(test_case.b_before)}_"
|
||||
f"B_after_{format_jumps(test_case.b_after)}_"
|
||||
f"tools_{test_case.has_tools}"
|
||||
)
|
||||
|
||||
|
||||
def _is_valid_test_case(
|
||||
a_before: list[JumpTo],
|
||||
a_after: list[JumpTo],
|
||||
b_before: list[JumpTo],
|
||||
b_after: list[JumpTo],
|
||||
has_tools: bool,
|
||||
) -> bool:
|
||||
"""Check if test case is valid - can't jump to tools when no tools available."""
|
||||
if has_tools:
|
||||
return True
|
||||
|
||||
all_jump_targets = {*a_before, *a_after, *b_before, *b_after}
|
||||
return "tools" not in all_jump_targets
|
||||
|
||||
|
||||
def generate_test_cases() -> list[TestCase]:
|
||||
"""Generate all valid test case combinations."""
|
||||
test_cases: list[TestCase] = []
|
||||
|
||||
for has_tools in [False, True]:
|
||||
for a_before, a_after, b_before, b_after in itertools.product(JUMP_TARGETS, repeat=4):
|
||||
if _is_valid_test_case(a_before, a_after, b_before, b_after, has_tools):
|
||||
test_cases.append(
|
||||
TestCase(
|
||||
a_before=a_before,
|
||||
a_after=a_after,
|
||||
b_before=b_before,
|
||||
b_after=b_after,
|
||||
has_tools=has_tools,
|
||||
)
|
||||
)
|
||||
|
||||
return test_cases
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"test_case",
|
||||
generate_test_cases(),
|
||||
ids=format_test_case_name,
|
||||
)
|
||||
def test_middleware_graph_structure(
|
||||
snapshot: SnapshotAssertion,
|
||||
test_case: TestCase,
|
||||
) -> None:
|
||||
"""Test that middleware graphs are created with correct structure for all combinations."""
|
||||
middleware_a = create_middleware_with_jump_to(
|
||||
"MiddlewareA", test_case.a_before, test_case.a_after
|
||||
)
|
||||
middleware_b = create_middleware_with_jump_to(
|
||||
"MiddlewareB", test_case.b_before, test_case.b_after
|
||||
)
|
||||
|
||||
tools = [some_tool] if test_case.has_tools else []
|
||||
agent = create_agent(
|
||||
model=FakeToolCallingModel(),
|
||||
tools=tools,
|
||||
middleware=[middleware_a, middleware_b],
|
||||
)
|
||||
|
||||
mermaid_diagram = agent.compile().get_graph().draw_mermaid(with_styles=False)
|
||||
assert mermaid_diagram == snapshot
|
||||
|
||||
|
||||
def test_tool_registration_with_middleware(
|
||||
snapshot: SnapshotAssertion,
|
||||
) -> None:
|
||||
"""Test tool registration via middleware, agent, or both."""
|
||||
# Test case 1: No tools anywhere
|
||||
middleware_no_tools = create_middleware_with_jump_to(
|
||||
"Middleware", before_model_jump_to_=[], after_model_jump_to_=[], tools_=[]
|
||||
)
|
||||
agent_no_tools = create_agent(
|
||||
model=FakeToolCallingModel(),
|
||||
tools=[],
|
||||
middleware=[middleware_no_tools],
|
||||
)
|
||||
|
||||
diagram_no_tools = agent_no_tools.compile().get_graph().draw_mermaid(with_styles=False)
|
||||
assert diagram_no_tools == snapshot(name="no_tools")
|
||||
|
||||
# Test case 2: Tools only via middleware
|
||||
middleware_with_tools = create_middleware_with_jump_to(
|
||||
"Middleware", before_model_jump_to_=[], after_model_jump_to_=[], tools_=[some_tool]
|
||||
)
|
||||
agent_middleware_tools = create_agent(
|
||||
model=FakeToolCallingModel(),
|
||||
tools=[],
|
||||
middleware=[middleware_with_tools],
|
||||
)
|
||||
|
||||
diagram_middleware_tools = (
|
||||
agent_middleware_tools.compile().get_graph().draw_mermaid(with_styles=False)
|
||||
)
|
||||
assert diagram_middleware_tools == snapshot(name="middleware_tools_only")
|
||||
|
||||
# Test case 3: Tools only via agent
|
||||
middleware_no_tools_agent = create_middleware_with_jump_to(
|
||||
"Middleware", before_model_jump_to_=[], after_model_jump_to_=[], tools_=[]
|
||||
)
|
||||
agent_with_tools = create_agent(
|
||||
model=FakeToolCallingModel(),
|
||||
tools=[some_tool],
|
||||
middleware=[middleware_no_tools_agent],
|
||||
)
|
||||
|
||||
diagram_agent_tools = agent_with_tools.compile().get_graph().draw_mermaid(with_styles=False)
|
||||
assert diagram_agent_tools == snapshot(name="agent_tools_only")
|
||||
|
||||
# Test case 4: Tools via both middleware and agent
|
||||
agent_both_tools = create_agent(
|
||||
model=FakeToolCallingModel(),
|
||||
tools=[some_tool],
|
||||
middleware=[middleware_with_tools],
|
||||
)
|
||||
|
||||
diagram_both_tools = agent_both_tools.compile().get_graph().draw_mermaid(with_styles=False)
|
||||
assert diagram_both_tools == snapshot(name="both_tools")
|
||||
Reference in New Issue
Block a user