community: Add conversions from GVS to networkx (#26906)
These allow converting linked documents (such as those used with GraphVectorStore) to networkx graphs for rendering and/or in-memory graph algorithms such as community detection.
parent 7809b31b95
commit 29bf89db25
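For context, a minimal usage sketch of the new helper (not part of this commit; the documents, link kinds, and tags are illustrative, and networkx is assumed to be installed):

from langchain_core.documents import Document
from langchain_community.graph_vectorstores.links import METADATA_LINKS_KEY, Link
from langchain_community.graph_vectorstores.networkx import documents_to_networkx

# Two documents sharing a keyword tag; the link metadata is illustrative.
docs = [
    Document(
        id="a",
        page_content="some content",
        metadata={METADATA_LINKS_KEY: [Link.bidir("kw", "foo")]},
    ),
    Document(
        id="b",
        page_content="more content",
        metadata={METADATA_LINKS_KEY: [Link.bidir("kw", "foo")]},
    ),
]

# tag_nodes=True renders each (kind, tag) pair as its own node,
# with edges to/from the documents that carry that link.
graph = documents_to_networkx(docs, tag_nodes=True)
print(graph.number_of_nodes(), graph.number_of_edges())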
@@ -45,6 +45,7 @@ motor>=3.3.1,<4
 msal>=1.25.0,<2
 mwparserfromhell>=0.6.4,<0.7
 mwxml>=0.3.3,<0.4
+networkx>=3.2.1,<4
 newspaper3k>=0.2.8,<0.3
 numexpr>=2.8.6,<3
 nvidia-riva-client>=2.14.0,<3
@@ -0,0 +1,84 @@
"""Utilities for using Graph Vector Stores with networkx."""

import typing

from langchain_core.documents import Document

from langchain_community.graph_vectorstores.links import get_links

if typing.TYPE_CHECKING:
    import networkx as nx


def documents_to_networkx(
    documents: typing.Iterable[Document],
    *,
    tag_nodes: bool = True,
) -> "nx.DiGraph":
    """Return the networkx directed graph corresponding to the documents.

    Args:
        documents: The documents to convert to networkx.
        tag_nodes: If `True`, each tag will be rendered as a node, with edges
            to/from the corresponding documents. If `False`, edges will be
            between documents, with a label corresponding to the tag(s)
            connecting them.
    """
    import networkx as nx

    graph = nx.DiGraph()

    tag_ids: typing.Dict[typing.Tuple[str, str], str] = {}
    tag_labels: typing.Dict[str, str] = {}
    documents_by_incoming: typing.Dict[str, typing.Set[str]] = {}

    # First pass:
    # - Register tag IDs for each unique (kind, tag).
    # - If rendering tag nodes, add them to the graph.
    # - If not rendering tag nodes, create a dictionary of documents by incoming tags.
    for document in documents:
        if document.id is None:
            raise ValueError(f"Illegal graph document without ID: {document}")

        for link in get_links(document):
            tag_key = (link.kind, link.tag)
            tag_id = tag_ids.get(tag_key)
            if tag_id is None:
                tag_id = f"tag_{len(tag_ids)}"
                tag_ids[tag_key] = tag_id

                if tag_nodes:
                    graph.add_node(tag_id, label=f"{link.kind}:{link.tag}")

            if not tag_nodes and (link.direction == "in" or link.direction == "bidir"):
                tag_labels[tag_id] = f"{link.kind}:{link.tag}"
                documents_by_incoming.setdefault(tag_id, set()).add(document.id)

    # Second pass:
    # - Render document nodes
    # - If rendering tag nodes, render edges to/from documents and tag nodes.
    # - If not rendering tag nodes, render edges to/from documents based on tags.
    for document in documents:
        graph.add_node(document.id, text=document.page_content)

        targets: typing.Dict[str, typing.List[str]] = {}
        for link in get_links(document):
            tag_id = tag_ids[(link.kind, link.tag)]
            if tag_nodes:
                if link.direction == "in" or link.direction == "bidir":
                    graph.add_edge(tag_id, document.id)
                if link.direction == "out" or link.direction == "bidir":
                    graph.add_edge(document.id, tag_id)
            else:
                if link.direction == "out" or link.direction == "bidir":
                    label = tag_labels[tag_id]
                    for target in documents_by_incoming[tag_id]:
                        if target != document.id:
                            targets.setdefault(target, []).append(label)

        # Avoid a multigraph by collecting the list of labels for each edge.
        if not tag_nodes:
            for target, labels in targets.items():
                graph.add_edge(document.id, target, label=str(labels))

    return graph
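The commit description mentions in-memory graph algorithms such as community detection. A minimal sketch of what that could look like on the `graph` built in the usage sketch above (an illustration, not something this commit ships; modularity-based community detection in networkx expects an undirected graph, so the DiGraph is converted first):

from networkx.algorithms.community import greedy_modularity_communities

# `graph` is the DiGraph returned by documents_to_networkx(...) above.
undirected = graph.to_undirected()
communities = greedy_modularity_communities(undirected)
for i, community in enumerate(communities):
    print(f"community {i}: {sorted(community)}")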
@@ -0,0 +1,77 @@
import pytest
from langchain_core.documents import Document

from langchain_community.graph_vectorstores.links import METADATA_LINKS_KEY, Link
from langchain_community.graph_vectorstores.networkx import documents_to_networkx


@pytest.mark.requires("networkx")
def test_documents_to_networkx() -> None:
    import networkx as nx

    doc1 = Document(
        id="a",
        page_content="some content",
        metadata={
            METADATA_LINKS_KEY: [
                Link.incoming("href", "a"),
                Link.bidir("kw", "foo"),
            ]
        },
    )
    doc2 = Document(
        id="b",
        page_content="<some\n more content>",
        metadata={
            METADATA_LINKS_KEY: [
                Link.incoming("href", "b"),
                Link.outgoing("href", "a"),
                Link.bidir("kw", "foo"),
                Link.bidir("kw", "bar"),
            ]
        },
    )

    graph_with_tags = documents_to_networkx([doc1, doc2], tag_nodes=True)
    link_data = nx.node_link_data(graph_with_tags)
    assert link_data["directed"]
    assert not link_data["multigraph"]

    link_data["nodes"].sort(key=lambda n: n["id"])
    assert link_data["nodes"] == [
        {"id": "a", "text": "some content"},
        {"id": "b", "text": "<some\n more content>"},
        {"id": "tag_0", "label": "href:a"},
        {"id": "tag_1", "label": "kw:foo"},
        {"id": "tag_2", "label": "href:b"},
        {"id": "tag_3", "label": "kw:bar"},
    ]
    link_data["links"].sort(key=lambda n: (n["source"], n["target"]))
    assert link_data["links"] == [
        {"source": "a", "target": "tag_1"},
        {"source": "b", "target": "tag_0"},
        {"source": "b", "target": "tag_1"},
        {"source": "b", "target": "tag_3"},
        {"source": "tag_0", "target": "a"},
        {"source": "tag_1", "target": "a"},
        {"source": "tag_1", "target": "b"},
        {"source": "tag_2", "target": "b"},
        {"source": "tag_3", "target": "b"},
    ]

    graph_without_tags = documents_to_networkx([doc1, doc2], tag_nodes=False)
    link_data = nx.node_link_data(graph_without_tags)
    assert link_data["directed"]
    assert not link_data["multigraph"]

    link_data["nodes"].sort(key=lambda n: n["id"])
    assert link_data["nodes"] == [
        {"id": "a", "text": "some content"},
        {"id": "b", "text": "<some\n more content>"},
    ]

    link_data["links"].sort(key=lambda n: (n["source"], n["target"]))
    assert link_data["links"] == [
        {"source": "a", "target": "b", "label": "['kw:foo']"},
        {"source": "b", "target": "a", "label": "['href:a', 'kw:foo']"},
    ]
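For the rendering side mentioned in the commit description, a small sketch (again illustrative and not part of the commit; it assumes matplotlib is installed). With tag_nodes=True, tag nodes carry a "label" attribute while document nodes only have their id, so building a label mapping keeps the drawing readable:

import matplotlib.pyplot as plt
import networkx as nx

# `graph` as built with documents_to_networkx(docs, tag_nodes=True) above.
labels = {node: data.get("label", node) for node, data in graph.nodes(data=True)}
nx.draw_networkx(graph, labels=labels)
plt.show()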