mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-24 23:54:14 +00:00
community[patch]: Add an option for indexed generic label when import neo4j graph documents (#18122)
Current implementation doesn't have an indexed property that would optimize the import. I have added a `baseEntityLabel` parameter that allows you to add a secondary node label, which has an indexed id `property`. By default, the behaviour is identical to previous version. Since multi-labeled nodes are terrible for text2cypher, I removed the secondary label from schema representation object and string, which is used in text2cypher.
This commit is contained in:
parent
e6e60e2492
commit
f6bfb969ba
@ -5,10 +5,13 @@ from langchain_core.utils import get_from_env
|
||||
from langchain_community.graphs.graph_document import GraphDocument
|
||||
from langchain_community.graphs.graph_store import GraphStore
|
||||
|
||||
BASE_ENTITY_LABEL = "__Entity__"
|
||||
|
||||
node_properties_query = """
|
||||
CALL apoc.meta.data()
|
||||
YIELD label, other, elementType, type, property
|
||||
WHERE NOT type = "RELATIONSHIP" AND elementType = "node"
|
||||
AND NOT label IN [$BASE_ENTITY_LABEL]
|
||||
WITH label AS nodeLabels, collect({property:property, type:type}) AS properties
|
||||
RETURN {labels: nodeLabels, properties: properties} AS output
|
||||
|
||||
@ -27,9 +30,18 @@ CALL apoc.meta.data()
|
||||
YIELD label, other, elementType, type, property
|
||||
WHERE type = "RELATIONSHIP" AND elementType = "node"
|
||||
UNWIND other AS other_node
|
||||
WITH * WHERE NOT label IN [$BASE_ENTITY_LABEL]
|
||||
AND NOT other_node IN [$BASE_ENTITY_LABEL]
|
||||
RETURN {start: label, type: property, end: toString(other_node)} AS output
|
||||
"""
|
||||
|
||||
include_docs_query = (
|
||||
"CREATE (d:Document) "
|
||||
"SET d.text = $document.page_content "
|
||||
"SET d += $document.metadata "
|
||||
"WITH d "
|
||||
)
|
||||
|
||||
|
||||
def value_sanitize(d: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Sanitize the input dictionary.
|
||||
@ -63,6 +75,53 @@ def value_sanitize(d: Dict[str, Any]) -> Dict[str, Any]:
|
||||
return new_dict
|
||||
|
||||
|
||||
def _get_node_import_query(baseEntityLabel: bool, include_source: bool) -> str:
|
||||
if baseEntityLabel:
|
||||
return (
|
||||
f"{include_docs_query if include_source else ''}"
|
||||
"UNWIND $data AS row "
|
||||
f"MERGE (source:`{BASE_ENTITY_LABEL}` {{id: row.id}}) "
|
||||
"SET source += row.properties "
|
||||
f"{'MERGE (d)-[:MENTIONS]->(source) ' if include_source else ''}"
|
||||
"WITH source, row "
|
||||
"CALL apoc.create.addLabels( source, [row.type] ) YIELD node "
|
||||
"RETURN distinct 'done' AS result"
|
||||
)
|
||||
else:
|
||||
return (
|
||||
f"{include_docs_query if include_source else ''}"
|
||||
"UNWIND $data AS row "
|
||||
"CALL apoc.merge.node([row.type], {id: row.id}, "
|
||||
"row.properties, {}) YIELD node "
|
||||
f"{'MERGE (d)-[:MENTIONS]->(node) ' if include_source else ''}"
|
||||
"RETURN distinct 'done' AS result"
|
||||
)
|
||||
|
||||
|
||||
def _get_rel_import_query(baseEntityLabel: bool) -> str:
|
||||
if baseEntityLabel:
|
||||
return (
|
||||
"UNWIND $data AS row "
|
||||
f"MERGE (source:`{BASE_ENTITY_LABEL}` {{id: row.source}}) "
|
||||
f"MERGE (target:`{BASE_ENTITY_LABEL}` {{id: row.target}}) "
|
||||
"WITH source, target, row "
|
||||
"CALL apoc.merge.relationship(source, row.type, "
|
||||
"{}, row.properties, target) YIELD rel "
|
||||
"RETURN distinct 'done'"
|
||||
)
|
||||
else:
|
||||
return (
|
||||
"UNWIND $data AS row "
|
||||
"CALL apoc.merge.node([row.source_label], {id: row.source},"
|
||||
"{}, {}) YIELD node as source "
|
||||
"CALL apoc.merge.node([row.target_label], {id: row.target},"
|
||||
"{}, {}) YIELD node as target "
|
||||
"CALL apoc.merge.relationship(source, row.type, "
|
||||
"{}, row.properties, target) YIELD rel "
|
||||
"RETURN distinct 'done'"
|
||||
)
|
||||
|
||||
|
||||
class Neo4jGraph(GraphStore):
|
||||
"""Neo4j database wrapper for various graph operations.
|
||||
|
||||
@ -173,14 +232,42 @@ class Neo4jGraph(GraphStore):
|
||||
"""
|
||||
Refreshes the Neo4j graph schema information.
|
||||
"""
|
||||
node_properties = [el["output"] for el in self.query(node_properties_query)]
|
||||
rel_properties = [el["output"] for el in self.query(rel_properties_query)]
|
||||
relationships = [el["output"] for el in self.query(rel_query)]
|
||||
from neo4j.exceptions import ClientError
|
||||
|
||||
node_properties = [
|
||||
el["output"]
|
||||
for el in self.query(
|
||||
node_properties_query, params={"BASE_ENTITY_LABEL": BASE_ENTITY_LABEL}
|
||||
)
|
||||
]
|
||||
rel_properties = [
|
||||
el["output"]
|
||||
for el in self.query(
|
||||
rel_properties_query, params={"BASE_ENTITY_LABEL": BASE_ENTITY_LABEL}
|
||||
)
|
||||
]
|
||||
relationships = [
|
||||
el["output"]
|
||||
for el in self.query(
|
||||
rel_query, params={"BASE_ENTITY_LABEL": BASE_ENTITY_LABEL}
|
||||
)
|
||||
]
|
||||
|
||||
# Get constraints & indexes
|
||||
try:
|
||||
constraint = self.query("SHOW CONSTRAINTS")
|
||||
index = self.query("SHOW INDEXES YIELD *")
|
||||
except (
|
||||
ClientError
|
||||
): # Read-only user might not have access to schema information
|
||||
constraint = []
|
||||
index = []
|
||||
|
||||
self.structured_schema = {
|
||||
"node_props": {el["labels"]: el["properties"] for el in node_properties},
|
||||
"rel_props": {el["type"]: el["properties"] for el in rel_properties},
|
||||
"relationships": relationships,
|
||||
"metadata": {"constraint": constraint, "index": index},
|
||||
}
|
||||
|
||||
# Format node properties
|
||||
@ -216,28 +303,51 @@ class Neo4jGraph(GraphStore):
|
||||
)
|
||||
|
||||
def add_graph_documents(
|
||||
self, graph_documents: List[GraphDocument], include_source: bool = False
|
||||
self,
|
||||
graph_documents: List[GraphDocument],
|
||||
include_source: bool = False,
|
||||
baseEntityLabel: bool = False,
|
||||
) -> None:
|
||||
"""
|
||||
Take GraphDocument as input as uses it to construct a graph.
|
||||
This method constructs nodes and relationships in the graph based on the
|
||||
provided GraphDocument objects.
|
||||
|
||||
Parameters:
|
||||
- graph_documents (List[GraphDocument]): A list of GraphDocument objects
|
||||
that contain the nodes and relationships to be added to the graph. Each
|
||||
GraphDocument should encapsulate the structure of part of the graph,
|
||||
including nodes, relationships, and the source document information.
|
||||
- include_source (bool, optional): If True, stores the source document
|
||||
and links it to nodes in the graph using the MENTIONS relationship.
|
||||
This is useful for tracing back the origin of data. Defaults to False.
|
||||
- baseEntityLabel (bool, optional): If True, each newly created node
|
||||
gets a secondary __Entity__ label, which is indexed and improves import
|
||||
speed and performance. Defaults to False.
|
||||
"""
|
||||
for document in graph_documents:
|
||||
include_docs_query = (
|
||||
"CREATE (d:Document) "
|
||||
"SET d.text = $document.page_content "
|
||||
"SET d += $document.metadata "
|
||||
"WITH d "
|
||||
if baseEntityLabel: # Check if constraint already exists
|
||||
constraint_exists = any(
|
||||
[
|
||||
el["labelsOrTypes"] == [BASE_ENTITY_LABEL]
|
||||
and el["properties"] == ["id"]
|
||||
for el in self.structured_schema.get("metadata", {}).get(
|
||||
"constraint"
|
||||
)
|
||||
]
|
||||
)
|
||||
if not constraint_exists:
|
||||
# Create constraint
|
||||
self.query(
|
||||
f"CREATE CONSTRAINT IF NOT EXISTS FOR (b:{BASE_ENTITY_LABEL}) "
|
||||
"REQUIRE b.id IS UNIQUE;"
|
||||
)
|
||||
self.refresh_schema() # Refresh constraint information
|
||||
|
||||
node_import_query = _get_node_import_query(baseEntityLabel, include_source)
|
||||
rel_import_query = _get_rel_import_query(baseEntityLabel)
|
||||
for document in graph_documents:
|
||||
# Import nodes
|
||||
self.query(
|
||||
(
|
||||
f"{include_docs_query if include_source else ''}"
|
||||
"UNWIND $data AS row "
|
||||
"CALL apoc.merge.node([row.type], {id: row.id}, "
|
||||
"row.properties, {}) YIELD node "
|
||||
f"{'MERGE (d)-[:MENTIONS]->(node) ' if include_source else ''}"
|
||||
"RETURN distinct 'done' AS result"
|
||||
),
|
||||
node_import_query,
|
||||
{
|
||||
"data": [el.__dict__ for el in document.nodes],
|
||||
"document": document.source.__dict__,
|
||||
@ -245,14 +355,7 @@ class Neo4jGraph(GraphStore):
|
||||
)
|
||||
# Import relationships
|
||||
self.query(
|
||||
"UNWIND $data AS row "
|
||||
"CALL apoc.merge.node([row.source_label], {id: row.source},"
|
||||
"{}, {}) YIELD node as source "
|
||||
"CALL apoc.merge.node([row.target_label], {id: row.target},"
|
||||
"{}, {}) YIELD node as target "
|
||||
"CALL apoc.merge.relationship(source, row.type, "
|
||||
"{}, row.properties, target) YIELD rel "
|
||||
"RETURN distinct 'done'",
|
||||
rel_import_query,
|
||||
{
|
||||
"data": [
|
||||
{
|
||||
|
@ -1,12 +1,30 @@
|
||||
import os
|
||||
|
||||
from langchain_core.documents import Document
|
||||
|
||||
from langchain_community.graphs import Neo4jGraph
|
||||
from langchain_community.graphs.graph_document import GraphDocument, Node, Relationship
|
||||
from langchain_community.graphs.neo4j_graph import (
|
||||
BASE_ENTITY_LABEL,
|
||||
node_properties_query,
|
||||
rel_properties_query,
|
||||
rel_query,
|
||||
)
|
||||
|
||||
test_data = [
|
||||
GraphDocument(
|
||||
nodes=[Node(id="foo", type="foo"), Node(id="bar", type="bar")],
|
||||
relationships=[
|
||||
Relationship(
|
||||
source=Node(id="foo", type="foo"),
|
||||
target=Node(id="bar", type="bar"),
|
||||
type="REL",
|
||||
)
|
||||
],
|
||||
source=Document(page_content="source document"),
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
def test_cypher_return_correct_schema() -> None:
|
||||
"""Test that chain returns direct results."""
|
||||
@ -37,9 +55,15 @@ def test_cypher_return_correct_schema() -> None:
|
||||
# Refresh schema information
|
||||
graph.refresh_schema()
|
||||
|
||||
node_properties = graph.query(node_properties_query)
|
||||
relationships_properties = graph.query(rel_properties_query)
|
||||
relationships = graph.query(rel_query)
|
||||
node_properties = graph.query(
|
||||
node_properties_query, params={"BASE_ENTITY_LABEL": BASE_ENTITY_LABEL}
|
||||
)
|
||||
relationships_properties = graph.query(
|
||||
rel_properties_query, params={"BASE_ENTITY_LABEL": BASE_ENTITY_LABEL}
|
||||
)
|
||||
relationships = graph.query(
|
||||
rel_query, params={"BASE_ENTITY_LABEL": BASE_ENTITY_LABEL}
|
||||
)
|
||||
|
||||
expected_node_properties = [
|
||||
{
|
||||
@ -116,3 +140,112 @@ def test_neo4j_sanitize_values() -> None:
|
||||
|
||||
output = graph.query("RETURN range(0,130,1) AS result")
|
||||
assert output == [{}]
|
||||
|
||||
|
||||
def test_neo4j_add_data() -> None:
|
||||
"""Test that neo4j correctly import graph document."""
|
||||
url = os.environ.get("NEO4J_URI")
|
||||
username = os.environ.get("NEO4J_USERNAME")
|
||||
password = os.environ.get("NEO4J_PASSWORD")
|
||||
assert url is not None
|
||||
assert username is not None
|
||||
assert password is not None
|
||||
|
||||
graph = Neo4jGraph(url=url, username=username, password=password, sanitize=True)
|
||||
# Delete all nodes in the graph
|
||||
graph.query("MATCH (n) DETACH DELETE n")
|
||||
# Remove all constraints
|
||||
graph.query("CALL apoc.schema.assert({}, {})")
|
||||
graph.refresh_schema()
|
||||
# Create two nodes and a relationship
|
||||
graph.add_graph_documents(test_data)
|
||||
output = graph.query(
|
||||
"MATCH (n) RETURN labels(n) AS label, count(*) AS count ORDER BY label"
|
||||
)
|
||||
assert output == [{"label": ["bar"], "count": 1}, {"label": ["foo"], "count": 1}]
|
||||
assert graph.structured_schema["metadata"]["constraint"] == []
|
||||
|
||||
|
||||
def test_neo4j_add_data_source() -> None:
|
||||
"""Test that neo4j correctly import graph document with source."""
|
||||
url = os.environ.get("NEO4J_URI")
|
||||
username = os.environ.get("NEO4J_USERNAME")
|
||||
password = os.environ.get("NEO4J_PASSWORD")
|
||||
assert url is not None
|
||||
assert username is not None
|
||||
assert password is not None
|
||||
|
||||
graph = Neo4jGraph(url=url, username=username, password=password, sanitize=True)
|
||||
# Delete all nodes in the graph
|
||||
graph.query("MATCH (n) DETACH DELETE n")
|
||||
# Remove all constraints
|
||||
graph.query("CALL apoc.schema.assert({}, {})")
|
||||
graph.refresh_schema()
|
||||
# Create two nodes and a relationship
|
||||
graph.add_graph_documents(test_data, include_source=True)
|
||||
output = graph.query(
|
||||
"MATCH (n) RETURN labels(n) AS label, count(*) AS count ORDER BY label"
|
||||
)
|
||||
assert output == [
|
||||
{"label": ["Document"], "count": 1},
|
||||
{"label": ["bar"], "count": 1},
|
||||
{"label": ["foo"], "count": 1},
|
||||
]
|
||||
assert graph.structured_schema["metadata"]["constraint"] == []
|
||||
|
||||
|
||||
def test_neo4j_add_data_base() -> None:
|
||||
"""Test that neo4j correctly import graph document with base_entity."""
|
||||
url = os.environ.get("NEO4J_URI")
|
||||
username = os.environ.get("NEO4J_USERNAME")
|
||||
password = os.environ.get("NEO4J_PASSWORD")
|
||||
assert url is not None
|
||||
assert username is not None
|
||||
assert password is not None
|
||||
|
||||
graph = Neo4jGraph(url=url, username=username, password=password, sanitize=True)
|
||||
# Delete all nodes in the graph
|
||||
graph.query("MATCH (n) DETACH DELETE n")
|
||||
# Remove all constraints
|
||||
graph.query("CALL apoc.schema.assert({}, {})")
|
||||
graph.refresh_schema()
|
||||
# Create two nodes and a relationship
|
||||
graph.add_graph_documents(test_data, baseEntityLabel=True)
|
||||
output = graph.query(
|
||||
"MATCH (n) RETURN apoc.coll.sort(labels(n)) AS label, "
|
||||
"count(*) AS count ORDER BY label"
|
||||
)
|
||||
assert output == [
|
||||
{"label": [BASE_ENTITY_LABEL, "bar"], "count": 1},
|
||||
{"label": [BASE_ENTITY_LABEL, "foo"], "count": 1},
|
||||
]
|
||||
assert graph.structured_schema["metadata"]["constraint"] != []
|
||||
|
||||
|
||||
def test_neo4j_add_data_base_source() -> None:
|
||||
"""Test that neo4j correctly import graph document with base_entity and source."""
|
||||
url = os.environ.get("NEO4J_URI")
|
||||
username = os.environ.get("NEO4J_USERNAME")
|
||||
password = os.environ.get("NEO4J_PASSWORD")
|
||||
assert url is not None
|
||||
assert username is not None
|
||||
assert password is not None
|
||||
|
||||
graph = Neo4jGraph(url=url, username=username, password=password, sanitize=True)
|
||||
# Delete all nodes in the graph
|
||||
graph.query("MATCH (n) DETACH DELETE n")
|
||||
# Remove all constraints
|
||||
graph.query("CALL apoc.schema.assert({}, {})")
|
||||
graph.refresh_schema()
|
||||
# Create two nodes and a relationship
|
||||
graph.add_graph_documents(test_data, baseEntityLabel=True, include_source=True)
|
||||
output = graph.query(
|
||||
"MATCH (n) RETURN apoc.coll.sort(labels(n)) AS label, "
|
||||
"count(*) AS count ORDER BY label"
|
||||
)
|
||||
assert output == [
|
||||
{"label": ["Document"], "count": 1},
|
||||
{"label": [BASE_ENTITY_LABEL, "bar"], "count": 1},
|
||||
{"label": [BASE_ENTITY_LABEL, "foo"], "count": 1},
|
||||
]
|
||||
assert graph.structured_schema["metadata"]["constraint"] != []
|
||||
|
Loading…
Reference in New Issue
Block a user