core: Fix implementation of trim_first_node/trim_last_node to use exact same definition of first/last node as in the getter methods (#24802)

This commit is contained in:
Nuno Campos 2024-07-30 08:44:27 -07:00 committed by GitHub
parent c2706cfb9e
commit 68ecebf1ec
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 134 additions and 28 deletions

View File

@ -13,6 +13,7 @@ from typing import (
NamedTuple, NamedTuple,
Optional, Optional,
Protocol, Protocol,
Sequence,
Tuple, Tuple,
Type, Type,
TypedDict, TypedDict,
@ -448,48 +449,27 @@ class Graph:
"""Find the single node that is not a target of any edge. """Find the single node that is not a target of any edge.
If there is no such node, or there are multiple, return None. If there is no such node, or there are multiple, return None.
When drawing the graph, this node would be the origin.""" When drawing the graph, this node would be the origin."""
targets = {edge.target for edge in self.edges} return _first_node(self)
found: List[Node] = []
for node in self.nodes.values():
if node.id not in targets:
found.append(node)
return found[0] if len(found) == 1 else None
def last_node(self) -> Optional[Node]: def last_node(self) -> Optional[Node]:
"""Find the single node that is not a source of any edge. """Find the single node that is not a source of any edge.
If there is no such node, or there are multiple, return None. If there is no such node, or there are multiple, return None.
When drawing the graph, this node would be the destination. When drawing the graph, this node would be the destination."""
""" return _last_node(self)
sources = {edge.source for edge in self.edges}
found: List[Node] = []
for node in self.nodes.values():
if node.id not in sources:
found.append(node)
return found[0] if len(found) == 1 else None
def trim_first_node(self) -> None: def trim_first_node(self) -> None:
"""Remove the first node if it exists and has a single outgoing edge, """Remove the first node if it exists and has a single outgoing edge,
i.e., if removing it would not leave the graph without a "first" node.""" i.e., if removing it would not leave the graph without a "first" node."""
first_node = self.first_node() first_node = self.first_node()
if first_node: if first_node and _first_node(self, exclude=[first_node.id]):
if ( self.remove_node(first_node)
len(self.nodes) == 1
or len([edge for edge in self.edges if edge.source == first_node.id])
== 1
):
self.remove_node(first_node)
def trim_last_node(self) -> None: def trim_last_node(self) -> None:
"""Remove the last node if it exists and has a single incoming edge, """Remove the last node if it exists and has a single incoming edge,
i.e., if removing it would not leave the graph without a "last" node.""" i.e., if removing it would not leave the graph without a "last" node."""
last_node = self.last_node() last_node = self.last_node()
if last_node: if last_node and _last_node(self, exclude=[last_node.id]):
if ( self.remove_node(last_node)
len(self.nodes) == 1
or len([edge for edge in self.edges if edge.target == last_node.id])
== 1
):
self.remove_node(last_node)
def draw_ascii(self) -> str: def draw_ascii(self) -> str:
"""Draw the graph as an ASCII art string.""" """Draw the graph as an ASCII art string."""
@ -631,3 +611,29 @@ class Graph:
background_color=background_color, background_color=background_color,
padding=padding, padding=padding,
) )
def _first_node(graph: Graph, exclude: Sequence[str] = ()) -> Optional[Node]:
"""Find the single node that is not a target of any edge.
Exclude nodes/sources with ids in the exclude list.
If there is no such node, or there are multiple, return None.
When drawing the graph, this node would be the origin."""
targets = {edge.target for edge in graph.edges if edge.source not in exclude}
found: List[Node] = []
for node in graph.nodes.values():
if node.id not in exclude and node.id not in targets:
found.append(node)
return found[0] if len(found) == 1 else None
def _last_node(graph: Graph, exclude: Sequence[str] = ()) -> Optional[Node]:
"""Find the single node that is not a source of any edge.
Exclude nodes/targets with ids in the exclude list.
If there is no such node, or there are multiple, return None.
When drawing the graph, this node would be the destination."""
sources = {edge.source for edge in graph.edges if edge.target not in exclude}
found: List[Node] = []
for node in graph.nodes.values():
if node.id not in exclude and node.id not in sources:
found.append(node)
return found[0] if len(found) == 1 else None

