mirror of
https://github.com/hwchase17/langchain.git
synced 2025-05-16 20:42:01 +00:00
Description: The current AGEGraph() implementation does some custom wrapping of graph queries. The method responsible is _wrap_query(), which parses the fields from the original query in order to add SQL context to it. This change improves the parsing logic to cover additional edge cases, which are also added to the test coverage: previously, if any node property name or value contained the literal "return", the generated graph/SQL query would break. We discovered this while dealing with real-world datasets; it is not an uncommon scenario, and I think it needs to be covered.
171 lines
5.1 KiB
Python
171 lines
5.1 KiB
Python
import re
|
|
import unittest
|
|
from collections import namedtuple
|
|
from typing import Any, Dict, List
|
|
|
|
from langchain_community.graphs.age_graph import AGEGraph
|
|
|
|
|
|
class TestAGEGraph(unittest.TestCase):
    """Unit tests for the helper functions exposed on ``AGEGraph``.

    Every helper is invoked directly on the class; no ``AGEGraph`` instance
    is constructed by these tests.

    Table-driven tests keep their ``inputs`` and ``expected`` lists side by
    side. Each one first asserts that both lists have the same length (so a
    forgotten entry cannot be skipped silently) and then checks every pair
    inside ``self.subTest`` so that, under runners that support subtests, a
    failure names the offending input.
    """

    def test_format_triples(self) -> None:
        """Each start/type/end dict is rendered as a backtick-quoted pattern."""
        test_input = [
            {"start": "from_a", "type": "edge_a", "end": "to_a"},
            {"start": "from_b", "type": "edge_b", "end": "to_b"},
        ]

        expected = [
            "(:`from_a`)-[:`edge_a`]->(:`to_a`)",
            "(:`from_b`)-[:`edge_b`]->(:`to_b`)",
        ]

        self.assertEqual(AGEGraph._format_triples(test_input), expected)

    def test_get_col_name(self) -> None:
        """Column names come from the alias, the expression, or a fallback.

        Each input is the positional argument tuple passed to
        ``_get_col_name``: a return-field string and an integer. The cases
        below expect an ``as`` alias to win, a function call to be flattened
        (``sum(a)`` -> ``sum_a``), and the literals ``true``/``false``/``null``
        to fall back to ``column_1`` (presumably built from the integer
        argument -- confirm against ``_get_col_name``).
        """
        inputs = [
            ("a", 1),
            ("a as b", 1),
            (" c ", 1),
            (" c as d ", 1),
            ("sum(a)", 1),
            ("sum(a) as b", 1),
            ("count(*)", 1),
            ("count(*) as cnt", 1),
            ("true", 1),
            ("false", 1),
            ("null", 1),
        ]

        expected = [
            "a",
            "b",
            "c",
            "d",
            "sum_a",
            "b",
            "count_*",
            "cnt",
            "column_1",
            "column_1",
            "column_1",
        ]

        # Guard against the two tables drifting apart.
        self.assertEqual(len(inputs), len(expected))

        for args, want in zip(inputs, expected):
            with self.subTest(field=args[0]):
                self.assertEqual(AGEGraph._get_col_name(*args), want)

    def test_wrap_query(self) -> None:
        """Queries are wrapped in a ``SELECT ... ag_catalog.cypher`` call.

        Covers a query with a real ``RETURN`` clause, queries without one
        (a single ``a agtype`` column is expected), and queries where the
        word "return" only occurs inside a string value or a property key
        and therefore must not be treated as a return clause. Finally, a
        ``RETURN *`` query is expected to raise ``ValueError``.
        """
        inputs = [
            # Positive case: Simple return clause
            """
            MATCH (keanu:Person {name:'Keanu Reeves'})
            RETURN keanu.name AS name, keanu.born AS born
            """,
            # No return clause at all
            """
            MERGE (n:a {id: 1})
            """,
            # Negative case: Return in a string value
            """
            MATCH (n {description: "This will return a value"})
            MERGE (n)-[:RELATED]->(m)
            """,
            # Negative case: Return in a property key
            """
            MATCH (n {returnValue: "some value"})
            MERGE (n)-[:RELATED]->(m)
            """,
        ]

        expected = [
            # Expected output for the first positive case
            """
            SELECT * FROM ag_catalog.cypher('test', $$
            MATCH (keanu:Person {name:'Keanu Reeves'})
            RETURN keanu.name AS name, keanu.born AS born
            $$) AS (name agtype, born agtype);
            """,
            # Expected output when there is no return clause
            """
            SELECT * FROM ag_catalog.cypher('test', $$
            MERGE (n:a {id: 1})
            $$) AS (a agtype);
            """,
            # Expected output for the negative cases (no return clause)
            """
            SELECT * FROM ag_catalog.cypher('test', $$
            MATCH (n {description: "This will return a value"})
            MERGE (n)-[:RELATED]->(m)
            $$) AS (a agtype);
            """,
            """
            SELECT * FROM ag_catalog.cypher('test', $$
            MATCH (n {returnValue: "some value"})
            MERGE (n)-[:RELATED]->(m)
            $$) AS (a agtype);
            """,
        ]

        # Guard against the two tables drifting apart.
        self.assertEqual(len(inputs), len(expected))

        # Layout of the generated SQL is irrelevant, so all whitespace is
        # removed from both sides before comparing.
        whitespace = re.compile(r"\s")

        for query, want in zip(inputs, expected):
            with self.subTest(query=query.strip()):
                self.assertEqual(
                    whitespace.sub("", AGEGraph._wrap_query(query, "test")),
                    whitespace.sub("", want),
                )

        with self.assertRaises(ValueError):
            AGEGraph._wrap_query(
                """
                MATCH ()
                RETURN *
                """,
                "test",
            )

    def test_format_properties(self) -> None:
        """Property dicts are rendered as ``{`key`: value, ...}`` strings.

        The cases expect keys in backticks, string values in double quotes,
        and the Python ``True`` rendered as lowercase ``true``.
        """
        inputs: List[Dict[str, Any]] = [{}, {"a": "b"}, {"a": "b", "c": 1, "d": True}]

        expected = ["{}", '{`a`: "b"}', '{`a`: "b", `c`: 1, `d`: true}']

        # Guard against the two tables drifting apart.
        self.assertEqual(len(inputs), len(expected))

        for properties, want in zip(inputs, expected):
            with self.subTest(properties=properties):
                self.assertEqual(AGEGraph._format_properties(properties), want)

    def test_clean_graph_labels(self) -> None:
        """Spaces and special characters in labels become underscores."""
        inputs = ["label", "label 1", "label#$"]

        expected = ["label", "label_1", "label_"]

        # Guard against the two tables drifting apart.
        self.assertEqual(len(inputs), len(expected))

        for label, want in zip(inputs, expected):
            with self.subTest(label=label):
                self.assertEqual(AGEGraph.clean_graph_labels(label), want)

    def test_record_to_dict(self) -> None:
        """Records of agtype strings are converted to plain Python values.

        The first record holds two ``::vertex`` values and one ``::edge``
        value: vertices are expected to map to their properties, and the
        edge to a ``(start properties, label, end properties)`` tuple. The
        second record holds scalar values and ``None``.
        """
        Record = namedtuple("Record", ["node1", "edge", "node2"])
        r = Record(
            node1='{"id": 1, "label": "label1", "properties":'
            ' {"prop": "a"}}::vertex',
            edge='{"id": 3, "label": "edge", "end_id": 2, '
            '"start_id": 1, "properties": {"test": "abc"}}::edge',
            node2='{"id": 2, "label": "label1", '
            '"properties": {"prop": "b"}}::vertex',
        )

        result = AGEGraph._record_to_dict(r)

        expected = {
            "node1": {"prop": "a"},
            "edge": ({"prop": "a"}, "edge", {"prop": "b"}),
            "node2": {"prop": "b"},
        }

        self.assertEqual(result, expected)

        Record2 = namedtuple("Record2", ["string", "int", "float", "bool", "null"])
        r2 = Record2('"test"', "1", "1.5", "true", None)

        result2 = AGEGraph._record_to_dict(r2)

        expected2 = {
            "string": "test",
            "int": 1,
            "float": 1.5,
            "bool": True,
            "null": None,
        }

        self.assertEqual(result2, expected2)