langchain/libs/community/tests/unit_tests/graphs/test_age_graph.py
German Martin d5d18c62b3
community: Apache AGE wrapper additional edge cases. (#28151)
Description: 
The current AGEGraph() implementation does some custom wrapping for graph
queries. The method in question is _wrap_query(), which parses the fields
from the original query in order to add some SQL context to it.
This change improves the current parsing logic to cover additional edge
cases, which are added to the test coverage: previously, if any node
property name or value contained the literal "return", the resulting
graph / SQL query would break.
We discovered this while dealing with real-world datasets; it is not an
uncommon scenario, and I think it needs to be covered.
2024-12-16 11:28:01 -05:00

171 lines
5.1 KiB
Python

import re
import unittest
from collections import namedtuple
from typing import Any, Dict, List
from langchain_community.graphs.age_graph import AGEGraph
class TestAGEGraph(unittest.TestCase):
    """Unit tests for the static helper methods of ``AGEGraph``.

    Every test calls the helper directly on the class, so no database
    connection or ``AGEGraph`` instance is required.
    """

    def test_format_triples(self) -> None:
        """Triples are rendered as backtick-quoted Cypher path patterns."""
        test_input = [
            {"start": "from_a", "type": "edge_a", "end": "to_a"},
            {"start": "from_b", "type": "edge_b", "end": "to_b"},
        ]
        expected = [
            "(:`from_a`)-[:`edge_a`]->(:`to_a`)",
            "(:`from_b`)-[:`edge_b`]->(:`to_b`)",
        ]
        self.assertEqual(AGEGraph._format_triples(test_input), expected)

    def test_get_col_name(self) -> None:
        """Column names are derived from a return field and its index.

        The cases cover plain identifiers, ``as`` aliases, surrounding
        whitespace, function calls, and the literals ``true`` / ``false`` /
        ``null`` (which are expected to fall back to ``column_<idx>``).
        """
        inputs = [
            ("a", 1),
            ("a as b", 1),
            (" c ", 1),
            (" c as d ", 1),
            ("sum(a)", 1),
            ("sum(a) as b", 1),
            ("count(*)", 1),
            ("count(*) as cnt", 1),
            ("true", 1),
            ("false", 1),
            ("null", 1),
        ]
        expected = [
            "a",
            "b",
            "c",
            "d",
            "sum_a",
            "b",
            "count_*",
            "cnt",
            "column_1",
            "column_1",
            "column_1",
        ]
        # Guard against the two tables drifting apart: zip() would otherwise
        # silently drop the unmatched tail.
        self.assertEqual(len(inputs), len(expected))
        for (field, idx), want in zip(inputs, expected):
            # subTest reports which field failed and keeps checking the rest.
            with self.subTest(field=field, idx=idx):
                self.assertEqual(AGEGraph._get_col_name(field, idx), want)

    def test_wrap_query(self) -> None:
        """Cypher queries are wrapped in an ``ag_catalog.cypher`` SELECT.

        Covers a query with a real RETURN clause, queries without one, and
        queries where the text "return" only appears inside a string value
        or a property key (these must be treated as having no RETURN).
        A ``RETURN *`` query is expected to raise ``ValueError``.
        """
        inputs = [
            # Positive case: Simple return clause
            """
            MATCH (keanu:Person {name:'Keanu Reeves'})
            RETURN keanu.name AS name, keanu.born AS born
            """,
            """
            MERGE (n:a {id: 1})
            """,
            # Negative case: Return in a string value
            """
            MATCH (n {description: "This will return a value"})
            MERGE (n)-[:RELATED]->(m)
            """,
            # Negative case: Return in a property key
            """
            MATCH (n {returnValue: "some value"})
            MERGE (n)-[:RELATED]->(m)
            """,
        ]
        expected = [
            # Expected output for the first positive case
            """
            SELECT * FROM ag_catalog.cypher('test', $$
            MATCH (keanu:Person {name:'Keanu Reeves'})
            RETURN keanu.name AS name, keanu.born AS born
            $$) AS (name agtype, born agtype);
            """,
            """
            SELECT * FROM ag_catalog.cypher('test', $$
            MERGE (n:a {id: 1})
            $$) AS (a agtype);
            """,
            # Expected output for the negative cases (no return clause)
            """
            SELECT * FROM ag_catalog.cypher('test', $$
            MATCH (n {description: "This will return a value"})
            MERGE (n)-[:RELATED]->(m)
            $$) AS (a agtype);
            """,
            """
            SELECT * FROM ag_catalog.cypher('test', $$
            MATCH (n {returnValue: "some value"})
            MERGE (n)-[:RELATED]->(m)
            $$) AS (a agtype);
            """,
        ]
        self.assertEqual(len(inputs), len(expected))
        for case, (query, want) in enumerate(zip(inputs, expected)):
            with self.subTest(case=case):
                # All whitespace is stripped on both sides so the comparison
                # is insensitive to the layout of the literals above.
                self.assertEqual(
                    re.sub(r"\s", "", AGEGraph._wrap_query(query, "test")),
                    re.sub(r"\s", "", want),
                )

        with self.assertRaises(ValueError):
            AGEGraph._wrap_query(
                """
                MATCH ()
                RETURN *
                """,
                "test",
            )

    def test_format_properties(self) -> None:
        """Property dicts are rendered as Cypher map literals."""
        inputs: List[Dict[str, Any]] = [{}, {"a": "b"}, {"a": "b", "c": 1, "d": True}]
        expected = ["{}", '{`a`: "b"}', '{`a`: "b", `c`: 1, `d`: true}']
        self.assertEqual(len(inputs), len(expected))
        for props, want in zip(inputs, expected):
            with self.subTest(props=props):
                self.assertEqual(AGEGraph._format_properties(props), want)

    def test_clean_graph_labels(self) -> None:
        """Labels containing spaces or symbols are cleaned to safe names."""
        inputs = ["label", "label 1", "label#$"]
        expected = ["label", "label_1", "label_"]
        self.assertEqual(len(inputs), len(expected))
        for label, want in zip(inputs, expected):
            with self.subTest(label=label):
                self.assertEqual(AGEGraph.clean_graph_labels(label), want)

    def test_record_to_dict(self) -> None:
        """Result records are converted to plain Python dictionaries.

        The first record holds ``::vertex`` / ``::edge`` annotated values:
        vertices are expected to map to their properties and the edge to a
        ``(start properties, label, end properties)`` triple. The second
        record holds scalar values that are expected to be parsed into the
        matching Python types.
        """
        Record = namedtuple("Record", ["node1", "edge", "node2"])
        graph_record = Record(
            node1='{"id": 1, "label": "label1", "properties":'
            + ' {"prop": "a"}}::vertex',
            edge='{"id": 3, "label": "edge", "end_id": 2, '
            + '"start_id": 1, "properties": {"test": "abc"}}::edge',
            node2='{"id": 2, "label": "label1", '
            + '"properties": {"prop": "b"}}::vertex',
        )
        graph_result = AGEGraph._record_to_dict(graph_record)
        graph_expected = {
            "node1": {"prop": "a"},
            "edge": ({"prop": "a"}, "edge", {"prop": "b"}),
            "node2": {"prop": "b"},
        }
        self.assertEqual(graph_result, graph_expected)

        Record2 = namedtuple("Record2", ["string", "int", "float", "bool", "null"])
        scalar_record = Record2('"test"', "1", "1.5", "true", None)
        scalar_result = AGEGraph._record_to_dict(scalar_record)
        scalar_expected = {
            "string": "test",
            "int": 1,
            "float": 1.5,
            "bool": True,
            "null": None,
        }
        self.assertEqual(scalar_result, scalar_expected)