View File

@ -1097,3 +1097,65 @@
''' '''
# --- # ---
# name: test_trim
dict({
'edges': list([
dict({
'source': '__start__',
'target': 'ask_question',
}),
dict({
'source': 'ask_question',
'target': 'answer_question',
}),
dict({
'conditional': True,
'source': 'answer_question',
'target': 'ask_question',
}),
dict({
'conditional': True,
'source': 'answer_question',
'target': '__end__',
}),
]),
'nodes': list([
dict({
'data': '__start__',
'id': '__start__',
'type': 'schema',
}),
dict({
'data': dict({
'id': list([
'langchain',
'schema',
'output_parser',
'StrOutputParser',
]),
'name': 'ask_question',
}),
'id': 'ask_question',
'type': 'runnable',
}),
dict({
'data': dict({
'id': list([
'langchain',
'schema',
'output_parser',
'StrOutputParser',
]),
'name': 'answer_question',
}),
'id': 'answer_question',
'type': 'runnable',
}),
dict({
'data': '__end__',
'id': '__end__',
'type': 'schema',
}),
]),
})
# ---

View File

@ -7,7 +7,9 @@ from langchain_core.output_parsers.list import CommaSeparatedListOutputParser
from langchain_core.output_parsers.string import StrOutputParser from langchain_core.output_parsers.string import StrOutputParser
from langchain_core.output_parsers.xml import XMLOutputParser from langchain_core.output_parsers.xml import XMLOutputParser
from langchain_core.prompts.prompt import PromptTemplate from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.pydantic_v1 import BaseModel
from langchain_core.runnables.base import Runnable, RunnableConfig from langchain_core.runnables.base import Runnable, RunnableConfig
from langchain_core.runnables.graph import Graph
from langchain_core.runnables.graph_mermaid import _escape_node_label from langchain_core.runnables.graph_mermaid import _escape_node_label
@ -27,6 +29,42 @@ def test_graph_single_runnable(snapshot: SnapshotAssertion) -> None:
assert graph.draw_ascii() == snapshot(name="ascii") assert graph.draw_ascii() == snapshot(name="ascii")
assert graph.draw_mermaid() == snapshot(name="mermaid") assert graph.draw_mermaid() == snapshot(name="mermaid")
graph.trim_first_node()
first_node = graph.first_node()
assert first_node is not None
assert first_node.data == runnable
graph.trim_last_node()
last_node = graph.last_node()
assert last_node is not None
assert last_node.data == runnable
def test_trim(snapshot: SnapshotAssertion) -> None:
runnable = StrOutputParser()
class Schema(BaseModel):
a: int
graph = Graph()
start = graph.add_node(Schema, id="__start__")
ask = graph.add_node(runnable, id="ask_question")
answer = graph.add_node(runnable, id="answer_question")
end = graph.add_node(Schema, id="__end__")
graph.add_edge(start, ask)
graph.add_edge(ask, answer)
graph.add_edge(answer, ask, conditional=True)
graph.add_edge(answer, end, conditional=True)
assert graph.to_json() == snapshot
assert graph.first_node() is start
assert graph.last_node() is end
# can't trim start or end node
graph.trim_first_node()
assert graph.first_node() is start
graph.trim_last_node()
assert graph.last_node() is end
def test_graph_sequence(snapshot: SnapshotAssertion) -> None: def test_graph_sequence(snapshot: SnapshotAssertion) -> None:
fake_llm = FakeListLLM(responses=["a"]) fake_llm = FakeListLLM(responses=["a"])