community[patch]: Neo4j enhanced schema (#20983)

Scan the database for example property values and include them in the graph
schema passed to the LLM, so that Text2Cypher query generation is more accurate.
This commit is contained in:
Tomaz Bratanic
2024-04-29 16:45:55 +02:00
committed by GitHub
parent dc70c23a11
commit 67428c4052
3 changed files with 423 additions and 92 deletions

View File

@@ -9,6 +9,11 @@ from langchain_community.graphs.graph_store import GraphStore
# Generic entity label; its use is not visible in this excerpt
# (presumably applied to imported nodes -- confirm against the import queries).
BASE_ENTITY_LABEL = "__Entity__"
# Node labels and relationship types that are skipped when enhanced schema
# information is collected.
EXCLUDED_LABELS = ["_Bloom_Perspective_", "_Bloom_Scene_"]
EXCLUDED_RELS = ["_Bloom_HAS_SCENE_"]
# Labels/types whose element count is below this value are scanned
# exhaustively for the enhanced schema; larger ones are only sampled.
EXHAUSTIVE_SEARCH_LIMIT = 10000
# LIST properties whose minimum size exceeds this value are left out of the
# enhanced schema text (they are treated as embeddings).
LIST_LIMIT = 128
# Threshold for returning all available prop values in graph schema
DISTINCT_VALUE_LIMIT = 10
# Newline bound to a name so it can be referenced inside f-string expressions,
# which cannot contain a backslash before Python 3.12.
NL = "\n"
node_properties_query = """
CALL apoc.meta.data()
@@ -56,7 +61,6 @@ def value_sanitize(d: Any) -> Any:
results, can occupy significant context space and detract from
the LLM's performance by introducing unnecessary noise and cost.
"""
LIST_LIMIT = 128
if isinstance(d, dict):
new_dict = {}
for key, value in d.items():
@@ -135,6 +139,223 @@ def _get_rel_import_query(baseEntityLabel: bool) -> str:
)
def _enhanced_schema_cypher(
label_or_type: str,
properties: List[Dict[str, Any]],
exhaustive: bool,
is_relationship: bool = False,
) -> str:
if is_relationship:
match_clause = f"MATCH ()-[n:{label_or_type}]->()"
else:
match_clause = f"MATCH (n:{label_or_type})"
with_clauses = []
return_clauses = []
output_dict = {}
if exhaustive:
for prop in properties:
prop_name = prop["property"]
prop_type = prop["type"]
if prop_type == "STRING":
with_clauses.append(
(
f"collect(distinct substring(n.`{prop_name}`, 0, 50)) "
f"AS `{prop_name}_values`"
)
)
return_clauses.append(
(
f"values:`{prop_name}_values`[..{DISTINCT_VALUE_LIMIT}],"
f" distinct_count: size(`{prop_name}_values`)"
)
)
elif prop_type in ["INTEGER", "FLOAT", "DATE"]:
with_clauses.append(f"min(n.`{prop_name}`) AS `{prop_name}_min`")
with_clauses.append(f"max(n.`{prop_name}`) AS `{prop_name}_max`")
with_clauses.append(
f"count(distinct n.`{prop_name}`) AS `{prop_name}_distinct`"
)
return_clauses.append(
(
f"min: toString(`{prop_name}_min`), "
f"max: toString(`{prop_name}_max`), "
f"distinct_count: `{prop_name}_distinct`"
)
)
elif prop_type == "LIST":
with_clauses.append(
(
f"min(size(n.`{prop_name}`)) AS `{prop_name}_size_min`, "
f"max(size(n.`{prop_name}`)) AS `{prop_name}_size_max`"
)
)
return_clauses.append(
f"min_size: `{prop_name}_size_min`, "
f"max_size: `{prop_name}_size_max`"
)
output_dict[prop_name] = "{" + return_clauses.pop() + "}"
else:
# Just sample 5 random nodes
match_clause += " WITH n LIMIT 5"
for prop in properties:
prop_name = prop["property"]
prop_type = prop["type"]
if prop_type == "STRING":
with_clauses.append(
(
f"collect(distinct substring(n.`{prop_name}`, 0, 50)) "
f"AS `{prop_name}_values`"
)
)
return_clauses.append(f"values: `{prop_name}_values`")
elif prop_type in ["INTEGER", "FLOAT", "DATE"]:
with_clauses.append(
f"collect(distinct toString(n.`{prop_name}`)) "
f"AS `{prop_name}_values`"
)
return_clauses.append(f"values: `{prop_name}_values`")
elif prop_type == "LIST":
with_clauses.append(
(
f"min(size(n.`{prop_name}`)) AS `{prop_name}_size_min`, "
f"max(size(n.`{prop_name}`)) AS `{prop_name}_size_max`"
)
)
return_clauses.append(
f"min_size: `{prop_name}_size_min`,max_size: `{prop_name}_size_max`"
)
output_dict[prop_name] = "{" + return_clauses.pop() + "}"
with_clause = "WITH " + ",\n ".join(with_clauses)
return_clause = (
"RETURN {"
+ ", ".join(f"{k}: {v}" for k, v in output_dict.items())
+ "} AS output"
)
# Combine all parts of the Cypher query
cypher_query = "\n".join([match_clause, with_clause, return_clause])
return cypher_query
def _format_schema(schema: Dict, is_enhanced: bool) -> str:
formatted_node_props = []
formatted_rel_props = []
if is_enhanced:
# Enhanced formatting for nodes
for node_type, properties in schema["node_props"].items():
formatted_node_props.append(f"- **{node_type}**")
for prop in properties:
example = ""
if prop["type"] == "STRING":
if prop.get("distinct_count", 11) > DISTINCT_VALUE_LIMIT:
example = (
f'Example: "{prop["values"][0].replace(NL, " ")}"'
if prop["values"]
else ""
)
else: # If less than 10 possible values return all
example = (
(
"Available options: "
f'{[el.replace(NL, " ") for el in prop["values"]]}'
)
if prop["values"]
else ""
)
elif prop["type"] in ["INTEGER", "FLOAT", "DATE"]:
if prop.get("min") is not None:
example = f'Min: {prop["min"]}, Max: {prop["max"]}'
else:
example = (
f'Example: "{prop["values"][0]}"' if prop["values"] else ""
)
elif prop["type"] == "LIST":
# Skip embeddings
if prop["min_size"] > LIST_LIMIT:
continue
example = (
f'Min Size: {prop["min_size"]}, Max Size: {prop["max_size"]}'
)
formatted_node_props.append(
f" - `{prop['property']}: {prop['type']}` {example}"
)
# Enhanced formatting for relationships
for rel_type, properties in schema["rel_props"].items():
formatted_rel_props.append(f"- **{rel_type}**")
for prop in properties:
example = ""
if prop["type"] == "STRING":
if prop.get("distinct_count", 11) > DISTINCT_VALUE_LIMIT:
example = (
f'Example: "{prop["values"][0].replace(NL, " ")}"'
if prop["values"]
else ""
)
else: # If less than 10 possible values return all
example = (
(
"Available options: "
f'{[el.replace(NL, " ") for el in prop["values"]]}'
)
if prop["values"]
else ""
)
elif prop["type"] in ["INTEGER", "FLOAT", "DATE"]:
if prop.get("min"): # If we have min/max
example = f'Min: {prop["min"]}, Max: {prop["max"]}'
else: # return a single value
example = (
f'Example: "{prop["values"][0]}"' if prop["values"] else ""
)
elif prop["type"] == "LIST":
# Skip embeddings
if prop["min_size"] > LIST_LIMIT:
continue
example = (
f'Min Size: {prop["min_size"]}, Max Size: {prop["max_size"]}'
)
formatted_rel_props.append(
f" - `{prop['property']}: {prop['type']}` {example}"
)
else:
# Format node properties
for label, props in schema["node_props"].items():
props_str = ", ".join(
[f"{prop['property']}: {prop['type']}" for prop in props]
)
formatted_node_props.append(f"{label} {{{props_str}}}")
# Format relationship properties using structured_schema
for type, props in schema["rel_props"].items():
props_str = ", ".join(
[f"{prop['property']}: {prop['type']}" for prop in props]
)
formatted_rel_props.append(f"{type} {{{props_str}}}")
# Format relationships
formatted_rels = [
f"(:{el['start']})-[:{el['type']}]->(:{el['end']})"
for el in schema["relationships"]
]
return "\n".join(
[
"Node properties:",
"\n".join(formatted_node_props),
"Relationship properties:",
"\n".join(formatted_rel_props),
"The relationships:",
"\n".join(formatted_rels),
]
)
class Neo4jGraph(GraphStore):
"""Neo4j database wrapper for various graph operations.
@@ -151,6 +372,8 @@ class Neo4jGraph(GraphStore):
embedding-like properties from database responses. Default is False.
refresh_schema (bool): A flag whether to refresh schema information
at initialization. Default is True.
enhanced_schema (bool): A flag whether to scan the database for
example values and use them in the graph schema. Default is False.
driver_config (Dict): Configuration passed to Neo4j Driver.
*Security note*: Make sure that the database connection uses credentials
@@ -176,6 +399,7 @@ class Neo4jGraph(GraphStore):
refresh_schema: bool = True,
*,
driver_config: Optional[Dict] = None,
enhanced_schema: bool = False,
) -> None:
"""Create a new Neo4j graph wrapper instance."""
try:
@@ -203,6 +427,7 @@ class Neo4jGraph(GraphStore):
self._database = database
self.timeout = timeout
self.sanitize = sanitize
self._enhanced_schema = enhanced_schema
self.schema: str = ""
self.structured_schema: Dict[str, Any] = {}
# Verify connection
@@ -300,37 +525,48 @@ class Neo4jGraph(GraphStore):
"metadata": {"constraint": constraint, "index": index},
}
# Format node properties
formatted_node_props = []
for el in node_properties:
props_str = ", ".join(
[f"{prop['property']}: {prop['type']}" for prop in el["properties"]]
if self._enhanced_schema:
schema_counts = self.query(
"CALL apoc.meta.graphSample() YIELD nodes, relationships "
"RETURN nodes, [rel in relationships | {name:apoc.any.property"
"(rel, 'type'), count: apoc.any.property(rel, 'count')}]"
" AS relationships"
)
formatted_node_props.append(f"{el['labels']} {{{props_str}}}")
# Update node info
for node in schema_counts[0]["nodes"]:
# Skip bloom labels
if node["name"] in EXCLUDED_LABELS:
continue
node_props = self.structured_schema["node_props"][node["name"]]
enhanced_cypher = _enhanced_schema_cypher(
node["name"], node_props, node["count"] < EXHAUSTIVE_SEARCH_LIMIT
)
enhanced_info = self.query(enhanced_cypher)[0]["output"]
for prop in node_props:
if prop["property"] in enhanced_info:
prop.update(enhanced_info[prop["property"]])
# Update rel info
for rel in schema_counts[0]["relationships"]:
# Skip bloom labels
if rel["name"] in EXCLUDED_RELS:
continue
rel_props = self.structured_schema["rel_props"].get(rel["name"])
if not rel_props:
continue
enhanced_cypher = _enhanced_schema_cypher(
rel["name"],
rel_props,
rel["count"] < EXHAUSTIVE_SEARCH_LIMIT,
is_relationship=True,
)
enhanced_info = self.query(enhanced_cypher)[0]["output"]
for prop in rel_props:
if prop["property"] in enhanced_info:
prop.update(enhanced_info[prop["property"]])
# Format relationship properties
formatted_rel_props = []
for el in rel_properties:
props_str = ", ".join(
[f"{prop['property']}: {prop['type']}" for prop in el["properties"]]
)
formatted_rel_props.append(f"{el['type']} {{{props_str}}}")
schema = _format_schema(self.structured_schema, self._enhanced_schema)
# Format relationships
formatted_rels = [
f"(:{el['start']})-[:{el['type']}]->(:{el['end']})" for el in relationships
]
self.schema = "\n".join(
[
"Node properties are the following:",
",".join(formatted_node_props),
"Relationship properties are the following:",
",".join(formatted_rel_props),
"The relationships are the following:",
",".join(formatted_rels),
]
)
self.schema = schema
def add_graph_documents(
self,

View File

@@ -291,3 +291,45 @@ def test_driver_config() -> None:
driver_config={"max_connection_pool_size": 1},
)
graph.query("RETURN 'foo'")
def test_enhanced_schema() -> None:
    """Test that the enhanced schema adds example values to the structured schema."""
    url = os.environ.get("NEO4J_URI")
    username = os.environ.get("NEO4J_USERNAME")
    password = os.environ.get("NEO4J_PASSWORD")
    assert url is not None
    assert username is not None
    assert password is not None

    graph = Neo4jGraph(
        url=url, username=username, password=password, enhanced_schema=True
    )
    # Start from an empty database so only the test data shapes the schema.
    graph.query("MATCH (n) DETACH DELETE n")
    graph.add_graph_documents(test_data)
    graph.refresh_schema()
    expected_output = {
        "node_props": {
            "foo": [
                {
                    "property": "id",
                    "type": "STRING",
                    "values": ["foo"],
                    "distinct_count": 1,
                }
            ],
            "bar": [
                {
                    "property": "id",
                    "type": "STRING",
                    "values": ["bar"],
                    "distinct_count": 1,
                }
            ],
        },
        "rel_props": {},
        "relationships": [{"start": "foo", "type": "REL", "end": "bar"}],
    }
    # Compare everything except the metadata portion of the schema, without
    # mutating the graph object's own structured_schema.
    actual_output = {
        key: value
        for key, value in graph.structured_schema.items()
        if key != "metadata"
    }
    assert actual_output == expected_output