mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-24 15:43:54 +00:00
## Description: This PR addresses issue #29429 by fixing the _wrap_query method in langchain_community/graphs/age_graph.py. The method now correctly handles Cypher queries with UNION and EXCEPT operators, ensuring that the fields in the SQL query are ordered as they appear in the Cypher query. Additionally, the method now properly handles cases where RETURN * is not supported. ### Issue: #29429 ### Dependencies: None ### Add tests and docs: Added unit tests in tests/unit_tests/graphs/test_age_graph.py to validate the changes. No new integrations were added, so no example notebook is necessary. Lint and test: Ran make format, make lint, and make test to ensure code quality and functionality.
This commit is contained in:
parent
2f97916dea
commit
db1693aa70
@ -473,71 +473,78 @@ class AGEGraph(GraphStore):
|
||||
@staticmethod
|
||||
def _wrap_query(query: str, graph_name: str) -> str:
|
||||
"""
|
||||
Convert a cypher query to an Apache Age compatible
|
||||
sql query by wrapping the cypher query in ag_catalog.cypher,
|
||||
casting results to agtype and building a select statement
|
||||
Convert a Cyper query to an Apache Age compatible Sql Query.
|
||||
Handles combined queries with UNION/EXCEPT operators
|
||||
|
||||
Args:
|
||||
query (str): a valid cypher query
|
||||
graph_name (str): the name of the graph to query
|
||||
query (str) : A valid cypher query, can include UNION/EXCEPT operators
|
||||
graph_name (str) : The name of the graph to query
|
||||
|
||||
Returns:
|
||||
str: an equivalent pgsql query
|
||||
Returns :
|
||||
str : An equivalent pgSql query wrapped with ag_catalog.cypher
|
||||
|
||||
Raises:
|
||||
ValueError : If query is empty, contain RETURN *, or has invalid field names
|
||||
"""
|
||||
|
||||
if not query.strip():
|
||||
raise ValueError("Empty query provided")
|
||||
|
||||
# pgsql template
|
||||
template = """SELECT {projection} FROM ag_catalog.cypher('{graph_name}', $$
|
||||
{query}
|
||||
$$) AS ({fields});"""
|
||||
|
||||
# if there are any returned fields they must be added to the pgsql query
|
||||
return_match = re.search(r'\breturn\b(?![^"]*")', query, re.IGNORECASE)
|
||||
if return_match:
|
||||
# Extract the part of the query after the RETURN keyword
|
||||
return_clause = query[return_match.end() :]
|
||||
# split the query into parts based on UNION and EXCEPT
|
||||
parts = re.split(r"\b(UNION\b|\bEXCEPT)\b", query, flags=re.IGNORECASE)
|
||||
|
||||
# parse return statement to identify returned fields
|
||||
fields = (
|
||||
return_clause.lower()
|
||||
.split("distinct")[-1]
|
||||
.split("order by")[0]
|
||||
.split("skip")[0]
|
||||
.split("limit")[0]
|
||||
.split(",")
|
||||
)
|
||||
all_fields = []
|
||||
|
||||
# raise exception if RETURN * is found as we can't resolve the fields
|
||||
if "*" in [x.strip() for x in fields]:
|
||||
raise ValueError(
|
||||
"AGE graph does not support 'RETURN *'"
|
||||
+ " statements in Cypher queries"
|
||||
for part in parts:
|
||||
if part.strip().upper() in ("UNION", "EXCEPT"):
|
||||
continue
|
||||
|
||||
# if there are any returned fields they must be added to the pgsql query
|
||||
return_match = re.search(r'\breturn\b(?![^"]*")', part, re.IGNORECASE)
|
||||
if return_match:
|
||||
# Extract the part of the query after the RETURN keyword
|
||||
return_clause = part[return_match.end() :]
|
||||
|
||||
# parse return statement to identify returned fields
|
||||
fields = (
|
||||
return_clause.lower()
|
||||
.split("distinct")[-1]
|
||||
.split("order by")[0]
|
||||
.split("skip")[0]
|
||||
.split("limit")[0]
|
||||
.split(",")
|
||||
)
|
||||
|
||||
# get pgsql formatted field names
|
||||
fields = [
|
||||
AGEGraph._get_col_name(field, idx) for idx, field in enumerate(fields)
|
||||
]
|
||||
# raise exception if RETURN * is found as we can't resolve the fields
|
||||
clean_fileds = [f.strip() for f in fields if f.strip()]
|
||||
if "*" in clean_fileds:
|
||||
raise ValueError(
|
||||
"Apache Age does not support RETURN * in Cypher queries"
|
||||
)
|
||||
|
||||
# build resulting pgsql relation
|
||||
fields_str = ", ".join(
|
||||
[
|
||||
field.split(".")[-1] + " agtype"
|
||||
for field in fields
|
||||
if field.split(".")[-1]
|
||||
]
|
||||
)
|
||||
# Format fields and maintain order of appearance
|
||||
for idx, field in enumerate(clean_fileds):
|
||||
field_name = AGEGraph._get_col_name(field, idx)
|
||||
if field_name not in all_fields:
|
||||
all_fields.append(field_name)
|
||||
|
||||
# if no return statement we still need to return a single field of type agtype
|
||||
else:
|
||||
# if no return statements found in any part
|
||||
if not all_fields:
|
||||
fields_str = "a agtype"
|
||||
|
||||
select_str = "*"
|
||||
else:
|
||||
fields_str = ", ".join(f"{field} agtype" for field in all_fields)
|
||||
|
||||
return template.format(
|
||||
graph_name=graph_name,
|
||||
query=query,
|
||||
fields=fields_str,
|
||||
projection=select_str,
|
||||
projection="*",
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
|
@ -53,6 +53,7 @@ class TestAGEGraph(unittest.TestCase):
|
||||
self.assertEqual(AGEGraph._get_col_name(*value), expected[idx])
|
||||
|
||||
def test_wrap_query(self) -> None:
|
||||
"""Test basic query wrapping functionality."""
|
||||
inputs = [
|
||||
# Positive case: Simple return clause
|
||||
"""
|
||||
@ -76,46 +77,195 @@ class TestAGEGraph(unittest.TestCase):
|
||||
|
||||
expected = [
|
||||
# Expected output for the first positive case
|
||||
"""
|
||||
SELECT * FROM ag_catalog.cypher('test', $$
|
||||
MATCH (keanu:Person {name:'Keanu Reeves'})
|
||||
RETURN keanu.name AS name, keanu.born AS born
|
||||
$$) AS (name agtype, born agtype);
|
||||
""",
|
||||
# Second test case (no RETURN clause)
|
||||
"""
|
||||
SELECT * FROM ag_catalog.cypher('test', $$
|
||||
MERGE (n:a {id: 1})
|
||||
$$) AS (a agtype);
|
||||
""",
|
||||
# Expected output for the negative cases (no RETURN clause)
|
||||
"""
|
||||
SELECT * FROM ag_catalog.cypher('test', $$
|
||||
MATCH (n {description: "This will return a value"})
|
||||
MERGE (n)-[:RELATED]->(m)
|
||||
$$) AS (a agtype);
|
||||
""",
|
||||
"""
|
||||
SELECT * FROM ag_catalog.cypher('test', $$
|
||||
MATCH (n {returnValue: "some value"})
|
||||
MERGE (n)-[:RELATED]->(m)
|
||||
$$) AS (a agtype);
|
||||
""",
|
||||
]
|
||||
|
||||
for idx, value in enumerate(inputs):
|
||||
result = AGEGraph._wrap_query(value, "test")
|
||||
expected_result = expected[idx]
|
||||
self.assertEqual(
|
||||
re.sub(r"\s", "", result),
|
||||
re.sub(r"\s", "", expected_result),
|
||||
(
|
||||
f"Failed on test case {idx + 1}\n"
|
||||
f"Input:\n{value}\n"
|
||||
f"Expected:\n{expected_result}\n"
|
||||
f"Got:\n{result}"
|
||||
),
|
||||
)
|
||||
|
||||
def test_wrap_query_union_except(self) -> None:
|
||||
"""Test query wrapping with UNION and EXCEPT operators."""
|
||||
inputs = [
|
||||
# UNION case
|
||||
"""
|
||||
MATCH (n:Person)
|
||||
RETURN n.name AS name, n.age AS age
|
||||
UNION
|
||||
MATCH (n:Employee)
|
||||
RETURN n.name AS name, n.salary AS salary
|
||||
""",
|
||||
"""
|
||||
MATCH (a:Employee {name: "Alice"})
|
||||
RETURN a.name AS name
|
||||
UNION
|
||||
MATCH (b:Manager {name: "Bob"})
|
||||
RETURN b.name AS name
|
||||
""",
|
||||
# Complex UNION case
|
||||
"""
|
||||
MATCH (n)-[r]->(m)
|
||||
RETURN n.name AS source, type(r) AS relationship, m.name AS target
|
||||
UNION
|
||||
MATCH (m)-[r]->(n)
|
||||
RETURN m.name AS source, type(r) AS relationship, n.name AS target
|
||||
""",
|
||||
"""
|
||||
MATCH (a:Person)-[:FRIEND]->(b:Person)
|
||||
WHERE a.age > 30
|
||||
RETURN a.name AS name
|
||||
UNION
|
||||
MATCH (c:Person)-[:FRIEND]->(d:Person)
|
||||
WHERE c.age < 25
|
||||
RETURN c.name AS name
|
||||
""",
|
||||
# EXCEPT case
|
||||
"""
|
||||
MATCH (n:Person)
|
||||
RETURN n.name AS name
|
||||
EXCEPT
|
||||
MATCH (n:Employee)
|
||||
RETURN n.name AS name
|
||||
""",
|
||||
"""
|
||||
MATCH (a:Person)
|
||||
RETURN a.name AS name, a.age AS age
|
||||
EXCEPT
|
||||
MATCH (b:Person {name: "Alice", age: 30})
|
||||
RETURN b.name AS name, b.age AS age
|
||||
""",
|
||||
]
|
||||
|
||||
expected = [
|
||||
"""
|
||||
SELECT * FROM ag_catalog.cypher('test', $$
|
||||
MATCH (keanu:Person {name:'Keanu Reeves'})
|
||||
RETURN keanu.name AS name, keanu.born AS born
|
||||
$$) AS (name agtype, born agtype);
|
||||
MATCH (n:Person)
|
||||
RETURN n.name AS name, n.age AS age
|
||||
UNION
|
||||
MATCH (n:Employee)
|
||||
RETURN n.name AS name, n.salary AS salary
|
||||
$$) AS (name agtype, age agtype, salary agtype);
|
||||
""",
|
||||
"""
|
||||
SELECT * FROM ag_catalog.cypher('test', $$
|
||||
MERGE (n:a {id: 1})
|
||||
$$) AS (a agtype);
|
||||
""",
|
||||
# Expected output for the negative cases (no return clause)
|
||||
"""
|
||||
SELECT * FROM ag_catalog.cypher('test', $$
|
||||
MATCH (n {description: "This will return a value"})
|
||||
MERGE (n)-[:RELATED]->(m)
|
||||
$$) AS (a agtype);
|
||||
MATCH (a:Employee {name: "Alice"})
|
||||
RETURN a.name AS name
|
||||
UNION
|
||||
MATCH (b:Manager {name: "Bob"})
|
||||
RETURN b.name AS name
|
||||
$$) AS (name agtype);
|
||||
""",
|
||||
"""
|
||||
SELECT * FROM ag_catalog.cypher('test', $$
|
||||
MATCH (n {returnValue: "some value"})
|
||||
MERGE (n)-[:RELATED]->(m)
|
||||
$$) AS (a agtype);
|
||||
MATCH (n)-[r]->(m)
|
||||
RETURN n.name AS source, type(r) AS relationship, m.name AS target
|
||||
UNION
|
||||
MATCH (m)-[r]->(n)
|
||||
RETURN m.name AS source, type(r) AS relationship, n.name AS target
|
||||
$$) AS (source agtype, relationship agtype, target agtype);
|
||||
""",
|
||||
"""
|
||||
SELECT * FROM ag_catalog.cypher('test', $$
|
||||
MATCH (a:Person)-[:FRIEND]->(b:Person)
|
||||
WHERE a.age > 30
|
||||
RETURN a.name AS name
|
||||
UNION
|
||||
MATCH (c:Person)-[:FRIEND]->(d:Person)
|
||||
WHERE c.age < 25
|
||||
RETURN c.name AS name
|
||||
$$) AS (name agtype);
|
||||
""",
|
||||
"""
|
||||
SELECT * FROM ag_catalog.cypher('test', $$
|
||||
MATCH (n:Person)
|
||||
RETURN n.name AS name
|
||||
EXCEPT
|
||||
MATCH (n:Employee)
|
||||
RETURN n.name AS name
|
||||
$$) AS (name agtype);
|
||||
""",
|
||||
"""
|
||||
SELECT * FROM ag_catalog.cypher('test', $$
|
||||
MATCH (a:Person)
|
||||
RETURN a.name AS name, a.age AS age
|
||||
EXCEPT
|
||||
MATCH (b:Person {name: "Alice", age: 30})
|
||||
RETURN b.name AS name, b.age AS age
|
||||
$$) AS (name agtype, age agtype);
|
||||
""",
|
||||
]
|
||||
|
||||
for idx, value in enumerate(inputs):
|
||||
result = AGEGraph._wrap_query(value, "test")
|
||||
expected_result = expected[idx]
|
||||
self.assertEqual(
|
||||
re.sub(r"\s", "", AGEGraph._wrap_query(value, "test")),
|
||||
re.sub(r"\s", "", expected[idx]),
|
||||
re.sub(r"\s", "", result),
|
||||
re.sub(r"\s", "", expected_result),
|
||||
(
|
||||
f"Failed on test case {idx + 1}\n"
|
||||
f"Input:\n{value}\n"
|
||||
f"Expected:\n{expected_result}\n"
|
||||
f"Got:\n{result}"
|
||||
),
|
||||
)
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
AGEGraph._wrap_query(
|
||||
"""
|
||||
def test_wrap_query_errors(self) -> None:
|
||||
"""Test error cases for query wrapping."""
|
||||
error_cases = [
|
||||
# Empty query
|
||||
"",
|
||||
# Return * case
|
||||
"""
|
||||
MATCH ()
|
||||
RETURN *
|
||||
""",
|
||||
"test",
|
||||
)
|
||||
# Return * in UNION
|
||||
"""
|
||||
MATCH (n:Person)
|
||||
RETURN n.name
|
||||
UNION
|
||||
MATCH ()
|
||||
RETURN *
|
||||
""",
|
||||
]
|
||||
|
||||
for query in error_cases:
|
||||
with self.assertRaises(ValueError):
|
||||
AGEGraph._wrap_query(query, "test")
|
||||
|
||||
def test_format_properties(self) -> None:
|
||||
inputs: List[Dict[str, Any]] = [{}, {"a": "b"}, {"a": "b", "c": 1, "d": True}]
|
||||
|
Loading…
Reference in New Issue
Block a user