mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-04 12:18:24 +00:00
core[patch]: Fixed trim functions, and added corresponding unit test for the solved issue (#28429)
- **Description:** - Trim functions were incorrectly deleting nodes with more than 1 outgoing/incoming edge, so an extra condition was added to check for this directly. A unit test "test_trim_multi_edge" was written to test this test case specifically. - **Issue:** - Fixes #28411 - Fixes https://github.com/langchain-ai/langgraph/issues/1676 - **Dependencies:** - No changes were made to the dependencies - [x] Unit tests were added to verify the changes. - [x] Updated documentation where necessary. - [x] Ran make format, make lint, and make test to ensure compliance with project standards. --------- Co-authored-by: Tasif Hussain <tasif006@gmail.com>
This commit is contained in:
parent
54fba7e520
commit
481c4bfaba
@ -470,14 +470,22 @@ class Graph:
|
|||||||
"""Remove the first node if it exists and has a single outgoing edge,
|
"""Remove the first node if it exists and has a single outgoing edge,
|
||||||
i.e., if removing it would not leave the graph without a "first" node."""
|
i.e., if removing it would not leave the graph without a "first" node."""
|
||||||
first_node = self.first_node()
|
first_node = self.first_node()
|
||||||
if first_node and _first_node(self, exclude=[first_node.id]):
|
if (
|
||||||
|
first_node
|
||||||
|
and _first_node(self, exclude=[first_node.id])
|
||||||
|
and len({e for e in self.edges if e.source == first_node.id}) == 1
|
||||||
|
):
|
||||||
self.remove_node(first_node)
|
self.remove_node(first_node)
|
||||||
|
|
||||||
def trim_last_node(self) -> None:
|
def trim_last_node(self) -> None:
|
||||||
"""Remove the last node if it exists and has a single incoming edge,
|
"""Remove the last node if it exists and has a single incoming edge,
|
||||||
i.e., if removing it would not leave the graph without a "last" node."""
|
i.e., if removing it would not leave the graph without a "last" node."""
|
||||||
last_node = self.last_node()
|
last_node = self.last_node()
|
||||||
if last_node and _last_node(self, exclude=[last_node.id]):
|
if (
|
||||||
|
last_node
|
||||||
|
and _last_node(self, exclude=[last_node.id])
|
||||||
|
and len({e for e in self.edges if e.target == last_node.id}) == 1
|
||||||
|
):
|
||||||
self.remove_node(last_node)
|
self.remove_node(last_node)
|
||||||
|
|
||||||
def draw_ascii(self) -> str:
|
def draw_ascii(self) -> str:
|
||||||
|
@ -69,6 +69,26 @@ def test_trim(snapshot: SnapshotAssertion) -> None:
|
|||||||
assert graph.last_node() is end
|
assert graph.last_node() is end
|
||||||
|
|
||||||
|
|
||||||
|
def test_trim_multi_edge() -> None:
|
||||||
|
class Scheme(BaseModel):
|
||||||
|
a: str
|
||||||
|
|
||||||
|
graph = Graph()
|
||||||
|
start = graph.add_node(Scheme, id="__start__")
|
||||||
|
a = graph.add_node(Scheme, id="a")
|
||||||
|
last = graph.add_node(Scheme, id="__end__")
|
||||||
|
|
||||||
|
graph.add_edge(start, a)
|
||||||
|
graph.add_edge(a, last)
|
||||||
|
graph.add_edge(start, last)
|
||||||
|
|
||||||
|
graph.trim_first_node() # should not remove __start__ since it has 2 outgoing edges
|
||||||
|
assert graph.first_node() is start
|
||||||
|
|
||||||
|
graph.trim_last_node() # should not remove the __end__ node since it has 2 incoming edges
|
||||||
|
assert graph.last_node() is last
|
||||||
|
|
||||||
|
|
||||||
def test_graph_sequence(snapshot: SnapshotAssertion) -> None:
|
def test_graph_sequence(snapshot: SnapshotAssertion) -> None:
|
||||||
fake_llm = FakeListLLM(responses=["a"])
|
fake_llm = FakeListLLM(responses=["a"])
|
||||||
prompt = PromptTemplate.from_template("Hello, {name}!")
|
prompt = PromptTemplate.from_template("Hello, {name}!")
|
||||||
|
Loading…
Reference in New Issue
Block a user