fix: fix unit test error (#2085)

Co-authored-by: aries_ckt <916701291@qq.com>
Co-authored-by: Appointat <kuda.czk@antgroup.com>
This commit is contained in:
Florian
2024-10-22 09:35:51 +08:00
committed by GitHub
parent 6d6667812b
commit d9e20426fe
11 changed files with 129 additions and 113 deletions

View File

@@ -1,7 +1,7 @@
"""TuGraph Connector."""
import json
from typing import Dict, Generator, List, Tuple, cast
from typing import Dict, Generator, Iterator, List, cast
from .base import BaseConnector
@@ -20,7 +20,7 @@ class TuGraphConnector(BaseConnector):
self._graph = graph
self._session = None
def create_graph(self, graph_name: str) -> None:
def create_graph(self, graph_name: str) -> bool:
"""Create a new graph in the database if it doesn't already exist."""
try:
with self._driver.session(database="default") as session:
@@ -33,6 +33,8 @@ class TuGraphConnector(BaseConnector):
except Exception as e:
raise Exception(f"Failed to create graph '{graph_name}': {str(e)}") from e
return not exists
def delete_graph(self, graph_name: str) -> None:
"""Delete a graph in the database if it exists."""
with self._driver.session(database="default") as session:
@@ -60,20 +62,18 @@ class TuGraphConnector(BaseConnector):
"`pip install neo4j`"
) from err
def get_table_names(self) -> Tuple[List[str], List[str]]:
def get_table_names(self) -> Iterator[str]:
"""Get all table names from the TuGraph by Neo4j driver."""
with self._driver.session(database=self._graph) as session:
# Run the query to get vertex labels
raw_vertex_labels: Dict[str, str] = session.run(
"CALL db.vertexLabels()"
).data()
raw_vertex_labels = session.run("CALL db.vertexLabels()").data()
vertex_labels = [table_name["label"] for table_name in raw_vertex_labels]
# Run the query to get edge labels
raw_edge_labels: Dict[str, str] = session.run("CALL db.edgeLabels()").data()
raw_edge_labels = session.run("CALL db.edgeLabels()").data()
edge_labels = [table_name["label"] for table_name in raw_edge_labels]
return vertex_labels, edge_labels
return iter(vertex_labels + edge_labels)
def get_grants(self):
"""Get grants."""