mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-05 06:33:20 +00:00
core: Fix implementation of trim_first_node/trim_last_node to use exact same definition of first/last node as in the getter methods (#24802)
This commit is contained in:
parent
c2706cfb9e
commit
68ecebf1ec
@ -13,6 +13,7 @@ from typing import (
|
|||||||
NamedTuple,
|
NamedTuple,
|
||||||
Optional,
|
Optional,
|
||||||
Protocol,
|
Protocol,
|
||||||
|
Sequence,
|
||||||
Tuple,
|
Tuple,
|
||||||
Type,
|
Type,
|
||||||
TypedDict,
|
TypedDict,
|
||||||
@ -448,48 +449,27 @@ class Graph:
|
|||||||
"""Find the single node that is not a target of any edge.
|
"""Find the single node that is not a target of any edge.
|
||||||
If there is no such node, or there are multiple, return None.
|
If there is no such node, or there are multiple, return None.
|
||||||
When drawing the graph, this node would be the origin."""
|
When drawing the graph, this node would be the origin."""
|
||||||
targets = {edge.target for edge in self.edges}
|
return _first_node(self)
|
||||||
found: List[Node] = []
|
|
||||||
for node in self.nodes.values():
|
|
||||||
if node.id not in targets:
|
|
||||||
found.append(node)
|
|
||||||
return found[0] if len(found) == 1 else None
|
|
||||||
|
|
||||||
def last_node(self) -> Optional[Node]:
|
def last_node(self) -> Optional[Node]:
|
||||||
"""Find the single node that is not a source of any edge.
|
"""Find the single node that is not a source of any edge.
|
||||||
If there is no such node, or there are multiple, return None.
|
If there is no such node, or there are multiple, return None.
|
||||||
When drawing the graph, this node would be the destination.
|
When drawing the graph, this node would be the destination."""
|
||||||
"""
|
return _last_node(self)
|
||||||
sources = {edge.source for edge in self.edges}
|
|
||||||
found: List[Node] = []
|
|
||||||
for node in self.nodes.values():
|
|
||||||
if node.id not in sources:
|
|
||||||
found.append(node)
|
|
||||||
return found[0] if len(found) == 1 else None
|
|
||||||
|
|
||||||
def trim_first_node(self) -> None:
|
def trim_first_node(self) -> None:
|
||||||
"""Remove the first node if it exists and has a single outgoing edge,
|
"""Remove the first node if it exists and has a single outgoing edge,
|
||||||
i.e., if removing it would not leave the graph without a "first" node."""
|
i.e., if removing it would not leave the graph without a "first" node."""
|
||||||
first_node = self.first_node()
|
first_node = self.first_node()
|
||||||
if first_node:
|
if first_node and _first_node(self, exclude=[first_node.id]):
|
||||||
if (
|
self.remove_node(first_node)
|
||||||
len(self.nodes) == 1
|
|
||||||
or len([edge for edge in self.edges if edge.source == first_node.id])
|
|
||||||
== 1
|
|
||||||
):
|
|
||||||
self.remove_node(first_node)
|
|
||||||
|
|
||||||
def trim_last_node(self) -> None:
|
def trim_last_node(self) -> None:
|
||||||
"""Remove the last node if it exists and has a single incoming edge,
|
"""Remove the last node if it exists and has a single incoming edge,
|
||||||
i.e., if removing it would not leave the graph without a "last" node."""
|
i.e., if removing it would not leave the graph without a "last" node."""
|
||||||
last_node = self.last_node()
|
last_node = self.last_node()
|
||||||
if last_node:
|
if last_node and _last_node(self, exclude=[last_node.id]):
|
||||||
if (
|
self.remove_node(last_node)
|
||||||
len(self.nodes) == 1
|
|
||||||
or len([edge for edge in self.edges if edge.target == last_node.id])
|
|
||||||
== 1
|
|
||||||
):
|
|
||||||
self.remove_node(last_node)
|
|
||||||
|
|
||||||
def draw_ascii(self) -> str:
|
def draw_ascii(self) -> str:
|
||||||
"""Draw the graph as an ASCII art string."""
|
"""Draw the graph as an ASCII art string."""
|
||||||
@ -631,3 +611,29 @@ class Graph:
|
|||||||
background_color=background_color,
|
background_color=background_color,
|
||||||
padding=padding,
|
padding=padding,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _first_node(graph: Graph, exclude: Sequence[str] = ()) -> Optional[Node]:
|
||||||
|
"""Find the single node that is not a target of any edge.
|
||||||
|
Exclude nodes/sources with ids in the exclude list.
|
||||||
|
If there is no such node, or there are multiple, return None.
|
||||||
|
When drawing the graph, this node would be the origin."""
|
||||||
|
targets = {edge.target for edge in graph.edges if edge.source not in exclude}
|
||||||
|
found: List[Node] = []
|
||||||
|
for node in graph.nodes.values():
|
||||||
|
if node.id not in exclude and node.id not in targets:
|
||||||
|
found.append(node)
|
||||||
|
return found[0] if len(found) == 1 else None
|
||||||
|
|
||||||
|
|
||||||
|
def _last_node(graph: Graph, exclude: Sequence[str] = ()) -> Optional[Node]:
|
||||||
|
"""Find the single node that is not a source of any edge.
|
||||||
|
Exclude nodes/targets with ids in the exclude list.
|
||||||
|
If there is no such node, or there are multiple, return None.
|
||||||
|
When drawing the graph, this node would be the destination."""
|
||||||
|
sources = {edge.source for edge in graph.edges if edge.target not in exclude}
|
||||||
|
found: List[Node] = []
|
||||||
|
for node in graph.nodes.values():
|
||||||
|
if node.id not in exclude and node.id not in sources:
|
||||||
|
found.append(node)
|
||||||
|
return found[0] if len(found) == 1 else None
|
||||||
|
@ -1097,3 +1097,65 @@
|
|||||||
|
|
||||||
'''
|
'''
|
||||||
# ---
|
# ---
|
||||||
|
# name: test_trim
|
||||||
|
dict({
|
||||||
|
'edges': list([
|
||||||
|
dict({
|
||||||
|
'source': '__start__',
|
||||||
|
'target': 'ask_question',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'source': 'ask_question',
|
||||||
|
'target': 'answer_question',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'conditional': True,
|
||||||
|
'source': 'answer_question',
|
||||||
|
'target': 'ask_question',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'conditional': True,
|
||||||
|
'source': 'answer_question',
|
||||||
|
'target': '__end__',
|
||||||
|
}),
|
||||||
|
]),
|
||||||
|
'nodes': list([
|
||||||
|
dict({
|
||||||
|
'data': '__start__',
|
||||||
|
'id': '__start__',
|
||||||
|
'type': 'schema',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'data': dict({
|
||||||
|
'id': list([
|
||||||
|
'langchain',
|
||||||
|
'schema',
|
||||||
|
'output_parser',
|
||||||
|
'StrOutputParser',
|
||||||
|
]),
|
||||||
|
'name': 'ask_question',
|
||||||
|
}),
|
||||||
|
'id': 'ask_question',
|
||||||
|
'type': 'runnable',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'data': dict({
|
||||||
|
'id': list([
|
||||||
|
'langchain',
|
||||||
|
'schema',
|
||||||
|
'output_parser',
|
||||||
|
'StrOutputParser',
|
||||||
|
]),
|
||||||
|
'name': 'answer_question',
|
||||||
|
}),
|
||||||
|
'id': 'answer_question',
|
||||||
|
'type': 'runnable',
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'data': '__end__',
|
||||||
|
'id': '__end__',
|
||||||
|
'type': 'schema',
|
||||||
|
}),
|
||||||
|
]),
|
||||||
|
})
|
||||||
|
# ---
|
||||||
|
@ -7,7 +7,9 @@ from langchain_core.output_parsers.list import CommaSeparatedListOutputParser
|
|||||||
from langchain_core.output_parsers.string import StrOutputParser
|
from langchain_core.output_parsers.string import StrOutputParser
|
||||||
from langchain_core.output_parsers.xml import XMLOutputParser
|
from langchain_core.output_parsers.xml import XMLOutputParser
|
||||||
from langchain_core.prompts.prompt import PromptTemplate
|
from langchain_core.prompts.prompt import PromptTemplate
|
||||||
|
from langchain_core.pydantic_v1 import BaseModel
|
||||||
from langchain_core.runnables.base import Runnable, RunnableConfig
|
from langchain_core.runnables.base import Runnable, RunnableConfig
|
||||||
|
from langchain_core.runnables.graph import Graph
|
||||||
from langchain_core.runnables.graph_mermaid import _escape_node_label
|
from langchain_core.runnables.graph_mermaid import _escape_node_label
|
||||||
|
|
||||||
|
|
||||||
@ -27,6 +29,42 @@ def test_graph_single_runnable(snapshot: SnapshotAssertion) -> None:
|
|||||||
assert graph.draw_ascii() == snapshot(name="ascii")
|
assert graph.draw_ascii() == snapshot(name="ascii")
|
||||||
assert graph.draw_mermaid() == snapshot(name="mermaid")
|
assert graph.draw_mermaid() == snapshot(name="mermaid")
|
||||||
|
|
||||||
|
graph.trim_first_node()
|
||||||
|
first_node = graph.first_node()
|
||||||
|
assert first_node is not None
|
||||||
|
assert first_node.data == runnable
|
||||||
|
|
||||||
|
graph.trim_last_node()
|
||||||
|
last_node = graph.last_node()
|
||||||
|
assert last_node is not None
|
||||||
|
assert last_node.data == runnable
|
||||||
|
|
||||||
|
|
||||||
|
def test_trim(snapshot: SnapshotAssertion) -> None:
|
||||||
|
runnable = StrOutputParser()
|
||||||
|
|
||||||
|
class Schema(BaseModel):
|
||||||
|
a: int
|
||||||
|
|
||||||
|
graph = Graph()
|
||||||
|
start = graph.add_node(Schema, id="__start__")
|
||||||
|
ask = graph.add_node(runnable, id="ask_question")
|
||||||
|
answer = graph.add_node(runnable, id="answer_question")
|
||||||
|
end = graph.add_node(Schema, id="__end__")
|
||||||
|
graph.add_edge(start, ask)
|
||||||
|
graph.add_edge(ask, answer)
|
||||||
|
graph.add_edge(answer, ask, conditional=True)
|
||||||
|
graph.add_edge(answer, end, conditional=True)
|
||||||
|
|
||||||
|
assert graph.to_json() == snapshot
|
||||||
|
assert graph.first_node() is start
|
||||||
|
assert graph.last_node() is end
|
||||||
|
# can't trim start or end node
|
||||||
|
graph.trim_first_node()
|
||||||
|
assert graph.first_node() is start
|
||||||
|
graph.trim_last_node()
|
||||||
|
assert graph.last_node() is end
|
||||||
|
|
||||||
|
|
||||||
def test_graph_sequence(snapshot: SnapshotAssertion) -> None:
|
def test_graph_sequence(snapshot: SnapshotAssertion) -> None:
|
||||||
fake_llm = FakeListLLM(responses=["a"])
|
fake_llm = FakeListLLM(responses=["a"])
|
||||||
|
Loading…
Reference in New Issue
Block a user