diff --git a/libs/community/langchain_community/graphs/neo4j_graph.py b/libs/community/langchain_community/graphs/neo4j_graph.py index c8022d1db0b..afb32e17a69 100644 --- a/libs/community/langchain_community/graphs/neo4j_graph.py +++ b/libs/community/langchain_community/graphs/neo4j_graph.py @@ -6,12 +6,14 @@ from langchain_community.graphs.graph_document import GraphDocument from langchain_community.graphs.graph_store import GraphStore BASE_ENTITY_LABEL = "__Entity__" +EXCLUDED_LABELS = ["_Bloom_Perspective_", "_Bloom_Scene_"] +EXCLUDED_RELS = ["_Bloom_HAS_SCENE_"] node_properties_query = """ CALL apoc.meta.data() YIELD label, other, elementType, type, property WHERE NOT type = "RELATIONSHIP" AND elementType = "node" - AND NOT label IN [$BASE_ENTITY_LABEL] + AND NOT label IN $EXCLUDED_LABELS WITH label AS nodeLabels, collect({property:property, type:type}) AS properties RETURN {labels: nodeLabels, properties: properties} AS output @@ -21,6 +23,7 @@ rel_properties_query = """ CALL apoc.meta.data() YIELD label, other, elementType, type, property WHERE NOT type = "RELATIONSHIP" AND elementType = "relationship" + AND NOT label in $EXCLUDED_LABELS WITH label AS nodeLabels, collect({property:property, type:type}) AS properties RETURN {type: nodeLabels, properties: properties} AS output """ @@ -30,8 +33,8 @@ CALL apoc.meta.data() YIELD label, other, elementType, type, property WHERE type = "RELATIONSHIP" AND elementType = "node" UNWIND other AS other_node -WITH * WHERE NOT label IN [$BASE_ENTITY_LABEL] - AND NOT other_node IN [$BASE_ENTITY_LABEL] +WITH * WHERE NOT label IN $EXCLUDED_LABELS + AND NOT other_node IN $EXCLUDED_LABELS RETURN {start: label, type: property, end: toString(other_node)} AS output """ @@ -243,19 +246,21 @@ class Neo4jGraph(GraphStore): node_properties = [ el["output"] for el in self.query( - node_properties_query, params={"BASE_ENTITY_LABEL": BASE_ENTITY_LABEL} + node_properties_query, + params={"EXCLUDED_LABELS": EXCLUDED_LABELS + [BASE_ENTITY_LABEL]}, ) ] rel_properties = [ el["output"] for el in self.query( - rel_properties_query, params={"BASE_ENTITY_LABEL": BASE_ENTITY_LABEL} + rel_properties_query, params={"EXCLUDED_LABELS": EXCLUDED_RELS} ) ] relationships = [ el["output"] for el in self.query( - rel_query, params={"BASE_ENTITY_LABEL": BASE_ENTITY_LABEL} + rel_query, + params={"EXCLUDED_LABELS": EXCLUDED_LABELS + [BASE_ENTITY_LABEL]}, ) ] diff --git a/libs/community/tests/integration_tests/graphs/test_neo4j.py b/libs/community/tests/integration_tests/graphs/test_neo4j.py index 1fcb4b7fbb6..50e0a9f7244 100644 --- a/libs/community/tests/integration_tests/graphs/test_neo4j.py +++ b/libs/community/tests/integration_tests/graphs/test_neo4j.py @@ -56,13 +56,13 @@ def test_cypher_return_correct_schema() -> None: graph.refresh_schema() node_properties = graph.query( - node_properties_query, params={"BASE_ENTITY_LABEL": BASE_ENTITY_LABEL} + node_properties_query, params={"EXCLUDED_LABELS": [BASE_ENTITY_LABEL]} ) relationships_properties = graph.query( - rel_properties_query, params={"BASE_ENTITY_LABEL": BASE_ENTITY_LABEL} + rel_properties_query, params={"EXCLUDED_LABELS": [BASE_ENTITY_LABEL]} ) relationships = graph.query( - rel_query, params={"BASE_ENTITY_LABEL": BASE_ENTITY_LABEL} + rel_query, params={"EXCLUDED_LABELS": [BASE_ENTITY_LABEL]} ) expected_node_properties = [ @@ -249,3 +249,27 @@ def test_neo4j_add_data_base_source() -> None: {"label": [BASE_ENTITY_LABEL, "foo"], "count": 1}, ] assert graph.structured_schema["metadata"]["constraint"] != [] + + +def test_neo4j_filtering_labels() -> None: + """Test that neo4j correctly filters excluded labels.""" + url = os.environ.get("NEO4J_URI") + username = os.environ.get("NEO4J_USERNAME") + password = os.environ.get("NEO4J_PASSWORD") + assert url is not None + assert username is not None + assert password is not None + + graph = Neo4jGraph(url=url, username=username, password=password, sanitize=True) + # Delete all nodes in the graph + graph.query("MATCH (n) DETACH DELETE n") + # Remove all constraints + graph.query("CALL apoc.schema.assert({}, {})") + graph.query( + "CREATE (:`_Bloom_Scene_`)-[:_Bloom_HAS_SCENE_]->(:`_Bloom_Perspective_`)" + ) + graph.refresh_schema() + + # Assert both are empty + assert graph.structured_schema["node_props"] == {} + assert graph.structured_schema["relationships"] == []