feat(GraphRAG): enhance GraphRAG by graph community summary (#1801)

Co-authored-by: Florian <fanzhidongyzby@163.com>
Co-authored-by: KingSkyLi <15566300566@163.com>
Co-authored-by: aries_ckt <916701291@qq.com>
Co-authored-by: Fangyin Cheng <staneyffer@gmail.com>
Co-authored-by: yvonneyx <zhuyuxin0627@gmail.com>
This commit is contained in:
M1n9X
2024-08-30 21:59:44 +08:00
committed by GitHub
parent 471689ba20
commit 759f7d99cc
59 changed files with 29316 additions and 411 deletions

View File

@@ -1,12 +1,14 @@
"""TuGraph vector store."""
"""TuGraph store."""
import base64
import json
import logging
import os
from typing import List, Optional, Tuple
from typing import Any, Generator, Iterator, List, Optional, Tuple
from dbgpt._private.pydantic import ConfigDict, Field
from dbgpt.datasource.conn_tugraph import TuGraphConnector
from dbgpt.storage.graph_store.base import GraphStoreBase, GraphStoreConfig
from dbgpt.storage.graph_store.graph import Direction, Edge, MemoryGraph, Vertex
from dbgpt.storage.graph_store.graph import Direction, Edge, Graph, MemoryGraph, Vertex
logger = logging.getLogger(__name__)
@@ -29,20 +31,24 @@ class TuGraphStoreConfig(GraphStoreConfig):
description="login username",
)
password: str = Field(
default="123456",
default="73@TuGraph",
description="login password",
)
vertex_type: str = Field(
default="entity",
description="The type of graph vertex, `entity` by default.",
description="The type of vertex, `entity` by default.",
)
edge_type: str = Field(
default="relation",
description="The type of graph edge, `relation` by default.",
description="The type of edge, `relation` by default.",
)
edge_name_key: str = Field(
default="label",
description="The label of edge name, `label` by default.",
plugin_names: List[str] = Field(
default=["leiden"],
description=(
"Plugins need to be loaded when initialize TuGraph, "
"code: https://github.com/TuGraph-family"
"/dbgpt-tugraph-plugins/tree/master/cpp"
),
)
@@ -51,20 +57,23 @@ class TuGraphStore(GraphStoreBase):
def __init__(self, config: TuGraphStoreConfig) -> None:
"""Initialize the TuGraphStore with connection details."""
self._host = os.getenv("TUGRAPH_HOST", "127.0.0.1") or config.host
self._port = int(os.getenv("TUGRAPH_PORT", 7687)) or config.port
self._username = os.getenv("TUGRAPH_USERNAME", "admin") or config.username
self._password = os.getenv("TUGRAPH_PASSWORD", "73@TuGraph") or config.password
self._node_label = (
os.getenv("TUGRAPH_VERTEX_TYPE", "entity") or config.vertex_type
self._config = config
self._host = os.getenv("TUGRAPH_HOST", config.host)
self._port = int(os.getenv("TUGRAPH_PORT", config.port))
self._username = os.getenv("TUGRAPH_USERNAME", config.username)
self._password = os.getenv("TUGRAPH_PASSWORD", config.password)
self._summary_enabled = (
os.getenv("GRAPH_COMMUNITY_SUMMARY_ENABLED", "").lower() == "true"
or config.summary_enabled
)
self._edge_label = (
os.getenv("TUGRAPH_EDGE_TYPE", "relation") or config.edge_type
)
self.edge_name_key = (
os.getenv("TUGRAPH_EDGE_NAME_KEY", "label") or config.edge_name_key
self._plugin_names = (
os.getenv("TUGRAPH_PLUGIN_NAMES", "leiden").split(",")
or config.plugin_names
)
self._graph_name = config.name
self._vertex_type = os.getenv("TUGRAPH_VERTEX_TYPE", config.vertex_type)
self._edge_type = os.getenv("TUGRAPH_EDGE_TYPE", config.edge_type)
self.conn = TuGraphConnector.from_uri_db(
host=self._host,
port=self._port,
@@ -72,35 +81,197 @@ class TuGraphStore(GraphStoreBase):
pwd=self._password,
db_name=config.name,
)
self.conn.create_graph(graph_name=config.name)
self._create_graph(config.name)
def get_vertex_type(self) -> str:
    """Return the vertex label held in ``_vertex_type``."""
    vertex_type = self._vertex_type
    return vertex_type
def get_edge_type(self) -> str:
    """Return the edge label held in ``_edge_type``."""
    edge_type = self._edge_type
    return edge_type
def _create_graph(self, graph_name: str):
    """Create the graph, then its schema, then (optionally) its plugins.

    Args:
        graph_name: Name forwarded to the connector's ``create_graph``.
    """
    self.conn.create_graph(graph_name=graph_name)
    self._create_schema()
    if not self._summary_enabled:
        return
    # Plugins are only uploaded when community summary is switched on.
    self._upload_plugin()
def _check_label(self, elem_type: str):
    """Report whether the configured vertex or edge label already exists.

    Args:
        elem_type: ``"vertex"`` looks up ``self._vertex_type`` in the
            connector's ``vertex_tables``; ``"edge"`` looks up
            ``self._edge_type`` in its ``edge_tables``.

    Returns:
        ``True`` when the label is listed, ``False`` otherwise (including
        for an unrecognised ``elem_type``).
    """
    tables = self.conn.get_table_names()
    if elem_type == "vertex":
        return self._vertex_type in tables["vertex_tables"]
    if elem_type == "edge":
        return self._edge_type in tables["edge_tables"]
    return False
def _add_vertex_index(self, field_name):
    """Run ``db.addIndex`` for one property of the configured vertex label.

    Args:
        field_name: Property name placed in the second argument of the
            procedure call; the third argument is always ``false``.
    """
    label = self._vertex_type
    self.conn.run(f"CALL db.addIndex('{label}', '{field_name}', false)")
def _upload_plugin(self):
    """Upload every configured plugin that TuGraph does not list yet.

    Runs ``db.plugin.listPlugin('CPP','v1')``, compares the ``name`` found
    in each record's JSON ``plugin_description`` with ``self._plugin_names``
    and, for each name that is absent, base64-encodes the binary returned by
    ``dbgpt_tugraph_plugins.get_plugin_binary_path`` and sends it through
    ``db.plugin.loadPlugin``.

    If ``dbgpt-tugraph-plugins`` is not installed the problem is logged and
    nothing is uploaded.
    """
    gql = "CALL db.plugin.listPlugin('CPP','v1')"
    result = self.conn.run(gql)
    result_names = [
        json.loads(record["plugin_description"])["name"] for record in result
    ]
    missing_plugins = [
        name for name in self._plugin_names if name not in result_names
    ]
    if not missing_plugins:
        return

    # Import once, before the loop, and stop here when the package is
    # absent: continuing would only fail on the unbound helper name.
    try:
        from dbgpt_tugraph_plugins import (  # type: ignore # noqa
            get_plugin_binary_path,
        )
    except ImportError:
        logger.error(
            "dbgpt-tugraph-plugins is not installed, "
            "pip install dbgpt-tugraph-plugins==0.1.0rc1 -U -i "
            "https://pypi.org/simple"
        )
        return

    for name in missing_plugins:
        # Resolve the binary for *this* plugin rather than a fixed one.
        plugin_path = get_plugin_binary_path(name)
        with open(plugin_path, "rb") as f:
            content = base64.b64encode(f.read()).decode()
        # Both fragments must be f-strings so the description carries the
        # plugin name instead of a literal "{name}".
        gql = (
            f"CALL db.plugin.loadPlugin('CPP', '{name}', '{content}', "
            f"'SO', '{name} Plugin', false, 'v1')"
        )
        self.conn.run(gql)
def _create_schema(self):
    """Create the vertex and edge labels when TuGraph does not list them.

    Both labels always get ``id`` and ``name``. When community summary is
    enabled the vertex label additionally gets ``_document_id``,
    ``_chunk_id``, ``_community_id`` and ``description`` plus an index on
    ``_community_id``, and the edge label additionally gets
    ``description``.
    """
    if not self._check_label("vertex"):
        vertex_fields = ["['id',string,false]", "['name',string,false]"]
        if self._summary_enabled:
            vertex_fields += [
                "['_document_id',string,true]",
                "['_chunk_id',string,true]",
                "['_community_id',string,true]",
                "['description',string,true]",
            ]
        create_vertex_gql = (
            f"CALL db.createLabel('vertex', '{self._vertex_type}', 'id', "
            f"{','.join(vertex_fields)})"
        )
        self.conn.run(create_vertex_gql)
        if self._summary_enabled:
            # The index can only be added once the label exists.
            self._add_vertex_index("_community_id")

    if not self._check_label("edge"):
        edge_fields = ['["id",STRING,false]', '["name",STRING,false]']
        if self._summary_enabled:
            edge_fields.append('["description",STRING,true]')
        # Edges are constrained to run between two vertices of the
        # configured vertex label.
        endpoints = f'[["{self._vertex_type}","{self._vertex_type}"]]'
        create_edge_gql = (
            f"CALL db.createLabel('edge', '{self._edge_type}', "
            f"'{endpoints}', {','.join(edge_fields)})"
        )
        self.conn.run(create_edge_gql)
def _format_query_data(self, data, white_prop_list: List[str]):
    """Turn driver records into ``Vertex`` and ``Edge`` lists.

    Each value of each record is dispatched on its ``neo4j.graph`` type:
    nodes become vertex rows, relationships become an edge row plus vertex
    rows for their endpoints, paths are expanded relationship by
    relationship, and any other value is stored once as the description of
    a synthetic ``json_node`` vertex.

    Args:
        data: Iterable of records offering ``keys()`` and item access.
        white_prop_list: Property names that are kept even though they
            start with ``_`` or are ``id``/``name``.

    Returns:
        ``{"nodes": [Vertex, ...], "edges": [Edge, ...]}``. Vertex rows are
        not de-duplicated here; edges are unique per (name, source, target).
    """
    from neo4j import graph

    vertex_rows = []
    edge_rows: List[Any] = []
    reserved_keys = ["id", "name"]

    def visible_props(props):
        # Hide private ("_"-prefixed) and identity keys unless whitelisted.
        kept = {}
        for prop_key, prop_value in props.items():
            hidden = prop_key.startswith("_") or prop_key in reserved_keys
            if not hidden or prop_key in white_prop_list:
                kept[prop_key] = prop_value
        return kept

    def add_node(node: graph.Node):
        props = node._properties
        vertex_rows.append(
            {
                "id": props.get("id"),
                "name": props.get("name"),
                "properties": visible_props(props),
            }
        )

    def add_relationship(rel: graph.Relationship):
        rel_name = rel._properties.get("name", "")
        endpoints = rel.nodes
        source_id = endpoints[0]._properties.get("id")
        target_id = endpoints[1]._properties.get("id")
        # Endpoints are recorded even when the edge itself is a duplicate.
        for endpoint in endpoints:
            add_node(endpoint)
        props = visible_props(rel._properties)
        for known in edge_rows:
            if (
                known.get("name") == rel_name
                and known.get("src_id") == source_id
                and known.get("dst_id") == target_id
            ):
                return
        edge_rows.append(
            {
                "src_id": source_id,
                "dst_id": target_id,
                "name": rel_name,
                "properties": props,
            }
        )

    def add_path(path: graph.Path):
        for rel in path.relationships:
            add_relationship(rel)

    def add_other(value):
        # Only the first non-graph value is kept, under a fixed vertex id.
        for row in vertex_rows:
            if row.get("id") == "json_node":
                return
        vertex_rows.append(
            {
                "id": "json_node",
                "name": "json_node",
                "properties": {"description": value},
            }
        )

    for record in data:
        for key in record.keys():
            item = record[key]
            if isinstance(item, graph.Node):
                add_node(item)
            elif isinstance(item, graph.Relationship):
                add_relationship(item)
            elif isinstance(item, graph.Path):
                add_path(item)
            else:
                add_other(item)

    vertices = []
    for row in vertex_rows:
        vertices.append(Vertex(row["id"], row["name"], **row["properties"]))
    edges = []
    for row in edge_rows:
        edges.append(
            Edge(row["src_id"], row["dst_id"], row["name"], **row["properties"])
        )
    return {"nodes": vertices, "edges": edges}
def get_config(self):
    """Return the config object held in ``_config``."""
    config = self._config
    return config
def get_triplets(self, subj: str) -> List[Tuple[str, str]]:
"""Get triplets."""
query = (
f"MATCH (n1:{self._node_label})-[r]->(n2:{self._node_label}) "
f"MATCH (n1:{self._vertex_type})-[r]->(n2:{self._vertex_type}) "
f'WHERE n1.id = "{subj}" RETURN r.id as rel, n2.id as obj;'
)
data = self.conn.run(query)
@@ -117,16 +288,83 @@ class TuGraphStore(GraphStoreBase):
rel_escaped = escape_quotes(rel)
obj_escaped = escape_quotes(obj)
subj_query = f"MERGE (n1:{self._node_label} {{id:'{subj_escaped}'}})"
obj_query = f"MERGE (n1:{self._node_label} {{id:'{obj_escaped}'}})"
rel_query = (
f"MERGE (n1:{self._node_label} {{id:'{subj_escaped}'}})"
f"-[r:{self._edge_label} {{id:'{rel_escaped}'}}]->"
f"(n2:{self._node_label} {{id:'{obj_escaped}'}})"
node_query = f"""CALL db.upsertVertex(
'{self._vertex_type}',
[{{id:'{subj_escaped}',name:'{subj_escaped}'}},
{{id:'{obj_escaped}',name:'{obj_escaped}'}}])"""
edge_query = f"""CALL db.upsertEdge(
'{self._edge_type}',
{{type:"{self._vertex_type}",key:"sid"}},
{{type:"{self._vertex_type}", key:"tid"}},
[{{sid:"{subj_escaped}",
tid: "{obj_escaped}",
id:"{rel_escaped}",
name: "{rel_escaped}"}}])"""
self.conn.run(query=node_query)
self.conn.run(query=edge_query)
def insert_graph(self, graph: Graph) -> None:
    """Upsert every vertex and edge of ``graph`` into TuGraph.

    Vertices are written with ``db.upsertVertex`` and edges with
    ``db.upsertEdge``, one call per kind, vertices first. A kind with no
    elements is skipped. Single and double quotes are removed from all
    string values before they are placed in the query text, and a missing
    value is written as an empty string.

    Args:
        graph: Source of the ``vertices()`` and ``edges()`` to store.
    """

    def escape_quotes(value: Optional[str]) -> str:
        """Strip quotes from ``value``; map ``None`` to an empty string."""
        if value is None:
            return ""
        return value.replace("'", "").replace('"', "")

    def parser(rows):
        """Render dicts as a comma-separated run of Cypher map literals."""
        return ", ".join(
            "{"
            + ", ".join(
                f'{k}: "{v}"' if isinstance(v, str) else f"{k}: {v}"
                for k, v in row.items()
            )
            + "}"
            for row in rows
        )

    node_list = [
        {
            "id": escape_quotes(node.vid),
            "name": escape_quotes(node.name),
            "description": escape_quotes(node.get_prop("description")),
            "_document_id": "0",
            "_chunk_id": "0",
            "_community_id": "0",
        }
        for node in graph.vertices()
    ]
    edge_list = [
        {
            "sid": escape_quotes(edge.sid),
            "tid": escape_quotes(edge.tid),
            # The edge name doubles as its id.
            "id": escape_quotes(edge.name),
            "name": escape_quotes(edge.name),
            "description": escape_quotes(edge.get_prop("description")),
        }
        for edge in graph.edges()
    ]

    if node_list:
        node_query = (
            f"""CALL db.upsertVertex("{self._vertex_type}", """
            f"""[{parser(node_list)}])"""
        )
        self.conn.run(query=node_query)
    if edge_list:
        edge_query = (
            f'CALL db.upsertEdge("{self._edge_type}", '
            f'{{type:"{self._vertex_type}", key:"sid"}}, '
            f'{{type:"{self._vertex_type}", key:"tid"}}, '
            f"[{parser(edge_list)}])"
        )
        self.conn.run(query=edge_query)
def truncate(self):
    """Run ``MATCH (n) DELETE n`` against the connected graph."""
    self.conn.run("MATCH (n) DELETE n")
def drop(self):
"""Delete Graph."""
@@ -135,9 +373,9 @@ class TuGraphStore(GraphStoreBase):
def delete_triplet(self, sub: str, rel: str, obj: str) -> None:
    """Delete a triplet matched by subject id, relation id and object id.

    The generated query deletes the matched edge *and* both endpoint
    vertices (``DELETE n1,n2,r``).

    Args:
        sub: ``id`` of the source vertex.
        rel: ``id`` of the edge.
        obj: ``id`` of the target vertex.
    """

    def strip_quotes(value: str) -> str:
        # Values are spliced into single-quoted literals, so a stray quote
        # would end the literal early. Quotes are removed the same way
        # insert_graph removes them before storing ids.
        return value.replace("'", "").replace('"', "")

    sub_id = strip_quotes(sub)
    rel_id = strip_quotes(rel)
    obj_id = strip_quotes(obj)
    del_query = (
        f"MATCH (n1:{self._vertex_type} {{id:'{sub_id}'}})"
        f"-[r:{self._edge_type} {{id:'{rel_id}'}}]->"
        f"(n2:{self._vertex_type} {{id:'{obj_id}'}}) DELETE n1,n2,r"
    )
    self.conn.run(query=del_query)
@@ -148,11 +386,20 @@ class TuGraphStore(GraphStoreBase):
schema = data[0]["schema"]
return schema
def get_full_graph(self, limit: Optional[int] = None) -> Graph:
    """Fetch up to ``limit`` matched (n)-[r]-(m) rows as one graph.

    The ``_community_id`` property is whitelisted so it survives the
    property filtering applied by ``query``.

    Args:
        limit: Value placed in the query's ``LIMIT`` clause; required.

    Returns:
        A new ``MemoryGraph`` holding the vertices and edges of the result.

    Raises:
        ValueError: If ``limit`` is ``None`` or ``0``.
    """
    if not limit:
        raise ValueError("limit must be set")
    graph_result = self.query(
        f"MATCH (n)-[r]-(m) RETURN n,r,m LIMIT {limit}",
        white_list=["_community_id"],
    )
    all_graph = MemoryGraph()
    for vertex in graph_result.vertices():
        all_graph.upsert_vertex(vertex)
    for edge in graph_result.edges():
        all_graph.append_edge(edge)
    return all_graph
def explore(
self,
@@ -161,8 +408,11 @@ class TuGraphStore(GraphStoreBase):
depth: Optional[int] = None,
fan: Optional[int] = None,
limit: Optional[int] = None,
) -> MemoryGraph:
) -> Graph:
"""Explore the graph from given subjects up to a depth."""
if not subs:
return MemoryGraph()
if fan is not None:
raise ValueError("Fan functionality is not supported at this time.")
else:
@@ -173,67 +423,88 @@ class TuGraphStore(GraphStoreBase):
limit_string = f"LIMIT {limit}"
if limit is None:
limit_string = ""
if direct.name == "OUT":
rel = f"-[r:{self._edge_type}*{depth_string}]->"
elif direct.name == "IN":
rel = f"<-[r:{self._edge_type}*{depth_string}]-"
else:
rel = f"-[r:{self._edge_type}*{depth_string}]-"
query = (
f"MATCH p=(n:{self._node_label})"
f"-[r:{self._edge_label}*{depth_string}]-(m:{self._node_label}) "
f"MATCH p=(n:{self._vertex_type})"
f"{rel}(m:{self._vertex_type}) "
f"WHERE n.id IN {subs} RETURN p {limit_string}"
)
return self.query(query)
def query(self, query: str, **args) -> MemoryGraph:
    """Execute a query and collect its result into a ``MemoryGraph``.

    Args:
        query: Query text passed to the connector's ``run``.
        **args: Optional ``white_list`` (list of property names) that is
            forwarded to ``_format_query_data``; defaults to an empty list.

    Returns:
        A ``MemoryGraph`` filled with the vertices and edges produced by
        ``_format_query_data``.
    """
    result = self.conn.run(query=query)
    white_list = args.get("white_list", [])
    graph = self._format_query_data(result, white_list)
    mg = MemoryGraph()
    for vertex in graph["nodes"]:
        mg.upsert_vertex(vertex)
    for edge in graph["edges"]:
        mg.append_edge(edge)
    return mg
def stream_query(self, query: str) -> Generator[Graph, None, None]:
    """Execute a stream query, yielding one ``MemoryGraph`` per record.

    Each value of a record is dispatched on its ``neo4j.graph`` type:
    nodes become vertices, relationships become edges, paths contribute
    their vertices and the edges between consecutive vertices, and any
    other value is stored as the description of a ``json_node`` vertex.
    A missing ``description`` property is passed on as ``None``.

    Args:
        query: Query text passed to the connector's ``run_stream``.

    Yields:
        A fresh ``MemoryGraph`` built from each streamed record.
    """
    from neo4j import graph

    def describe(element):
        # ``id`` is required; ``description`` may be absent on the element.
        props = element._properties
        return props["id"], props.get("description")

    for record in self.conn.run_stream(query):
        mg = MemoryGraph()
        for key in record.keys():
            value = record[key]
            if isinstance(value, graph.Node):
                node_id, description = describe(value)
                mg.upsert_vertex(
                    Vertex(node_id, name=node_id, description=description)
                )
            elif isinstance(value, graph.Relationship):
                rel_id, description = describe(value)
                rel_nodes = value.nodes
                src_id = rel_nodes[0]._properties["id"]
                dst_id = rel_nodes[1]._properties["id"]
                mg.append_edge(
                    Edge(src_id, dst_id, name=rel_id, description=description)
                )
            elif isinstance(value, graph.Path):
                # Use the matched value itself: the path may be returned
                # under any alias, not only "p".
                nodes = list(value.nodes)
                rels = list(value.relationships)
                for index, node in enumerate(nodes):
                    node_id, description = describe(node)
                    mg.upsert_vertex(
                        Vertex(node_id, name=node_id, description=description)
                    )
                    # rels[index] joins nodes[index] to nodes[index + 1].
                    if index < len(rels) and index + 1 < len(nodes):
                        rel_id, rel_description = describe(rels[index])
                        mg.append_edge(
                            Edge(
                                node_id,
                                nodes[index + 1]._properties["id"],
                                name=rel_id,
                                description=rel_description,
                            )
                        )
            else:
                vertex = Vertex("json_node", name="json_node", description=value)
                mg.upsert_vertex(vertex)
        yield mg