community[patch]: Add neo4j timeout and value sanitization option (#16138)

The timeout function comes in handy when you want to kill longrunning
queries.
The value sanitization removes all lists that are larger than 128
elements. The idea here is to remove embedding properties from results.
This commit is contained in:
Tomaz Bratanic
2024-01-17 22:22:19 +01:00
committed by GitHub
parent 27ed2673da
commit 1e80113ac9
3 changed files with 104 additions and 3 deletions

View File

@@ -31,8 +31,50 @@ RETURN {start: label, type: property, end: toString(other_node)} AS output
"""
def value_sanitize(d: Dict[str, Any]) -> Dict[str, Any]:
"""
Sanitizes the input dictionary by removing embedding-like values,
lists with more than 128 elements, that are mostly irrelevant for
generating answers in a LLM context. These properties, if left in
results, can occupy significant context space and detract from
the LLM's performance by introducing unnecessary noise and cost.
"""
LIST_LIMIT = 128
# Create a new dictionary to avoid changing size during iteration
new_dict = {}
for key, value in d.items():
if isinstance(value, dict):
# Recurse to handle nested dictionaries
new_dict[key] = value_sanitize(value)
elif isinstance(value, list):
# check if it has less than LIST_LIMIT values
if len(value) < LIST_LIMIT:
# if value is a list, check if it contains dictionaries to clean
cleaned_list = []
for item in value:
if isinstance(item, dict):
cleaned_list.append(value_sanitize(item))
else:
cleaned_list.append(item)
new_dict[key] = cleaned_list
else:
new_dict[key] = value
return new_dict
class Neo4jGraph(GraphStore):
"""Neo4j wrapper for graph operations.
"""Provides a connection to a Neo4j database for various graph operations.
Parameters:
url (Optional[str]): The URL of the Neo4j database server.
username (Optional[str]): The username for database authentication.
password (Optional[str]): The password for database authentication.
database (str): The name of the database to connect to. Default is 'neo4j'.
timeout (Optional[float]): The timeout for transactions in seconds.
Useful for terminating long-running queries.
By default, there is no timeout set.
sanitize (bool): A flag to indicate whether to remove lists with
more than 128 elements from results. Useful for removing
embedding-like properties from database responses. Default is False.
*Security note*: Make sure that the database connection uses credentials
that are narrowly-scoped to only include necessary permissions.
@@ -52,6 +94,8 @@ class Neo4jGraph(GraphStore):
username: Optional[str] = None,
password: Optional[str] = None,
database: str = "neo4j",
timeout: Optional[float] = None,
sanitize: bool = False,
) -> None:
"""Create a new Neo4j graph wrapper instance."""
try:
@@ -69,6 +113,8 @@ class Neo4jGraph(GraphStore):
self._driver = neo4j.GraphDatabase.driver(url, auth=(username, password))
self._database = database
self.timeout = timeout
self.sanitize = sanitize
self.schema: str = ""
self.structured_schema: Dict[str, Any] = {}
# Verify connection
@@ -106,12 +152,16 @@ class Neo4jGraph(GraphStore):
def query(self, query: str, params: dict = {}) -> List[Dict[str, Any]]:
"""Query Neo4j database."""
from neo4j import Query
from neo4j.exceptions import CypherSyntaxError
with self._driver.session(database=self._database) as session:
try:
data = session.run(query, params)
return [r.data() for r in data]
data = session.run(Query(text=query, timeout=self.timeout), params)
json_data = [r.data() for r in data]
if self.sanitize:
json_data = value_sanitize(json_data)
return json_data
except CypherSyntaxError as e:
raise ValueError(f"Generated Cypher Statement is not valid\n{e}")