diff --git a/libs/community/langchain_community/graphs/age_graph.py b/libs/community/langchain_community/graphs/age_graph.py index a64113856d5..4fba85731ad 100644 --- a/libs/community/langchain_community/graphs/age_graph.py +++ b/libs/community/langchain_community/graphs/age_graph.py @@ -491,11 +491,14 @@ class AGEGraph(GraphStore): $$) AS ({fields});""" # if there are any returned fields they must be added to the pgsql query - if "return" in query.lower(): + return_match = re.search(r'\breturn\b(?![^"]*")', query, re.IGNORECASE) + if return_match: + # Extract the part of the query after the RETURN keyword + return_clause = query[return_match.end() :] + # parse return statement to identify returned fields fields = ( - query.lower() - .split("return")[-1] + return_clause.lower() .split("distinct")[-1] .split("order by")[0] .split("skip")[0] @@ -517,7 +520,11 @@ class AGEGraph(GraphStore): # build resulting pgsql relation fields_str = ", ".join( - [field.split(".")[-1] + " agtype" for field in fields] + [ + field.split(".")[-1] + " agtype" + for field in fields + if field.split(".")[-1] + ] ) # if no return statement we still need to return a single field of type agtype diff --git a/libs/community/tests/unit_tests/graphs/test_age_graph.py b/libs/community/tests/unit_tests/graphs/test_age_graph.py index 7b9044eb15d..6981c16d88a 100644 --- a/libs/community/tests/unit_tests/graphs/test_age_graph.py +++ b/libs/community/tests/unit_tests/graphs/test_age_graph.py @@ -54,6 +54,7 @@ class TestAGEGraph(unittest.TestCase): def test_wrap_query(self) -> None: inputs = [ + # Positive case: Simple return clause """ MATCH (keanu:Person {name:'Keanu Reeves'}) RETURN keanu.name AS name, keanu.born AS born @@ -61,9 +62,20 @@ class TestAGEGraph(unittest.TestCase): """ MERGE (n:a {id: 1}) """, + # Negative case: Return in a string value + """ + MATCH (n {description: "This will return a value"}) + MERGE (n)-[:RELATED]->(m) + """, + # Negative case: Return in a property key + """ + MATCH (n {returnValue: "some value"}) + MERGE (n)-[:RELATED]->(m) + """, ] expected = [ + # Expected output for the first positive case """ SELECT * FROM ag_catalog.cypher('test', $$ MATCH (keanu:Person {name:'Keanu Reeves'}) @@ -75,6 +87,19 @@ class TestAGEGraph(unittest.TestCase): MERGE (n:a {id: 1}) $$) AS (a agtype); """, + # Expected output for the negative cases (no return clause) + """ + SELECT * FROM ag_catalog.cypher('test', $$ + MATCH (n {description: "This will return a value"}) + MERGE (n)-[:RELATED]->(m) + $$) AS (a agtype); + """, + """ + SELECT * FROM ag_catalog.cypher('test', $$ + MATCH (n {returnValue: "some value"}) + MERGE (n)-[:RELATED]->(m) + $$) AS (a agtype); + """, ] for idx, value in enumerate(inputs):