mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-21 17:37:52 +00:00
feat: add GraphRAG framework and integrate TuGraph (#1506)
Co-authored-by: KingSkyLi <15566300566@163.com> Co-authored-by: aries_ckt <916701291@qq.com> Co-authored-by: Fangyin Cheng <staneyffer@gmail.com>
This commit is contained in:
37
dbgpt/storage/graph_store/__init__.py
Normal file
37
dbgpt/storage/graph_store/__init__.py
Normal file
@@ -0,0 +1,37 @@
|
||||
"""Graph Store Module."""
|
||||
from typing import Tuple, Type
|
||||
|
||||
|
||||
def _import_memgraph() -> Tuple[Type, Type]:
    """Import the in-memory graph store on demand.

    The import happens inside the function so that merely importing this
    package does not load the store module.

    Returns:
        The ``(store class, config class)`` pair of the memory graph store.
    """
    from dbgpt.storage.graph_store.memgraph_store import (
        MemoryGraphStore as store_cls,
        MemoryGraphStoreConfig as config_cls,
    )

    return store_cls, config_cls
|
||||
|
||||
|
||||
def _import_tugraph() -> Tuple[Type, Type]:
    """Import the TuGraph store on demand.

    The import happens inside the function so that merely importing this
    package does not load the store module.

    Returns:
        The ``(store class, config class)`` pair of the TuGraph store.
    """
    from dbgpt.storage.graph_store.tugraph_store import (
        TuGraphStore as store_cls,
        TuGraphStoreConfig as config_cls,
    )

    return store_cls, config_cls
|
||||
|
||||
|
||||
def _import_neo4j() -> Tuple[Type, Type]:
    """Import the Neo4j store on demand.

    The import happens inside the function so that merely importing this
    package does not load the store module.

    Returns:
        The ``(store class, config class)`` pair of the Neo4j store.
    """
    from dbgpt.storage.graph_store.neo4j_store import (
        Neo4jStore as store_cls,
        Neo4jStoreConfig as config_cls,
    )

    return store_cls, config_cls
|
||||
|
||||
|
||||
def __getattr__(name: str) -> Tuple[Type, Type]:
    """Resolve a graph store type name to its classes lazily.

    Args:
        name: One of ``"Memory"``, ``"TuGraph"`` or ``"Neo4j"``.

    Returns:
        The ``(store class, config class)`` pair for the requested store.

    Raises:
        AttributeError: If ``name`` is not a known graph store type.
    """
    if name == "Memory":
        return _import_memgraph()
    if name == "TuGraph":
        return _import_tugraph()
    if name == "Neo4j":
        return _import_neo4j()
    raise AttributeError(f"Could not find: {name}")
|
||||
|
||||
|
||||
__all__ = ["Memory", "TuGraph", "Neo4j"]
|
68
dbgpt/storage/graph_store/base.py
Normal file
68
dbgpt/storage/graph_store/base.py
Normal file
@@ -0,0 +1,68 @@
|
||||
"""Graph store base class."""
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from dbgpt._private.pydantic import BaseModel, ConfigDict, Field
|
||||
from dbgpt.core import Embeddings
|
||||
from dbgpt.storage.graph_store.graph import Direction, Graph
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class GraphStoreConfig(BaseModel):
    """Base configuration shared by the graph store implementations."""

    # Accept field types pydantic cannot validate natively and keep any
    # extra, undeclared fields passed by the caller.
    model_config = ConfigDict(extra="allow", arbitrary_types_allowed=True)

    # Name of the graph store.
    name: str = Field(
        description="The name of graph store, inherit from index store.",
        default="dbgpt_collection",
    )
    # Optional embedding function.
    embedding_fn: Optional[Embeddings] = Field(
        description="The embedding function of graph store, optional.",
        default=None,
    )
|
||||
|
||||
|
||||
class GraphStoreBase(ABC):
    """Abstract interface every graph store backend has to implement.

    A store keeps ``(subject, relation, object)`` triplets and can return
    parts of the stored data as a ``Graph``.
    """

    @abstractmethod
    def insert_triplet(self, sub: str, rel: str, obj: str):
        """Store the triplet ``sub -[rel]-> obj``.

        Args:
            sub: Subject (source vertex id).
            rel: Relation name.
            obj: Object (target vertex id).
        """

    @abstractmethod
    def get_triplets(self, sub: str) -> List[Tuple[str, str]]:
        """Return the triplets that start at ``sub``.

        Args:
            sub: Subject (source vertex id).

        Returns:
            A list of 2-tuples, one per triplet.
        """

    @abstractmethod
    def delete_triplet(self, sub: str, rel: str, obj: str):
        """Remove the triplet ``sub -[rel]-> obj``.

        Args:
            sub: Subject (source vertex id).
            rel: Relation name.
            obj: Object (target vertex id).
        """

    @abstractmethod
    def drop(self):
        """Drop the whole graph."""

    @abstractmethod
    def get_schema(self, refresh: bool = False) -> str:
        """Return the graph schema as a string.

        Args:
            refresh: Flag passed by callers that want a refreshed schema.
        """

    @abstractmethod
    def get_full_graph(self, limit: Optional[int] = None) -> Graph:
        """Return the stored graph.

        Args:
            limit: Optional upper bound on the amount of returned data.
        """

    @abstractmethod
    def explore(
        self,
        subs: List[str],
        direct: Direction = Direction.BOTH,
        depth: Optional[int] = None,
        fan: Optional[int] = None,
        limit: Optional[int] = None,
    ) -> Graph:
        """Explore the graph starting from the given subjects.

        Args:
            subs: Ids of the start vertices.
            direct: Direction in which edges are followed.
            depth: Optional maximum number of hops.
            fan: Optional maximum number of edges followed per vertex.
            limit: Optional upper bound on the size of the result.

        Returns:
            The explored subgraph.
        """

    @abstractmethod
    def query(self, query: str, **args) -> Graph:
        """Execute a query and return its result as a graph.

        Args:
            query: Query text in the language of the backend.
            **args: Additional backend specific arguments.
        """
|
42
dbgpt/storage/graph_store/factory.py
Normal file
42
dbgpt/storage/graph_store/factory.py
Normal file
@@ -0,0 +1,42 @@
|
||||
"""Factory for graph stores."""
|
||||
import logging
|
||||
from typing import Tuple, Type
|
||||
|
||||
from dbgpt.storage import graph_store
|
||||
from dbgpt.storage.graph_store.base import GraphStoreBase, GraphStoreConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class GraphStoreFactory:
    """Factory that builds graph store instances from a store type name."""

    @staticmethod
    def create(graph_store_type: str, graph_store_configure=None) -> GraphStoreBase:
        """Create a GraphStore instance.

        Args:
            graph_store_type: Store type name, compared case-insensitively
                with the names exported by the ``graph_store`` package
                (Memory, TuGraph, Neo4j).
            graph_store_configure: Optional callable that receives the
                default-constructed config object and may modify it in place
                before the store is built.

        Returns:
            The store instance built from the (possibly customized) config.

        Raises:
            Exception: If the store type is not supported. Any error raised
                while building the config or the store is logged and
                re-raised unchanged.
        """
        store_cls, cfg_cls = GraphStoreFactory.__find_type(graph_store_type)

        try:
            config = cfg_cls()
            if graph_store_configure:
                graph_store_configure(config)
            return store_cls(config)
        except Exception as e:
            logger.error("create graph store failed: %s", e)
            # Bare raise keeps the original traceback untouched.
            raise

    @staticmethod
    def __find_type(graph_store_type: str) -> Tuple[Type, Type]:
        """Look up the store and config classes registered for a type name.

        Args:
            graph_store_type: Store type name, matched case-insensitively.

        Returns:
            The ``(store class, config class)`` pair.

        Raises:
            Exception: If no exported type matches, or the matching classes
                do not derive from ``GraphStoreBase`` / ``GraphStoreConfig``.
        """
        wanted = graph_store_type.lower()
        for t in graph_store.__all__:
            if t.lower() != wanted:
                continue
            # Attribute access triggers the lazy import of the store module.
            store_cls, cfg_cls = getattr(graph_store, t)
            if issubclass(store_cls, GraphStoreBase) and issubclass(
                cfg_cls, GraphStoreConfig
            ):
                return store_cls, cfg_cls
        raise Exception(f"Graph store {graph_store_type} not supported")
|
477
dbgpt/storage/graph_store/graph.py
Normal file
477
dbgpt/storage/graph_store/graph.py
Normal file
@@ -0,0 +1,477 @@
|
||||
"""Graph data structures: vertices, edges and an in-memory graph."""
|
||||
import itertools
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from abc import ABC, abstractmethod
|
||||
from collections import defaultdict
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, Iterator, List, Optional, Set, Tuple
|
||||
|
||||
import networkx as nx
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Direction(Enum):
    """Direction in which the edges of a vertex are followed."""

    # Edges that start at the vertex.
    OUT = 0
    # Edges that end at the vertex.
    IN = 1
    # Edges in either direction.
    BOTH = 2
|
||||
|
||||
|
||||
class Elem(ABC):
    """Base class of graph elements that carry a key/value property map."""

    def __init__(self):
        """Create an element without any property."""
        self._props = {}

    @property
    def props(self) -> Dict[str, Any]:
        """Return the property dict itself (not a copy)."""
        return self._props

    def set_prop(self, key: str, value: Any):
        """Set property ``key`` to ``value``, replacing any previous value."""
        self._props[key] = value

    def get_prop(self, key: str):
        """Return the value of property ``key`` or ``None`` if it is unset."""
        return self._props.get(key)

    def del_prop(self, key: str):
        """Remove property ``key``; a missing key is ignored."""
        self._props.pop(key, None)

    def has_props(self, **props):
        """Tell whether every given property equals the stored value.

        Returns ``True`` when called without arguments.
        """
        for key, expected in props.items():
            if not (self._props.get(key) == expected):
                return False
        return True

    @abstractmethod
    def format(self, label_key: Optional[str] = None):
        """Render the properties as ``{key:json;key:json}``.

        Args:
            label_key: Property key to leave out of the rendering.

        Returns:
            The rendered string, ``"{}"`` when nothing is left to render.
        """
        rendered = []
        for key, value in self._props.items():
            if key == label_key:
                continue
            rendered.append(f"{key}:{json.dumps(value)}")
        return "{" + ";".join(rendered) + "}"
|
||||
|
||||
|
||||
class Vertex(Elem):
    """Graph vertex identified by its id, with optional properties."""

    def __init__(self, vid: str, **props):
        """Create a vertex.

        Args:
            vid: Vertex id.
            **props: Initial properties of the vertex.
        """
        super().__init__()
        self._vid = vid
        for key in props:
            self.set_prop(key, props[key])

    @property
    def vid(self) -> str:
        """Return the vertex id."""
        return self._vid

    def format(self, label_key: Optional[str] = None):
        """Render the vertex as ``(label)`` or ``(label:{props})``.

        Args:
            label_key: Property whose value is used as the label; the vertex
                id is used when it is not given.
        """
        if label_key:
            label = self.get_prop(label_key)
        else:
            label = self._vid
        body = super().format(label_key)
        return f"({label})" if body == "{}" else f"({label}:{body})"

    def __str__(self):
        """Return ``(vid)``."""
        return f"({self._vid})"
|
||||
|
||||
|
||||
class Edge(Elem):
    """Directed graph edge from a source vertex to a target vertex."""

    def __init__(self, sid: str, tid: str, **props):
        """Create an edge.

        Args:
            sid: Id of the source vertex.
            tid: Id of the target vertex.
            **props: Initial properties of the edge.
        """
        super().__init__()
        self._sid = sid
        self._tid = tid
        for key in props:
            self.set_prop(key, props[key])

    @property
    def sid(self) -> str:
        """Return the id of the source vertex."""
        return self._sid

    @property
    def tid(self) -> str:
        """Return the id of the target vertex."""
        return self._tid

    def nid(self, vid):
        """Return the id of the endpoint opposite to ``vid``.

        Raises:
            ValueError: If ``vid`` is neither the source nor the target.
        """
        if vid == self._sid:
            return self._tid
        if vid == self._tid:
            return self._sid
        raise ValueError(f"Get nid of {vid} on {self} failed")

    def format(self, label_key: Optional[str] = None):
        """Render the edge as an arrow with optional label and properties.

        Args:
            label_key: Property whose value is used as the edge label.

        Returns:
            ``->``, ``-[label]->``, ``-[{props}]->`` or ``-[label:{props}]->``.
        """
        label = self.get_prop(label_key) if label_key else ""
        body = super().format(label_key)
        with_props = body != "{}"
        if label and with_props:
            return f"-[{label}:{body}]->"
        if label:
            return f"-[{label}]->"
        if with_props:
            return f"-[{body}]->"
        return "->"

    def triplet(self, label_key: str) -> Tuple[str, str, str]:
        """Return ``(source id, label, target id)``.

        Args:
            label_key: Property whose value, converted to ``str``, is the label.
        """
        assert label_key, "label key is needed"
        relation = str(self.get_prop(label_key))
        return self._sid, relation, self._tid

    def __str__(self):
        """Return ``(sid)->(tid)``."""
        return f"({self._sid})->({self._tid})"
|
||||
|
||||
|
||||
class Graph(ABC):
    """Abstract interface of a property graph made of vertices and edges."""

    @abstractmethod
    def upsert_vertex(self, vertex: Vertex):
        """Insert the vertex, or update it when it already exists."""

    @abstractmethod
    def append_edge(self, edge: Edge):
        """Add an edge to the graph."""

    @abstractmethod
    def has_vertex(self, vid: str) -> bool:
        """Tell whether a vertex with id ``vid`` exists."""

    @abstractmethod
    def get_vertex(self, vid: str) -> Vertex:
        """Return the vertex with id ``vid``."""

    @abstractmethod
    def get_neighbor_edges(
        self,
        vid: str,
        direction: Direction = Direction.OUT,
        limit: Optional[int] = None,
    ) -> Iterator[Edge]:
        """Iterate over the edges attached to a vertex.

        Args:
            vid: Id of the vertex.
            direction: Which edges to consider relative to the vertex.
            limit: Optional maximum number of edges to yield.
        """

    @abstractmethod
    def vertices(self) -> Iterator[Vertex]:
        """Iterate over all vertices."""

    @abstractmethod
    def edges(self) -> Iterator[Edge]:
        """Iterate over all edges."""

    @abstractmethod
    def del_vertices(self, *vids: str):
        """Delete the given vertices together with their neighbor edges."""

    @abstractmethod
    def del_edges(self, sid: str, tid: str, **props):
        """Delete the edges from ``sid`` to ``tid`` that match ``props``."""

    @abstractmethod
    def del_neighbor_edges(self, vid: str, direction: Direction = Direction.OUT):
        """Delete the edges attached to a vertex in the given direction."""

    @abstractmethod
    def search(
        self,
        vids: List[str],
        direct: Direction = Direction.OUT,
        depth: Optional[int] = None,
        fan: Optional[int] = None,
        limit: Optional[int] = None,
    ) -> "Graph":
        """Search the graph starting from the given vertices.

        Args:
            vids: Ids of the start vertices.
            direct: Direction in which edges are followed.
            depth: Optional maximum number of hops.
            fan: Optional maximum number of edges followed per vertex.
            limit: Optional upper bound on the size of the result.

        Returns:
            The subgraph that was reached.
        """

    @abstractmethod
    def schema(self) -> Dict[str, Any]:
        """Return the schema of the graph as a dict."""

    @abstractmethod
    def format(self) -> str:
        """Render the graph data as a string."""
|
||||
|
||||
|
||||
class MemoryGraph(Graph):
    """In-memory graph that indexes edges by source and by target vertex.

    Several edges may connect the same pair of vertices. Every edge must
    carry the property named by ``edge_label``.
    """

    def __init__(self, vertex_label: Optional[str] = None, edge_label: str = "label"):
        """Initialize MemoryGraph with vertex label and edge label.

        Args:
            vertex_label: Optional vertex property used as the display label
                by ``format``; the vertex id is used when it is ``None``.
            edge_label: Edge property that holds the edge name. Required.
        """
        assert edge_label, "Edge label is needed"

        # metadata
        self._vertex_label = vertex_label
        self._edge_label = edge_label
        self._vertex_prop_keys = {vertex_label} if vertex_label else set()
        self._edge_prop_keys = {edge_label}
        self._edge_count = 0

        # vertex index: vid -> Vertex
        self._vs: Any = defaultdict()
        # out edge index: sid -> tid -> {Edge}
        self._oes: Any = defaultdict(lambda: defaultdict(set))
        # in edge index: tid -> sid -> {Edge}
        self._ies: Any = defaultdict(lambda: defaultdict(set))

    @property
    def vertex_label(self):
        """Return the label for vertices."""
        return self._vertex_label

    @property
    def edge_label(self):
        """Return the label for edges."""
        return self._edge_label

    @property
    def vertex_prop_keys(self):
        """Return a set of property keys for vertices."""
        return self._vertex_prop_keys

    @property
    def edge_prop_keys(self):
        """Return a set of property keys for edges."""
        return self._edge_prop_keys

    @property
    def vertex_count(self):
        """Return the number of vertices in the graph."""
        return len(self._vs)

    @property
    def edge_count(self):
        """Return the count of edges in the graph."""
        return self._edge_count

    def upsert_vertex(self, vertex: Vertex):
        """Insert a vertex, or merge its props into the existing vertex."""
        if vertex.vid in self._vs:
            self._vs[vertex.vid].props.update(vertex.props)
        else:
            self._vs[vertex.vid] = vertex

        # update metadata
        self._vertex_prop_keys.update(vertex.props.keys())

    def append_edge(self, edge: Edge):
        """Append an edge unless this edge object is already stored.

        Missing endpoint vertices are created without properties.

        Returns:
            ``True`` if the edge was added, ``False`` if it was already there.

        Raises:
            ValueError: If the edge lacks the edge label property.
        """
        if self.edge_label not in edge.props.keys():
            raise ValueError(f"Edge prop '{self.edge_label}' is needed")

        sid = edge.sid
        tid = edge.tid

        if edge in self._oes[sid][tid]:
            return False

        # init vertex index
        self._vs.setdefault(sid, Vertex(sid))
        self._vs.setdefault(tid, Vertex(tid))

        # update edge index
        self._oes[sid][tid].add(edge)
        self._ies[tid][sid].add(edge)

        # update metadata
        self._edge_prop_keys.update(edge.props.keys())
        self._edge_count += 1
        return True

    def has_vertex(self, vid: str) -> bool:
        """Check whether a vertex with the given ID exists."""
        return vid in self._vs

    def get_vertex(self, vid: str) -> Vertex:
        """Retrieve a vertex by ID; raises ``KeyError`` if it is unknown."""
        return self._vs[vid]

    def get_neighbor_edges(
        self,
        vid: str,
        direction: Direction = Direction.OUT,
        limit: Optional[int] = None,
    ) -> Iterator[Edge]:
        """Get edges connected to a vertex by direction.

        Args:
            vid: Id of the vertex.
            direction: ``OUT`` for edges leaving the vertex, ``IN`` for edges
                entering it, ``BOTH`` for the de-duplicated, interleaved union.
            limit: Maximum number of edges to yield; ``None`` or ``0`` means
                no limit.

        Raises:
            ValueError: If ``direction`` is not a known direction.
        """
        # ``.get`` instead of indexing: looking up an unknown vertex must not
        # insert empty entries into the defaultdict indexes.
        if direction == Direction.OUT:
            es = (e for es in self._oes.get(vid, {}).values() for e in es)

        elif direction == Direction.IN:
            es = (e for es in self._ies.get(vid, {}).values() for e in es)

        elif direction == Direction.BOTH:
            oes = (e for es in self._oes.get(vid, {}).values() for e in es)
            ies = (e for es in self._ies.get(vid, {}).values() for e in es)

            # merge: alternate between out and in edges
            tuples = itertools.zip_longest(oes, ies)
            es = (e for t in tuples for e in t if e is not None)

            # distinct: a self-loop is present in both indexes
            seen = set()

            def unique_elements(elements):
                for element in elements:
                    if element not in seen:
                        seen.add(element)
                        yield element

            es = unique_elements(es)
        else:
            raise ValueError(f"Invalid direction: {direction}")

        return itertools.islice(es, limit) if limit else es

    def vertices(self) -> Iterator[Vertex]:
        """Return vertices."""
        return iter(self._vs.values())

    def edges(self) -> Iterator[Edge]:
        """Return edges."""
        return (e for nbs in self._oes.values() for es in nbs.values() for e in es)

    def del_vertices(self, *vids: str):
        """Delete specified vertices and all edges attached to them."""
        for vid in vids:
            self.del_neighbor_edges(vid, Direction.BOTH)
            self._vs.pop(vid, None)

    def del_edges(self, sid: str, tid: str, **props):
        """Delete edges from ``sid`` to ``tid``.

        Args:
            sid: Id of the source vertex.
            tid: Id of the target vertex.
            **props: When given, only edges whose properties match are
                deleted; otherwise all edges between the two vertices are.
        """
        old_edge_cnt = len(self._oes[sid][tid])

        if not props:
            self._edge_count -= old_edge_cnt
            self._oes[sid].pop(tid, None)
            self._ies[tid].pop(sid, None)
            return

        def remove_matches(es):
            return set(filter(lambda e: not e.has_props(**props), es))

        self._oes[sid][tid] = remove_matches(self._oes[sid][tid])
        self._ies[tid][sid] = remove_matches(self._ies[tid][sid])

        self._edge_count -= old_edge_cnt - len(self._oes[sid][tid])

    def del_neighbor_edges(self, vid: str, direction: Direction = Direction.OUT):
        """Delete all neighbor edges of a vertex in the given direction."""

        def del_index(idx, i_idx):
            # Remove the edges from the inverse index first (this is where
            # they are counted), then drop the whole entry of the vertex.
            for nid in idx[vid].keys():
                self._edge_count -= len(i_idx[nid][vid])
                i_idx[nid].pop(vid, None)
            idx.pop(vid, None)

        if direction in [Direction.OUT, Direction.BOTH]:
            del_index(self._oes, self._ies)

        if direction in [Direction.IN, Direction.BOTH]:
            del_index(self._ies, self._oes)

    def search(
        self,
        vids: List[str],
        direct: Direction = Direction.OUT,
        depth: Optional[int] = None,
        fan: Optional[int] = None,
        limit: Optional[int] = None,
    ) -> "MemoryGraph":
        """Search the graph from the given vertices.

        Args:
            vids: Ids of the start vertices; unknown ids are skipped.
            direct: Direction in which edges are followed.
            depth: Maximum number of hops; ``None`` or ``0`` means unlimited.
            fan: Maximum number of edges followed per visited vertex.
            limit: Maximum number of edges in the result; ``None`` or ``0``
                means unlimited.

        Returns:
            A new graph with the reached vertices and edges. It uses the same
            vertex and edge labels as this graph.
        """
        # The result must share this graph's labels: with the default edge
        # label, ``append_edge`` would reject edges of a graph that uses a
        # custom one.
        subgraph = MemoryGraph(self._vertex_label, self._edge_label)

        for vid in vids:
            self.__search(vid, direct, depth, fan, limit, 0, set(), subgraph)

        return subgraph

    def __search(
        self,
        vid: str,
        direct: Direction,
        depth: Optional[int],
        fan: Optional[int],
        limit: Optional[int],
        _depth: int,
        _visited: Set,
        _subgraph: "MemoryGraph",
    ):
        """Depth-first helper of ``search`` that fills ``_subgraph``."""
        if vid in _visited or depth and _depth >= depth:
            return

        # visit vertex
        if not self.has_vertex(vid):
            return
        _subgraph.upsert_vertex(self.get_vertex(vid))
        _visited.add(vid)

        # visit edges
        nids = set()
        for edge in self.get_neighbor_edges(vid, direct, fan):
            if limit and _subgraph.edge_count >= limit:
                return

            nid = edge.nid(vid)
            # Copy the neighbor before adding the edge. Otherwise a neighbor
            # that is never visited (depth limit reached) would only exist as
            # the bare vertex created by ``append_edge`` and lose its props.
            if self.has_vertex(nid):
                _subgraph.upsert_vertex(self.get_vertex(nid))

            # append edge success then visit new vertex
            if _subgraph.append_edge(edge):
                if nid not in _visited:
                    nids.add(nid)

        # next hop
        for nid in nids:
            self.__search(
                nid, direct, depth, fan, limit, _depth + 1, _visited, _subgraph
            )

    def schema(self) -> Dict[str, Any]:
        """Return the vertex and edge labels with their known property keys."""
        return {
            "schema": [
                {
                    "type": "VERTEX",
                    "label": f"{self._vertex_label}",
                    "properties": [{"name": k} for k in self._vertex_prop_keys],
                },
                {
                    "type": "EDGE",
                    "label": f"{self._edge_label}",
                    "properties": [{"name": k} for k in self._edge_prop_keys],
                },
            ]
        }

    def format(self) -> str:
        """Format graph to string: one line per vertex, one line per edge."""
        vs_str = "\n".join(v.format(self.vertex_label) for v in self.vertices())
        es_str = "\n".join(
            f"{self.get_vertex(e.sid).format(self.vertex_label)}"
            f"{e.format(self.edge_label)}"
            f"{self.get_vertex(e.tid).format(self.vertex_label)}"
            for e in self.edges()
        )
        return f"Vertices:\n{vs_str}\n\nEdges:\n{es_str}"

    def graphviz(self, name="g"):
        """View graphviz graph: https://dreampuf.github.io/GraphvizOnline.

        Args:
            name: Name given to the generated digraph.

        Returns:
            The graph in DOT format as a string.
        """
        g = nx.MultiDiGraph()
        for vertex in self.vertices():
            g.add_node(vertex.vid)

        for edge in self.edges():
            triplet = edge.triplet(self.edge_label)
            g.add_edge(triplet[0], triplet[2], label=triplet[1])

        digraph = nx.nx_agraph.to_agraph(g).to_string()
        digraph = digraph.replace('digraph ""', f"digraph {name}")
        # strip the "key=<n>" attributes of the generated edges
        digraph = re.sub(r"key=\d+,?\s*", "", digraph)
        return digraph
|
81
dbgpt/storage/graph_store/memgraph_store.py
Normal file
81
dbgpt/storage/graph_store/memgraph_store.py
Normal file
@@ -0,0 +1,81 @@
|
||||
"""Memory graph store."""
|
||||
import json
|
||||
import logging
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from dbgpt._private.pydantic import ConfigDict, Field
|
||||
from dbgpt.storage.graph_store.base import GraphStoreBase, GraphStoreConfig
|
||||
from dbgpt.storage.graph_store.graph import Direction, Edge, Graph, MemoryGraph
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MemoryGraphStoreConfig(GraphStoreConfig):
    """Configuration of the in-memory graph store."""

    model_config = ConfigDict(arbitrary_types_allowed=True)

    # Name of the edge property that stores the relation name.
    edge_name_key: str = Field(
        description="The label of edge name, `label` by default.",
        default="label",
    )
|
||||
|
||||
|
||||
class MemoryGraphStore(GraphStoreBase):
    """Graph store that keeps all triplets in a ``MemoryGraph``."""

    def __init__(self, graph_store_config: MemoryGraphStoreConfig):
        """Initialize MemoryGraphStore with a memory graph.

        Args:
            graph_store_config: Config providing ``edge_name_key``, the edge
                property under which relation names are stored.
        """
        self._edge_name_key = graph_store_config.edge_name_key
        self._graph = MemoryGraph(edge_label=self._edge_name_key)

    def insert_triplet(self, sub: str, rel: str, obj: str):
        """Insert the triplet ``sub -[rel]-> obj`` into the graph."""
        self._graph.append_edge(Edge(sub, obj, **{self._edge_name_key: rel}))

    def get_triplets(self, sub: str) -> List[Tuple[str, str]]:
        """Retrieve triplets originating from a subject.

        Returns:
            One ``(relation, object id)`` tuple per outgoing edge of ``sub``.
        """
        subgraph = self.explore([sub], direct=Direction.OUT, depth=1)
        return [(e.get_prop(self._edge_name_key), e.tid) for e in subgraph.edges()]

    def delete_triplet(self, sub: str, rel: str, obj: str):
        """Delete the edges from ``sub`` to ``obj`` whose relation is ``rel``."""
        self._graph.del_edges(sub, obj, **{self._edge_name_key: rel})

    def drop(self):
        """Drop graph.

        The graph reference is set to ``None``; the other methods dereference
        it, so the store must not be used afterwards.
        """
        self._graph = None

    def get_schema(self, refresh: bool = False) -> str:
        """Return the graph schema as a JSON string.

        Args:
            refresh: Unused by this store.
        """
        return json.dumps(self._graph.schema())

    def get_full_graph(self, limit: Optional[int] = None) -> MemoryGraph:
        """Return the stored graph, optionally truncated.

        Args:
            limit: Maximum number of edges to return. When it is ``None`` or
                ``0`` the internal graph object itself is returned.

        Returns:
            The internal graph, or a new graph holding the first ``limit``
            edges and their endpoint vertices.
        """
        if not limit:
            return self._graph

        # Use the configured edge label: with the default label the copy
        # would reject the edges of a store that uses a custom key.
        subgraph = MemoryGraph(edge_label=self._edge_name_key)
        for count, edge in enumerate(self._graph.edges()):
            if count >= limit:
                break
            subgraph.upsert_vertex(self._graph.get_vertex(edge.sid))
            subgraph.upsert_vertex(self._graph.get_vertex(edge.tid))
            subgraph.append_edge(edge)
        return subgraph

    def explore(
        self,
        subs: List[str],
        direct: Direction = Direction.BOTH,
        depth: Optional[int] = None,
        fan: Optional[int] = None,
        limit: Optional[int] = None,
    ) -> MemoryGraph:
        """Explore the graph from given subjects up to a depth.

        All arguments are forwarded to ``MemoryGraph.search``.
        """
        return self._graph.search(subs, direct, depth, fan, limit)

    def query(self, query: str, **args) -> Graph:
        """Execute a query on graph.

        Raises:
            NotImplementedError: Always; the memory graph cannot be queried.
        """
        raise NotImplementedError("Query memory graph not allowed")
|
64
dbgpt/storage/graph_store/neo4j_store.py
Normal file
64
dbgpt/storage/graph_store/neo4j_store.py
Normal file
@@ -0,0 +1,64 @@
|
||||
"""Neo4j graph store."""
|
||||
import logging
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from dbgpt._private.pydantic import ConfigDict
|
||||
from dbgpt.storage.graph_store.base import GraphStoreBase, GraphStoreConfig
|
||||
from dbgpt.storage.graph_store.graph import Direction, Graph, MemoryGraph
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Neo4jStoreConfig(GraphStoreConfig):
    """Configuration of the Neo4j graph store."""

    # Allow field types that pydantic cannot validate natively.
    model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
|
||||
class Neo4jStore(GraphStoreBase):
    """Placeholder Neo4j graph store.

    Every method is a stub: writes do nothing and reads return empty results.
    """

    # todo: add neo4j implementation

    def __init__(self, graph_store_config: Neo4jStoreConfig):
        """Initialize the Neo4jStore; the config is currently ignored."""

    def insert_triplet(self, sub: str, rel: str, obj: str):
        """Insert a triplet (not implemented, does nothing)."""

    def get_triplets(self, sub: str) -> List[Tuple[str, str]]:
        """Get triplets (not implemented, always an empty list)."""
        return []

    def delete_triplet(self, sub: str, rel: str, obj: str):
        """Delete a triplet (not implemented, does nothing)."""

    def drop(self):
        """Drop graph (not implemented, does nothing)."""

    def get_schema(self, refresh: bool = False) -> str:
        """Get schema (not implemented, always an empty string)."""
        return ""

    def get_full_graph(self, limit: Optional[int] = None) -> Graph:
        """Get full graph (not implemented, always an empty graph)."""
        return MemoryGraph()

    def explore(
        self,
        subs: List[str],
        direct: Direction = Direction.BOTH,
        depth: Optional[int] = None,
        fan: Optional[int] = None,
        limit: Optional[int] = None,
    ) -> Graph:
        """Explore the graph (not implemented, always an empty graph)."""
        return MemoryGraph()

    def query(self, query: str, **args) -> Graph:
        """Execute a query (not implemented, always an empty graph)."""
        return MemoryGraph()
|
239
dbgpt/storage/graph_store/tugraph_store.py
Normal file
239
dbgpt/storage/graph_store/tugraph_store.py
Normal file
@@ -0,0 +1,239 @@
|
||||
"""TuGraph graph store."""
|
||||
import logging
|
||||
import os
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from dbgpt._private.pydantic import ConfigDict, Field
|
||||
from dbgpt.datasource.conn_tugraph import TuGraphConnector
|
||||
from dbgpt.storage.graph_store.base import GraphStoreBase, GraphStoreConfig
|
||||
from dbgpt.storage.graph_store.graph import Direction, Edge, MemoryGraph, Vertex
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TuGraphStoreConfig(GraphStoreConfig):
    """TuGraph store config."""

    model_config = ConfigDict(arbitrary_types_allowed=True)

    host: str = Field(
        default="127.0.0.1",
        description="TuGraph host",
    )
    port: int = Field(
        default=7687,
        description="TuGraph port",
    )
    username: str = Field(
        default="admin",
        description="login username",
    )
    # Same value the store previously hard-coded as its fallback, so the
    # effective default does not change now that the config is honored.
    password: str = Field(
        default="73@TuGraph",
        description="login password",
    )
    vertex_type: str = Field(
        default="entity",
        description="The type of graph vertex, `entity` by default.",
    )
    edge_type: str = Field(
        default="relation",
        description="The type of graph edge, `relation` by default.",
    )
    edge_name_key: str = Field(
        default="label",
        description="The label of edge name, `label` by default.",
    )


class TuGraphStore(GraphStoreBase):
    """TuGraph graph store."""

    def __init__(self, config: TuGraphStoreConfig) -> None:
        """Initialize the TuGraphStore with connection details.

        Each setting is read from its ``TUGRAPH_*`` environment variable; if
        the variable is unset or empty the value of ``config`` is used.
        Connects to TuGraph, creates the graph named ``config.name`` and the
        vertex / edge labels if they are missing.

        Args:
            config: Store configuration.
        """
        # ``getenv(...) or config.x``: passing a non-empty default to
        # ``getenv`` would make the config value unreachable.
        self._host = os.getenv("TUGRAPH_HOST") or config.host
        self._port = int(os.getenv("TUGRAPH_PORT") or config.port)
        self._username = os.getenv("TUGRAPH_USERNAME") or config.username
        self._password = os.getenv("TUGRAPH_PASSWORD") or config.password
        self._node_label = os.getenv("TUGRAPH_VERTEX_TYPE") or config.vertex_type
        self._edge_label = os.getenv("TUGRAPH_EDGE_TYPE") or config.edge_type
        self.edge_name_key = (
            os.getenv("TUGRAPH_EDGE_NAME_KEY") or config.edge_name_key
        )
        self._graph_name = config.name
        self.conn = TuGraphConnector.from_uri_db(
            host=self._host,
            port=self._port,
            user=self._username,
            pwd=self._password,
            db_name=config.name,
        )
        self.conn.create_graph(graph_name=config.name)

        self._create_schema()

    @staticmethod
    def _escape_quotes(value: str) -> str:
        """Escape single and double quotes in a string for queries."""
        return value.replace("'", "\\'").replace('"', '\\"')

    def _check_label(self, elem_type: str):
        """Tell whether the configured label of ``elem_type`` exists.

        Args:
            elem_type: ``"vertex"`` or ``"edge"``.

        Returns:
            ``True`` / ``False`` for the two known element types, ``None``
            for any other value.
        """
        result = self.conn.get_table_names()
        if elem_type == "vertex":
            return self._node_label in result["vertex_tables"]
        if elem_type == "edge":
            return self._edge_label in result["edge_tables"]

    def _create_schema(self):
        """Create the vertex and edge labels that do not exist yet."""
        if not self._check_label("vertex"):
            create_vertex_gql = (
                f"CALL db.createLabel("
                f"'vertex', '{self._node_label}', "
                f"'id', ['id',string,false])"
            )
            self.conn.run(create_vertex_gql)
        if not self._check_label("edge"):
            create_edge_gql = f"""CALL db.createLabel(
                'edge', '{self._edge_label}', '[["{self._node_label}",
                "{self._node_label}"]]', ["id",STRING,false])"""
            self.conn.run(create_edge_gql)

    def get_triplets(self, subj: str) -> List[Tuple[str, str]]:
        """Get triplets.

        Args:
            subj: Id of the subject vertex.

        Returns:
            One ``(relation id, object id)`` tuple per outgoing edge.
        """
        subj_escaped = self._escape_quotes(subj)
        query = (
            f"MATCH (n1:{self._node_label})-[r]->(n2:{self._node_label}) "
            f'WHERE n1.id = "{subj_escaped}" RETURN r.id as rel, n2.id as obj;'
        )
        data = self.conn.run(query)
        return [(record["rel"], record["obj"]) for record in data]

    def insert_triplet(self, subj: str, rel: str, obj: str) -> None:
        """Add triplet.

        Runs three MERGE queries: subject vertex, object vertex, then the
        relation between them.
        """
        subj_escaped = self._escape_quotes(subj)
        rel_escaped = self._escape_quotes(rel)
        obj_escaped = self._escape_quotes(obj)

        subj_query = f"MERGE (n1:{self._node_label} {{id:'{subj_escaped}'}})"
        obj_query = f"MERGE (n1:{self._node_label} {{id:'{obj_escaped}'}})"
        rel_query = (
            f"MERGE (n1:{self._node_label} {{id:'{subj_escaped}'}})"
            f"-[r:{self._edge_label} {{id:'{rel_escaped}'}}]->"
            f"(n2:{self._node_label} {{id:'{obj_escaped}'}})"
        )
        self.conn.run(query=subj_query)
        self.conn.run(query=obj_query)
        self.conn.run(query=rel_query)

    def drop(self):
        """Delete Graph."""
        self.conn.delete_graph(self._graph_name)

    def delete_triplet(self, sub: str, rel: str, obj: str) -> None:
        """Delete triplet.

        NOTE(review): the query deletes both endpoint vertices along with
        the relation (``DELETE n1,n2,r``) - confirm this is intended for
        vertices that still have other relations.
        """
        sub_escaped = self._escape_quotes(sub)
        rel_escaped = self._escape_quotes(rel)
        obj_escaped = self._escape_quotes(obj)
        del_query = (
            f"MATCH (n1:{self._node_label} {{id:'{sub_escaped}'}})"
            f"-[r:{self._edge_label} {{id:'{rel_escaped}'}}]->"
            f"(n2:{self._node_label} {{id:'{obj_escaped}'}}) DELETE n1,n2,r"
        )
        self.conn.run(query=del_query)

    def get_schema(self, refresh: bool = False) -> str:
        """Get the schema of the graph store.

        Args:
            refresh: Unused by this store.
        """
        query = "CALL dbms.graph.getGraphSchema()"
        data = self.conn.run(query=query)
        schema = data[0]["schema"]
        return schema

    def get_full_graph(self, limit: Optional[int] = None) -> MemoryGraph:
        """Get full graph.

        Args:
            limit: Maximum number of returned rows. Required.

        Raises:
            Exception: If ``limit`` is not set.
        """
        if not limit:
            raise Exception("limit must be set")
        return self.query(f"MATCH (n)-[r]-(m) RETURN n,m,r LIMIT {limit}")

    def explore(
        self,
        subs: List[str],
        direct: Direction = Direction.BOTH,
        depth: Optional[int] = None,
        fan: Optional[int] = None,
        limit: Optional[int] = None,
    ) -> MemoryGraph:
        """Explore the graph from given subjects up to a depth.

        Args:
            subs: Ids of the start vertices.
            direct: Not used by the generated query, which always matches
                relations in both directions.
            depth: Maximum path length; unbounded when ``None``.
            fan: Not supported, must be ``None``.
            limit: Maximum number of returned paths; unlimited when ``None``.

        Raises:
            ValueError: If ``fan`` is given.
        """
        if fan is not None:
            raise ValueError("Fan functionality is not supported at this time.")

        depth_string = ".." if depth is None else f"1..{depth}"
        limit_string = "" if limit is None else f"LIMIT {limit}"

        # ``list(subs)`` so any sequence is rendered with list brackets.
        query = (
            f"MATCH p=(n:{self._node_label})"
            f"-[r:{self._edge_label}*{depth_string}]-(m:{self._node_label}) "
            f"WHERE n.id IN {list(subs)} RETURN p {limit_string}"
        )
        return self.query(query)

    def query(self, query: str, **args) -> MemoryGraph:
        """Execute a query on graph.

        Nodes, relationships and paths found in the returned records are
        collected into a ``MemoryGraph``; other values are ignored. Vertices
        and edges are identified by their ``id`` property, and the relation
        id is stored under the edge property ``label``.

        Args:
            query: Query text to run.
            **args: Unused.
        """
        result = self.conn.run(query=query)

        from neo4j import graph as neo4j_graph

        node_ids = set()
        triplets = set()
        for record in result:
            for key in record.keys():
                value = record[key]
                if isinstance(value, neo4j_graph.Node):
                    node_ids.add(value._properties["id"])
                elif isinstance(value, neo4j_graph.Relationship):
                    src, dst = value.nodes[0], value.nodes[1]
                    triplets.add(
                        (
                            src._properties["id"],
                            dst._properties["id"],
                            value._properties["id"],
                        )
                    )
                elif isinstance(value, neo4j_graph.Path):
                    # Only this path value is processed, whatever its record
                    # key is. The i-th relationship is recorded from the
                    # i-th to the (i+1)-th node of the path.
                    path_ids = [n._properties["id"] for n in value.nodes]
                    rel_ids = [r._properties["id"] for r in value.relationships]
                    node_ids.update(path_ids)
                    triplets.update(zip(path_ids, path_ids[1:], rel_ids))

        mg = MemoryGraph()
        for node_id in node_ids:
            mg.upsert_vertex(Vertex(node_id))
        for src_id, dst_id, prop_id in triplets:
            mg.append_edge(Edge(src_id, dst_id, label=prop_id))
        return mg
|
Reference in New Issue
Block a user