mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-07-26 05:23:37 +00:00
Co-authored-by: Florian <fanzhidongyzby@163.com> Co-authored-by: KingSkyLi <15566300566@163.com> Co-authored-by: aries_ckt <916701291@qq.com> Co-authored-by: Fangyin Cheng <staneyffer@gmail.com> Co-authored-by: yvonneyx <zhuyuxin0627@gmail.com>
496 lines
14 KiB
Python
496 lines
14 KiB
Python
"""Graph definition."""
|
|
import itertools
|
|
import json
|
|
import logging
|
|
import re
|
|
from abc import ABC, abstractmethod
|
|
from collections import defaultdict
|
|
from enum import Enum
|
|
from typing import Any, Dict, Iterator, List, Optional, Set, Tuple
|
|
|
|
import networkx as nx
|
|
|
|
# Module-level logger, named after this module so handlers can filter on it.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class Direction(Enum):
    """Edge orientation relative to a vertex, used when walking neighbors."""

    # Edges whose source is the vertex.
    OUT = 0
    # Edges whose target is the vertex.
    IN = 1
    # Edges in either orientation.
    BOTH = 2
|
|
|
|
|
|
class Elem(ABC):
    """Base class for graph elements: an optional name plus a property map."""

    def __init__(self, name: Optional[str] = None):
        """Initialize the element with an optional name and no properties."""
        self._name = name
        self._props: Dict[str, Any] = {}

    @property
    def name(self) -> str:
        """Return the element name, or an empty string when it is unset."""
        return self._name or ""

    @property
    def props(self) -> Dict[str, Any]:
        """Return the element's property dictionary (the live object, not a copy)."""
        return self._props

    def set_prop(self, key: str, value: Any):
        """Set property ``key`` to ``value``, overwriting any previous value."""
        self._props[key] = value

    def get_prop(self, key: str):
        """Return the value of property ``key``, or ``None`` if it is absent."""
        return self._props.get(key)

    def del_prop(self, key: str):
        """Delete property ``key``; a missing key is silently ignored."""
        self._props.pop(key, None)

    def has_props(self, **props):
        """Check that every given key exists and is equal to the given value.

        With no keyword arguments the result is ``True``.
        """
        # Test membership explicitly: a plain ``.get(k) == v`` would report a
        # match for ``has_props(key=None)`` even when ``key`` does not exist.
        return all(k in self._props and self._props[k] == v for k, v in props.items())

    @abstractmethod
    def format(self) -> str:
        """Format properties into a string.

        Exactly one property is rendered as ``str(value)`` alone. Any other
        count (including zero) is rendered as ``{k:<json>;k:<json>}`` with each
        value JSON-encoded, keeping non-ASCII characters as-is.

        The method is abstract so subclasses must override it, but they can
        reuse this implementation through ``super().format()``.
        """
        if len(self._props) == 1:
            return str(next(iter(self._props.values())))

        formatted_props = [
            f"{k}:{json.dumps(v, ensure_ascii=False)}" for k, v in self._props.items()
        ]
        return f"{{{';'.join(formatted_props)}}}"
|
|
|
|
|
|
class Vertex(Elem):
    """Graph vertex identified by ``vid``, with an optional name and properties."""

    def __init__(self, vid: str, name: Optional[str] = None, **props):
        """Create a vertex; extra keyword arguments become its properties."""
        super().__init__(name)
        self._vid = vid
        for key, value in props.items():
            self.set_prop(key, value)

    @property
    def vid(self) -> str:
        """Return the vertex ID."""
        return self._vid

    @property
    def name(self) -> str:
        """Return the vertex name, falling back to the vertex ID when unset."""
        explicit_name = super().name
        return explicit_name if explicit_name else self._vid

    def format(self, concise: bool = False):
        """Format the vertex as ``(name)`` or ``(name:props)``.

        The short ``(name)`` form is used when ``concise`` is true or when the
        vertex has no properties.
        """
        label = self._name or self._vid
        if concise or not self._props:
            return f"({label})"
        return f"({label}:{super().format()})"

    def __str__(self):
        """Return ``(vid)``: the vertex ID wrapped in parentheses."""
        return f"({self._vid})"
|
|
|
|
|
|
class IdVertex(Vertex):
    """Vertex that carries only an ID: no explicit name and no properties."""

    def __init__(self, vid: str):
        """Create an ID-only vertex for ``vid``."""
        super().__init__(vid=vid)
