mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-11 22:09:44 +00:00
feat: add GraphRAG framework and integrate TuGraph (#1506)
Co-authored-by: KingSkyLi <15566300566@163.com> Co-authored-by: aries_ckt <916701291@qq.com> Co-authored-by: Fangyin Cheng <staneyffer@gmail.com>
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
"""TuGraph Connector."""
|
||||
|
||||
import json
|
||||
from typing import Any, Dict, List, cast
|
||||
from typing import Dict, List, cast
|
||||
|
||||
from .base import BaseConnector
|
||||
|
||||
@@ -12,37 +13,60 @@ class TuGraphConnector(BaseConnector):
|
||||
driver: str = "bolt"
|
||||
dialect: str = "tugraph"
|
||||
|
||||
def __init__(self, session):
|
||||
def __init__(self, driver, graph):
|
||||
"""Initialize the connector with a Neo4j driver."""
|
||||
self._session = session
|
||||
self._driver = driver
|
||||
self._schema = None
|
||||
self._graph = graph
|
||||
self._session = None
|
||||
|
||||
def create_graph(self, graph_name: str) -> None:
|
||||
"""Create a new graph."""
|
||||
# run the query to get vertex labels
|
||||
with self._driver.session(database="default") as session:
|
||||
graph_list = session.run("CALL dbms.graph.listGraphs()").data()
|
||||
exists = any(item["graph_name"] == graph_name for item in graph_list)
|
||||
if not exists:
|
||||
session.run(f"CALL dbms.graph.createGraph('{graph_name}', '', 2048)")
|
||||
|
||||
def delete_graph(self, graph_name: str) -> None:
|
||||
"""Delete a graph."""
|
||||
with self._driver.session(database="default") as session:
|
||||
graph_list = session.run("CALL dbms.graph.listGraphs()").data()
|
||||
exists = any(item["graph_name"] == graph_name for item in graph_list)
|
||||
if exists:
|
||||
session.run(f"Call dbms.graph.deleteGraph('{graph_name}')")
|
||||
|
||||
@classmethod
|
||||
def from_uri_db(
|
||||
cls, host: str, port: int, user: str, pwd: str, db_name: str, **kwargs: Any
|
||||
cls, host: str, port: int, user: str, pwd: str, db_name: str
|
||||
) -> "TuGraphConnector":
|
||||
"""Create a new TuGraphConnector from host, port, user, pwd, db_name."""
|
||||
try:
|
||||
from neo4j import GraphDatabase
|
||||
|
||||
db_url = f"{cls.driver}://{host}:{str(port)}"
|
||||
with GraphDatabase.driver(db_url, auth=(user, pwd)) as client:
|
||||
client.verify_connectivity()
|
||||
session = client.session(database=db_name)
|
||||
return cast(TuGraphConnector, cls(session=session))
|
||||
driver = GraphDatabase.driver(db_url, auth=(user, pwd))
|
||||
driver.verify_connectivity()
|
||||
return cast(TuGraphConnector, cls(driver=driver, graph=db_name))
|
||||
|
||||
except ImportError as err:
|
||||
raise ImportError("requests package is not installed") from err
|
||||
raise ImportError(
|
||||
"neo4j package is not installed, please install it with "
|
||||
"`pip install neo4j`"
|
||||
) from err
|
||||
|
||||
def get_table_names(self) -> Dict[str, List[str]]:
|
||||
"""Get all table names from the TuGraph database using the Neo4j driver."""
|
||||
# Run the query to get vertex labels
|
||||
v_result = self._session.run("CALL db.vertexLabels()").data()
|
||||
v_data = [table_name["label"] for table_name in v_result]
|
||||
"""Get all table names from the TuGraph by Neo4j driver."""
|
||||
# run the query to get vertex labels
|
||||
with self._driver.session(database=self._graph) as session:
|
||||
v_result = session.run("CALL db.vertexLabels()").data()
|
||||
v_data = [table_name["label"] for table_name in v_result]
|
||||
|
||||
# Run the query to get edge labels
|
||||
e_result = self._session.run("CALL db.edgeLabels()").data()
|
||||
e_data = [table_name["label"] for table_name in e_result]
|
||||
return {"vertex_tables": v_data, "edge_tables": e_data}
|
||||
# run the query to get edge labels
|
||||
e_result = session.run("CALL db.edgeLabels()").data()
|
||||
e_data = [table_name["label"] for table_name in e_result]
|
||||
return {"vertex_tables": v_data, "edge_tables": e_data}
|
||||
|
||||
def get_grants(self):
|
||||
"""Get grants."""
|
||||
@@ -62,11 +86,13 @@ class TuGraphConnector(BaseConnector):
|
||||
|
||||
def close(self):
|
||||
"""Close the Neo4j driver."""
|
||||
self._session.close()
|
||||
self._driver.close()
|
||||
|
||||
def run(self):
|
||||
def run(self, query: str, fetch: str = "all") -> List:
|
||||
"""Run GQL."""
|
||||
return []
|
||||
with self._driver.session(database=self._graph) as session:
|
||||
result = session.run(query)
|
||||
return list(result)
|
||||
|
||||
def get_columns(self, table_name: str, table_type: str = "vertex") -> List[Dict]:
|
||||
"""Get fields about specified graph.
|
||||
@@ -80,27 +106,27 @@ class TuGraphConnector(BaseConnector):
|
||||
eg:[{'name': 'id', 'type': 'int', 'default_expression': '',
|
||||
'is_in_primary_key': True, 'comment': 'id'}, ...]
|
||||
"""
|
||||
data = []
|
||||
result = None
|
||||
if table_type == "vertex":
|
||||
result = self._session.run(
|
||||
f"CALL db.getVertexSchema('{table_name}')"
|
||||
).data()
|
||||
else:
|
||||
result = self._session.run(f"CALL db.getEdgeSchema('{table_name}')").data()
|
||||
schema_info = json.loads(result[0]["schema"])
|
||||
for prop in schema_info.get("properties", []):
|
||||
prop_dict = {
|
||||
"name": prop["name"],
|
||||
"type": prop["type"],
|
||||
"default_expression": "",
|
||||
"is_in_primary_key": bool(
|
||||
"primary" in schema_info and prop["name"] == schema_info["primary"]
|
||||
),
|
||||
"comment": prop["name"],
|
||||
}
|
||||
data.append(prop_dict)
|
||||
return data
|
||||
with self._driver.session(database=self._graph) as session:
|
||||
data = []
|
||||
result = None
|
||||
if table_type == "vertex":
|
||||
result = session.run(f"CALL db.getVertexSchema('{table_name}')").data()
|
||||
else:
|
||||
result = session.run(f"CALL db.getEdgeSchema('{table_name}')").data()
|
||||
schema_info = json.loads(result[0]["schema"])
|
||||
for prop in schema_info.get("properties", []):
|
||||
prop_dict = {
|
||||
"name": prop["name"],
|
||||
"type": prop["type"],
|
||||
"default_expression": "",
|
||||
"is_in_primary_key": bool(
|
||||
"primary" in schema_info
|
||||
and prop["name"] == schema_info["primary"]
|
||||
),
|
||||
"comment": prop["name"],
|
||||
}
|
||||
data.append(prop_dict)
|
||||
return data
|
||||
|
||||
def get_indexes(self, table_name: str, table_type: str = "vertex") -> List[Dict]:
|
||||
"""Get table indexes about specified table.
|
||||
@@ -112,14 +138,15 @@ class TuGraphConnector(BaseConnector):
|
||||
List[Dict]:eg:[{'name': 'idx_key', 'column_names': ['id']}]
|
||||
"""
|
||||
# [{'name':'id','column_names':['id']}]
|
||||
result = self._session.run(
|
||||
f"CALL db.listLabelIndexes('{table_name}','{table_type}')"
|
||||
).data()
|
||||
transformed_data = []
|
||||
for item in result:
|
||||
new_dict = {"name": item["field"], "column_names": [item["field"]]}
|
||||
transformed_data.append(new_dict)
|
||||
return transformed_data
|
||||
with self._driver.session(database=self._graph) as session:
|
||||
result = session.run(
|
||||
f"CALL db.listLabelIndexes('{table_name}','{table_type}')"
|
||||
).data()
|
||||
transformed_data = []
|
||||
for item in result:
|
||||
new_dict = {"name": item["field"], "column_names": [item["field"]]}
|
||||
transformed_data.append(new_dict)
|
||||
return transformed_data
|
||||
|
||||
@classmethod
|
||||
def is_graph_type(cls) -> bool:
|
||||
|
Reference in New Issue
Block a user