community: Add conversions from GVS to networkx (#26906)

These allow converting linked documents (such as those used with
GraphVectorStore) to networkx for rendering and/or in-memory graph
algorithms such as community detection.
This commit is contained in:
Ben Chambers 2024-09-27 13:48:55 -07:00 committed by GitHub
parent 7809b31b95
commit 29bf89db25
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 162 additions and 0 deletions

View File

@ -45,6 +45,7 @@ motor>=3.3.1,<4
msal>=1.25.0,<2
mwparserfromhell>=0.6.4,<0.7
mwxml>=0.3.3,<0.4
networkx>=3.2.1,<4
newspaper3k>=0.2.8,<0.3
numexpr>=2.8.6,<3
nvidia-riva-client>=2.14.0,<3

View File

@ -0,0 +1,84 @@
"""Utilities for using Graph Vector Stores with networkx."""
import typing
from langchain_core.documents import Document
from langchain_community.graph_vectorstores.links import get_links
if typing.TYPE_CHECKING:
import networkx as nx
def documents_to_networkx(
documents: typing.Iterable[Document],
*,
tag_nodes: bool = True,
) -> "nx.DiGraph":
"""Return the networkx directed graph corresponding to the documents.
Args:
documents: The documents to convenrt to networkx.
tag_nodes: If `True`, each tag will be rendered as a node, with edges
to/from the corresponding documents. If `False`, edges will be
between documents, with a label corresponding to the tag(s)
connecting them.
"""
import networkx as nx
graph = nx.DiGraph()
tag_ids: typing.Dict[typing.Tuple[str, str], str] = {}
tag_labels: typing.Dict[str, str] = {}
documents_by_incoming: typing.Dict[str, typing.Set[str]] = {}
# First pass:
# - Register tag IDs for each unique (kind, tag).
# - If rendering tag nodes, add them to the graph.
# - If not rendering tag nodes, create a dictionary of documents by incoming tags.
for document in documents:
if document.id is None:
raise ValueError(f"Illegal graph document without ID: {document}")
for link in get_links(document):
tag_key = (link.kind, link.tag)
tag_id = tag_ids.get(tag_key)
if tag_id is None:
tag_id = f"tag_{len(tag_ids)}"
tag_ids[tag_key] = tag_id
if tag_nodes:
graph.add_node(tag_id, label=f"{link.kind}:{link.tag}")
if not tag_nodes and (link.direction == "in" or link.direction == "bidir"):
tag_labels[tag_id] = f"{link.kind}:{link.tag}"
documents_by_incoming.setdefault(tag_id, set()).add(document.id)
# Second pass:
# - Render document nodes
# - If rendering tag nodes, render edges to/from documents and tag nodes.
# - If not rendering tag nodes, render edges to/from documents based on tags.
for document in documents:
graph.add_node(document.id, text=document.page_content)
targets: typing.Dict[str, typing.List[str]] = {}
for link in get_links(document):
tag_id = tag_ids[(link.kind, link.tag)]
if tag_nodes:
if link.direction == "in" or link.direction == "bidir":
graph.add_edge(tag_id, document.id)
if link.direction == "out" or link.direction == "bidir":
graph.add_edge(document.id, tag_id)
else:
if link.direction == "out" or link.direction == "bidir":
label = tag_labels[tag_id]
for target in documents_by_incoming[tag_id]:
if target != document.id:
targets.setdefault(target, []).append(label)
# Avoid a multigraph by collecting the list of labels for each edge.
if not tag_nodes:
for target, labels in targets.items():
graph.add_edge(document.id, target, label=str(labels))
return graph

View File

@ -0,0 +1,77 @@
import pytest
from langchain_core.documents import Document
from langchain_community.graph_vectorstores.links import METADATA_LINKS_KEY, Link
from langchain_community.graph_vectorstores.networkx import documents_to_networkx
@pytest.mark.requires("networkx")
def test_documents_to_networkx() -> None:
import networkx as nx
doc1 = Document(
id="a",
page_content="some content",
metadata={
METADATA_LINKS_KEY: [
Link.incoming("href", "a"),
Link.bidir("kw", "foo"),
]
},
)
doc2 = Document(
id="b",
page_content="<some\n more content>",
metadata={
METADATA_LINKS_KEY: [
Link.incoming("href", "b"),
Link.outgoing("href", "a"),
Link.bidir("kw", "foo"),
Link.bidir("kw", "bar"),
]
},
)
graph_with_tags = documents_to_networkx([doc1, doc2], tag_nodes=True)
link_data = nx.node_link_data(graph_with_tags)
assert link_data["directed"]
assert not link_data["multigraph"]
link_data["nodes"].sort(key=lambda n: n["id"])
assert link_data["nodes"] == [
{"id": "a", "text": "some content"},
{"id": "b", "text": "<some\n more content>"},
{"id": "tag_0", "label": "href:a"},
{"id": "tag_1", "label": "kw:foo"},
{"id": "tag_2", "label": "href:b"},
{"id": "tag_3", "label": "kw:bar"},
]
link_data["links"].sort(key=lambda n: (n["source"], n["target"]))
assert link_data["links"] == [
{"source": "a", "target": "tag_1"},
{"source": "b", "target": "tag_0"},
{"source": "b", "target": "tag_1"},
{"source": "b", "target": "tag_3"},
{"source": "tag_0", "target": "a"},
{"source": "tag_1", "target": "a"},
{"source": "tag_1", "target": "b"},
{"source": "tag_2", "target": "b"},
{"source": "tag_3", "target": "b"},
]
graph_without_tags = documents_to_networkx([doc1, doc2], tag_nodes=False)
link_data = nx.node_link_data(graph_without_tags)
assert link_data["directed"]
assert not link_data["multigraph"]
link_data["nodes"].sort(key=lambda n: n["id"])
assert link_data["nodes"] == [
{"id": "a", "text": "some content"},
{"id": "b", "text": "<some\n more content>"},
]
link_data["links"].sort(key=lambda n: (n["source"], n["target"]))
assert link_data["links"] == [
{"source": "a", "target": "b", "label": "['kw:foo']"},
{"source": "b", "target": "a", "label": "['href:a', 'kw:foo']"},
]