mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-28 21:12:13 +00:00
✨ feat(GraphRAG): enhance GraphRAG by graph community summary (#1801)
Co-authored-by: Florian <fanzhidongyzby@163.com> Co-authored-by: KingSkyLi <15566300566@163.com> Co-authored-by: aries_ckt <916701291@qq.com> Co-authored-by: Fangyin Cheng <staneyffer@gmail.com> Co-authored-by: yvonneyx <zhuyuxin0627@gmail.com>
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
"""Graph store base class."""
|
||||
"""Graph definition."""
|
||||
import itertools
|
||||
import json
|
||||
import logging
|
||||
@@ -24,9 +24,15 @@ class Direction(Enum):
|
||||
class Elem(ABC):
|
||||
"""Elem class."""
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self, name: Optional[str] = None):
|
||||
"""Initialize Elem."""
|
||||
self._props = {}
|
||||
self._name = name
|
||||
self._props: Dict[str, Any] = {}
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
"""Return the edge label."""
|
||||
return self._name or ""
|
||||
|
||||
@property
|
||||
def props(self) -> Dict[str, Any]:
|
||||
@@ -46,14 +52,17 @@ class Elem(ABC):
|
||||
self._props.pop(key, None)
|
||||
|
||||
def has_props(self, **props):
|
||||
"""Check if the element has the specified properties with the given values."""
|
||||
"""Check all key-value pairs exist."""
|
||||
return all(self._props.get(k) == v for k, v in props.items())
|
||||
|
||||
@abstractmethod
|
||||
def format(self, label_key: Optional[str] = None):
|
||||
def format(self) -> str:
|
||||
"""Format properties into a string."""
|
||||
if len(self._props) == 1:
|
||||
return str(next(iter(self._props.values())))
|
||||
|
||||
formatted_props = [
|
||||
f"{k}:{json.dumps(v)}" for k, v in self._props.items() if k != label_key
|
||||
f"{k}:{json.dumps(v, ensure_ascii=False)}" for k, v in self._props.items()
|
||||
]
|
||||
return f"{{{';'.join(formatted_props)}}}"
|
||||
|
||||
@@ -61,9 +70,9 @@ class Elem(ABC):
|
||||
class Vertex(Elem):
|
||||
"""Vertex class."""
|
||||
|
||||
def __init__(self, vid: str, **props):
|
||||
def __init__(self, vid: str, name: Optional[str] = None, **props):
|
||||
"""Initialize Vertex."""
|
||||
super().__init__()
|
||||
super().__init__(name)
|
||||
self._vid = vid
|
||||
for k, v in props.items():
|
||||
self.set_prop(k, v)
|
||||
@@ -73,26 +82,43 @@ class Vertex(Elem):
|
||||
"""Return the vertex ID."""
|
||||
return self._vid
|
||||
|
||||
def format(self, label_key: Optional[str] = None):
|
||||
"""Format vertex properties into a string."""
|
||||
label = self.get_prop(label_key) if label_key else self._vid
|
||||
props_str = super().format(label_key)
|
||||
if props_str == "{}":
|
||||
return f"({label})"
|
||||
@property
|
||||
def name(self) -> str:
|
||||
"""Return the vertex name."""
|
||||
return super().name or self._vid
|
||||
|
||||
def format(self, concise: bool = False):
|
||||
"""Format vertex into a string."""
|
||||
name = self._name or self._vid
|
||||
if concise:
|
||||
return f"({name})"
|
||||
|
||||
if self._props:
|
||||
return f"({name}:{super().format()})"
|
||||
else:
|
||||
return f"({label}:{props_str})"
|
||||
return f"({name})"
|
||||
|
||||
def __str__(self):
|
||||
"""Return the vertex ID as its string representation."""
|
||||
return f"({self._vid})"
|
||||
|
||||
|
||||
class IdVertex(Vertex):
|
||||
"""IdVertex class."""
|
||||
|
||||
def __init__(self, vid: str):
|
||||
"""Initialize Idvertex."""
|
||||
super().__init__(vid)
|
||||
|
||||
|
||||
class Edge(Elem):
|
||||
"""Edge class."""
|
||||
|
||||
def __init__(self, sid: str, tid: str, **props):
|
||||
def __init__(self, sid: str, tid: str, name: str, **props):
|
||||
"""Initialize Edge."""
|
||||
super().__init__()
|
||||
assert name, "Edge name is required"
|
||||
|
||||
super().__init__(name)
|
||||
self._sid = sid
|
||||
self._tid = tid
|
||||
for k, v in props.items():
|
||||
@@ -117,23 +143,20 @@ class Edge(Elem):
|
||||
else:
|
||||
raise ValueError(f"Get nid of {vid} on {self} failed")
|
||||
|
||||
def format(self, label_key: Optional[str] = None):
|
||||
def format(self):
|
||||
"""Format the edge properties into a string."""
|
||||
label = self.get_prop(label_key) if label_key else ""
|
||||
props_str = super().format(label_key)
|
||||
if props_str == "{}":
|
||||
return f"-[{label}]->" if label else "->"
|
||||
if self._props:
|
||||
return f"-[{self._name}:{super().format()}]->"
|
||||
else:
|
||||
return f"-[{label}:{props_str}]->" if label else f"-[{props_str}]->"
|
||||
return f"-[{self._name}]->"
|
||||
|
||||
def triplet(self, label_key: str) -> Tuple[str, str, str]:
|
||||
def triplet(self) -> Tuple[str, str, str]:
|
||||
"""Return a triplet."""
|
||||
assert label_key, "label key is needed"
|
||||
return self._sid, str(self.get_prop(label_key)), self._tid
|
||||
return self.sid, self.name, self.tid
|
||||
|
||||
def __str__(self):
|
||||
"""Return the edge '(sid)->(tid)'."""
|
||||
return f"({self._sid})->({self._tid})"
|
||||
return f"({self._sid})-[{self._name}]->({self._tid})"
|
||||
|
||||
|
||||
class Graph(ABC):
|
||||
@@ -177,8 +200,8 @@ class Graph(ABC):
|
||||
"""Delete vertices and their neighbor edges."""
|
||||
|
||||
@abstractmethod
|
||||
def del_edges(self, sid: str, tid: str, **props):
|
||||
"""Delete edges(sid -> tid) matches props."""
|
||||
def del_edges(self, sid: str, tid: str, name: str, **props):
|
||||
"""Delete edges(sid -[name]-> tid) matches props."""
|
||||
|
||||
@abstractmethod
|
||||
def del_neighbor_edges(self, vid: str, direction: Direction = Direction.OUT):
|
||||
@@ -203,19 +226,19 @@ class Graph(ABC):
|
||||
def format(self) -> str:
|
||||
"""Format graph data to string."""
|
||||
|
||||
@abstractmethod
|
||||
def truncate(self):
|
||||
"""Truncate graph."""
|
||||
|
||||
|
||||
class MemoryGraph(Graph):
|
||||
"""Graph class."""
|
||||
|
||||
def __init__(self, vertex_label: Optional[str] = None, edge_label: str = "label"):
|
||||
def __init__(self):
|
||||
"""Initialize MemoryGraph with vertex label and edge label."""
|
||||
assert edge_label, "Edge label is needed"
|
||||
|
||||
# metadata
|
||||
self._vertex_label = vertex_label
|
||||
self._edge_label = edge_label
|
||||
self._vertex_prop_keys = {vertex_label} if vertex_label else set()
|
||||
self._edge_prop_keys = {edge_label}
|
||||
self._vertex_prop_keys = set()
|
||||
self._edge_prop_keys = set()
|
||||
self._edge_count = 0
|
||||
|
||||
# init vertices, out edges, in edges index
|
||||
@@ -223,26 +246,6 @@ class MemoryGraph(Graph):
|
||||
self._oes: Any = defaultdict(lambda: defaultdict(set))
|
||||
self._ies: Any = defaultdict(lambda: defaultdict(set))
|
||||
|
||||
@property
|
||||
def vertex_label(self):
|
||||
"""Return the label for vertices."""
|
||||
return self._vertex_label
|
||||
|
||||
@property
|
||||
def edge_label(self):
|
||||
"""Return the label for edges."""
|
||||
return self._edge_label
|
||||
|
||||
@property
|
||||
def vertex_prop_keys(self):
|
||||
"""Return a set of property keys for vertices."""
|
||||
return self._vertex_prop_keys
|
||||
|
||||
@property
|
||||
def edge_prop_keys(self):
|
||||
"""Return a set of property keys for edges."""
|
||||
return self._edge_prop_keys
|
||||
|
||||
@property
|
||||
def vertex_count(self):
|
||||
"""Return the number of vertices in the graph."""
|
||||
@@ -256,7 +259,10 @@ class MemoryGraph(Graph):
|
||||
def upsert_vertex(self, vertex: Vertex):
|
||||
"""Insert or update a vertex based on its ID."""
|
||||
if vertex.vid in self._vs:
|
||||
self._vs[vertex.vid].props.update(vertex.props)
|
||||
if isinstance(self._vs[vertex.vid], IdVertex):
|
||||
self._vs[vertex.vid] = vertex
|
||||
else:
|
||||
self._vs[vertex.vid].props.update(vertex.props)
|
||||
else:
|
||||
self._vs[vertex.vid] = vertex
|
||||
|
||||
@@ -265,9 +271,6 @@ class MemoryGraph(Graph):
|
||||
|
||||
def append_edge(self, edge: Edge):
|
||||
"""Append an edge if it doesn't exist; requires edge label."""
|
||||
if self.edge_label not in edge.props.keys():
|
||||
raise ValueError(f"Edge prop '{self.edge_label}' is needed")
|
||||
|
||||
sid = edge.sid
|
||||
tid = edge.tid
|
||||
|
||||
@@ -275,8 +278,8 @@ class MemoryGraph(Graph):
|
||||
return False
|
||||
|
||||
# init vertex index
|
||||
self._vs.setdefault(sid, Vertex(sid))
|
||||
self._vs.setdefault(tid, Vertex(tid))
|
||||
self._vs.setdefault(sid, IdVertex(sid))
|
||||
self._vs.setdefault(tid, IdVertex(tid))
|
||||
|
||||
# update edge index
|
||||
self._oes[sid][tid].add(edge)
|
||||
@@ -346,18 +349,19 @@ class MemoryGraph(Graph):
|
||||
self.del_neighbor_edges(vid, Direction.BOTH)
|
||||
self._vs.pop(vid, None)
|
||||
|
||||
def del_edges(self, sid: str, tid: str, **props):
|
||||
def del_edges(self, sid: str, tid: str, name: str, **props):
|
||||
"""Delete edges."""
|
||||
old_edge_cnt = len(self._oes[sid][tid])
|
||||
|
||||
if not props:
|
||||
self._edge_count -= old_edge_cnt
|
||||
self._oes[sid].pop(tid, None)
|
||||
self._ies[tid].pop(sid, None)
|
||||
return
|
||||
|
||||
def remove_matches(es):
|
||||
return set(filter(lambda e: not e.has_props(**props), es))
|
||||
return set(
|
||||
filter(
|
||||
lambda e: not (
|
||||
(name == e.name if name else True) and e.has_props(**props)
|
||||
),
|
||||
es,
|
||||
)
|
||||
)
|
||||
|
||||
self._oes[sid][tid] = remove_matches(self._oes[sid][tid])
|
||||
self._ies[tid][sid] = remove_matches(self._ies[tid][sid])
|
||||
@@ -439,12 +443,10 @@ class MemoryGraph(Graph):
|
||||
"schema": [
|
||||
{
|
||||
"type": "VERTEX",
|
||||
"label": f"{self._vertex_label}",
|
||||
"properties": [{"name": k} for k in self._vertex_prop_keys],
|
||||
},
|
||||
{
|
||||
"type": "EDGE",
|
||||
"label": f"{self._edge_label}",
|
||||
"properties": [{"name": k} for k in self._edge_prop_keys],
|
||||
},
|
||||
]
|
||||
@@ -452,14 +454,30 @@ class MemoryGraph(Graph):
|
||||
|
||||
def format(self) -> str:
|
||||
"""Format graph to string."""
|
||||
vs_str = "\n".join(v.format(self.vertex_label) for v in self.vertices())
|
||||
vs_str = "\n".join(v.format() for v in self.vertices())
|
||||
es_str = "\n".join(
|
||||
f"{self.get_vertex(e.sid).format(self.vertex_label)}"
|
||||
f"{e.format(self.edge_label)}"
|
||||
f"{self.get_vertex(e.tid).format(self.vertex_label)}"
|
||||
f"{self.get_vertex(e.sid).format(concise=True)}"
|
||||
f"{e.format()}"
|
||||
f"{self.get_vertex(e.tid).format(concise=True)}"
|
||||
for e in self.edges()
|
||||
)
|
||||
return f"Vertices:\n{vs_str}\n\nEdges:\n{es_str}"
|
||||
return (
|
||||
f"Entities:\n{vs_str}\n\n" f"Relationships:\n{es_str}"
|
||||
if (vs_str or es_str)
|
||||
else ""
|
||||
)
|
||||
|
||||
def truncate(self):
|
||||
"""Truncate graph."""
|
||||
# clean metadata
|
||||
self._vertex_prop_keys.clear()
|
||||
self._edge_prop_keys.clear()
|
||||
self._edge_count = 0
|
||||
|
||||
# clean data and index
|
||||
self._vs.clear()
|
||||
self._oes.clear()
|
||||
self._ies.clear()
|
||||
|
||||
def graphviz(self, name="g"):
|
||||
"""View graphviz graph: https://dreampuf.github.io/GraphvizOnline."""
|
||||
@@ -468,7 +486,7 @@ class MemoryGraph(Graph):
|
||||
g.add_node(vertex.vid)
|
||||
|
||||
for edge in self.edges():
|
||||
triplet = edge.triplet(self.edge_label)
|
||||
triplet = edge.triplet()
|
||||
g.add_edge(triplet[0], triplet[2], label=triplet[1])
|
||||
|
||||
digraph = nx.nx_agraph.to_agraph(g).to_string()
|
||||
|
Reference in New Issue
Block a user