diff --git a/libs/core/langchain_core/runnables/graph.py b/libs/core/langchain_core/runnables/graph.py index 03b18402297..b064ff567f3 100644 --- a/libs/core/langchain_core/runnables/graph.py +++ b/libs/core/langchain_core/runnables/graph.py @@ -13,6 +13,7 @@ from typing import ( NamedTuple, Optional, Protocol, + Sequence, Tuple, Type, TypedDict, @@ -448,48 +449,27 @@ class Graph: """Find the single node that is not a target of any edge. If there is no such node, or there are multiple, return None. When drawing the graph, this node would be the origin.""" - targets = {edge.target for edge in self.edges} - found: List[Node] = [] - for node in self.nodes.values(): - if node.id not in targets: - found.append(node) - return found[0] if len(found) == 1 else None + return _first_node(self) def last_node(self) -> Optional[Node]: """Find the single node that is not a source of any edge. If there is no such node, or there are multiple, return None. - When drawing the graph, this node would be the destination. - """ - sources = {edge.source for edge in self.edges} - found: List[Node] = [] - for node in self.nodes.values(): - if node.id not in sources: - found.append(node) - return found[0] if len(found) == 1 else None + When drawing the graph, this node would be the destination.""" + return _last_node(self) def trim_first_node(self) -> None: """Remove the first node if it exists and has a single outgoing edge, i.e., if removing it would not leave the graph without a "first" node.""" first_node = self.first_node() - if first_node: - if ( - len(self.nodes) == 1 - or len([edge for edge in self.edges if edge.source == first_node.id]) - == 1 - ): - self.remove_node(first_node) + if first_node and _first_node(self, exclude=[first_node.id]): + self.remove_node(first_node) def trim_last_node(self) -> None: """Remove the last node if it exists and has a single incoming edge, i.e., if removing it would not leave the graph without a "last" node.""" last_node = self.last_node() - if last_node: - if ( - len(self.nodes) == 1 - or len([edge for edge in self.edges if edge.target == last_node.id]) - == 1 - ): - self.remove_node(last_node) + if last_node and _last_node(self, exclude=[last_node.id]): + self.remove_node(last_node) def draw_ascii(self) -> str: """Draw the graph as an ASCII art string.""" @@ -631,3 +611,29 @@ class Graph: background_color=background_color, padding=padding, ) + + +def _first_node(graph: Graph, exclude: Sequence[str] = ()) -> Optional[Node]: + """Find the single node that is not a target of any edge. + Exclude nodes/sources with ids in the exclude list. + If there is no such node, or there are multiple, return None. + When drawing the graph, this node would be the origin.""" + targets = {edge.target for edge in graph.edges if edge.source not in exclude} + found: List[Node] = [] + for node in graph.nodes.values(): + if node.id not in exclude and node.id not in targets: + found.append(node) + return found[0] if len(found) == 1 else None + + +def _last_node(graph: Graph, exclude: Sequence[str] = ()) -> Optional[Node]: + """Find the single node that is not a source of any edge. + Exclude nodes/targets with ids in the exclude list. + If there is no such node, or there are multiple, return None. + When drawing the graph, this node would be the destination.""" + sources = {edge.source for edge in graph.edges if edge.target not in exclude} + found: List[Node] = [] + for node in graph.nodes.values(): + if node.id not in exclude and node.id not in sources: + found.append(node) + return found[0] if len(found) == 1 else None diff --git a/libs/core/tests/unit_tests/runnables/__snapshots__/test_graph.ambr b/libs/core/tests/unit_tests/runnables/__snapshots__/test_graph.ambr index 75bdc180919..73b19f29e75 100644 --- a/libs/core/tests/unit_tests/runnables/__snapshots__/test_graph.ambr +++ b/libs/core/tests/unit_tests/runnables/__snapshots__/test_graph.ambr @@ -1097,3 +1097,65 @@ ''' # --- +# name: test_trim + dict({ + 'edges': list([ + dict({ + 'source': '__start__', + 'target': 'ask_question', + }), + dict({ + 'source': 'ask_question', + 'target': 'answer_question', + }), + dict({ + 'conditional': True, + 'source': 'answer_question', + 'target': 'ask_question', + }), + dict({ + 'conditional': True, + 'source': 'answer_question', + 'target': '__end__', + }), + ]), + 'nodes': list([ + dict({ + 'data': '__start__', + 'id': '__start__', + 'type': 'schema', + }), + dict({ + 'data': dict({ + 'id': list([ + 'langchain', + 'schema', + 'output_parser', + 'StrOutputParser', + ]), + 'name': 'ask_question', + }), + 'id': 'ask_question', + 'type': 'runnable', + }), + dict({ + 'data': dict({ + 'id': list([ + 'langchain', + 'schema', + 'output_parser', + 'StrOutputParser', + ]), + 'name': 'answer_question', + }), + 'id': 'answer_question', + 'type': 'runnable', + }), + dict({ + 'data': '__end__', + 'id': '__end__', + 'type': 'schema', + }), + ]), + }) +# --- diff --git a/libs/core/tests/unit_tests/runnables/test_graph.py b/libs/core/tests/unit_tests/runnables/test_graph.py index 975d6b816d5..8e60092f7ad 100644 --- a/libs/core/tests/unit_tests/runnables/test_graph.py +++ b/libs/core/tests/unit_tests/runnables/test_graph.py @@ -7,7 +7,9 @@ from langchain_core.output_parsers.list import CommaSeparatedListOutputParser from langchain_core.output_parsers.string import StrOutputParser from langchain_core.output_parsers.xml import XMLOutputParser from langchain_core.prompts.prompt import PromptTemplate +from langchain_core.pydantic_v1 import BaseModel from langchain_core.runnables.base import Runnable, RunnableConfig +from langchain_core.runnables.graph import Graph from langchain_core.runnables.graph_mermaid import _escape_node_label @@ -27,6 +29,42 @@ def test_graph_single_runnable(snapshot: SnapshotAssertion) -> None: assert graph.draw_ascii() == snapshot(name="ascii") assert graph.draw_mermaid() == snapshot(name="mermaid") + graph.trim_first_node() + first_node = graph.first_node() + assert first_node is not None + assert first_node.data == runnable + + graph.trim_last_node() + last_node = graph.last_node() + assert last_node is not None + assert last_node.data == runnable + + +def test_trim(snapshot: SnapshotAssertion) -> None: + runnable = StrOutputParser() + + class Schema(BaseModel): + a: int + + graph = Graph() + start = graph.add_node(Schema, id="__start__") + ask = graph.add_node(runnable, id="ask_question") + answer = graph.add_node(runnable, id="answer_question") + end = graph.add_node(Schema, id="__end__") + graph.add_edge(start, ask) + graph.add_edge(ask, answer) + graph.add_edge(answer, ask, conditional=True) + graph.add_edge(answer, end, conditional=True) + + assert graph.to_json() == snapshot + assert graph.first_node() is start + assert graph.last_node() is end + # can't trim start or end node + graph.trim_first_node() + assert graph.first_node() is start + graph.trim_last_node() + assert graph.last_node() is end + def test_graph_sequence(snapshot: SnapshotAssertion) -> None: fake_llm = FakeListLLM(responses=["a"])