diff --git a/libs/core/langchain_core/runnables/graph.py b/libs/core/langchain_core/runnables/graph.py index 2cb57c4e0fe..d0e69e0b481 100644 --- a/libs/core/langchain_core/runnables/graph.py +++ b/libs/core/langchain_core/runnables/graph.py @@ -470,14 +470,22 @@ class Graph: """Remove the first node if it exists and has a single outgoing edge, i.e., if removing it would not leave the graph without a "first" node.""" first_node = self.first_node() - if first_node and _first_node(self, exclude=[first_node.id]): + if ( + first_node + and _first_node(self, exclude=[first_node.id]) + and len({e for e in self.edges if e.source == first_node.id}) == 1 + ): self.remove_node(first_node) def trim_last_node(self) -> None: """Remove the last node if it exists and has a single incoming edge, i.e., if removing it would not leave the graph without a "last" node.""" last_node = self.last_node() - if last_node and _last_node(self, exclude=[last_node.id]): + if ( + last_node + and _last_node(self, exclude=[last_node.id]) + and len({e for e in self.edges if e.target == last_node.id}) == 1 + ): self.remove_node(last_node) def draw_ascii(self) -> str: diff --git a/libs/core/tests/unit_tests/runnables/test_graph.py b/libs/core/tests/unit_tests/runnables/test_graph.py index 789875db623..e98a0a18a9c 100644 --- a/libs/core/tests/unit_tests/runnables/test_graph.py +++ b/libs/core/tests/unit_tests/runnables/test_graph.py @@ -69,6 +69,26 @@ def test_trim(snapshot: SnapshotAssertion) -> None: assert graph.last_node() is end +def test_trim_multi_edge() -> None: + class Scheme(BaseModel): + a: str + + graph = Graph() + start = graph.add_node(Scheme, id="__start__") + a = graph.add_node(Scheme, id="a") + last = graph.add_node(Scheme, id="__end__") + + graph.add_edge(start, a) + graph.add_edge(a, last) + graph.add_edge(start, last) + + graph.trim_first_node() # should not remove __start__ since it has 2 outgoing edges + assert graph.first_node() is start + + graph.trim_last_node() # should not remove the __end__ node since it has 2 incoming edges + assert graph.last_node() is last + + def test_graph_sequence(snapshot: SnapshotAssertion) -> None: fake_llm = FakeListLLM(responses=["a"]) prompt = PromptTemplate.from_template("Hello, {name}!")