diff --git a/libs/community/langchain_community/graphs/kuzu_graph.py b/libs/community/langchain_community/graphs/kuzu_graph.py index eda7417f940..1f99f49fc94 100644 --- a/libs/community/langchain_community/graphs/kuzu_graph.py +++ b/libs/community/langchain_community/graphs/kuzu_graph.py @@ -36,10 +36,7 @@ class KuzuGraph: def query(self, query: str, params: dict = {}) -> List[Dict[str, Any]]: """Query Kùzu database""" - params_list = [] - for param_name in params: - params_list.append([param_name, params[param_name]]) - result = self.conn.execute(query, params_list) + result = self.conn.execute(query, params) column_names = result.get_column_names() return_list = [] while result.has_next(): @@ -79,20 +76,16 @@ class KuzuGraph: rel_properties = [] for table in rel_tables: - current_table_schema = {"properties": [], "label": table["name"]} - properties_text = self.conn._connection.get_rel_property_names( - table["name"] - ).split("\n") - for i, line in enumerate(properties_text): - # The first 3 lines defines src, dst and name, so we skip them - if i < 3: - continue - if not line: - continue - property_name, property_type = line.strip().split(" ") - current_table_schema["properties"].append( - (property_name, property_type) - ) + table_name = table["name"] + current_table_schema = {"properties": [], "label": table_name} + query_result = self.conn.execute( + f"CALL table_info('{table_name}') RETURN *;" + ) + while query_result.has_next(): + row = query_result.get_next() + prop_name = row[1] + prop_type = row[2] + current_table_schema["properties"].append((prop_name, prop_type)) rel_properties.append(current_table_schema) self.schema = ( diff --git a/libs/community/tests/integration_tests/graphs/test_kuzu.py b/libs/community/tests/integration_tests/graphs/test_kuzu.py index 10a35f2175e..20d002d55d6 100644 --- a/libs/community/tests/integration_tests/graphs/test_kuzu.py +++ b/libs/community/tests/integration_tests/graphs/test_kuzu.py @@ -4,8 +4,7 @@ import unittest from langchain_community.graphs import KuzuGraph -EXPECTED_SCHEMA = """ -Node properties: [{'properties': [('name', 'STRING')], 'label': 'Movie'}, {'properties': [('name', 'STRING'), ('birthDate', 'STRING')], 'label': 'Person'}] +EXPECTED_SCHEMA = """Node properties: [{'properties': [('name', 'STRING')], 'label': 'Movie'}, {'properties': [('name', 'STRING'), ('birthDate', 'STRING')], 'label': 'Person'}] Relationships properties: [{'properties': [], 'label': 'ActedIn'}] Relationships: ['(:Person)-[:ActedIn]->(:Movie)'] """ # noqa: E501 @@ -36,7 +35,7 @@ class TestKuzu(unittest.TestCase): def tearDown(self) -> None: shutil.rmtree(self.tmpdir, ignore_errors=True) - def test_query(self) -> None: + def test_query_no_params(self) -> None: result = self.kuzu_graph.query("MATCH (n:Movie) RETURN n.name ORDER BY n.name") excepted_result = [ {"n.name": "The Godfather"}, @@ -45,6 +44,16 @@ class TestKuzu(unittest.TestCase): ] self.assertEqual(result, excepted_result) + def test_query_params(self) -> None: + result = self.kuzu_graph.query( + query="MATCH (n:Movie) WHERE n.name = $name RETURN n.name", + params={"name": "The Godfather"}, + ) + excepted_result = [ + {"n.name": "The Godfather"}, + ] + self.assertEqual(result, excepted_result) + def test_refresh_schema(self) -> None: self.conn.execute( "CREATE NODE TABLE Person (name STRING, birthDate STRING, PRIMARY "