mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-10-29 22:00:48 +00:00
fix: fix unit test error (#2085)
Co-authored-by: aries_ckt <916701291@qq.com> Co-authored-by: Appointat <kuda.czk@antgroup.com>
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
"""TuGraph Connector."""
|
||||
|
||||
import json
|
||||
from typing import Dict, Generator, List, Tuple, cast
|
||||
from typing import Dict, Generator, Iterator, List, cast
|
||||
|
||||
from .base import BaseConnector
|
||||
|
||||
@@ -20,7 +20,7 @@ class TuGraphConnector(BaseConnector):
|
||||
self._graph = graph
|
||||
self._session = None
|
||||
|
||||
def create_graph(self, graph_name: str) -> None:
|
||||
def create_graph(self, graph_name: str) -> bool:
|
||||
"""Create a new graph in the database if it doesn't already exist."""
|
||||
try:
|
||||
with self._driver.session(database="default") as session:
|
||||
@@ -33,6 +33,8 @@ class TuGraphConnector(BaseConnector):
|
||||
except Exception as e:
|
||||
raise Exception(f"Failed to create graph '{graph_name}': {str(e)}") from e
|
||||
|
||||
return not exists
|
||||
|
||||
def delete_graph(self, graph_name: str) -> None:
|
||||
"""Delete a graph in the database if it exists."""
|
||||
with self._driver.session(database="default") as session:
|
||||
@@ -60,20 +62,18 @@ class TuGraphConnector(BaseConnector):
|
||||
"`pip install neo4j`"
|
||||
) from err
|
||||
|
||||
def get_table_names(self) -> Tuple[List[str], List[str]]:
|
||||
def get_table_names(self) -> Iterator[str]:
|
||||
"""Get all table names from the TuGraph by Neo4j driver."""
|
||||
with self._driver.session(database=self._graph) as session:
|
||||
# Run the query to get vertex labels
|
||||
raw_vertex_labels: Dict[str, str] = session.run(
|
||||
"CALL db.vertexLabels()"
|
||||
).data()
|
||||
raw_vertex_labels = session.run("CALL db.vertexLabels()").data()
|
||||
vertex_labels = [table_name["label"] for table_name in raw_vertex_labels]
|
||||
|
||||
# Run the query to get edge labels
|
||||
raw_edge_labels: Dict[str, str] = session.run("CALL db.edgeLabels()").data()
|
||||
raw_edge_labels = session.run("CALL db.edgeLabels()").data()
|
||||
edge_labels = [table_name["label"] for table_name in raw_edge_labels]
|
||||
|
||||
return vertex_labels, edge_labels
|
||||
return iter(vertex_labels + edge_labels)
|
||||
|
||||
def get_grants(self):
|
||||
"""Get grants."""
|
||||
|
||||
Reference in New Issue
Block a user