feat: add GraphRAG framework and integrate TuGraph (#1506)

Co-authored-by: KingSkyLi <15566300566@163.com>
Co-authored-by: aries_ckt <916701291@qq.com>
Co-authored-by: Fangyin Cheng <staneyffer@gmail.com>
This commit is contained in:
Florian
2024-05-16 15:39:50 +08:00
committed by GitHub
parent 593e974405
commit a9087c3853
133 changed files with 10139 additions and 6631 deletions

View File

@@ -0,0 +1,37 @@
"""Graph Store Module."""
from typing import Tuple, Type
def _import_memgraph() -> Tuple[Type, Type]:
    """Import the in-memory graph store on demand.

    Returns:
        The ``(MemoryGraphStore, MemoryGraphStoreConfig)`` class pair.
    """
    # The import lives inside the function so the backend module is only
    # loaded when this loader is actually called.
    from dbgpt.storage.graph_store.memgraph_store import (
        MemoryGraphStore as store_cls,
        MemoryGraphStoreConfig as config_cls,
    )

    return store_cls, config_cls
def _import_tugraph() -> Tuple[Type, Type]:
    """Import the TuGraph graph store on demand.

    Returns:
        The ``(TuGraphStore, TuGraphStoreConfig)`` class pair.
    """
    # The import lives inside the function so the backend module is only
    # loaded when this loader is actually called.
    from dbgpt.storage.graph_store.tugraph_store import (
        TuGraphStore as store_cls,
        TuGraphStoreConfig as config_cls,
    )

    return store_cls, config_cls
def _import_neo4j() -> Tuple[Type, Type]:
    """Import the Neo4j graph store on demand.

    Returns:
        The ``(Neo4jStore, Neo4jStoreConfig)`` class pair.
    """
    # The import lives inside the function so the backend module is only
    # loaded when this loader is actually called.
    from dbgpt.storage.graph_store.neo4j_store import (
        Neo4jStore as store_cls,
        Neo4jStoreConfig as config_cls,
    )

    return store_cls, config_cls
def __getattr__(name: str) -> Tuple[Type, Type]:
    """Resolve a graph store name to its ``(store class, config class)`` pair.

    This is the module-level attribute hook: it is invoked for attribute
    lookups on this module that are not found the normal way.

    Args:
        name: The requested attribute; ``"Memory"``, ``"TuGraph"`` or
            ``"Neo4j"`` are recognised.

    Raises:
        AttributeError: If ``name`` is not one of the recognised store names.
    """
    loaders = {
        "Memory": _import_memgraph,
        "TuGraph": _import_tugraph,
        "Neo4j": _import_neo4j,
    }
    loader = loaders.get(name)
    if loader is None:
        raise AttributeError(f"Could not find: {name}")
    return loader()


__all__ = ["Memory", "TuGraph", "Neo4j"]

View File

@@ -0,0 +1,68 @@
"""Graph store base class."""
import logging
from abc import ABC, abstractmethod
from typing import List, Optional, Tuple
from dbgpt._private.pydantic import BaseModel, ConfigDict, Field
from dbgpt.core import Embeddings
from dbgpt.storage.graph_store.graph import Direction, Graph
logger = logging.getLogger(__name__)
class GraphStoreConfig(BaseModel):
    """Base configuration model shared by the graph store implementations."""

    # Allow arbitrary (non-pydantic) field types and keep extra fields.
    model_config = ConfigDict(extra="allow", arbitrary_types_allowed=True)

    # Name of the graph store.
    name: str = Field(
        description="The name of graph store, inherit from index store.",
        default="dbgpt_collection",
    )
    # Optional embedding function; ``None`` when not provided.
    embedding_fn: Optional[Embeddings] = Field(
        description="The embedding function of graph store, optional.",
        default=None,
    )
class GraphStoreBase(ABC):
    """Abstract interface every graph store backend has to implement."""

    @abstractmethod
    def insert_triplet(self, sub: str, rel: str, obj: str):
        """Store the triplet ``(sub, rel, obj)`` in the graph."""

    @abstractmethod
    def get_triplets(self, sub: str) -> List[Tuple[str, str]]:
        """Return the triplets of subject ``sub`` as a list of string pairs."""

    @abstractmethod
    def delete_triplet(self, sub: str, rel: str, obj: str):
        """Remove the triplet ``(sub, rel, obj)`` from the graph."""

    @abstractmethod
    def drop(self):
        """Drop the graph held by this store."""

    @abstractmethod
    def get_schema(self, refresh: bool = False) -> str:
        """Return the graph schema as a string."""

    @abstractmethod
    def get_full_graph(self, limit: Optional[int] = None) -> Graph:
        """Return the stored graph, optionally bounded by ``limit``."""

    @abstractmethod
    def explore(
        self,
        subs: List[str],
        direct: Direction = Direction.BOTH,
        depth: Optional[int] = None,
        fan: Optional[int] = None,
        limit: Optional[int] = None,
    ) -> Graph:
        """Explore the graph starting from the subjects ``subs``.

        Args:
            subs: Subjects to start from.
            direct: Edge direction to follow; both directions by default.
            depth: Optional bound on the exploration depth.
            fan: Optional bound on the fan-out.
            limit: Optional bound on the size of the result.
        """

    @abstractmethod
    def query(self, query: str, **args) -> Graph:
        """Run ``query`` against the store and return the result as a graph."""

View File

@@ -0,0 +1,42 @@
"""Factory for creating graph store instances."""
import logging
from typing import Tuple, Type
from dbgpt.storage import graph_store
from dbgpt.storage.graph_store.base import GraphStoreBase, GraphStoreConfig
logger = logging.getLogger(__name__)
class GraphStoreFactory:
    """Factory for graph store."""

    @staticmethod
    def create(graph_store_type: str, graph_store_configure=None) -> GraphStoreBase:
        """Create a GraphStore instance.

        Args:
            graph_store_type: Graph store type, matched case-insensitively
                against the names exported by ``graph_store.__all__``
                (Memory, TuGraph, Neo4j).
            graph_store_configure: Optional callable that receives the
                default-constructed config object and customises it in place
                before the store is built.

        Returns:
            The store instance built from the (possibly customised) config.

        Raises:
            ValueError: If ``graph_store_type`` is not supported.
            Exception: Any error raised while building the config or store is
                logged and re-raised unchanged.
        """
        store_cls, cfg_cls = GraphStoreFactory.__find_type(graph_store_type)

        try:
            config = cfg_cls()
            if graph_store_configure:
                graph_store_configure(config)
            return store_cls(config)
        except Exception as e:
            # logger.exception records the traceback; the bare ``raise``
            # re-raises the original exception untouched.
            logger.exception("create graph store failed: %s", e)
            raise

    @staticmethod
    def __find_type(graph_store_type: str) -> Tuple[Type, Type]:
        """Look up the ``(store class, config class)`` pair for a type name.

        Raises:
            ValueError: If no exported store matches ``graph_store_type`` or
                the matching pair does not derive from ``GraphStoreBase`` /
                ``GraphStoreConfig``.
        """
        wanted = graph_store_type.lower()
        for t in graph_store.__all__:
            if t.lower() != wanted:
                continue
            store_cls, cfg_cls = getattr(graph_store, t)
            if issubclass(store_cls, GraphStoreBase) and issubclass(
                cfg_cls, GraphStoreConfig
            ):
                return store_cls, cfg_cls
        # ValueError is a subclass of Exception, so callers catching the
        # previous generic Exception keep working.
        raise ValueError(f"Graph store {graph_store_type} not supported")

View File

@@ -0,0 +1,477 @@
"""Graph data model: vertices, edges and the in-memory graph."""
import itertools
import json
import logging
import re
from abc import ABC, abstractmethod
from collections import defaultdict
from enum import Enum
from typing import Any, Dict, Iterator, List, Optional, Set, Tuple
import networkx as nx
logger = logging.getLogger(__name__)
class Direction(Enum):
    """Edge direction relative to a vertex."""

    # Edges leaving the vertex.
    OUT = 0
    # Edges arriving at the vertex.
    IN = 1
    # Edges in either direction.
    BOTH = 2
class Elem(ABC):
    """Base class of graph elements that carry a property dictionary."""

    def __init__(self):
        """Create an element with an empty property dictionary."""
        self._props = {}

    @property
    def props(self) -> Dict[str, Any]:
        """Return the property dictionary itself (not a copy)."""
        return self._props

    def set_prop(self, key: str, value: Any):
        """Store ``value`` under property ``key``, replacing any old value."""
        self._props[key] = value

    def get_prop(self, key: str):
        """Return the value of property ``key``, or ``None`` when absent."""
        return self._props.get(key)

    def del_prop(self, key: str):
        """Remove property ``key``; a missing key is ignored."""
        if key in self._props:
            del self._props[key]

    def has_props(self, **props):
        """Tell whether every given property equals the stored value."""
        for name, expected in props.items():
            if not (self._props.get(name) == expected):
                return False
        return True

    @abstractmethod
    def format(self, label_key: Optional[str] = None):
        """Render the properties as ``{key:json;key:json}``.

        The property named ``label_key`` is left out; values are encoded with
        ``json.dumps``.
        """
        parts = []
        for name, value in self._props.items():
            if name == label_key:
                continue
            parts.append(f"{name}:{json.dumps(value)}")
        return "{" + ";".join(parts) + "}"
class Vertex(Elem):
    """A graph vertex identified by its vertex id."""

    def __init__(self, vid: str, **props):
        """Create a vertex with id ``vid`` and the given properties."""
        super().__init__()
        self._vid = vid
        for key in props:
            self.set_prop(key, props[key])

    @property
    def vid(self) -> str:
        """Return the id of this vertex."""
        return self._vid

    def format(self, label_key: Optional[str] = None):
        """Render the vertex as ``(label)`` or ``(label:{props})``.

        The label is the value of property ``label_key`` when a key is given,
        otherwise the vertex id.
        """
        label = self._vid if not label_key else self.get_prop(label_key)
        props_str = super().format(label_key)
        if props_str != "{}":
            return f"({label}:{props_str})"
        return f"({label})"

    def __str__(self):
        """Render the vertex as ``(vid)``."""
        return f"({self._vid})"
class Edge(Elem):
    """A directed graph edge from a source vertex to a target vertex."""

    def __init__(self, sid: str, tid: str, **props):
        """Create an edge ``sid -> tid`` with the given properties."""
        super().__init__()
        self._sid = sid
        self._tid = tid
        for key in props:
            self.set_prop(key, props[key])

    @property
    def sid(self) -> str:
        """Return the id of the source vertex."""
        return self._sid

    @property
    def tid(self) -> str:
        """Return the id of the target vertex."""
        return self._tid

    def nid(self, vid):
        """Return the id of the endpoint opposite to ``vid``.

        Raises:
            ValueError: If ``vid`` is neither the source nor the target.
        """
        if vid == self._sid:
            return self._tid
        if vid == self._tid:
            return self._sid
        raise ValueError(f"Get nid of {vid} on {self} failed")

    def format(self, label_key: Optional[str] = None):
        """Render the edge as an arrow, with label and properties if any."""
        label = self.get_prop(label_key) if label_key else ""
        props_str = super().format(label_key)
        with_props = props_str != "{}"
        if label and with_props:
            return f"-[{label}:{props_str}]->"
        if label:
            return f"-[{label}]->"
        if with_props:
            return f"-[{props_str}]->"
        return "->"

    def triplet(self, label_key: str) -> Tuple[str, str, str]:
        """Return ``(sid, str(label value), tid)`` for the given label key."""
        assert label_key, "label key is needed"
        relation = str(self.get_prop(label_key))
        return self._sid, relation, self._tid

    def __str__(self):
        """Render the edge as ``(sid)->(tid)``."""
        return f"({self._sid})->({self._tid})"
class Graph(ABC):
    """Abstract interface of a property graph made of vertices and edges."""

    @abstractmethod
    def upsert_vertex(self, vertex: Vertex):
        """Insert ``vertex`` or update the vertex that has the same id."""

    @abstractmethod
    def append_edge(self, edge: Edge):
        """Add ``edge`` to the graph."""

    @abstractmethod
    def has_vertex(self, vid: str) -> bool:
        """Tell whether a vertex with id ``vid`` exists."""

    @abstractmethod
    def get_vertex(self, vid: str) -> Vertex:
        """Return the vertex with id ``vid``."""

    @abstractmethod
    def get_neighbor_edges(
        self,
        vid: str,
        direction: Direction = Direction.OUT,
        limit: Optional[int] = None,
    ) -> Iterator[Edge]:
        """Iterate over the edges adjacent to vertex ``vid``.

        Args:
            vid: Id of the vertex whose edges are requested.
            direction: Which edges to follow; outgoing by default.
            limit: Optional bound on the number of edges.
        """

    @abstractmethod
    def vertices(self) -> Iterator[Vertex]:
        """Iterate over all vertices."""

    @abstractmethod
    def edges(self) -> Iterator[Edge]:
        """Iterate over all edges."""

    @abstractmethod
    def del_vertices(self, *vids: str):
        """Delete the given vertices together with their neighbor edges."""

    @abstractmethod
    def del_edges(self, sid: str, tid: str, **props):
        """Delete the edges ``sid -> tid`` that match ``props``."""

    @abstractmethod
    def del_neighbor_edges(self, vid: str, direction: Direction = Direction.OUT):
        """Delete the neighbor edges of ``vid`` in the given direction."""

    @abstractmethod
    def search(
        self,
        vids: List[str],
        direct: Direction = Direction.OUT,
        depth: Optional[int] = None,
        fan: Optional[int] = None,
        limit: Optional[int] = None,
    ) -> "Graph":
        """Search the graph starting from ``vids`` and return the result.

        Args:
            vids: Ids of the start vertices.
            direct: Edge direction to follow; outgoing by default.
            depth: Optional bound on the search depth.
            fan: Optional bound on the fan-out.
            limit: Optional bound on the size of the result.
        """

    @abstractmethod
    def schema(self) -> Dict[str, Any]:
        """Return the schema of the graph as a dictionary."""

    @abstractmethod
    def format(self) -> str:
        """Render the graph data as a string."""
class MemoryGraph(Graph):
    """In-memory graph backed by a vertex map and out-/in-edge indexes.

    Edges are kept in two nested indexes: ``_oes[sid][tid]`` and
    ``_ies[tid][sid]`` both hold the set of edges going from ``sid`` to
    ``tid``.
    """

    def __init__(self, vertex_label: Optional[str] = None, edge_label: str = "label"):
        """Initialize MemoryGraph with vertex label and edge label.

        Args:
            vertex_label: Property key used as vertex label when formatting;
                the vertex id is used when it is ``None``.
            edge_label: Property key that every appended edge must carry.
        """
        assert edge_label, "Edge label is needed"

        # metadata
        self._vertex_label = vertex_label
        self._edge_label = edge_label
        self._vertex_prop_keys = {vertex_label} if vertex_label else set()
        self._edge_prop_keys = {edge_label}
        self._edge_count = 0

        # init vertices, out edges, in edges index
        self._vs: Any = {}
        self._oes: Any = defaultdict(lambda: defaultdict(set))
        self._ies: Any = defaultdict(lambda: defaultdict(set))

    @property
    def vertex_label(self):
        """Return the label for vertices."""
        return self._vertex_label

    @property
    def edge_label(self):
        """Return the label for edges."""
        return self._edge_label

    @property
    def vertex_prop_keys(self):
        """Return a set of property keys for vertices."""
        return self._vertex_prop_keys

    @property
    def edge_prop_keys(self):
        """Return a set of property keys for edges."""
        return self._edge_prop_keys

    @property
    def vertex_count(self):
        """Return the number of vertices in the graph."""
        return len(self._vs)

    @property
    def edge_count(self):
        """Return the count of edges in the graph."""
        return self._edge_count

    def upsert_vertex(self, vertex: Vertex):
        """Insert or update a vertex based on its ID.

        When the id already exists, the stored vertex's properties are
        updated from ``vertex``; otherwise ``vertex`` itself is stored.
        """
        if vertex.vid in self._vs:
            self._vs[vertex.vid].props.update(vertex.props)
        else:
            self._vs[vertex.vid] = vertex

        # update metadata
        self._vertex_prop_keys.update(vertex.props.keys())

    def append_edge(self, edge: Edge):
        """Append an edge if it doesn't exist; requires edge label.

        Returns:
            ``True`` when the edge was added, ``False`` when this edge object
            is already present between its endpoints.

        Raises:
            ValueError: If the edge lacks the graph's edge label property.
        """
        if self.edge_label not in edge.props:
            raise ValueError(f"Edge prop '{self.edge_label}' is needed")

        sid = edge.sid
        tid = edge.tid

        if edge in self._oes[sid][tid]:
            return False

        # init vertex index; endpoints get a bare vertex when unknown
        if sid not in self._vs:
            self._vs[sid] = Vertex(sid)
        if tid not in self._vs:
            self._vs[tid] = Vertex(tid)

        # update edge index
        self._oes[sid][tid].add(edge)
        self._ies[tid][sid].add(edge)

        # update metadata
        self._edge_prop_keys.update(edge.props.keys())
        self._edge_count += 1
        return True

    def has_vertex(self, vid: str) -> bool:
        """Check whether a vertex with the given ID exists."""
        return vid in self._vs

    def get_vertex(self, vid: str) -> Vertex:
        """Retrieve a vertex by ID.

        Raises:
            KeyError: If no vertex with this ID exists.
        """
        return self._vs[vid]

    @staticmethod
    def _iter_index(index, vid: str) -> Iterator[Edge]:
        """Iterate the edges indexed under ``vid`` without touching the index.

        ``.get`` is used instead of subscripting so a lookup of an unknown
        vertex does not insert an empty entry into the defaultdict.
        """
        return (e for es in index.get(vid, {}).values() for e in es)

    @staticmethod
    def _distinct(elements) -> Iterator[Edge]:
        """Yield each element once, keeping first-seen order."""
        seen = set()
        for element in elements:
            if element not in seen:
                seen.add(element)
                yield element

    def get_neighbor_edges(
        self,
        vid: str,
        direction: Direction = Direction.OUT,
        limit: Optional[int] = None,
    ) -> Iterator[Edge]:
        """Get edges connected to a vertex by direction.

        Args:
            vid: Id of the vertex.
            direction: ``OUT``, ``IN`` or ``BOTH``. For ``BOTH`` outgoing and
                incoming edges are interleaved and de-duplicated (a self-loop
                appears in both indexes).
            limit: Maximum number of edges to yield; a falsy value means no
                limit.

        Raises:
            ValueError: If ``direction`` is not a known direction.
        """
        if direction == Direction.OUT:
            es = self._iter_index(self._oes, vid)

        elif direction == Direction.IN:
            es = self._iter_index(self._ies, vid)

        elif direction == Direction.BOTH:
            oes = self._iter_index(self._oes, vid)
            ies = self._iter_index(self._ies, vid)

            # merge: alternate between outgoing and incoming edges
            tuples = itertools.zip_longest(oes, ies)
            merged = (e for t in tuples for e in t if e is not None)

            # distinct
            es = self._distinct(merged)
        else:
            raise ValueError(f"Invalid direction: {direction}")

        return itertools.islice(es, limit) if limit else es

    def vertices(self) -> Iterator[Vertex]:
        """Return vertices."""
        return iter(self._vs.values())

    def edges(self) -> Iterator[Edge]:
        """Return edges."""
        return iter(e for nbs in self._oes.values() for es in nbs.values() for e in es)

    def del_vertices(self, *vids: str):
        """Delete specified vertices together with all their edges."""
        for vid in vids:
            self.del_neighbor_edges(vid, Direction.BOTH)
            self._vs.pop(vid, None)

    def del_edges(self, sid: str, tid: str, **props):
        """Delete edges from ``sid`` to ``tid``.

        Without ``props`` every edge between the two vertices is removed;
        otherwise only the edges whose properties match all of ``props``.
        """
        out_edges = self._oes.get(sid, {}).get(tid)
        if not out_edges:
            # nothing stored between the two vertices
            return

        if not props:
            self._edge_count -= len(out_edges)
            self._oes[sid].pop(tid, None)
            self._ies[tid].pop(sid, None)
            return

        def remove_matches(es):
            return {e for e in es if not e.has_props(**props)}

        old_edge_cnt = len(out_edges)
        self._oes[sid][tid] = remove_matches(out_edges)
        self._ies[tid][sid] = remove_matches(self._ies[tid][sid])

        self._edge_count -= old_edge_cnt - len(self._oes[sid][tid])

    def del_neighbor_edges(self, vid: str, direction: Direction = Direction.OUT):
        """Delete all neighbor edges of ``vid`` in the given direction."""

        def del_index(idx, i_idx):
            # Drop vid's entry first, then remove the mirrored entries from
            # the inverse index, counting the edges that go away.
            for nid in idx.pop(vid, {}):
                removed = i_idx[nid].pop(vid, ())
                self._edge_count -= len(removed)

        if direction in [Direction.OUT, Direction.BOTH]:
            del_index(self._oes, self._ies)

        if direction in [Direction.IN, Direction.BOTH]:
            del_index(self._ies, self._oes)

    def search(
        self,
        vids: List[str],
        direct: Direction = Direction.OUT,
        depth: Optional[int] = None,
        fan: Optional[int] = None,
        limit: Optional[int] = None,
    ) -> "MemoryGraph":
        """Search the graph from a vertex with specified parameters.

        Args:
            vids: Ids of the start vertices; each is searched with its own
                visited set into one shared result graph.
            direct: Edge direction to follow.
            depth: Maximum number of hops; a falsy value means unbounded.
            fan: Maximum number of edges followed per vertex; a falsy value
                means unbounded.
            limit: Maximum number of edges in the result; a falsy value means
                unbounded.

        Returns:
            A new MemoryGraph holding the visited vertices and edges.
        """
        # The subgraph inherits this graph's labels: with the default labels
        # append_edge would reject edges of a graph using a custom edge label.
        subgraph = MemoryGraph(self._vertex_label, self._edge_label)

        for vid in vids:
            self.__search(vid, direct, depth, fan, limit, 0, set(), subgraph)

        return subgraph

    def __search(
        self,
        vid: str,
        direct: Direction,
        depth: Optional[int],
        fan: Optional[int],
        limit: Optional[int],
        _depth: int,
        _visited: Set,
        _subgraph: "MemoryGraph",
    ):
        """Depth-first helper of ``search``; fills ``_subgraph`` in place."""
        if vid in _visited or (depth and _depth >= depth):
            return

        # visit vertex
        if not self.has_vertex(vid):
            return
        _subgraph.upsert_vertex(self.get_vertex(vid))
        _visited.add(vid)

        # visit edges
        nids = set()
        for edge in self.get_neighbor_edges(vid, direct, fan):
            if limit and _subgraph.edge_count >= limit:
                return

            # append edge success then visit new vertex
            if _subgraph.append_edge(edge):
                nid = edge.nid(vid)
                if nid not in _visited:
                    nids.add(nid)

        # next hop
        for nid in nids:
            self.__search(
                nid, direct, depth, fan, limit, _depth + 1, _visited, _subgraph
            )

    def schema(self) -> Dict[str, Any]:
        """Return schema: vertex and edge labels with their property keys."""
        return {
            "schema": [
                {
                    "type": "VERTEX",
                    "label": f"{self._vertex_label}",
                    "properties": [{"name": k} for k in self._vertex_prop_keys],
                },
                {
                    "type": "EDGE",
                    "label": f"{self._edge_label}",
                    "properties": [{"name": k} for k in self._edge_prop_keys],
                },
            ]
        }

    def format(self) -> str:
        """Format graph to string: a vertices section, then an edges section."""
        vs_str = "\n".join(v.format(self.vertex_label) for v in self.vertices())
        es_str = "\n".join(
            f"{self.get_vertex(e.sid).format(self.vertex_label)}"
            f"{e.format(self.edge_label)}"
            f"{self.get_vertex(e.tid).format(self.vertex_label)}"
            for e in self.edges()
        )
        return f"Vertices:\n{vs_str}\n\nEdges:\n{es_str}"

    def graphviz(self, name="g"):
        """View graphviz graph: https://dreampuf.github.io/GraphvizOnline."""
        g = nx.MultiDiGraph()
        for vertex in self.vertices():
            g.add_node(vertex.vid)

        for edge in self.edges():
            triplet = edge.triplet(self.edge_label)
            g.add_edge(triplet[0], triplet[2], label=triplet[1])

        digraph = nx.nx_agraph.to_agraph(g).to_string()
        digraph = digraph.replace('digraph ""', f"digraph {name}")
        # strip the multigraph edge keys from the generated text
        digraph = re.sub(r"key=\d+,?\s*", "", digraph)
        return digraph

View File

@@ -0,0 +1,81 @@
"""In-memory graph store."""
import json
import logging
from typing import List, Optional, Tuple
from dbgpt._private.pydantic import ConfigDict, Field
from dbgpt.storage.graph_store.base import GraphStoreBase, GraphStoreConfig
from dbgpt.storage.graph_store.graph import Direction, Edge, Graph, MemoryGraph
logger = logging.getLogger(__name__)
class MemoryGraphStoreConfig(GraphStoreConfig):
    """Configuration of the in-memory graph store."""

    model_config = ConfigDict(arbitrary_types_allowed=True)

    # Edge property key under which the relation name is stored.
    edge_name_key: str = Field(
        description="The label of edge name, `label` by default.",
        default="label",
    )
class MemoryGraphStore(GraphStoreBase):
    """Graph store that keeps all data in a ``MemoryGraph``."""

    def __init__(self, graph_store_config: MemoryGraphStoreConfig):
        """Initialize MemoryGraphStore with a memory graph.

        Args:
            graph_store_config: Supplies ``edge_name_key``, the edge property
                key under which relation names are stored.
        """
        self._edge_name_key = graph_store_config.edge_name_key
        self._graph = MemoryGraph(edge_label=self._edge_name_key)

    def insert_triplet(self, sub: str, rel: str, obj: str):
        """Insert a triplet into the graph as an edge ``sub -> obj``."""
        self._graph.append_edge(Edge(sub, obj, **{self._edge_name_key: rel}))

    def get_triplets(self, sub: str) -> List[Tuple[str, str]]:
        """Retrieve ``(relation, object)`` pairs originating from a subject."""
        subgraph = self.explore([sub], direct=Direction.OUT, depth=1)
        return [(e.get_prop(self._edge_name_key), e.tid) for e in subgraph.edges()]

    def delete_triplet(self, sub: str, rel: str, obj: str):
        """Delete a specific triplet from the graph."""
        self._graph.del_edges(sub, obj, **{self._edge_name_key: rel})

    def drop(self):
        """Drop graph.

        The graph reference is cleared; the other methods dereference it, so
        the store is not usable after this call.
        """
        self._graph = None

    def get_schema(self, refresh: bool = False) -> str:
        """Return the graph schema as a JSON string.

        ``refresh`` is accepted for interface compatibility and not used.
        """
        return json.dumps(self._graph.schema())

    def get_full_graph(self, limit: Optional[int] = None) -> MemoryGraph:
        """Return the stored graph, or a copy bounded to ``limit`` edges.

        Args:
            limit: Maximum number of edges to copy. When falsy, the store's
                own graph object is returned (not a copy).
        """
        if not limit:
            return self._graph

        # Use the store's edge label so edges keyed by a custom
        # edge_name_key are accepted by append_edge.
        subgraph = MemoryGraph(edge_label=self._edge_name_key)
        for count, edge in enumerate(self._graph.edges()):
            if count >= limit:
                break
            subgraph.upsert_vertex(self._graph.get_vertex(edge.sid))
            subgraph.upsert_vertex(self._graph.get_vertex(edge.tid))
            subgraph.append_edge(edge)
        return subgraph

    def explore(
        self,
        subs: List[str],
        direct: Direction = Direction.BOTH,
        depth: Optional[int] = None,
        fan: Optional[int] = None,
        limit: Optional[int] = None,
    ) -> MemoryGraph:
        """Explore the graph from given subjects up to a depth.

        All arguments are forwarded to ``MemoryGraph.search``.
        """
        return self._graph.search(subs, direct, depth, fan, limit)

    def query(self, query: str, **args) -> Graph:
        """Execute a query on graph.

        Raises:
            NotImplementedError: Always; the memory graph has no query
                language.
        """
        raise NotImplementedError("Query memory graph not allowed")

View File

@@ -0,0 +1,64 @@
"""Neo4j graph store."""
import logging
from typing import List, Optional, Tuple
from dbgpt._private.pydantic import ConfigDict
from dbgpt.storage.graph_store.base import GraphStoreBase, GraphStoreConfig
from dbgpt.storage.graph_store.graph import Direction, Graph, MemoryGraph
logger = logging.getLogger(__name__)
class Neo4jStoreConfig(GraphStoreConfig):
    """Configuration of the Neo4j graph store; adds no fields of its own."""

    model_config = ConfigDict(arbitrary_types_allowed=True)
class Neo4jStore(GraphStoreBase):
    """Placeholder Neo4j graph store.

    Every method is a stub: writes do nothing and reads return empty results.
    """

    # todo: add neo4j implementation

    def __init__(self, graph_store_config: Neo4jStoreConfig):
        """Accept the config; nothing is set up yet."""

    def insert_triplet(self, sub: str, rel: str, obj: str):
        """Insert a triplet (currently a no-op)."""

    def get_triplets(self, sub: str) -> List[Tuple[str, str]]:
        """Get the triplets of a subject (currently always an empty list)."""
        return []

    def delete_triplet(self, sub: str, rel: str, obj: str):
        """Delete a triplet (currently a no-op)."""

    def drop(self):
        """Drop the graph (currently a no-op)."""

    def get_schema(self, refresh: bool = False) -> str:
        """Get the schema (currently always an empty string)."""
        return ""

    def get_full_graph(self, limit: Optional[int] = None) -> Graph:
        """Get the full graph (currently always an empty MemoryGraph)."""
        return MemoryGraph()

    def explore(
        self,
        subs: List[str],
        direct: Direction = Direction.BOTH,
        depth: Optional[int] = None,
        fan: Optional[int] = None,
        limit: Optional[int] = None,
    ) -> Graph:
        """Explore from the subjects (currently always an empty MemoryGraph)."""
        return MemoryGraph()

    def query(self, query: str, **args) -> Graph:
        """Execute a query (currently always an empty MemoryGraph)."""
        return MemoryGraph()

View File

@@ -0,0 +1,239 @@
"""TuGraph graph store."""
import logging
import os
from typing import List, Optional, Tuple
from dbgpt._private.pydantic import ConfigDict, Field
from dbgpt.datasource.conn_tugraph import TuGraphConnector
from dbgpt.storage.graph_store.base import GraphStoreBase, GraphStoreConfig
from dbgpt.storage.graph_store.graph import Direction, Edge, MemoryGraph, Vertex
logger = logging.getLogger(__name__)
class TuGraphStoreConfig(GraphStoreConfig):
    """TuGraph store config."""

    model_config = ConfigDict(arbitrary_types_allowed=True)

    host: str = Field(
        default="127.0.0.1",
        description="TuGraph host",
    )
    port: int = Field(
        default=7687,
        description="TuGraph port",
    )
    username: str = Field(
        default="admin",
        description="login username",
    )
    # Same value TuGraphStore used to fall back to when TUGRAPH_PASSWORD was
    # unset, so the effective default login does not change now that the
    # store honours this field.
    password: str = Field(
        default="73@TuGraph",
        description="login password",
    )
    vertex_type: str = Field(
        default="entity",
        description="The type of graph vertex, `entity` by default.",
    )
    edge_type: str = Field(
        default="relation",
        description="The type of graph edge, `relation` by default.",
    )
    edge_name_key: str = Field(
        default="label",
        description="The label of edge name, `label` by default.",
    )


class TuGraphStore(GraphStoreBase):
    """TuGraph graph store."""

    def __init__(self, config: TuGraphStoreConfig) -> None:
        """Initialize the TuGraphStore with connection details.

        Each setting is taken from its ``TUGRAPH_*`` environment variable
        when that variable is set to a non-empty value, otherwise from
        ``config``. Connects to TuGraph, creates the graph named
        ``config.name`` and makes sure the vertex and edge labels exist.
        """
        # ``os.getenv(name) or config.x``: passing a default to getenv would
        # always win over the config value, so no default is given here.
        self._host = os.getenv("TUGRAPH_HOST") or config.host
        self._port = int(os.getenv("TUGRAPH_PORT") or 0) or config.port
        self._username = os.getenv("TUGRAPH_USERNAME") or config.username
        self._password = os.getenv("TUGRAPH_PASSWORD") or config.password
        self._node_label = os.getenv("TUGRAPH_VERTEX_TYPE") or config.vertex_type
        self._edge_label = os.getenv("TUGRAPH_EDGE_TYPE") or config.edge_type
        self.edge_name_key = (
            os.getenv("TUGRAPH_EDGE_NAME_KEY") or config.edge_name_key
        )
        self._graph_name = config.name
        self.conn = TuGraphConnector.from_uri_db(
            host=self._host,
            port=self._port,
            user=self._username,
            pwd=self._password,
            db_name=config.name,
        )
        self.conn.create_graph(graph_name=config.name)

        self._create_schema()

    @staticmethod
    def _escape_quotes(value: str) -> str:
        """Escape single and double quotes in a string for queries."""
        return value.replace("'", "\\'").replace('"', '\\"')

    def _check_label(self, elem_type: str):
        """Tell whether the configured label of ``elem_type`` already exists.

        ``elem_type`` is ``"vertex"`` or ``"edge"``; any other value yields
        ``False``.
        """
        result = self.conn.get_table_names()
        if elem_type == "vertex":
            return self._node_label in result["vertex_tables"]
        if elem_type == "edge":
            return self._edge_label in result["edge_tables"]
        return False

    def _create_schema(self):
        """Create the vertex and edge labels unless they already exist."""
        if not self._check_label("vertex"):
            create_vertex_gql = (
                f"CALL db.createLabel("
                f"'vertex', '{self._node_label}', "
                f"'id', ['id',string,false])"
            )
            self.conn.run(create_vertex_gql)
        if not self._check_label("edge"):
            create_edge_gql = f"""CALL db.createLabel(
                'edge', '{self._edge_label}', '[["{self._node_label}",
                "{self._node_label}"]]', ["id",STRING,false])"""
            self.conn.run(create_edge_gql)

    def get_triplets(self, subj: str) -> List[Tuple[str, str]]:
        """Get the ``(relation, object)`` pairs of the edges leaving ``subj``."""
        # Quotes are escaped the same way insert_triplet does, so subjects
        # containing quotes can be looked up without breaking the query.
        subj_escaped = self._escape_quotes(subj)
        query = (
            f"MATCH (n1:{self._node_label})-[r]->(n2:{self._node_label}) "
            f'WHERE n1.id = "{subj_escaped}" RETURN r.id as rel, n2.id as obj;'
        )
        data = self.conn.run(query)
        return [(record["rel"], record["obj"]) for record in data]

    def insert_triplet(self, subj: str, rel: str, obj: str) -> None:
        """Add triplet: merge both vertices, then the edge between them."""
        subj_escaped = self._escape_quotes(subj)
        rel_escaped = self._escape_quotes(rel)
        obj_escaped = self._escape_quotes(obj)

        subj_query = f"MERGE (n1:{self._node_label} {{id:'{subj_escaped}'}})"
        obj_query = f"MERGE (n1:{self._node_label} {{id:'{obj_escaped}'}})"
        rel_query = (
            f"MERGE (n1:{self._node_label} {{id:'{subj_escaped}'}})"
            f"-[r:{self._edge_label} {{id:'{rel_escaped}'}}]->"
            f"(n2:{self._node_label} {{id:'{obj_escaped}'}})"
        )
        self.conn.run(query=subj_query)
        self.conn.run(query=obj_query)
        self.conn.run(query=rel_query)

    def drop(self):
        """Delete Graph."""
        self.conn.delete_graph(self._graph_name)

    def delete_triplet(self, sub: str, rel: str, obj: str) -> None:
        """Delete triplet."""
        sub_escaped = self._escape_quotes(sub)
        rel_escaped = self._escape_quotes(rel)
        obj_escaped = self._escape_quotes(obj)
        # NOTE(review): the statement deletes both endpoint vertices as well
        # as the edge, unlike MemoryGraphStore.delete_triplet which removes
        # only the edge - confirm this is intended.
        del_query = (
            f"MATCH (n1:{self._node_label} {{id:'{sub_escaped}'}})"
            f"-[r:{self._edge_label} {{id:'{rel_escaped}'}}]->"
            f"(n2:{self._node_label} {{id:'{obj_escaped}'}}) DELETE n1,n2,r"
        )
        self.conn.run(query=del_query)

    def get_schema(self, refresh: bool = False) -> str:
        """Get the schema of the graph store.

        ``refresh`` is accepted for interface compatibility and not used.
        """
        query = "CALL dbms.graph.getGraphSchema()"
        data = self.conn.run(query=query)
        schema = data[0]["schema"]
        return schema

    def get_full_graph(self, limit: Optional[int] = None) -> MemoryGraph:
        """Get full graph.

        Raises:
            Exception: If ``limit`` is not set.
        """
        if not limit:
            raise Exception("limit must be set")
        return self.query(f"MATCH (n)-[r]-(m) RETURN n,m,r LIMIT {limit}")

    def explore(
        self,
        subs: List[str],
        direct: Direction = Direction.BOTH,
        depth: Optional[int] = None,
        fan: Optional[int] = None,
        limit: Optional[int] = None,
    ) -> MemoryGraph:
        """Explore the graph from given subjects up to a depth.

        Args:
            subs: Ids of the start vertices.
            direct: Not used by this implementation; the generated pattern
                is undirected.
            depth: Maximum path length; unbounded when ``None``.
            fan: Must be ``None``.
            limit: Maximum number of returned paths; unbounded when ``None``.

        Raises:
            ValueError: If ``fan`` is given.
        """
        if fan is not None:
            raise ValueError("Fan functionality is not supported at this time.")

        depth_string = ".." if depth is None else f"1..{depth}"
        limit_string = "" if limit is None else f"LIMIT {limit}"

        query = (
            f"MATCH p=(n:{self._node_label})"
            f"-[r:{self._edge_label}*{depth_string}]-(m:{self._node_label}) "
            f"WHERE n.id IN {subs} RETURN p {limit_string}"
        )
        return self.query(query)

    def query(self, query: str, **args) -> MemoryGraph:
        """Execute a query on graph.

        Nodes, relationships and paths found in the result records are
        collected (de-duplicated by their ``id`` property) into a new
        MemoryGraph. Other value types are ignored.
        """
        result = self.conn.run(query=query)

        from neo4j import graph as neo4j_graph

        node_ids_set = set()
        rels_set = set()
        for record in result:
            for key in record.keys():
                value = record[key]
                if isinstance(value, neo4j_graph.Node):
                    node_ids_set.add(value._properties["id"])
                elif isinstance(value, neo4j_graph.Relationship):
                    rel_nodes = value.nodes
                    src_id = rel_nodes[0]._properties["id"]
                    dst_id = rel_nodes[1]._properties["id"]
                    rels_set.add((src_id, dst_id, value._properties["id"]))
                elif isinstance(value, neo4j_graph.Path):
                    # Only this path is processed, whatever key it is
                    # returned under; edges follow the path order
                    # (node i -> node i+1).
                    path_ids = [n._properties["id"] for n in value.nodes]
                    rel_ids = [r._properties["id"] for r in value.relationships]
                    node_ids_set.update(path_ids)
                    rels_set.update(zip(path_ids, path_ids[1:], rel_ids))

        mg = MemoryGraph()
        for node_id in node_ids_set:
            mg.upsert_vertex(Vertex(node_id))
        for src_id, dst_id, prop_id in rels_set:
            mg.append_edge(Edge(src_id, dst_id, label=prop_id))
        return mg