community: fix issue #29429 in age_graph.py (#29506)

## Description:

This PR addresses issue #29429 by fixing the `_wrap_query` method in
`langchain_community/graphs/age_graph.py`. The method now correctly
handles Cypher queries with UNION and EXCEPT operators, ensuring that
the fields in the SQL query are ordered as they appear in the Cypher
query. Additionally, the method now raises a `ValueError` for empty
queries and for queries that use `RETURN *`, which is not supported.

### Issue: #29429

### Dependencies: None


### Add tests and docs:

Added unit tests in `tests/unit_tests/graphs/test_age_graph.py` to
validate the changes.
No new integrations were added, so no example notebook is necessary.

### Lint and test:

Ran `make format`, `make lint`, and `make test` to ensure code quality and
functionality.
This commit is contained in:
Hemant Rawat 2025-02-02 07:54:45 +05:30 committed by GitHub
parent 2f97916dea
commit db1693aa70
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 221 additions and 64 deletions

View File

@ -473,71 +473,78 @@ class AGEGraph(GraphStore):
@staticmethod
def _wrap_query(query: str, graph_name: str) -> str:
    """
    Convert a Cypher query to an Apache AGE compatible SQL query.

    The Cypher text is wrapped in ``ag_catalog.cypher`` and a column
    definition list (every column typed ``agtype``) is built from the
    RETURN clause(s). Queries combined with UNION/EXCEPT are split on those
    operators and every part's RETURN clause is parsed; column names are
    collected in order of first appearance, without duplicates.

    Args:
        query (str): a valid cypher query, may include UNION/EXCEPT operators
        graph_name (str): the name of the graph to query

    Returns:
        str: an equivalent pgsql query wrapped with ag_catalog.cypher

    Raises:
        ValueError: if the query is empty or contains ``RETURN *``
    """
    if not query.strip():
        raise ValueError("Empty query provided")

    # pgsql template
    template = (
        "SELECT {projection} FROM ag_catalog.cypher('{graph_name}', $$\n"
        "{query}\n"
        "$$) AS ({fields});"
    )

    # UNION/EXCEPT join several sub-queries, each with its own RETURN clause.
    # The group is non-capturing so the operators themselves are not returned.
    parts = re.split(r"\b(?:UNION|EXCEPT)\b", query, flags=re.IGNORECASE)

    all_fields = []

    for part in parts:
        # if there are any returned fields they must be added to the pgsql query
        return_match = re.search(r'\breturn\b(?![^"]*")', part, re.IGNORECASE)
        if not return_match:
            continue

        # Extract the part of the query after the RETURN keyword
        return_clause = part[return_match.end() :]

        # parse return statement to identify returned fields
        raw_fields = (
            return_clause.lower()
            .split("distinct")[-1]
            .split("order by")[0]
            .split("skip")[0]
            .split("limit")[0]
            .split(",")
        )
        clean_fields = [f.strip() for f in raw_fields if f.strip()]

        # raise exception if RETURN * is found as we can't resolve the fields
        if "*" in clean_fields:
            raise ValueError(
                "Apache Age does not support RETURN * in Cypher queries"
            )

        # Format fields and maintain order of appearance
        for idx, field in enumerate(clean_fields):
            # Keep only the last segment of a dotted name (n.name -> name):
            # a dotted identifier is not a valid column name in the AS (...)
            # list. This is a no-op when the name carries no dot.
            # NOTE(review): assumes _get_col_name may return dotted names for
            # unaliased fields - confirm against its implementation.
            field_name = AGEGraph._get_col_name(field, idx).split(".")[-1]
            if field_name and field_name not in all_fields:
                all_fields.append(field_name)

    # a column definition is required even when nothing is returned,
    # so fall back to a single agtype column
    if all_fields:
        fields_str = ", ".join(f"{field} agtype" for field in all_fields)
    else:
        fields_str = "a agtype"

    return template.format(
        graph_name=graph_name,
        query=query,
        fields=fields_str,
        projection="*",
    )
@staticmethod

View File

