mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-24 23:54:14 +00:00
community[patch]: Add an option for an indexed generic label when importing Neo4j graph documents (#18122)
The current implementation doesn't have an indexed property that would optimize the import. I have added a `baseEntityLabel` parameter that allows you to add a secondary node label, which has an indexed `id` property. By default, the behaviour is identical to the previous version. Since multi-labeled nodes are terrible for text2cypher, I removed the secondary label from the schema representation object and string, which are used in text2cypher.
This commit is contained in:
parent
e6e60e2492
commit
f6bfb969ba
@ -5,10 +5,13 @@ from langchain_core.utils import get_from_env
|
|||||||
from langchain_community.graphs.graph_document import GraphDocument
|
from langchain_community.graphs.graph_document import GraphDocument
|
||||||
from langchain_community.graphs.graph_store import GraphStore
|
from langchain_community.graphs.graph_store import GraphStore
|
||||||
|
|
||||||
|
# Secondary label attached to imported nodes when `baseEntityLabel` is enabled;
# the schema queries receive it as a parameter so they can filter it out.
BASE_ENTITY_LABEL = "__Entity__"
|
||||||
|
|
||||||
node_properties_query = """
|
node_properties_query = """
|
||||||
CALL apoc.meta.data()
|
CALL apoc.meta.data()
|
||||||
YIELD label, other, elementType, type, property
|
YIELD label, other, elementType, type, property
|
||||||
WHERE NOT type = "RELATIONSHIP" AND elementType = "node"
|
WHERE NOT type = "RELATIONSHIP" AND elementType = "node"
|
||||||
|
AND NOT label IN [$BASE_ENTITY_LABEL]
|
||||||
WITH label AS nodeLabels, collect({property:property, type:type}) AS properties
|
WITH label AS nodeLabels, collect({property:property, type:type}) AS properties
|
||||||
RETURN {labels: nodeLabels, properties: properties} AS output
|
RETURN {labels: nodeLabels, properties: properties} AS output
|
||||||
|
|
||||||
@ -27,9 +30,18 @@ CALL apoc.meta.data()
|
|||||||
YIELD label, other, elementType, type, property
|
YIELD label, other, elementType, type, property
|
||||||
WHERE type = "RELATIONSHIP" AND elementType = "node"
|
WHERE type = "RELATIONSHIP" AND elementType = "node"
|
||||||
UNWIND other AS other_node
|
UNWIND other AS other_node
|
||||||
|
WITH * WHERE NOT label IN [$BASE_ENTITY_LABEL]
|
||||||
|
AND NOT other_node IN [$BASE_ENTITY_LABEL]
|
||||||
RETURN {start: label, type: property, end: toString(other_node)} AS output
|
RETURN {start: label, type: property, end: toString(other_node)} AS output
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# Cypher prefix that creates a `Document` node from the `$document` parameter
# (page content stored as `text`, metadata merged in as properties) and carries
# it forward as `d` so that following clauses can link nodes to it.
include_docs_query = (
    "CREATE (d:Document) SET d.text = $document.page_content "
    "SET d += $document.metadata WITH d "
)
|
||||||
|
|
||||||
|
|
||||||
def value_sanitize(d: Dict[str, Any]) -> Dict[str, Any]:
|
def value_sanitize(d: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
"""Sanitize the input dictionary.
|
"""Sanitize the input dictionary.
|
||||||
@ -63,6 +75,53 @@ def value_sanitize(d: Dict[str, Any]) -> Dict[str, Any]:
|
|||||||
return new_dict
|
return new_dict
|
||||||
|
|
||||||
|
|
||||||
|
def _get_node_import_query(baseEntityLabel: bool, include_source: bool) -> str:
|
||||||
|
if baseEntityLabel:
|
||||||
|
return (
|
||||||
|
f"{include_docs_query if include_source else ''}"
|
||||||
|
"UNWIND $data AS row "
|
||||||
|
f"MERGE (source:`{BASE_ENTITY_LABEL}` {{id: row.id}}) "
|
||||||
|
"SET source += row.properties "
|
||||||
|
f"{'MERGE (d)-[:MENTIONS]->(source) ' if include_source else ''}"
|
||||||
|
"WITH source, row "
|
||||||
|
"CALL apoc.create.addLabels( source, [row.type] ) YIELD node "
|
||||||
|
"RETURN distinct 'done' AS result"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return (
|
||||||
|
f"{include_docs_query if include_source else ''}"
|
||||||
|
"UNWIND $data AS row "
|
||||||
|
"CALL apoc.merge.node([row.type], {id: row.id}, "
|
||||||
|
"row.properties, {}) YIELD node "
|
||||||
|
f"{'MERGE (d)-[:MENTIONS]->(node) ' if include_source else ''}"
|
||||||
|
"RETURN distinct 'done' AS result"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_rel_import_query(baseEntityLabel: bool) -> str:
|
||||||
|
if baseEntityLabel:
|
||||||
|
return (
|
||||||
|
"UNWIND $data AS row "
|
||||||
|
f"MERGE (source:`{BASE_ENTITY_LABEL}` {{id: row.source}}) "
|
||||||
|
f"MERGE (target:`{BASE_ENTITY_LABEL}` {{id: row.target}}) "
|
||||||
|
"WITH source, target, row "
|
||||||
|
"CALL apoc.merge.relationship(source, row.type, "
|
||||||
|
"{}, row.properties, target) YIELD rel "
|
||||||
|
"RETURN distinct 'done'"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return (
|
||||||
|
"UNWIND $data AS row "
|
||||||
|
"CALL apoc.merge.node([row.source_label], {id: row.source},"
|
||||||
|
"{}, {}) YIELD node as source "
|
||||||
|
"CALL apoc.merge.node([row.target_label], {id: row.target},"
|
||||||
|
"{}, {}) YIELD node as target "
|
||||||
|
"CALL apoc.merge.relationship(source, row.type, "
|
||||||
|
"{}, row.properties, target) YIELD rel "
|
||||||
|
"RETURN distinct 'done'"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class Neo4jGraph(GraphStore):
|
class Neo4jGraph(GraphStore):
|
||||||
"""Neo4j database wrapper for various graph operations.
|
"""Neo4j database wrapper for various graph operations.
|
||||||
|
|
||||||
@ -173,14 +232,42 @@ class Neo4jGraph(GraphStore):
|
|||||||
"""
|
"""
|
||||||
Refreshes the Neo4j graph schema information.
|
Refreshes the Neo4j graph schema information.
|
||||||
"""
|
"""
|
||||||
node_properties = [el["output"] for el in self.query(node_properties_query)]
|
from neo4j.exceptions import ClientError
|
||||||
rel_properties = [el["output"] for el in self.query(rel_properties_query)]
|
|
||||||
relationships = [el["output"] for el in self.query(rel_query)]
|
node_properties = [
|
||||||
|
el["output"]
|
||||||
|
for el in self.query(
|
||||||
|
node_properties_query, params={"BASE_ENTITY_LABEL": BASE_ENTITY_LABEL}
|
||||||
|
)
|
||||||
|
]
|
||||||
|
rel_properties = [
|
||||||
|
el["output"]
|
||||||
|
for el in self.query(
|
||||||
|
rel_properties_query, params={"BASE_ENTITY_LABEL": BASE_ENTITY_LABEL}
|
||||||
|
)
|
||||||
|
]
|
||||||
|
relationships = [
|
||||||
|
el["output"]
|
||||||
|
for el in self.query(
|
||||||
|
rel_query, params={"BASE_ENTITY_LABEL": BASE_ENTITY_LABEL}
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
# Get constraints & indexes
|
||||||
|
try:
|
||||||
|
constraint = self.query("SHOW CONSTRAINTS")
|
||||||
|
index = self.query("SHOW INDEXES YIELD *")
|
||||||
|
except (
|
||||||
|
ClientError
|
||||||
|
): # Read-only user might not have access to schema information
|
||||||
|
constraint = []
|
||||||
|
index = []
|
||||||
|
|
||||||
self.structured_schema = {
|
self.structured_schema = {
|
||||||
"node_props": {el["labels"]: el["properties"] for el in node_properties},
|
"node_props": {el["labels"]: el["properties"] for el in node_properties},
|
||||||
"rel_props": {el["type"]: el["properties"] for el in rel_properties},
|
"rel_props": {el["type"]: el["properties"] for el in rel_properties},
|
||||||
"relationships": relationships,
|
"relationships": relationships,
|
||||||
|
"metadata": {"constraint": constraint, "index": index},
|
||||||
}
|
}
|
||||||
|
|
||||||
# Format node properties
|
# Format node properties
|
||||||
@ -216,28 +303,51 @@ class Neo4jGraph(GraphStore):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def add_graph_documents(
|
def add_graph_documents(
|
||||||
self, graph_documents: List[GraphDocument], include_source: bool = False
|
self,
|
||||||
|
graph_documents: List[GraphDocument],
|
||||||
|
include_source: bool = False,
|
||||||
|
baseEntityLabel: bool = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Take GraphDocument as input as uses it to construct a graph.
|
This method constructs nodes and relationships in the graph based on the
|
||||||
|
provided GraphDocument objects.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
- graph_documents (List[GraphDocument]): A list of GraphDocument objects
|
||||||
|
that contain the nodes and relationships to be added to the graph. Each
|
||||||
|
GraphDocument should encapsulate the structure of part of the graph,
|
||||||
|
including nodes, relationships, and the source document information.
|
||||||
|
- include_source (bool, optional): If True, stores the source document
|
||||||
|
and links it to nodes in the graph using the MENTIONS relationship.
|
||||||
|
This is useful for tracing back the origin of data. Defaults to False.
|
||||||
|
- baseEntityLabel (bool, optional): If True, each newly created node
|
||||||
|
gets a secondary __Entity__ label, which is indexed and improves import
|
||||||
|
speed and performance. Defaults to False.
|
||||||
"""
|
"""
|
||||||
for document in graph_documents:
|
if baseEntityLabel: # Check if constraint already exists
|
||||||
include_docs_query = (
|
constraint_exists = any(
|
||||||
"CREATE (d:Document) "
|
[
|
||||||
"SET d.text = $document.page_content "
|
el["labelsOrTypes"] == [BASE_ENTITY_LABEL]
|
||||||
"SET d += $document.metadata "
|
and el["properties"] == ["id"]
|
||||||
"WITH d "
|
for el in self.structured_schema.get("metadata", {}).get(
|
||||||
|
"constraint"
|
||||||
|
)
|
||||||
|
]
|
||||||
)
|
)
|
||||||
|
if not constraint_exists:
|
||||||
|
# Create constraint
|
||||||
|
self.query(
|
||||||
|
f"CREATE CONSTRAINT IF NOT EXISTS FOR (b:{BASE_ENTITY_LABEL}) "
|
||||||
|
"REQUIRE b.id IS UNIQUE;"
|
||||||
|
)
|
||||||
|
self.refresh_schema() # Refresh constraint information
|
||||||
|
|
||||||
|
node_import_query = _get_node_import_query(baseEntityLabel, include_source)
|
||||||
|
rel_import_query = _get_rel_import_query(baseEntityLabel)
|
||||||
|
for document in graph_documents:
|
||||||
# Import nodes
|
# Import nodes
|
||||||
self.query(
|
self.query(
|
||||||
(
|
node_import_query,
|
||||||
f"{include_docs_query if include_source else ''}"
|
|
||||||
"UNWIND $data AS row "
|
|
||||||
"CALL apoc.merge.node([row.type], {id: row.id}, "
|
|
||||||
"row.properties, {}) YIELD node "
|
|
||||||
f"{'MERGE (d)-[:MENTIONS]->(node) ' if include_source else ''}"
|
|
||||||
"RETURN distinct 'done' AS result"
|
|
||||||
),
|
|
||||||
{
|
{
|
||||||
"data": [el.__dict__ for el in document.nodes],
|
"data": [el.__dict__ for el in document.nodes],
|
||||||
"document": document.source.__dict__,
|
"document": document.source.__dict__,
|
||||||
@ -245,14 +355,7 @@ class Neo4jGraph(GraphStore):
|
|||||||
)
|
)
|
||||||
# Import relationships
|
# Import relationships
|
||||||
self.query(
|
self.query(
|
||||||
"UNWIND $data AS row "
|
rel_import_query,
|
||||||
"CALL apoc.merge.node([row.source_label], {id: row.source},"
|
|
||||||
"{}, {}) YIELD node as source "
|
|
||||||
"CALL apoc.merge.node([row.target_label], {id: row.target},"
|
|
||||||
"{}, {}) YIELD node as target "
|
|
||||||
"CALL apoc.merge.relationship(source, row.type, "
|
|
||||||
"{}, row.properties, target) YIELD rel "
|
|
||||||
"RETURN distinct 'done'",
|
|
||||||
{
|
{
|
||||||
"data": [
|
"data": [
|
||||||
{
|
{
|
||||||
|
@ -1,12 +1,30 @@
|
|||||||
import os
|
import os
|
||||||
|
|
||||||
|
from langchain_core.documents import Document
|
||||||
|
|
||||||
from langchain_community.graphs import Neo4jGraph
|
from langchain_community.graphs import Neo4jGraph
|
||||||
|
from langchain_community.graphs.graph_document import GraphDocument, Node, Relationship
|
||||||
from langchain_community.graphs.neo4j_graph import (
|
from langchain_community.graphs.neo4j_graph import (
|
||||||
|
BASE_ENTITY_LABEL,
|
||||||
node_properties_query,
|
node_properties_query,
|
||||||
rel_properties_query,
|
rel_properties_query,
|
||||||
rel_query,
|
rel_query,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Shared fixture: one graph document holding two nodes ("foo" and "bar"),
# a single REL relationship from foo to bar, and a source text document.
test_data = [
    GraphDocument(
        nodes=[
            Node(id="foo", type="foo"),
            Node(id="bar", type="bar"),
        ],
        relationships=[
            Relationship(
                source=Node(id="foo", type="foo"),
                target=Node(id="bar", type="bar"),
                type="REL",
            ),
        ],
        source=Document(page_content="source document"),
    ),
]
|
||||||
|
|
||||||
|
|
||||||
def test_cypher_return_correct_schema() -> None:
|
def test_cypher_return_correct_schema() -> None:
|
||||||
"""Test that chain returns direct results."""
|
"""Test that chain returns direct results."""
|
||||||
@ -37,9 +55,15 @@ def test_cypher_return_correct_schema() -> None:
|
|||||||
# Refresh schema information
|
# Refresh schema information
|
||||||
graph.refresh_schema()
|
graph.refresh_schema()
|
||||||
|
|
||||||
node_properties = graph.query(node_properties_query)
|
node_properties = graph.query(
|
||||||
relationships_properties = graph.query(rel_properties_query)
|
node_properties_query, params={"BASE_ENTITY_LABEL": BASE_ENTITY_LABEL}
|
||||||
relationships = graph.query(rel_query)
|
)
|
||||||
|
relationships_properties = graph.query(
|
||||||
|
rel_properties_query, params={"BASE_ENTITY_LABEL": BASE_ENTITY_LABEL}
|
||||||
|
)
|
||||||
|
relationships = graph.query(
|
||||||
|
rel_query, params={"BASE_ENTITY_LABEL": BASE_ENTITY_LABEL}
|
||||||
|
)
|
||||||
|
|
||||||
expected_node_properties = [
|
expected_node_properties = [
|
||||||
{
|
{
|
||||||
@ -116,3 +140,112 @@ def test_neo4j_sanitize_values() -> None:
|
|||||||
|
|
||||||
output = graph.query("RETURN range(0,130,1) AS result")
|
output = graph.query("RETURN range(0,130,1) AS result")
|
||||||
assert output == [{}]
|
assert output == [{}]
|
||||||
|
|
||||||
|
|
||||||
|
def test_neo4j_add_data() -> None:
    """Import a graph document with default options and check the stored labels."""
    url, username, password = (
        os.environ.get(name)
        for name in ("NEO4J_URI", "NEO4J_USERNAME", "NEO4J_PASSWORD")
    )
    assert url is not None
    assert username is not None
    assert password is not None

    graph = Neo4jGraph(url=url, username=username, password=password, sanitize=True)
    # Start from a clean database: delete every node, then remove all constraints
    graph.query("MATCH (n) DETACH DELETE n")
    graph.query("CALL apoc.schema.assert({}, {})")
    graph.refresh_schema()

    # Create two nodes and a relationship
    graph.add_graph_documents(test_data)
    label_counts = graph.query(
        "MATCH (n) RETURN labels(n) AS label, count(*) AS count ORDER BY label"
    )
    expected = [
        {"label": ["bar"], "count": 1},
        {"label": ["foo"], "count": 1},
    ]
    assert label_counts == expected
    # Without the base entity label no constraint should have been created
    assert graph.structured_schema["metadata"]["constraint"] == []
|
||||||
|
|
||||||
|
|
||||||
|
def test_neo4j_add_data_source() -> None:
    """Import a graph document together with its source document node."""
    url, username, password = (
        os.environ.get(name)
        for name in ("NEO4J_URI", "NEO4J_USERNAME", "NEO4J_PASSWORD")
    )
    assert url is not None
    assert username is not None
    assert password is not None

    graph = Neo4jGraph(url=url, username=username, password=password, sanitize=True)
    # Start from a clean database: delete every node, then remove all constraints
    graph.query("MATCH (n) DETACH DELETE n")
    graph.query("CALL apoc.schema.assert({}, {})")
    graph.refresh_schema()

    # Create two nodes and a relationship, plus the source Document node
    graph.add_graph_documents(test_data, include_source=True)
    label_counts = graph.query(
        "MATCH (n) RETURN labels(n) AS label, count(*) AS count ORDER BY label"
    )
    expected = [
        {"label": ["Document"], "count": 1},
        {"label": ["bar"], "count": 1},
        {"label": ["foo"], "count": 1},
    ]
    assert label_counts == expected
    # Without the base entity label no constraint should have been created
    assert graph.structured_schema["metadata"]["constraint"] == []
|
||||||
|
|
||||||
|
|
||||||
|
def test_neo4j_add_data_base() -> None:
    """Import a graph document with the base entity label enabled."""
    url, username, password = (
        os.environ.get(name)
        for name in ("NEO4J_URI", "NEO4J_USERNAME", "NEO4J_PASSWORD")
    )
    assert url is not None
    assert username is not None
    assert password is not None

    graph = Neo4jGraph(url=url, username=username, password=password, sanitize=True)
    # Start from a clean database: delete every node, then remove all constraints
    graph.query("MATCH (n) DETACH DELETE n")
    graph.query("CALL apoc.schema.assert({}, {})")
    graph.refresh_schema()

    # Create two nodes and a relationship
    graph.add_graph_documents(test_data, baseEntityLabel=True)
    # Labels are sorted inside the query so the comparison is order-independent
    label_counts = graph.query(
        "MATCH (n) RETURN apoc.coll.sort(labels(n)) AS label, "
        "count(*) AS count ORDER BY label"
    )
    expected = [
        {"label": [BASE_ENTITY_LABEL, "bar"], "count": 1},
        {"label": [BASE_ENTITY_LABEL, "foo"], "count": 1},
    ]
    assert label_counts == expected
    # Enabling the base entity label must leave constraint metadata in the schema
    assert graph.structured_schema["metadata"]["constraint"] != []
|
||||||
|
|
||||||
|
|
||||||
|
def test_neo4j_add_data_base_source() -> None:
    """Import a graph document with both the base entity label and the source."""
    url, username, password = (
        os.environ.get(name)
        for name in ("NEO4J_URI", "NEO4J_USERNAME", "NEO4J_PASSWORD")
    )
    assert url is not None
    assert username is not None
    assert password is not None

    graph = Neo4jGraph(url=url, username=username, password=password, sanitize=True)
    # Start from a clean database: delete every node, then remove all constraints
    graph.query("MATCH (n) DETACH DELETE n")
    graph.query("CALL apoc.schema.assert({}, {})")
    graph.refresh_schema()

    # Create two nodes and a relationship, plus the source Document node
    graph.add_graph_documents(test_data, baseEntityLabel=True, include_source=True)
    # Labels are sorted inside the query so the comparison is order-independent
    label_counts = graph.query(
        "MATCH (n) RETURN apoc.coll.sort(labels(n)) AS label, "
        "count(*) AS count ORDER BY label"
    )
    expected = [
        {"label": ["Document"], "count": 1},
        {"label": [BASE_ENTITY_LABEL, "bar"], "count": 1},
        {"label": [BASE_ENTITY_LABEL, "foo"], "count": 1},
    ]
    assert label_counts == expected
    # Enabling the base entity label must leave constraint metadata in the schema
    assert graph.structured_schema["metadata"]["constraint"] != []
|
||||||
|
Loading…
Reference in New Issue
Block a user