|
|
|
|
|
|
class Edge(Elem):
    """Directed, named edge from vertex ``sid`` to vertex ``tid``."""

    def __init__(self, sid: str, tid: str, name: str, **props):
        """Create an edge; extra keyword arguments become its properties.

        Raises:
            AssertionError: if ``name`` is empty or ``None``.
        """
        # Raised explicitly instead of via ``assert`` so the check is not
        # stripped under ``python -O``; the exception type is kept so existing
        # callers observe the same failure.
        if not name:
            raise AssertionError("Edge name is required")

        super().__init__(name)
        self._sid = sid
        self._tid = tid
        for key, value in props.items():
            self.set_prop(key, value)

    @property
    def sid(self) -> str:
        """Return the source vertex ID of the edge."""
        return self._sid

    @property
    def tid(self) -> str:
        """Return the target vertex ID of the edge."""
        return self._tid

    def nid(self, vid):
        """Return the ID of the endpoint opposite to ``vid``.

        For a self-loop (``sid == tid``) the same ID is returned.

        Raises:
            ValueError: if ``vid`` is neither endpoint of this edge.
        """
        if vid == self._sid:
            return self._tid
        if vid == self._tid:
            return self._sid
        raise ValueError(f"Get nid of {vid} on {self} failed")

    def format(self):
        """Format the edge as ``-[name]->`` or ``-[name:props]->``."""
        if self._props:
            return f"-[{self._name}:{super().format()}]->"
        return f"-[{self._name}]->"

    def triplet(self) -> Tuple[str, str, str]:
        """Return the ``(sid, name, tid)`` triplet."""
        return self.sid, self.name, self.tid

    def __str__(self):
        """Return the edge as ``(sid)-[name]->(tid)``."""
        return f"({self._sid})-[{self._name}]->({self._tid})"
|
|
|
|
|
|
class Graph(ABC):
    """Abstract interface that every graph implementation must provide."""

    @abstractmethod
    def upsert_vertex(self, vertex: Vertex):
        """Add ``vertex`` to the graph."""

    @abstractmethod
    def append_edge(self, edge: Edge):
        """Add ``edge`` to the graph."""

    @abstractmethod
    def has_vertex(self, vid: str) -> bool:
        """Tell whether a vertex with ID ``vid`` exists."""

    @abstractmethod
    def get_vertex(self, vid: str) -> Vertex:
        """Return the vertex with ID ``vid``."""

    @abstractmethod
    def get_neighbor_edges(
        self,
        vid: str,
        direction: Direction = Direction.OUT,
        limit: Optional[int] = None,
    ) -> Iterator[Edge]:
        """Iterate the edges adjacent to ``vid`` in the given ``direction``."""

    @abstractmethod
    def vertices(self) -> Iterator[Vertex]:
        """Return an iterator over the vertices."""

    @abstractmethod
    def edges(self) -> Iterator[Edge]:
        """Return an iterator over the edges."""

    @abstractmethod
    def del_vertices(self, *vids: str):
        """Delete the given vertices together with their neighbor edges."""

    @abstractmethod
    def del_edges(self, sid: str, tid: str, name: str, **props):
        """Delete the edges ``sid -[name]-> tid`` that match ``props``."""

    @abstractmethod
    def del_neighbor_edges(self, vid: str, direction: Direction = Direction.OUT):
        """Delete the edges adjacent to ``vid`` in the given ``direction``."""

    @abstractmethod
    def search(
        self,
        vids: List[str],
        direct: Direction = Direction.OUT,
        depth: Optional[int] = None,
        fan: Optional[int] = None,
        limit: Optional[int] = None,
    ) -> "Graph":
        """Search the graph starting from ``vids`` and return the result graph."""

    @abstractmethod
    def schema(self) -> Dict[str, Any]:
        """Return the graph schema."""

    @abstractmethod
    def format(self) -> str:
        """Render the graph data as a string."""

    @abstractmethod
    def truncate(self):
        """Remove all data from the graph."""
