mirror of
https://github.com/hwchase17/langchain.git
synced 2026-04-24 04:36:46 +00:00
fix(langchain_v1): fix edges when there's no middleware (#33321)
1. Main fix: when we don't have a response format or middleware, don't draw a conditional edge back to the loop entrypoint (self loop on model) 2. Supplementary fix: when we jump to `end` and there is an `after_agent` hook, jump there instead of `__end__` Other improvements -- I can remove these if they're more harmful than helpful 1. Use keyword only arguments for edge generator functions for clarity 2. Rename args to `model_destination` and `end_destination` for clarity
This commit is contained in:
@@ -842,22 +842,41 @@ def create_agent( # noqa: PLR0915
|
||||
graph.add_conditional_edges(
|
||||
"tools",
|
||||
_make_tools_to_model_edge(
|
||||
tool_node, loop_entry_node, structured_output_tools, exit_node
|
||||
tool_node=tool_node,
|
||||
model_destination=loop_entry_node,
|
||||
structured_output_tools=structured_output_tools,
|
||||
end_destination=exit_node,
|
||||
),
|
||||
[loop_entry_node, exit_node],
|
||||
)
|
||||
|
||||
# base destinations are tools and exit_node
|
||||
# we add the loop_entry node to edge destinations if:
|
||||
# - there is an after model hook(s) -- allows jump_to to model
|
||||
# potentially artificially injected tool messages, ex HITL
|
||||
# - there is a response format -- to allow for jumping to model to handle
|
||||
# regenerating structured output tool calls
|
||||
model_to_tools_destinations = ["tools", exit_node]
|
||||
if response_format or loop_exit_node != "model":
|
||||
model_to_tools_destinations.append(loop_entry_node)
|
||||
|
||||
graph.add_conditional_edges(
|
||||
loop_exit_node,
|
||||
_make_model_to_tools_edge(
|
||||
loop_entry_node, structured_output_tools, tool_node, exit_node
|
||||
model_destination=loop_entry_node,
|
||||
structured_output_tools=structured_output_tools,
|
||||
tool_node=tool_node,
|
||||
end_destination=exit_node,
|
||||
),
|
||||
[loop_entry_node, "tools", exit_node],
|
||||
model_to_tools_destinations,
|
||||
)
|
||||
elif len(structured_output_tools) > 0:
|
||||
graph.add_conditional_edges(
|
||||
loop_exit_node,
|
||||
_make_model_to_model_edge(loop_entry_node, exit_node),
|
||||
_make_model_to_model_edge(
|
||||
model_destination=loop_entry_node,
|
||||
end_destination=exit_node,
|
||||
),
|
||||
[loop_entry_node, exit_node],
|
||||
)
|
||||
elif loop_exit_node == "model":
|
||||
@@ -867,9 +886,10 @@ def create_agent( # noqa: PLR0915
|
||||
else:
|
||||
_add_middleware_edge(
|
||||
graph,
|
||||
f"{middleware_w_after_model[0].name}.after_model",
|
||||
exit_node,
|
||||
loop_entry_node,
|
||||
name=f"{middleware_w_after_model[0].name}.after_model",
|
||||
default_destination=exit_node,
|
||||
model_destination=loop_entry_node,
|
||||
end_destination=exit_node,
|
||||
can_jump_to=_get_can_jump_to(middleware_w_after_model[0], "after_model"),
|
||||
)
|
||||
|
||||
@@ -878,17 +898,19 @@ def create_agent( # noqa: PLR0915
|
||||
for m1, m2 in itertools.pairwise(middleware_w_before_agent):
|
||||
_add_middleware_edge(
|
||||
graph,
|
||||
f"{m1.name}.before_agent",
|
||||
f"{m2.name}.before_agent",
|
||||
loop_entry_node,
|
||||
name=f"{m1.name}.before_agent",
|
||||
default_destination=f"{m2.name}.before_agent",
|
||||
model_destination=loop_entry_node,
|
||||
end_destination=exit_node,
|
||||
can_jump_to=_get_can_jump_to(m1, "before_agent"),
|
||||
)
|
||||
# Connect last before_agent to loop_entry_node (before_model or model)
|
||||
_add_middleware_edge(
|
||||
graph,
|
||||
f"{middleware_w_before_agent[-1].name}.before_agent",
|
||||
loop_entry_node,
|
||||
loop_entry_node,
|
||||
name=f"{middleware_w_before_agent[-1].name}.before_agent",
|
||||
default_destination=loop_entry_node,
|
||||
model_destination=loop_entry_node,
|
||||
end_destination=exit_node,
|
||||
can_jump_to=_get_can_jump_to(middleware_w_before_agent[-1], "before_agent"),
|
||||
)
|
||||
|
||||
@@ -897,17 +919,19 @@ def create_agent( # noqa: PLR0915
|
||||
for m1, m2 in itertools.pairwise(middleware_w_before_model):
|
||||
_add_middleware_edge(
|
||||
graph,
|
||||
f"{m1.name}.before_model",
|
||||
f"{m2.name}.before_model",
|
||||
loop_entry_node,
|
||||
name=f"{m1.name}.before_model",
|
||||
default_destination=f"{m2.name}.before_model",
|
||||
model_destination=loop_entry_node,
|
||||
end_destination=exit_node,
|
||||
can_jump_to=_get_can_jump_to(m1, "before_model"),
|
||||
)
|
||||
# Go directly to model after the last before_model
|
||||
_add_middleware_edge(
|
||||
graph,
|
||||
f"{middleware_w_before_model[-1].name}.before_model",
|
||||
"model",
|
||||
loop_entry_node,
|
||||
name=f"{middleware_w_before_model[-1].name}.before_model",
|
||||
default_destination="model",
|
||||
model_destination=loop_entry_node,
|
||||
end_destination=exit_node,
|
||||
can_jump_to=_get_can_jump_to(middleware_w_before_model[-1], "before_model"),
|
||||
)
|
||||
|
||||
@@ -919,9 +943,10 @@ def create_agent( # noqa: PLR0915
|
||||
m2 = middleware_w_after_model[idx - 1]
|
||||
_add_middleware_edge(
|
||||
graph,
|
||||
f"{m1.name}.after_model",
|
||||
f"{m2.name}.after_model",
|
||||
loop_entry_node,
|
||||
name=f"{m1.name}.after_model",
|
||||
default_destination=f"{m2.name}.after_model",
|
||||
model_destination=loop_entry_node,
|
||||
end_destination=exit_node,
|
||||
can_jump_to=_get_can_jump_to(m1, "after_model"),
|
||||
)
|
||||
# Note: Connection from after_model to after_agent/END is handled above
|
||||
@@ -935,18 +960,20 @@ def create_agent( # noqa: PLR0915
|
||||
m2 = middleware_w_after_agent[idx - 1]
|
||||
_add_middleware_edge(
|
||||
graph,
|
||||
f"{m1.name}.after_agent",
|
||||
f"{m2.name}.after_agent",
|
||||
loop_entry_node,
|
||||
name=f"{m1.name}.after_agent",
|
||||
default_destination=f"{m2.name}.after_agent",
|
||||
model_destination=loop_entry_node,
|
||||
end_destination=exit_node,
|
||||
can_jump_to=_get_can_jump_to(m1, "after_agent"),
|
||||
)
|
||||
|
||||
# Connect the last after_agent to END
|
||||
_add_middleware_edge(
|
||||
graph,
|
||||
f"{middleware_w_after_agent[0].name}.after_agent",
|
||||
END,
|
||||
loop_entry_node,
|
||||
name=f"{middleware_w_after_agent[0].name}.after_agent",
|
||||
default_destination=END,
|
||||
model_destination=loop_entry_node,
|
||||
end_destination=exit_node,
|
||||
can_jump_to=_get_can_jump_to(middleware_w_after_agent[0], "after_agent"),
|
||||
)
|
||||
|
||||
@@ -961,11 +988,16 @@ def create_agent( # noqa: PLR0915
|
||||
)
|
||||
|
||||
|
||||
def _resolve_jump(jump_to: JumpTo | None, first_node: str) -> str | None:
|
||||
def _resolve_jump(
|
||||
jump_to: JumpTo | None,
|
||||
*,
|
||||
model_destination: str,
|
||||
end_destination: str,
|
||||
) -> str | None:
|
||||
if jump_to == "model":
|
||||
return first_node
|
||||
return model_destination
|
||||
if jump_to == "end":
|
||||
return "__end__"
|
||||
return end_destination
|
||||
if jump_to == "tools":
|
||||
return "tools"
|
||||
return None
|
||||
@@ -988,17 +1020,22 @@ def _fetch_last_ai_and_tool_messages(
|
||||
|
||||
|
||||
def _make_model_to_tools_edge(
|
||||
first_node: str,
|
||||
*,
|
||||
model_destination: str,
|
||||
structured_output_tools: dict[str, OutputToolBinding],
|
||||
tool_node: ToolNode,
|
||||
exit_node: str,
|
||||
end_destination: str,
|
||||
) -> Callable[[dict[str, Any], Runtime[ContextT]], str | list[Send] | None]:
|
||||
def model_to_tools(
|
||||
state: dict[str, Any], runtime: Runtime[ContextT]
|
||||
) -> str | list[Send] | None:
|
||||
# 1. if there's an explicit jump_to in the state, use it
|
||||
if jump_to := state.get("jump_to"):
|
||||
return _resolve_jump(jump_to, first_node)
|
||||
return _resolve_jump(
|
||||
jump_to,
|
||||
model_destination=model_destination,
|
||||
end_destination=end_destination,
|
||||
)
|
||||
|
||||
last_ai_message, tool_messages = _fetch_last_ai_and_tool_messages(state["messages"])
|
||||
tool_message_ids = [m.tool_call_id for m in tool_messages]
|
||||
@@ -1006,7 +1043,7 @@ def _make_model_to_tools_edge(
|
||||
# 2. if the model hasn't called any tools, exit the loop
|
||||
# this is the classic exit condition for an agent loop
|
||||
if len(last_ai_message.tool_calls) == 0:
|
||||
return exit_node
|
||||
return end_destination
|
||||
|
||||
pending_tool_calls = [
|
||||
c
|
||||
@@ -1024,18 +1061,19 @@ def _make_model_to_tools_edge(
|
||||
|
||||
# 4. if there is a structured response, exit the loop
|
||||
if "structured_response" in state:
|
||||
return exit_node
|
||||
return end_destination
|
||||
|
||||
# 5. AIMessage has tool calls, but there are no pending tool calls
|
||||
# which suggests the injection of artificial tool messages. jump to the first node
|
||||
return first_node
|
||||
# which suggests the injection of artificial tool messages. jump to the model node
|
||||
return model_destination
|
||||
|
||||
return model_to_tools
|
||||
|
||||
|
||||
def _make_model_to_model_edge(
|
||||
first_node: str,
|
||||
exit_node: str,
|
||||
*,
|
||||
model_destination: str,
|
||||
end_destination: str,
|
||||
) -> Callable[[dict[str, Any], Runtime[ContextT]], str | list[Send] | None]:
|
||||
def model_to_model(
|
||||
state: dict[str, Any],
|
||||
@@ -1043,24 +1081,29 @@ def _make_model_to_model_edge(
|
||||
) -> str | list[Send] | None:
|
||||
# 1. Priority: Check for explicit jump_to directive from middleware
|
||||
if jump_to := state.get("jump_to"):
|
||||
return _resolve_jump(jump_to, first_node)
|
||||
return _resolve_jump(
|
||||
jump_to,
|
||||
model_destination=model_destination,
|
||||
end_destination=end_destination,
|
||||
)
|
||||
|
||||
# 2. Exit condition: A structured response was generated
|
||||
if "structured_response" in state:
|
||||
return exit_node
|
||||
return end_destination
|
||||
|
||||
# 3. Default: Continue the loop, there may have been an issue
|
||||
# with structured output generation, so we need to retry
|
||||
return first_node
|
||||
return model_destination
|
||||
|
||||
return model_to_model
|
||||
|
||||
|
||||
def _make_tools_to_model_edge(
|
||||
*,
|
||||
tool_node: ToolNode,
|
||||
next_node: str,
|
||||
model_destination: str,
|
||||
structured_output_tools: dict[str, OutputToolBinding],
|
||||
exit_node: str,
|
||||
end_destination: str,
|
||||
) -> Callable[[dict[str, Any], Runtime[ContextT]], str | None]:
|
||||
def tools_to_model(state: dict[str, Any], runtime: Runtime[ContextT]) -> str | None: # noqa: ARG001
|
||||
last_ai_message, tool_messages = _fetch_last_ai_and_tool_messages(state["messages"])
|
||||
@@ -1071,25 +1114,27 @@ def _make_tools_to_model_edge(
|
||||
for c in last_ai_message.tool_calls
|
||||
if c["name"] in tool_node.tools_by_name
|
||||
):
|
||||
return exit_node
|
||||
return end_destination
|
||||
|
||||
# 2. Exit condition: A structured output tool was executed
|
||||
if any(t.name in structured_output_tools for t in tool_messages):
|
||||
return exit_node
|
||||
return end_destination
|
||||
|
||||
# 3. Default: Continue the loop
|
||||
# Tool execution completed successfully, route back to the model
|
||||
# so it can process the tool results and decide the next action.
|
||||
return next_node
|
||||
return model_destination
|
||||
|
||||
return tools_to_model
|
||||
|
||||
|
||||
def _add_middleware_edge(
|
||||
graph: StateGraph[AgentState, ContextT, PublicAgentState, PublicAgentState],
|
||||
*,
|
||||
name: str,
|
||||
default_destination: str,
|
||||
model_destination: str,
|
||||
end_destination: str,
|
||||
can_jump_to: list[JumpTo] | None,
|
||||
) -> None:
|
||||
"""Add an edge to the graph for a middleware node.
|
||||
@@ -1099,17 +1144,25 @@ def _add_middleware_edge(
|
||||
name: The name of the middleware node.
|
||||
default_destination: The default destination for the edge.
|
||||
model_destination: The destination for the edge to the model.
|
||||
end_destination: The destination for the edge to the end.
|
||||
can_jump_to: The conditionally jumpable destinations for the edge.
|
||||
"""
|
||||
if can_jump_to:
|
||||
|
||||
def jump_edge(state: dict[str, Any]) -> str:
|
||||
return _resolve_jump(state.get("jump_to"), model_destination) or default_destination
|
||||
return (
|
||||
_resolve_jump(
|
||||
state.get("jump_to"),
|
||||
model_destination=model_destination,
|
||||
end_destination=end_destination,
|
||||
)
|
||||
or default_destination
|
||||
)
|
||||
|
||||
destinations = [default_destination]
|
||||
|
||||
if "end" in can_jump_to:
|
||||
destinations.append(END)
|
||||
destinations.append(end_destination)
|
||||
if "tools" in can_jump_to:
|
||||
destinations.append("tools")
|
||||
if "model" in can_jump_to and name != model_destination:
|
||||
|
||||
@@ -1,4 +1,34 @@
|
||||
# serializer version: 1
|
||||
# name: test_agent_graph_with_jump_to_end_as_after_agent
|
||||
'''
|
||||
---
|
||||
config:
|
||||
flowchart:
|
||||
curve: linear
|
||||
---
|
||||
graph TD;
|
||||
__start__([<p>__start__</p>]):::first
|
||||
model(model)
|
||||
tools(tools)
|
||||
NoopZero\2ebefore_agent(NoopZero.before_agent)
|
||||
NoopOne\2eafter_agent(NoopOne.after_agent)
|
||||
NoopTwo\2eafter_agent(NoopTwo.after_agent)
|
||||
__end__([<p>__end__</p>]):::last
|
||||
NoopTwo\2eafter_agent --> NoopOne\2eafter_agent;
|
||||
NoopZero\2ebefore_agent -.-> NoopTwo\2eafter_agent;
|
||||
NoopZero\2ebefore_agent -.-> model;
|
||||
__start__ --> NoopZero\2ebefore_agent;
|
||||
model -.-> NoopTwo\2eafter_agent;
|
||||
model -.-> tools;
|
||||
tools -.-> NoopTwo\2eafter_agent;
|
||||
tools -.-> model;
|
||||
NoopOne\2eafter_agent --> __end__;
|
||||
classDef default fill:#f2f0ff,line-height:1.2
|
||||
classDef first fill-opacity:0
|
||||
classDef last fill:#bfb6fc
|
||||
|
||||
'''
|
||||
# ---
|
||||
# name: test_create_agent_diagram
|
||||
'''
|
||||
---
|
||||
@@ -452,3 +482,26 @@
|
||||
|
||||
'''
|
||||
# ---
|
||||
# name: test_simple_agent_graph
|
||||
'''
|
||||
---
|
||||
config:
|
||||
flowchart:
|
||||
curve: linear
|
||||
---
|
||||
graph TD;
|
||||
__start__([<p>__start__</p>]):::first
|
||||
model(model)
|
||||
tools(tools)
|
||||
__end__([<p>__end__</p>]):::last
|
||||
__start__ --> model;
|
||||
model -.-> __end__;
|
||||
model -.-> tools;
|
||||
tools -.-> __end__;
|
||||
tools -.-> model;
|
||||
classDef default fill:#f2f0ff,line-height:1.2
|
||||
classDef first fill-opacity:0
|
||||
classDef last fill:#bfb6fc
|
||||
|
||||
'''
|
||||
# ---
|
||||
|
||||
@@ -383,6 +383,54 @@ def test_create_agent_jump(
|
||||
assert calls == ["NoopSeven.before_model", "NoopEight.before_model"]
|
||||
|
||||
|
||||
def test_simple_agent_graph(snapshot: SnapshotAssertion) -> None:
|
||||
@tool
|
||||
def my_tool(input_string: str) -> str:
|
||||
"""A great tool."""
|
||||
return input_string
|
||||
|
||||
agent_one = create_agent(
|
||||
model=FakeToolCallingModel(
|
||||
tool_calls=[[ToolCall(id="1", name="my_tool", args={"input": "yo"})]],
|
||||
),
|
||||
tools=[my_tool],
|
||||
system_prompt="You are a helpful assistant.",
|
||||
)
|
||||
|
||||
assert agent_one.get_graph().draw_mermaid() == snapshot
|
||||
|
||||
|
||||
def test_agent_graph_with_jump_to_end_as_after_agent(snapshot: SnapshotAssertion) -> None:
|
||||
@tool
|
||||
def my_tool(input_string: str) -> str:
|
||||
"""A great tool."""
|
||||
return input_string
|
||||
|
||||
class NoopZero(AgentMiddleware):
|
||||
@hook_config(can_jump_to=["end"])
|
||||
def before_agent(self, state, runtime) -> None:
|
||||
return None
|
||||
|
||||
class NoopOne(AgentMiddleware):
|
||||
def after_agent(self, state, runtime) -> None:
|
||||
return None
|
||||
|
||||
class NoopTwo(AgentMiddleware):
|
||||
def after_agent(self, state, runtime) -> None:
|
||||
return None
|
||||
|
||||
agent_one = create_agent(
|
||||
model=FakeToolCallingModel(
|
||||
tool_calls=[[ToolCall(id="1", name="my_tool", args={"input": "yo"})]],
|
||||
),
|
||||
tools=[my_tool],
|
||||
system_prompt="You are a helpful assistant.",
|
||||
middleware=[NoopZero(), NoopOne(), NoopTwo()],
|
||||
)
|
||||
|
||||
assert agent_one.get_graph().draw_mermaid() == snapshot
|
||||
|
||||
|
||||
# Tests for HumanInTheLoopMiddleware
|
||||
def test_human_in_the_loop_middleware_initialization() -> None:
|
||||
"""Test HumanInTheLoopMiddleware initialization."""
|
||||
|
||||
Reference in New Issue
Block a user