feat(GraphRAG): enhance GraphRAG by graph community summary (#1801)

Co-authored-by: Florian <fanzhidongyzby@163.com>
Co-authored-by: KingSkyLi <15566300566@163.com>
Co-authored-by: aries_ckt <916701291@qq.com>
Co-authored-by: Fangyin Cheng <staneyffer@gmail.com>
Co-authored-by: yvonneyx <zhuyuxin0627@gmail.com>
This commit is contained in:
M1n9X
2024-08-30 21:59:44 +08:00
committed by GitHub
parent 471689ba20
commit 759f7d99cc
59 changed files with 29316 additions and 411 deletions

View File

@@ -1,4 +1,4 @@
"""Graph store base class."""
"""Graph definition."""
import itertools
import json
import logging
@@ -24,9 +24,15 @@ class Direction(Enum):
class Elem(ABC):
"""Elem class."""
def __init__(self):
def __init__(self, name: Optional[str] = None):
"""Initialize Elem."""
self._props = {}
self._name = name
self._props: Dict[str, Any] = {}
@property
def name(self) -> str:
"""Return the element name, or an empty string when no name was set."""
return self._name or ""
@property
def props(self) -> Dict[str, Any]:
@@ -46,14 +52,17 @@ class Elem(ABC):
self._props.pop(key, None)
def has_props(self, **props):
"""Check if the element has the specified properties with the given values."""
"""Check all key-value pairs exist."""
return all(self._props.get(k) == v for k, v in props.items())
@abstractmethod
def format(self, label_key: Optional[str] = None):
def format(self) -> str:
"""Format properties into a string."""
if len(self._props) == 1:
return str(next(iter(self._props.values())))
formatted_props = [
f"{k}:{json.dumps(v)}" for k, v in self._props.items() if k != label_key
f"{k}:{json.dumps(v, ensure_ascii=False)}" for k, v in self._props.items()
]
return f"{{{';'.join(formatted_props)}}}"
@@ -61,9 +70,9 @@ class Elem(ABC):
class Vertex(Elem):
"""Vertex class."""
def __init__(self, vid: str, **props):
def __init__(self, vid: str, name: Optional[str] = None, **props):
"""Initialize Vertex."""
super().__init__()
super().__init__(name)
self._vid = vid
for k, v in props.items():
self.set_prop(k, v)
@@ -73,26 +82,43 @@ class Vertex(Elem):
"""Return the vertex ID."""
return self._vid
def format(self, label_key: Optional[str] = None):
"""Format vertex properties into a string."""
label = self.get_prop(label_key) if label_key else self._vid
props_str = super().format(label_key)
if props_str == "{}":
return f"({label})"
@property
def name(self) -> str:
"""Return the vertex name, falling back to the vertex ID when unnamed."""
# An empty or missing name is falsy, so the vertex ID is returned instead.
return super().name or self._vid
def format(self, concise: bool = False):
"""Format vertex into a string."""
name = self._name or self._vid
if concise:
return f"({name})"
if self._props:
return f"({name}:{super().format()})"
else:
return f"({label}:{props_str})"
return f"({name})"
def __str__(self):
"""Return the vertex ID wrapped in parentheses, i.e. ``(vid)``."""
return f"({self._vid})"
class IdVertex(Vertex):
"""Vertex created from an ID alone, without a name or properties.

MemoryGraph.append_edge registers missing edge endpoints as IdVertex, and
MemoryGraph.upsert_vertex replaces a stored IdVertex with the upserted vertex.
"""
def __init__(self, vid: str):
"""Initialize IdVertex with the given vertex ID."""
super().__init__(vid)
class Edge(Elem):
"""Edge class."""
def __init__(self, sid: str, tid: str, **props):
def __init__(self, sid: str, tid: str, name: str, **props):
"""Initialize Edge."""
super().__init__()
assert name, "Edge name is required"
super().__init__(name)
self._sid = sid
self._tid = tid
for k, v in props.items():
@@ -117,23 +143,20 @@ class Edge(Elem):
else:
raise ValueError(f"Get nid of {vid} on {self} failed")
def format(self, label_key: Optional[str] = None):
def format(self):
"""Format the edge properties into a string."""
label = self.get_prop(label_key) if label_key else ""
props_str = super().format(label_key)
if props_str == "{}":
return f"-[{label}]->" if label else "->"
if self._props:
return f"-[{self._name}:{super().format()}]->"
else:
return f"-[{label}:{props_str}]->" if label else f"-[{props_str}]->"
return f"-[{self._name}]->"
def triplet(self, label_key: str) -> Tuple[str, str, str]:
def triplet(self) -> Tuple[str, str, str]:
"""Return a triplet."""
assert label_key, "label key is needed"
return self._sid, str(self.get_prop(label_key)), self._tid
return self.sid, self.name, self.tid
def __str__(self):
"""Return the edge '(sid)->(tid)'."""
return f"({self._sid})->({self._tid})"
return f"({self._sid})-[{self._name}]->({self._tid})"
class Graph(ABC):
@@ -177,8 +200,8 @@ class Graph(ABC):
"""Delete vertices and their neighbor edges."""
@abstractmethod
def del_edges(self, sid: str, tid: str, **props):
"""Delete edges(sid -> tid) matches props."""
def del_edges(self, sid: str, tid: str, name: str, **props):
"""Delete edges(sid -[name]-> tid) matches props."""
@abstractmethod
def del_neighbor_edges(self, vid: str, direction: Direction = Direction.OUT):
@@ -203,19 +226,19 @@ class Graph(ABC):
def format(self) -> str:
"""Format graph data to string."""
@abstractmethod
def truncate(self):
"""Truncate graph (abstract; each subclass provides the implementation)."""
class MemoryGraph(Graph):
"""Graph class."""
def __init__(self, vertex_label: Optional[str] = None, edge_label: str = "label"):
def __init__(self):
"""Initialize MemoryGraph with vertex label and edge label."""
assert edge_label, "Edge label is needed"
# metadata
self._vertex_label = vertex_label
self._edge_label = edge_label
self._vertex_prop_keys = {vertex_label} if vertex_label else set()
self._edge_prop_keys = {edge_label}
self._vertex_prop_keys = set()
self._edge_prop_keys = set()
self._edge_count = 0
# init vertices, out edges, in edges index
@@ -223,26 +246,6 @@ class MemoryGraph(Graph):
self._oes: Any = defaultdict(lambda: defaultdict(set))
self._ies: Any = defaultdict(lambda: defaultdict(set))
@property
def vertex_label(self):
"""Return the label for vertices."""
return self._vertex_label
@property
def edge_label(self):
"""Return the label for edges."""
return self._edge_label
@property
def vertex_prop_keys(self):
"""Return a set of property keys for vertices."""
return self._vertex_prop_keys
@property
def edge_prop_keys(self):
"""Return a set of property keys for edges."""
return self._edge_prop_keys
@property
def vertex_count(self):
"""Return the number of vertices in the graph."""
@@ -256,7 +259,10 @@ class MemoryGraph(Graph):
def upsert_vertex(self, vertex: Vertex):
"""Insert or update a vertex based on its ID."""
if vertex.vid in self._vs:
self._vs[vertex.vid].props.update(vertex.props)
if isinstance(self._vs[vertex.vid], IdVertex):
self._vs[vertex.vid] = vertex
else:
self._vs[vertex.vid].props.update(vertex.props)
else:
self._vs[vertex.vid] = vertex
@@ -265,9 +271,6 @@ class MemoryGraph(Graph):
def append_edge(self, edge: Edge):
"""Append an edge if it doesn't exist; requires edge label."""
if self.edge_label not in edge.props.keys():
raise ValueError(f"Edge prop '{self.edge_label}' is needed")
sid = edge.sid
tid = edge.tid
@@ -275,8 +278,8 @@ class MemoryGraph(Graph):
return False
# init vertex index
self._vs.setdefault(sid, Vertex(sid))
self._vs.setdefault(tid, Vertex(tid))
self._vs.setdefault(sid, IdVertex(sid))
self._vs.setdefault(tid, IdVertex(tid))
# update edge index
self._oes[sid][tid].add(edge)
@@ -346,18 +349,19 @@ class MemoryGraph(Graph):
self.del_neighbor_edges(vid, Direction.BOTH)
self._vs.pop(vid, None)
def del_edges(self, sid: str, tid: str, **props):
def del_edges(self, sid: str, tid: str, name: str, **props):
"""Delete edges."""
old_edge_cnt = len(self._oes[sid][tid])
if not props:
self._edge_count -= old_edge_cnt
self._oes[sid].pop(tid, None)
self._ies[tid].pop(sid, None)
return
def remove_matches(es):
return set(filter(lambda e: not e.has_props(**props), es))
return set(
filter(
lambda e: not (
(name == e.name if name else True) and e.has_props(**props)
),
es,
)
)
self._oes[sid][tid] = remove_matches(self._oes[sid][tid])
self._ies[tid][sid] = remove_matches(self._ies[tid][sid])
@@ -439,12 +443,10 @@ class MemoryGraph(Graph):
"schema": [
{
"type": "VERTEX",
"label": f"{self._vertex_label}",
"properties": [{"name": k} for k in self._vertex_prop_keys],
},
{
"type": "EDGE",
"label": f"{self._edge_label}",
"properties": [{"name": k} for k in self._edge_prop_keys],
},
]
@@ -452,14 +454,30 @@ class MemoryGraph(Graph):
def format(self) -> str:
"""Format graph to string."""
vs_str = "\n".join(v.format(self.vertex_label) for v in self.vertices())
vs_str = "\n".join(v.format() for v in self.vertices())
es_str = "\n".join(
f"{self.get_vertex(e.sid).format(self.vertex_label)}"
f"{e.format(self.edge_label)}"
f"{self.get_vertex(e.tid).format(self.vertex_label)}"
f"{self.get_vertex(e.sid).format(concise=True)}"
f"{e.format()}"
f"{self.get_vertex(e.tid).format(concise=True)}"
for e in self.edges()
)
return f"Vertices:\n{vs_str}\n\nEdges:\n{es_str}"
return (
f"Entities:\n{vs_str}\n\n" f"Relationships:\n{es_str}"
if (vs_str or es_str)
else ""
)
def truncate(self):
"""Remove all vertices and edges and reset the recorded metadata."""
# reset metadata: known vertex/edge property keys and the edge counter
self._vertex_prop_keys.clear()
self._edge_prop_keys.clear()
self._edge_count = 0
# drop the vertex map and the out-edge / in-edge indexes
self._vs.clear()
self._oes.clear()
self._ies.clear()
def graphviz(self, name="g"):
"""View graphviz graph: https://dreampuf.github.io/GraphvizOnline."""
@@ -468,7 +486,7 @@ class MemoryGraph(Graph):
g.add_node(vertex.vid)
for edge in self.edges():
triplet = edge.triplet(self.edge_label)
triplet = edge.triplet()
g.add_edge(triplet[0], triplet[2], label=triplet[1])
digraph = nx.nx_agraph.to_agraph(g).to_string()