|
|
|
|
|
|
class MemoryGraph(Graph):
    """In-memory ``Graph`` implementation backed by dictionaries.

    Vertices live in ``_vs`` keyed by vertex ID. Every edge is indexed twice:
    ``_oes[sid][tid]`` and ``_ies[tid][sid]`` both hold the set of edges going
    from ``sid`` to ``tid``, so neighbors can be enumerated in either direction.
    """

    def __init__(self):
        """Create an empty graph."""
        # metadata
        self._vertex_prop_keys: Set[str] = set()
        self._edge_prop_keys: Set[str] = set()
        self._edge_count = 0

        # vertices by ID, then the out-edge and in-edge indexes
        self._vs: Dict[str, Vertex] = {}
        self._oes: Any = defaultdict(lambda: defaultdict(set))
        self._ies: Any = defaultdict(lambda: defaultdict(set))

    @property
    def vertex_count(self):
        """Return the number of vertices in the graph."""
        return len(self._vs)

    @property
    def edge_count(self):
        """Return the number of edges in the graph."""
        return self._edge_count

    def upsert_vertex(self, vertex: Vertex):
        """Insert ``vertex``, or merge it into the vertex already stored.

        A missing vertex or an ``IdVertex`` placeholder is replaced by
        ``vertex``. Any other existing vertex is kept and only has its
        properties updated from ``vertex``.
        """
        current = self._vs.get(vertex.vid)
        if current is None or isinstance(current, IdVertex):
            self._vs[vertex.vid] = vertex
        else:
            current.props.update(vertex.props)

        # update metadata
        self._vertex_prop_keys.update(vertex.props.keys())

    def append_edge(self, edge: Edge) -> bool:
        """Append ``edge`` unless it is already stored.

        Missing endpoint vertices are created as ``IdVertex`` placeholders.

        Returns:
            ``True`` if the edge was added, ``False`` if it was already present.
        """
        sid = edge.sid
        tid = edge.tid

        if edge in self._oes[sid][tid]:
            return False

        # init vertex index
        self._vs.setdefault(sid, IdVertex(sid))
        self._vs.setdefault(tid, IdVertex(tid))

        # update edge index
        self._oes[sid][tid].add(edge)
        self._ies[tid][sid].add(edge)

        # update metadata
        self._edge_prop_keys.update(edge.props.keys())
        self._edge_count += 1
        return True

    def has_vertex(self, vid: str) -> bool:
        """Tell whether a vertex with ID ``vid`` exists."""
        return vid in self._vs

    def get_vertex(self, vid: str) -> Vertex:
        """Return the vertex with ID ``vid``.

        Raises:
            KeyError: if no such vertex exists.
        """
        return self._vs[vid]

    @staticmethod
    def _indexed_edges(index: Any, vid: str) -> Iterator[Edge]:
        """Iterate the edges stored under ``vid`` in ``index``, read-only."""
        # ``.get`` never triggers the defaultdict factory, so looking up an
        # unknown vid does not insert an empty entry into the index.
        return (e for es in index.get(vid, {}).values() for e in es)

    def get_neighbor_edges(
        self,
        vid: str,
        direction: Direction = Direction.OUT,
        limit: Optional[int] = None,
    ) -> Iterator[Edge]:
        """Iterate the edges adjacent to ``vid`` in the given ``direction``.

        For ``Direction.BOTH`` out-edges and in-edges are interleaved and each
        edge is yielded once (a self-loop sits in both indexes). A falsy
        ``limit`` (``None`` or ``0``) means no limit.

        Raises:
            ValueError: if ``direction`` is not a known ``Direction``.
        """
        if direction == Direction.OUT:
            es = self._indexed_edges(self._oes, vid)
        elif direction == Direction.IN:
            es = self._indexed_edges(self._ies, vid)
        elif direction == Direction.BOTH:
            out_edges = self._indexed_edges(self._oes, vid)
            in_edges = self._indexed_edges(self._ies, vid)

            def interleave_unique(first, second):
                seen = set()
                # zip_longest pads the shorter side with None
                for pair in itertools.zip_longest(first, second):
                    for e in pair:
                        if e is not None and e not in seen:
                            seen.add(e)
                            yield e

            es = interleave_unique(out_edges, in_edges)
        else:
            raise ValueError(f"Invalid direction: {direction}")

        return itertools.islice(es, limit) if limit else es

    def vertices(self) -> Iterator[Vertex]:
        """Return an iterator over the vertices."""
        return iter(self._vs.values())

    def edges(self) -> Iterator[Edge]:
        """Return an iterator over the edges, read from the out-edge index."""
        return (e for nbs in self._oes.values() for es in nbs.values() for e in es)

    def del_vertices(self, *vids: str):
        """Delete the given vertices and all edges adjacent to them."""
        for vid in vids:
            self.del_neighbor_edges(vid, Direction.BOTH)
            self._vs.pop(vid, None)

    def del_edges(self, sid: str, tid: str, name: str, **props):
        """Delete the edges from ``sid`` to ``tid`` that match ``name`` and ``props``.

        A falsy ``name`` matches every edge name.
        """
        out_edges = self._oes.get(sid, {}).get(tid)
        if not out_edges:
            # nothing stored for this pair; do not create empty index entries
            return

        def matches(e: Edge) -> bool:
            return (not name or e.name == name) and e.has_props(**props)

        kept = {e for e in out_edges if not matches(e)}
        self._edge_count -= len(out_edges) - len(kept)
        self._oes[sid][tid] = kept
        self._ies[tid][sid] = {e for e in self._ies[tid][sid] if not matches(e)}

    def del_neighbor_edges(self, vid: str, direction: Direction = Direction.OUT):
        """Delete every edge adjacent to ``vid`` in the given ``direction``."""

        def del_index(idx, i_idx):
            # Detach vid's entry first, then remove the mirrored sets from the
            # opposite index; nothing is created for unknown vids.
            neighbors = idx.pop(vid, None)
            if not neighbors:
                return
            for nid, es in neighbors.items():
                self._edge_count -= len(es)
                mirrored = i_idx.get(nid)
                if mirrored is not None:
                    mirrored.pop(vid, None)

        if direction in [Direction.OUT, Direction.BOTH]:
            del_index(self._oes, self._ies)

        if direction in [Direction.IN, Direction.BOTH]:
            del_index(self._ies, self._oes)

    def search(
        self,
        vids: List[str],
        direct: Direction = Direction.OUT,
        depth: Optional[int] = None,
        fan: Optional[int] = None,
        limit: Optional[int] = None,
    ) -> "MemoryGraph":
        """Collect the subgraph reachable from ``vids``.

        Args:
            vids: IDs of the start vertices; unknown IDs are skipped.
            direct: edge direction to follow from each visited vertex.
            depth: maximum number of hops from a start vertex.
            fan: maximum number of edges taken per visited vertex.
            limit: maximum number of edges in the returned subgraph.

        A falsy ``depth``, ``fan`` or ``limit`` means unbounded.

        Returns:
            A new ``MemoryGraph`` holding the visited vertices and edges.
        """
        subgraph = MemoryGraph()

        for vid in vids:
            # each start vertex gets its own visited set
            self.__search(vid, direct, depth, fan, limit, 0, set(), subgraph)

        return subgraph

    def __search(
        self,
        vid: str,
        direct: Direction,
        depth: Optional[int],
        fan: Optional[int],
        limit: Optional[int],
        _depth: int,
        _visited: Set,
        _subgraph: "MemoryGraph",
    ):
        """Depth-first step of ``search``: visit ``vid``, then recurse."""
        if vid in _visited or depth and _depth >= depth:
            return

        # visit vertex
        if not self.has_vertex(vid):
            return
        _subgraph.upsert_vertex(self.get_vertex(vid))
        _visited.add(vid)

        # visit edges
        nids = set()
        for edge in self.get_neighbor_edges(vid, direct, fan):
            if limit and _subgraph.edge_count >= limit:
                return

            # append edge success then visit new vertex
            if _subgraph.append_edge(edge):
                nid = edge.nid(vid)
                if nid not in _visited:
                    nids.add(nid)

        # next hop
        for nid in nids:
            self.__search(
                nid, direct, depth, fan, limit, _depth + 1, _visited, _subgraph
            )

    def schema(self) -> Dict[str, Any]:
        """Return the property keys seen so far on vertices and on edges."""
        return {
            "schema": [
                {
                    "type": "VERTEX",
                    "properties": [{"name": k} for k in self._vertex_prop_keys],
                },
                {
                    "type": "EDGE",
                    "properties": [{"name": k} for k in self._edge_prop_keys],
                },
            ]
        }

    def format(self) -> str:
        """Render the graph as an ``Entities`` and a ``Relationships`` section.

        Returns an empty string when the graph has neither vertices nor edges.
        """
        vs_str = "\n".join(v.format() for v in self.vertices())
        es_str = "\n".join(
            f"{self.get_vertex(e.sid).format(concise=True)}"
            f"{e.format()}"
            f"{self.get_vertex(e.tid).format(concise=True)}"
            for e in self.edges()
        )
        return (
            f"Entities:\n{vs_str}\n\n" f"Relationships:\n{es_str}"
            if (vs_str or es_str)
            else ""
        )

    def truncate(self):
        """Remove all vertices, edges and collected metadata."""
        # clean metadata
        self._vertex_prop_keys.clear()
        self._edge_prop_keys.clear()
        self._edge_count = 0

        # clean data and index
        self._vs.clear()
        self._oes.clear()
        self._ies.clear()

    def graphviz(self, name="g"):
        """Return the graph as DOT text named ``name``.

        The output can be viewed at https://dreampuf.github.io/GraphvizOnline.
        NOTE(review): ``nx.nx_agraph`` presumably needs pygraphviz installed;
        confirm for the target environment.
        """
        g = nx.MultiDiGraph()
        for vertex in self.vertices():
            g.add_node(vertex.vid)

        for edge in self.edges():
            sid, label, tid = edge.triplet()
            g.add_edge(sid, tid, label=label)

        digraph = nx.nx_agraph.to_agraph(g).to_string()
        digraph = digraph.replace('digraph ""', f"digraph {name}")
        # strip the ``key=<n>`` attributes from the generated text
        digraph = re.sub(r"key=\d+,?\s*", "", digraph)
        return digraph
|