DB-GPT/dbgpt/storage/graph_store/graph.py
M1n9X 759f7d99cc
feat(GraphRAG): enhance GraphRAG by graph community summary (#1801)
Co-authored-by: Florian <fanzhidongyzby@163.com>
Co-authored-by: KingSkyLi <15566300566@163.com>
Co-authored-by: aries_ckt <916701291@qq.com>
Co-authored-by: Fangyin Cheng <staneyffer@gmail.com>
Co-authored-by: yvonneyx <zhuyuxin0627@gmail.com>
2024-08-30 21:59:44 +08:00

496 lines
14 KiB
Python

"""Graph definition."""
import itertools
import json
import logging
import re
from abc import ABC, abstractmethod
from collections import defaultdict
from enum import Enum
from typing import Any, Dict, Iterator, List, Optional, Set, Tuple
import networkx as nx
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
class Direction(Enum):
    """Edge direction, relative to a vertex, used when walking the graph."""

    # Edges that start at the vertex.
    OUT = 0
    # Edges that end at the vertex.
    IN = 1
    # Edges in either direction.
    BOTH = 2
class Elem(ABC):
    """Base class of graph elements: an optional name plus a property dict."""

    def __init__(self, name: Optional[str] = None):
        """Remember the optional name and start with no properties."""
        self._props: Dict[str, Any] = {}
        self._name = name

    @property
    def name(self) -> str:
        """Return the element name, or an empty string if it is unset/empty."""
        if self._name:
            return self._name
        return ""

    @property
    def props(self) -> Dict[str, Any]:
        """Return the underlying property dict itself (not a copy)."""
        return self._props

    def set_prop(self, key: str, value: Any):
        """Create or overwrite the property ``key`` with ``value``."""
        self._props.update({key: value})

    def get_prop(self, key: str):
        """Return the value stored under ``key``, or None when it is absent."""
        try:
            return self._props[key]
        except KeyError:
            return None

    def del_prop(self, key: str):
        """Remove the property ``key``; a missing key is silently ignored."""
        if key in self._props:
            del self._props[key]

    def has_props(self, **props):
        """Return True when every given key maps to an equal stored value.

        A key that is not stored compares as None, so ``has_props(k=None)``
        is also True for a missing ``k``. With no arguments the result is True.
        """
        for key, expected in props.items():
            if not (self._props.get(key) == expected):
                return False
        return True

    @abstractmethod
    def format(self) -> str:
        """Render the properties as text; subclasses must override this.

        Exactly one property is rendered as ``str(value)`` without its key.
        Any other count is rendered as ``{k1:v1;k2:v2}`` where each value is
        JSON-encoded with non-ASCII characters kept as-is.
        """
        stored = self._props
        if len(stored) != 1:
            parts = []
            for key, value in stored.items():
                parts.append(f"{key}:{json.dumps(value, ensure_ascii=False)}")
            return "{" + ";".join(parts) + "}"
        (only_value,) = stored.values()
        return str(only_value)
class Vertex(Elem):
    """A graph vertex identified by ``vid`` with an optional display name."""

    def __init__(self, vid: str, name: Optional[str] = None, **props):
        """Create a vertex; extra keyword arguments become its properties."""
        super().__init__(name)
        self._vid = vid
        for key in props:
            self.set_prop(key, props[key])

    @property
    def vid(self) -> str:
        """Return the vertex ID."""
        return self._vid

    @property
    def name(self) -> str:
        """Return the vertex name, falling back to the vertex ID if unnamed."""
        base_name = super().name
        return base_name if base_name else self._vid

    def format(self, concise: bool = False):
        """Render the vertex as ``(label)`` or ``(label:props)``.

        The label is the name, or the vertex ID when no name is set. The
        property part is omitted when ``concise`` is true or when the vertex
        has no properties.
        """
        label = self._name or self._vid
        if concise or not self._props:
            return f"({label})"
        return f"({label}:{super().format()})"

    def __str__(self):
        """Return ``(vid)``; the name and properties are not included."""
        return f"({self._vid})"
class IdVertex(Vertex):
    """A vertex built from an ID only, with no name and no properties."""

    def __init__(self, vid: str):
        """Create the vertex from its ID alone."""
        super().__init__(vid)
class Edge(Elem):
    """A directed, named edge from vertex ``sid`` to vertex ``tid``."""

    def __init__(self, sid: str, tid: str, name: str, **props):
        """Create an edge; extra keyword arguments become its properties.

        NOTE(review): the non-empty name check is an ``assert``, so it is
        skipped when Python runs with ``-O``.
        """
        assert name, "Edge name is required"
        super().__init__(name)
        self._sid, self._tid = sid, tid
        for key in props:
            self.set_prop(key, props[key])

    @property
    def sid(self) -> str:
        """Return the ID of the source vertex."""
        return self._sid

    @property
    def tid(self) -> str:
        """Return the ID of the target vertex."""
        return self._tid

    def nid(self, vid):
        """Return the endpoint opposite to ``vid``.

        The source is checked first, so a self-loop returns its own ID.
        Raises ValueError when ``vid`` is neither endpoint.
        """
        if vid == self._sid:
            return self._tid
        if vid == self._tid:
            return self._sid
        raise ValueError(f"Get nid of {vid} on {self} failed")

    def format(self):
        """Render the edge as ``-[name]->`` or ``-[name:props]->``."""
        if not self._props:
            return f"-[{self._name}]->"
        return f"-[{self._name}:{super().format()}]->"

    def triplet(self) -> Tuple[str, str, str]:
        """Return the ``(source ID, edge name, target ID)`` tuple."""
        source, label, target = self.sid, self.name, self.tid
        return source, label, target

    def __str__(self):
        """Return ``(sid)-[name]->(tid)``; properties are not included."""
        return f"({self._sid})-[{self._name}]->({self._tid})"
class Graph(ABC):
    """Abstract interface that every graph implementation must provide."""

    @abstractmethod
    def upsert_vertex(self, vertex: Vertex):
        """Insert the vertex, or update it if it already exists."""

    @abstractmethod
    def append_edge(self, edge: Edge):
        """Add an edge to the graph."""

    @abstractmethod
    def has_vertex(self, vid: str) -> bool:
        """Tell whether a vertex with the given ID exists."""

    @abstractmethod
    def get_vertex(self, vid: str) -> Vertex:
        """Return the vertex with the given ID."""

    @abstractmethod
    def get_neighbor_edges(
        self,
        vid: str,
        direction: Direction = Direction.OUT,
        limit: Optional[int] = None,
    ) -> Iterator[Edge]:
        """Iterate the edges attached to ``vid`` in the given direction."""

    @abstractmethod
    def vertices(self) -> Iterator[Vertex]:
        """Iterate over the vertices of the graph."""

    @abstractmethod
    def edges(self) -> Iterator[Edge]:
        """Iterate over the edges of the graph."""

    @abstractmethod
    def del_vertices(self, *vids: str):
        """Delete the given vertices together with their neighbor edges."""

    @abstractmethod
    def del_edges(self, sid: str, tid: str, name: str, **props):
        """Delete the edges ``sid -[name]-> tid`` that match ``props``."""

    @abstractmethod
    def del_neighbor_edges(self, vid: str, direction: Direction = Direction.OUT):
        """Delete the edges attached to ``vid`` in the given direction."""

    @abstractmethod
    def search(
        self,
        vids: List[str],
        direct: Direction = Direction.OUT,
        depth: Optional[int] = None,
        fan: Optional[int] = None,
        limit: Optional[int] = None,
    ) -> "Graph":
        """Search the graph from the given vertex IDs and return a graph."""

    @abstractmethod
    def schema(self) -> Dict[str, Any]:
        """Return the graph schema as a dict."""

    @abstractmethod
    def format(self) -> str:
        """Render the graph data as a string."""

    @abstractmethod
    def truncate(self):
        """Truncate the graph."""
class MemoryGraph(Graph):
    """In-memory graph backed by a vertex dict and two adjacency indexes.

    ``_oes[sid][tid]`` and ``_ies[tid][sid]`` each hold the set of edges going
    from ``sid`` to ``tid``; every edge is stored in both indexes. Edge
    membership in those sets uses the Edge objects' own hashing/equality.
    """

    def __init__(self):
        """Create an empty graph."""
        # Metadata: property keys seen so far and the current number of edges.
        self._vertex_prop_keys = set()
        self._edge_prop_keys = set()
        self._edge_count = 0

        # Vertex store. No default factory is given, so this behaves like a
        # plain dict and a missing vid raises KeyError.
        self._vs: Any = defaultdict()
        # Out-edge and in-edge indexes (see the class docstring). Only write
        # paths may use ``[]`` on these: ``[]`` inserts missing keys.
        self._oes: Any = defaultdict(lambda: defaultdict(set))
        self._ies: Any = defaultdict(lambda: defaultdict(set))

    @property
    def vertex_count(self):
        """Return the number of vertices in the graph."""
        return len(self._vs)

    @property
    def edge_count(self):
        """Return the number of edges in the graph."""
        return self._edge_count

    def upsert_vertex(self, vertex: Vertex):
        """Insert a vertex, or merge it into the one already stored.

        The given object is stored as-is when its ID is unknown or when the
        stored vertex is an ``IdVertex`` placeholder. Otherwise only the given
        vertex's properties are merged into the stored vertex.
        """
        current = self._vs.get(vertex.vid)
        if current is None or isinstance(current, IdVertex):
            self._vs[vertex.vid] = vertex
        else:
            current.props.update(vertex.props)

        # update metadata
        self._vertex_prop_keys.update(vertex.props.keys())

    def append_edge(self, edge: Edge) -> bool:
        """Add an edge unless it is already stored.

        Endpoints that are not yet known are registered as ``IdVertex``
        placeholders.

        Returns:
            True if the edge was added, False if it was already present.
        """
        sid = edge.sid
        tid = edge.tid

        if edge in self._oes[sid][tid]:
            return False

        # Register missing endpoints; placeholders are built only when needed.
        for endpoint in (sid, tid):
            if endpoint not in self._vs:
                self._vs[endpoint] = IdVertex(endpoint)

        # update edge index
        self._oes[sid][tid].add(edge)
        self._ies[tid][sid].add(edge)

        # update metadata
        self._edge_prop_keys.update(edge.props.keys())
        self._edge_count += 1
        return True

    def has_vertex(self, vid: str) -> bool:
        """Tell whether a vertex with the given ID is stored."""
        return vid in self._vs

    def get_vertex(self, vid: str) -> Vertex:
        """Return the vertex with the given ID; raise KeyError if unknown."""
        return self._vs[vid]

    @staticmethod
    def _iter_index(index: Any, vid: str) -> Iterator[Edge]:
        """Iterate the edges stored under ``vid`` without touching the index.

        ``.get`` is used instead of ``[]`` so that looking up an unknown or
        leaf vertex does not insert an empty entry into the defaultdict.
        Such an insertion would break a concurrent ``edges()`` iteration with
        "dictionary changed size during iteration".
        """
        buckets = index.get(vid)
        if not buckets:
            return iter(())
        return (e for es in buckets.values() for e in es)

    @staticmethod
    def _distinct(elements: Iterator[Edge]) -> Iterator[Edge]:
        """Yield each element once, keeping first-seen order."""
        seen = set()
        for element in elements:
            if element not in seen:
                seen.add(element)
                yield element

    def get_neighbor_edges(
        self,
        vid: str,
        direction: Direction = Direction.OUT,
        limit: Optional[int] = None,
    ) -> Iterator[Edge]:
        """Iterate the edges attached to a vertex in the given direction.

        Args:
            vid: ID of the vertex; an unknown ID yields nothing.
            direction: OUT for edges leaving ``vid``, IN for edges entering
                it, BOTH for out and in edges interleaved without duplicates.
            limit: maximum number of edges to yield; None or 0 means no cap.

        Raises:
            ValueError: if ``direction`` is not a known Direction.
        """
        if direction == Direction.OUT:
            es = self._iter_index(self._oes, vid)
        elif direction == Direction.IN:
            es = self._iter_index(self._ies, vid)
        elif direction == Direction.BOTH:
            # Alternate between out and in edges so that a limit takes a
            # share of each direction.
            pairs = itertools.zip_longest(
                self._iter_index(self._oes, vid),
                self._iter_index(self._ies, vid),
            )
            merged = (e for pair in pairs for e in pair if e is not None)
            # A self-loop is stored in both indexes, so drop repeats.
            es = self._distinct(merged)
        else:
            raise ValueError(f"Invalid direction: {direction}")

        return itertools.islice(es, limit) if limit else es

    def vertices(self) -> Iterator[Vertex]:
        """Return an iterator over the stored vertices."""
        return iter(self._vs.values())

    def edges(self) -> Iterator[Edge]:
        """Return a lazy iterator over all edges, read from the out index."""
        return (e for nbs in self._oes.values() for es in nbs.values() for e in es)

    def del_vertices(self, *vids: str):
        """Delete the given vertices and every edge attached to them."""
        for vid in vids:
            self.del_neighbor_edges(vid, Direction.BOTH)
            self._vs.pop(vid, None)

    def del_edges(self, sid: str, tid: str, name: str, **props):
        """Delete the edges from ``sid`` to ``tid`` that match the filters.

        An edge matches when ``name`` is empty or equals the edge name, and
        the edge has all of the given ``props``. When no edge is stored
        between the two vertices this is a no-op that leaves the indexes
        untouched.
        """
        out_bucket = self._oes.get(sid, {}).get(tid)
        if not out_bucket:
            return

        def survives(e):
            return not ((not name or name == e.name) and e.has_props(**props))

        kept = {e for e in out_bucket if survives(e)}
        self._oes[sid][tid] = kept

        in_bucket = self._ies.get(tid, {}).get(sid)
        if in_bucket is not None:
            self._ies[tid][sid] = {e for e in in_bucket if survives(e)}

        self._edge_count -= len(out_bucket) - len(kept)

    def del_neighbor_edges(self, vid: str, direction: Direction = Direction.OUT):
        """Delete every edge attached to ``vid`` in the given direction."""

        def drop(index, inverse):
            # Detach vid's adjacency first, then remove the mirrored entries
            # from the opposite index; .get/.pop never insert missing keys.
            neighbors = index.pop(vid, None)
            if not neighbors:
                return
            for nid in neighbors:
                mirror = inverse.get(nid)
                if mirror is None:
                    continue
                removed = mirror.pop(vid, None)
                if removed:
                    self._edge_count -= len(removed)

        if direction in [Direction.OUT, Direction.BOTH]:
            drop(self._oes, self._ies)
        if direction in [Direction.IN, Direction.BOTH]:
            drop(self._ies, self._oes)

    def search(
        self,
        vids: List[str],
        direct: Direction = Direction.OUT,
        depth: Optional[int] = None,
        fan: Optional[int] = None,
        limit: Optional[int] = None,
    ) -> "MemoryGraph":
        """Collect the subgraph reachable from the given start vertices.

        Args:
            vids: IDs of the start vertices; unknown IDs are skipped.
            direct: direction in which edges are followed.
            depth: vertices at this hop distance or beyond are not expanded;
                None or 0 means unlimited.
            fan: maximum number of edges examined per vertex; None or 0
                means unlimited.
            limit: maximum number of edges in the result; None or 0 means
                unlimited.

        Returns:
            A new MemoryGraph that shares the Vertex and Edge objects of this
            graph. Each start vertex is traversed with its own visited set.
        """
        subgraph = MemoryGraph()
        for vid in vids:
            self.__search(vid, direct, depth, fan, limit, 0, set(), subgraph)
        return subgraph

    def __search(
        self,
        vid: str,
        direct: Direction,
        depth: Optional[int],
        fan: Optional[int],
        limit: Optional[int],
        _depth: int,
        _visited: Set,
        _subgraph: "MemoryGraph",
    ):
        """Depth-first step of ``search``; results accumulate in _subgraph."""
        if vid in _visited or (depth and _depth >= depth):
            return

        # visit vertex
        if not self.has_vertex(vid):
            return
        _subgraph.upsert_vertex(self.get_vertex(vid))
        _visited.add(vid)

        # visit edges
        nids = set()
        for edge in self.get_neighbor_edges(vid, direct, fan):
            if limit and _subgraph.edge_count >= limit:
                return

            # Only an edge that is new to the subgraph leads to a new hop.
            if _subgraph.append_edge(edge):
                nid = edge.nid(vid)
                if nid not in _visited:
                    nids.add(nid)

        # next hop
        for nid in nids:
            self.__search(
                nid, direct, depth, fan, limit, _depth + 1, _visited, _subgraph
            )

    def schema(self) -> Dict[str, Any]:
        """Return the property keys seen so far for vertices and for edges."""
        return {
            "schema": [
                {
                    "type": "VERTEX",
                    "properties": [{"name": k} for k in self._vertex_prop_keys],
                },
                {
                    "type": "EDGE",
                    "properties": [{"name": k} for k in self._edge_prop_keys],
                },
            ]
        }

    def format(self) -> str:
        """Render the graph as an "Entities" and a "Relationships" section.

        Returns an empty string when both sections would be empty.
        """
        vs_str = "\n".join(v.format() for v in self.vertices())
        es_str = "\n".join(
            f"{self.get_vertex(e.sid).format(concise=True)}"
            f"{e.format()}"
            f"{self.get_vertex(e.tid).format(concise=True)}"
            for e in self.edges()
        )
        return (
            f"Entities:\n{vs_str}\n\n" f"Relationships:\n{es_str}"
            if (vs_str or es_str)
            else ""
        )

    def truncate(self):
        """Remove all vertices, edges and collected metadata."""
        # clean metadata
        self._vertex_prop_keys.clear()
        self._edge_prop_keys.clear()
        self._edge_count = 0

        # clean data and index
        self._vs.clear()
        self._oes.clear()
        self._ies.clear()

    def graphviz(self, name="g"):
        """View graphviz graph: https://dreampuf.github.io/GraphvizOnline.

        Builds a networkx MultiDiGraph with one node per vertex ID and one
        labelled edge per stored edge, converts it to text, names the digraph
        ``name`` and strips the ``key=<n>`` attributes from that text.

        NOTE(review): ``nx.nx_agraph`` presumably needs pygraphviz installed;
        confirm against the networkx documentation.
        """
        g = nx.MultiDiGraph()
        for vertex in self.vertices():
            g.add_node(vertex.vid)

        for edge in self.edges():
            triplet = edge.triplet()
            g.add_edge(triplet[0], triplet[2], label=triplet[1])

        digraph = nx.nx_agraph.to_agraph(g).to_string()
        digraph = digraph.replace('digraph ""', f"digraph {name}")
        digraph = re.sub(r"key=\d+,?\s*", "", digraph)
        return digraph