@ -53,6 +53,7 @@ class TestAGEGraph(unittest.TestCase):
self.assertEqual(AGEGraph._get_col_name(*value), expected[idx])
def test_wrap_query(self) -> None:
"""Test basic query wrapping functionality."""
inputs = [
# Positive case: Simple return clause
"""
@ -76,46 +77,195 @@ class TestAGEGraph(unittest.TestCase):
expected = [
# Expected output for the first positive case
"""
SELECT * FROM ag_catalog.cypher('test', $$
MATCH (keanu:Person {name:'Keanu Reeves'})
RETURN keanu.name AS name, keanu.born AS born
$$) AS (name agtype, born agtype);
""",
# Second test case (no RETURN clause)
"""
SELECT * FROM ag_catalog.cypher('test', $$
MERGE (n:a {id: 1})
$$) AS (a agtype);
""",
# Expected output for the negative cases (no RETURN clause)
"""
SELECT * FROM ag_catalog.cypher('test', $$
MATCH (n {description: "This will return a value"})
MERGE (n)-[:RELATED]->(m)
$$) AS (a agtype);
""",
"""
SELECT * FROM ag_catalog.cypher('test', $$
MATCH (n {returnValue: "some value"})
MERGE (n)-[:RELATED]->(m)
$$) AS (a agtype);
""",
]
for idx, value in enumerate(inputs):
result = AGEGraph._wrap_query(value, "test")
expected_result = expected[idx]
self.assertEqual(
re.sub(r"\s", "", result),
re.sub(r"\s", "", expected_result),
(
f"Failed on test case {idx + 1}\n"
f"Input:\n{value}\n"
f"Expected:\n{expected_result}\n"
f"Got:\n{result}"
),
)
def test_wrap_query_union_except(self) -> None:
    """Check the SQL produced for Cypher queries joined by UNION or EXCEPT."""

    def squash(text: str) -> str:
        # Results are compared with every whitespace character removed.
        return re.sub(r"\s", "", text)

    inputs = [
        # UNION case
        """
        MATCH (n:Person)
        RETURN n.name AS name, n.age AS age
        UNION
        MATCH (n:Employee)
        RETURN n.name AS name, n.salary AS salary
        """,
        """
        MATCH (a:Employee {name: "Alice"})
        RETURN a.name AS name
        UNION
        MATCH (b:Manager {name: "Bob"})
        RETURN b.name AS name
        """,
        # Complex UNION case
        """
        MATCH (n)-[r]->(m)
        RETURN n.name AS source, type(r) AS relationship, m.name AS target
        UNION
        MATCH (m)-[r]->(n)
        RETURN m.name AS source, type(r) AS relationship, n.name AS target
        """,
        """
        MATCH (a:Person)-[:FRIEND]->(b:Person)
        WHERE a.age > 30
        RETURN a.name AS name
        UNION
        MATCH (c:Person)-[:FRIEND]->(d:Person)
        WHERE c.age < 25
        RETURN c.name AS name
        """,
        # EXCEPT case
        """
        MATCH (n:Person)
        RETURN n.name AS name
        EXCEPT
        MATCH (n:Employee)
        RETURN n.name AS name
        """,
        """
        MATCH (a:Person)
        RETURN a.name AS name, a.age AS age
        EXCEPT
        MATCH (b:Person {name: "Alice", age: 30})
        RETURN b.name AS name, b.age AS age
        """,
    ]

    # One expected SQL statement per entry of ``inputs``, in the same order.
    expected = [
        """
        SELECT * FROM ag_catalog.cypher('test', $$
        MATCH (n:Person)
        RETURN n.name AS name, n.age AS age
        UNION
        MATCH (n:Employee)
        RETURN n.name AS name, n.salary AS salary
        $$) AS (name agtype, age agtype, salary agtype);
        """,
        """
        SELECT * FROM ag_catalog.cypher('test', $$
        MATCH (a:Employee {name: "Alice"})
        RETURN a.name AS name
        UNION
        MATCH (b:Manager {name: "Bob"})
        RETURN b.name AS name
        $$) AS (name agtype);
        """,
        """
        SELECT * FROM ag_catalog.cypher('test', $$
        MATCH (n)-[r]->(m)
        RETURN n.name AS source, type(r) AS relationship, m.name AS target
        UNION
        MATCH (m)-[r]->(n)
        RETURN m.name AS source, type(r) AS relationship, n.name AS target
        $$) AS (source agtype, relationship agtype, target agtype);
        """,
        """
        SELECT * FROM ag_catalog.cypher('test', $$
        MATCH (a:Person)-[:FRIEND]->(b:Person)
        WHERE a.age > 30
        RETURN a.name AS name
        UNION
        MATCH (c:Person)-[:FRIEND]->(d:Person)
        WHERE c.age < 25
        RETURN c.name AS name
        $$) AS (name agtype);
        """,
        """
        SELECT * FROM ag_catalog.cypher('test', $$
        MATCH (n:Person)
        RETURN n.name AS name
        EXCEPT
        MATCH (n:Employee)
        RETURN n.name AS name
        $$) AS (name agtype);
        """,
        """
        SELECT * FROM ag_catalog.cypher('test', $$
        MATCH (a:Person)
        RETURN a.name AS name, a.age AS age
        EXCEPT
        MATCH (b:Person {name: "Alice", age: 30})
        RETURN b.name AS name, b.age AS age
        $$) AS (name agtype, age agtype);
        """,
    ]

    for case_no, (value, expected_result) in enumerate(zip(inputs, expected), 1):
        result = AGEGraph._wrap_query(value, "test")
        failure_note = (
            f"Failed on test case {case_no}\n"
            f"Input:\n{value}\n"
            f"Expected:\n{expected_result}\n"
            f"Got:\n{result}"
        )
        self.assertEqual(squash(result), squash(expected_result), failure_note)
with self.assertRaises(ValueError):
AGEGraph._wrap_query(
"""
def test_wrap_query_errors(self) -> None:
    """Test that empty and unsupported queries are rejected with ValueError."""
    error_cases = [
        # Empty query
        "",
        # Whitespace-only query is stripped before the emptiness check
        "   \n\t",
        # Return * case
        """
        MATCH ()
        RETURN *
        """,
        # Return * in UNION
        """
        MATCH (n:Person)
        RETURN n.name
        UNION
        MATCH ()
        RETURN *
        """,
    ]
    for query in error_cases:
        # subTest reports every failing query (and names it) instead of
        # aborting the whole test at the first case that does not raise.
        with self.subTest(query=query):
            with self.assertRaises(ValueError):
                AGEGraph._wrap_query(query, "test")
def test_format_properties(self) -> None:
inputs: List[Dict[str, Any]] = [{}, {"a": "b"}, {"a": "b", "c": 1, "d": True}]