mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-25 11:39:11 +00:00
feat: add document structure into GraphRAG (#2033)
Co-authored-by: Appointat <kuda.czk@antgroup.com> Co-authored-by: tpoisonooo <khj.application@aliyun.com> Co-authored-by: vritser <vritser@163.com>
This commit is contained in:
@@ -1,11 +1,19 @@
|
||||
"""Define Classes about Community."""
|
||||
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional
|
||||
from typing import AsyncGenerator, Iterator, List, Optional
|
||||
|
||||
from dbgpt.storage.graph_store.base import GraphStoreBase
|
||||
from dbgpt.storage.graph_store.graph import Graph
|
||||
from dbgpt.storage.graph_store.graph import (
|
||||
Direction,
|
||||
Edge,
|
||||
Graph,
|
||||
GraphElemType,
|
||||
MemoryGraph,
|
||||
Vertex,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -24,7 +32,7 @@ class CommunityTree:
|
||||
"""Represents a community tree."""
|
||||
|
||||
|
||||
class CommunityStoreAdapter(ABC):
|
||||
class GraphStoreAdapter(ABC):
|
||||
"""Community Store Adapter."""
|
||||
|
||||
def __init__(self, graph_store: GraphStoreBase):
|
||||
@@ -44,6 +52,113 @@ class CommunityStoreAdapter(ABC):
|
||||
async def get_community(self, community_id: str) -> Community:
|
||||
"""Get community."""
|
||||
|
||||
@abstractmethod
|
||||
def get_graph_config(self):
|
||||
"""Get config."""
|
||||
|
||||
@abstractmethod
|
||||
def get_vertex_type(self) -> str:
|
||||
"""Get vertex type."""
|
||||
|
||||
@abstractmethod
|
||||
def get_edge_type(self) -> str:
|
||||
"""Get edge type."""
|
||||
|
||||
@abstractmethod
|
||||
def get_triplets(self, sub: str) -> List[tuple[str, str]]:
|
||||
"""Get triplets."""
|
||||
|
||||
@abstractmethod
|
||||
def get_document_vertex(self, doc_name: str) -> Vertex:
|
||||
"""Get document vertex."""
|
||||
|
||||
@abstractmethod
|
||||
def get_schema(self, refresh: bool = False) -> str:
|
||||
"""Get schema."""
|
||||
|
||||
@abstractmethod
|
||||
def get_full_graph(self, limit: Optional[int] = None) -> Graph:
|
||||
"""Get full graph."""
|
||||
|
||||
@abstractmethod
|
||||
def upsert_entities(self, entities: Iterator[Vertex]) -> None:
|
||||
"""Upsert entity."""
|
||||
|
||||
@abstractmethod
|
||||
def upsert_edge(
|
||||
self, edges: Iterator[Edge], edge_type: str, src_type: str, dst_type: str
|
||||
):
|
||||
"""Upsert edge."""
|
||||
|
||||
@abstractmethod
|
||||
def upsert_chunks(self, chunk: Iterator[Vertex]) -> None:
|
||||
"""Upsert chunk."""
|
||||
|
||||
@abstractmethod
|
||||
def upsert_documents(self, documents: Iterator[Vertex]) -> None:
|
||||
"""Upsert documents."""
|
||||
|
||||
@abstractmethod
|
||||
def upsert_relations(self, relations: Iterator[Edge]) -> None:
|
||||
"""Upsert relations."""
|
||||
|
||||
@abstractmethod
|
||||
def insert_triplet(self, sub: str, rel: str, obj: str) -> None:
|
||||
"""Insert triplet."""
|
||||
|
||||
@abstractmethod
|
||||
def upsert_graph(self, graph: Graph) -> None:
|
||||
"""Insert graph."""
|
||||
|
||||
@abstractmethod
|
||||
def delete_document(self, chunk_id: str) -> None:
|
||||
"""Delete document in graph store."""
|
||||
|
||||
@abstractmethod
|
||||
def delete_triplet(self, sub: str, rel: str, obj: str) -> None:
|
||||
"""Delete triplet."""
|
||||
|
||||
@abstractmethod
|
||||
def drop(self) -> None:
|
||||
"""Drop graph."""
|
||||
|
||||
@abstractmethod
|
||||
def create_graph(self, graph_name: str) -> None:
|
||||
"""Create graph."""
|
||||
|
||||
@abstractmethod
|
||||
def create_graph_label(self) -> None:
|
||||
"""Create a graph label.
|
||||
|
||||
The graph label is used to identify and distinguish different types of nodes
|
||||
(vertices) and edges in the graph.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def truncate(self) -> None:
|
||||
"""Truncate graph."""
|
||||
|
||||
@abstractmethod
|
||||
def check_label(self, graph_elem_type: GraphElemType) -> bool:
|
||||
"""Check if the label exists in the graph."""
|
||||
|
||||
@abstractmethod
|
||||
def explore(
|
||||
self,
|
||||
subs: List[str],
|
||||
direct: Direction = Direction.BOTH,
|
||||
depth: Optional[int] = None,
|
||||
) -> MemoryGraph:
|
||||
"""Explore the graph from given subjects up to a depth."""
|
||||
|
||||
@abstractmethod
|
||||
def query(self, query: str, **kwargs) -> MemoryGraph:
|
||||
"""Execute a query on graph."""
|
||||
|
||||
@abstractmethod
|
||||
async def stream_query(self, query: str, **kwargs) -> AsyncGenerator[Graph, None]:
|
||||
"""Execute a stream query."""
|
||||
|
||||
|
||||
class CommunityMetastore(ABC):
|
||||
"""Community metastore class."""
|
||||
|
@@ -4,10 +4,7 @@ import logging
|
||||
from typing import List
|
||||
|
||||
from dbgpt.rag.transformer.community_summarizer import CommunitySummarizer
|
||||
from dbgpt.storage.knowledge_graph.community.base import (
|
||||
Community,
|
||||
CommunityStoreAdapter,
|
||||
)
|
||||
from dbgpt.storage.knowledge_graph.community.base import Community, GraphStoreAdapter
|
||||
from dbgpt.storage.knowledge_graph.community.community_metastore import (
|
||||
BuiltinCommunityMetastore,
|
||||
)
|
||||
@@ -21,23 +18,23 @@ class CommunityStore:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
community_store_adapter: CommunityStoreAdapter,
|
||||
graph_store_adapter: GraphStoreAdapter,
|
||||
community_summarizer: CommunitySummarizer,
|
||||
vector_store: VectorStoreBase,
|
||||
):
|
||||
"""Initialize the CommunityStore class."""
|
||||
self._community_store_adapter = community_store_adapter
|
||||
self._graph_store_adapter = graph_store_adapter
|
||||
self._community_summarizer = community_summarizer
|
||||
self._meta_store = BuiltinCommunityMetastore(vector_store)
|
||||
|
||||
async def build_communities(self):
|
||||
"""Discover communities."""
|
||||
community_ids = await self._community_store_adapter.discover_communities()
|
||||
community_ids = await self._graph_store_adapter.discover_communities()
|
||||
|
||||
# summarize communities
|
||||
communities = []
|
||||
for community_id in community_ids:
|
||||
community = await self._community_store_adapter.get_community(community_id)
|
||||
community = await self._graph_store_adapter.get_community(community_id)
|
||||
graph = community.data.format()
|
||||
if not graph:
|
||||
break
|
||||
@@ -65,7 +62,7 @@ class CommunityStore:
|
||||
self._community_summarizer.truncate()
|
||||
|
||||
logger.info("Truncate graph")
|
||||
self._community_store_adapter.graph_store.truncate()
|
||||
self._graph_store_adapter.truncate()
|
||||
|
||||
def drop(self):
|
||||
"""Drop community store."""
|
||||
@@ -76,4 +73,4 @@ class CommunityStore:
|
||||
self._community_summarizer.drop()
|
||||
|
||||
logger.info("Remove graph")
|
||||
self._community_store_adapter.graph_store.drop()
|
||||
self._graph_store_adapter.drop()
|
||||
|
@@ -1,28 +1,29 @@
|
||||
"""CommunityStoreAdapter factory."""
|
||||
"""GraphStoreAdapter factory."""
|
||||
|
||||
import logging
|
||||
|
||||
from dbgpt.storage.graph_store.base import GraphStoreBase
|
||||
from dbgpt.storage.graph_store.tugraph_store import TuGraphStore
|
||||
from dbgpt.storage.knowledge_graph.community.base import CommunityStoreAdapter
|
||||
from dbgpt.storage.knowledge_graph.community.tugraph_adapter import (
|
||||
TuGraphCommunityStoreAdapter,
|
||||
from dbgpt.storage.knowledge_graph.community.base import GraphStoreAdapter
|
||||
from dbgpt.storage.knowledge_graph.community.tugraph_store_adapter import (
|
||||
TuGraphStoreAdapter,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CommunityStoreAdapterFactory:
|
||||
class GraphStoreAdapterFactory:
|
||||
"""Factory for community store adapter."""
|
||||
|
||||
@staticmethod
|
||||
def create(graph_store: GraphStoreBase) -> CommunityStoreAdapter:
|
||||
"""Create a CommunityStoreAdapter instance.
|
||||
def create(graph_store: GraphStoreBase) -> GraphStoreAdapter:
|
||||
"""Create a GraphStoreAdapter instance.
|
||||
|
||||
Args:
|
||||
- graph_store_type: graph store type Memory, TuGraph, Neo4j
|
||||
"""
|
||||
if isinstance(graph_store, TuGraphStore):
|
||||
return TuGraphCommunityStoreAdapter(graph_store)
|
||||
return TuGraphStoreAdapter(graph_store)
|
||||
else:
|
||||
raise Exception(
|
||||
"create community store adapter for %s failed",
|
||||
|
@@ -0,0 +1,186 @@
|
||||
"""TuGraph Community Store Adapter."""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import AsyncGenerator, Iterator, List, Optional, Tuple
|
||||
|
||||
from dbgpt.storage.graph_store.graph import (
|
||||
Direction,
|
||||
Edge,
|
||||
Graph,
|
||||
GraphElemType,
|
||||
MemoryGraph,
|
||||
Vertex,
|
||||
)
|
||||
from dbgpt.storage.graph_store.memgraph_store import (
|
||||
MemoryGraphStore,
|
||||
MemoryGraphStoreConfig,
|
||||
)
|
||||
from dbgpt.storage.knowledge_graph.community.base import Community, GraphStoreAdapter
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MemGraphStoreAdapter(GraphStoreAdapter):
    """Graph store adapter backed by an in-memory ``MemoryGraphStore``.

    Operations that have no in-memory implementation are intentional no-ops
    that return ``None``.
    """

    MAX_HIERARCHY_LEVEL = 3

    def __init__(self, enable_summary: bool = False):
        """Create the in-memory graph store and its graph.

        Args:
            enable_summary: Stored on the adapter as ``_enable_summary``; it is
                not read anywhere else in this class.
        """
        self._graph_store = MemoryGraphStore(MemoryGraphStoreConfig())
        self._enable_summary = enable_summary

        super().__init__(self._graph_store)

        # Create the graph
        self.create_graph(self._graph_store.get_config().name)

    async def discover_communities(self, **kwargs) -> List[str]:
        """Run community discovery (no-op for the memory graph store)."""

    async def get_community(self, community_id: str) -> Community:
        """Get community (no-op for the memory graph store)."""

    def get_graph_config(self):
        """Get the graph store config."""
        return self._graph_store.get_config()

    def get_vertex_type(self) -> str:
        """Get the vertex type."""
        # The adapter reports no vertex type; an empty string is returned
        # instead of raising so callers can use it unconditionally.
        return ""

    def get_edge_type(self) -> str:
        """Get the edge type."""
        # Same convention as get_vertex_type: no edge type, empty string.
        return ""

    def get_triplets(self, subj: str) -> List[Tuple[str, str]]:
        """Get the outgoing ``(relation name, target id)`` pairs of a subject.

        Args:
            subj: Id of the vertex to start from.

        Returns:
            One ``(edge.name, edge.tid)`` tuple per edge found by a depth-1
            outgoing exploration from ``subj``.
        """
        subgraph = self.explore([subj], direct=Direction.OUT, depth=1)
        return [(e.name, e.tid) for e in subgraph.edges()]

    def get_document_vertex(self, doc_name: str) -> Vertex:
        """Get the document vertex in the graph.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError("Memory graph store does not have document vertex")

    def get_schema(self, refresh: bool = False) -> str:
        """Get the schema of the graph store as a JSON string.

        Args:
            refresh: Accepted for interface compatibility; not used here.
        """
        return json.dumps(self._graph_store._graph.schema())

    def get_full_graph(self, limit: Optional[int] = None) -> Graph:
        """Get the full graph, or a subgraph built from the first edges.

        Args:
            limit: Maximum number of edges to copy. When falsy (``None`` or
                ``0``) the store's own graph object is returned as is.

        Returns:
            The stored graph, or a new ``MemoryGraph`` holding at most ``limit``
            edges together with their endpoint vertices.
        """
        graph = self._graph_store._graph
        if not limit:
            return graph

        subgraph = MemoryGraph()
        # enumerate() already provides the running edge count, so no manual
        # counter increment is needed inside the loop.
        for count, edge in enumerate(graph.edges()):
            if count >= limit:
                break
            subgraph.upsert_vertex(graph.get_vertex(edge.sid))
            subgraph.upsert_vertex(graph.get_vertex(edge.tid))
            subgraph.append_edge(edge)
        return subgraph

    def upsert_entities(self, entities: Iterator[Vertex]) -> None:
        """Upsert entities (no-op for the memory graph store)."""

    def upsert_edge(
        self, edges: Iterator[Edge], edge_type: str, src_type: str, dst_type: str
    ) -> None:
        """Upsert edges (no-op for the memory graph store)."""

    def upsert_chunks(self, chunks: Iterator[Vertex]) -> None:
        """Upsert chunks (no-op for the memory graph store)."""

    def upsert_documents(self, documents: Iterator[Vertex]) -> None:
        """Upsert documents (no-op for the memory graph store)."""

    def upsert_relations(self, relations: Iterator[Edge]) -> None:
        """Upsert relations (no-op for the memory graph store)."""

    def insert_triplet(self, subj: str, rel: str, obj: str) -> None:
        """Add a triplet as a single edge ``subj -[rel]-> obj``."""
        self._graph_store._graph.append_edge(Edge(subj, obj, rel))

    def upsert_graph(self, graph: Graph) -> None:
        """Add graph to the graph store.

        Args:
            graph (Graph): The graph to be added.
        """
        target = self._graph_store._graph
        # Vertices first, then edges, in the order the source graph yields them.
        for vertex in graph.vertices():
            target.upsert_vertex(vertex)

        for edge in graph.edges():
            target.append_edge(edge)

    def delete_document(self, chunk_ids: str) -> None:
        """Delete document in the graph (no-op for the memory graph store)."""

    def delete_triplet(self, sub: str, rel: str, obj: str) -> None:
        """Delete the edges matching ``sub -[rel]-> obj``."""
        self._graph_store._graph.del_edges(sub, obj, rel)

    def drop(self):
        """Delete Graph."""
        # NOTE(review): the graph reference is set to None, so any later call
        # that dereferences ``_graph`` (truncate, explore, ...) will fail.
        self._graph_store._graph = None

    def create_graph(self, graph_name: str):
        """Create a graph (no-op for the memory graph store)."""

    def create_graph_label(
        self,
    ) -> None:
        """Create a graph label (no-op for the memory graph store).

        The graph label is used to identify and distinguish different types of nodes
        (vertices) and edges in the graph.
        """

    def truncate(self):
        """Truncate Graph."""
        self._graph_store._graph.truncate()

    def check_label(self, graph_elem_type: GraphElemType) -> bool:
        """Check if the label exists in the graph.

        Not implemented for the memory graph store: the body is empty, so the
        call currently returns ``None`` rather than a bool.

        Args:
            graph_elem_type (GraphElemType): The type of the graph element.
        """

    def explore(
        self,
        subs: List[str],
        direct: Direction = Direction.BOTH,
        depth: Optional[int] = None,
        fan: Optional[int] = None,
        limit: Optional[int] = None,
    ) -> MemoryGraph:
        """Explore the graph from given subjects up to a depth.

        All arguments are forwarded positionally to the in-memory graph's
        ``search`` method.
        """
        return self._graph_store._graph.search(subs, direct, depth, fan, limit)

    def query(self, query: str, **kwargs) -> MemoryGraph:
        """Execute a query on graph (no-op for the memory graph store)."""

    async def stream_query(self, query: str, **kwargs) -> AsyncGenerator[Graph, None]:
        """Execute a stream query (no-op for the memory graph store)."""
|
@@ -1,52 +0,0 @@
|
||||
"""TuGraph Community Store Adapter."""
|
||||
import json
|
||||
import logging
|
||||
from typing import List
|
||||
|
||||
from dbgpt.storage.graph_store.graph import MemoryGraph
|
||||
from dbgpt.storage.knowledge_graph.community.base import (
|
||||
Community,
|
||||
CommunityStoreAdapter,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TuGraphCommunityStoreAdapter(CommunityStoreAdapter):
|
||||
"""TuGraph Community Store Adapter."""
|
||||
|
||||
MAX_HIERARCHY_LEVEL = 3
|
||||
|
||||
async def discover_communities(self, **kwargs) -> List[str]:
|
||||
"""Run community discovery with leiden."""
|
||||
mg = self._graph_store.query(
|
||||
"CALL db.plugin.callPlugin"
|
||||
"('CPP','leiden','{\"leiden_val\":\"_community_id\"}',60.00,false)"
|
||||
)
|
||||
result = mg.get_vertex("json_node").get_prop("description")
|
||||
community_ids = json.loads(result)["community_id_list"]
|
||||
logger.info(f"Discovered {len(community_ids)} communities.")
|
||||
return community_ids
|
||||
|
||||
async def get_community(self, community_id: str) -> Community:
|
||||
"""Get community."""
|
||||
query = (
|
||||
f"MATCH (n:{self._graph_store.get_vertex_type()})"
|
||||
f"WHERE n._community_id = '{community_id}' RETURN n"
|
||||
)
|
||||
edge_query = (
|
||||
f"MATCH (n:{self._graph_store.get_vertex_type()})-"
|
||||
f"[r:{self._graph_store.get_edge_type()}]-"
|
||||
f"(m:{self._graph_store.get_vertex_type()})"
|
||||
f"WHERE n._community_id = '{community_id}' RETURN n,r,m"
|
||||
)
|
||||
|
||||
all_vertex_graph = self._graph_store.aquery(query)
|
||||
all_edge_graph = self._graph_store.aquery(edge_query)
|
||||
all_graph = MemoryGraph()
|
||||
for vertex in all_vertex_graph.vertices():
|
||||
all_graph.upsert_vertex(vertex)
|
||||
for edge in all_edge_graph.edges():
|
||||
all_graph.append_edge(edge)
|
||||
|
||||
return Community(id=community_id, data=all_graph)
|
808
dbgpt/storage/knowledge_graph/community/tugraph_store_adapter.py
Normal file
808
dbgpt/storage/knowledge_graph/community/tugraph_store_adapter.py
Normal file
@@ -0,0 +1,808 @@
|
||||
"""TuGraph Community Store Adapter."""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import (
|
||||
Any,
|
||||
AsyncGenerator,
|
||||
Dict,
|
||||
Iterator,
|
||||
List,
|
||||
Literal,
|
||||
Optional,
|
||||
Tuple,
|
||||
Union,
|
||||
)
|
||||
|
||||
from dbgpt.storage.graph_store.graph import (
|
||||
Direction,
|
||||
Edge,
|
||||
Graph,
|
||||
GraphElemType,
|
||||
MemoryGraph,
|
||||
Vertex,
|
||||
)
|
||||
from dbgpt.storage.graph_store.tugraph_store import TuGraphStore
|
||||
from dbgpt.storage.knowledge_graph.community.base import Community, GraphStoreAdapter
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TuGraphStoreAdapter(GraphStoreAdapter):
|
||||
"""TuGraph Community Store Adapter."""
|
||||
|
||||
MAX_QUERY_LIMIT = 1000
|
||||
MAX_HIERARCHY_LEVEL = 3
|
||||
|
||||
def __init__(self, graph_store: TuGraphStore):
    """Bind the adapter to a TuGraph store and create its graph.

    Args:
        graph_store: The TuGraph store this adapter wraps.
    """
    super().__init__(graph_store)

    # The graph name is taken from the store's own configuration.
    graph_name = self.graph_store.get_config().name
    self.create_graph(graph_name)
|
||||
|
||||
async def discover_communities(self, **kwargs) -> List[str]:
    """Run the leiden plugin and return the community ids it reports.

    Returns:
        The ``community_id_list`` entry of the JSON document stored in the
        ``description`` property of the ``json_node`` vertex of the result.
    """
    leiden_call = (
        "CALL db.plugin.callPlugin('CPP',"
        "'leiden','{\"leiden_val\":\"_community_id\"}',60.00,false)"
    )
    plugin_graph = self.query(leiden_call)
    payload = plugin_graph.get_vertex("json_node").get_prop("description")
    community_ids = json.loads(payload)["community_id_list"]
    logger.info(f"Discovered {len(community_ids)} communities.")
    return community_ids
|
||||
|
||||
async def get_community(self, community_id: str) -> Community:
    """Load one community as an in-memory graph.

    Two queries are issued: one for the vertices whose ``_community_id``
    equals ``community_id`` and one for the edges touching such vertices.
    Their results are merged into a fresh ``MemoryGraph``.

    Args:
        community_id: Id of the community to load.

    Returns:
        A ``Community`` with the given id and the merged graph as its data.
    """
    vertex_type = self.get_vertex_type()
    edge_type = self.get_edge_type()
    community_filter = f"WHERE n._community_id = '{community_id}'"

    vertex_query = f"MATCH (n:{vertex_type}) {community_filter} RETURN n"
    edge_query = (
        f"MATCH (n:{vertex_type})-[r:{edge_type}]-(m:{vertex_type})"
        f"{community_filter} RETURN n,r,m"
    )

    # Run both queries before merging, vertices first.
    vertex_result = self.query(vertex_query)
    edge_result = self.query(edge_query)

    community_graph = MemoryGraph()
    for vertex in vertex_result.vertices():
        community_graph.upsert_vertex(vertex)
    for edge in edge_result.edges():
        community_graph.append_edge(edge)

    return Community(id=community_id, data=community_graph)
|
||||
|
||||
@property
def graph_store(self) -> TuGraphStore:
    """Return the wrapped TuGraph store (the ``_graph_store`` attribute)."""
    store = self._graph_store
    return store
|
||||
|
||||
def get_graph_config(self):
    """Return the configuration object of the wrapped graph store."""
    store = self.graph_store
    return store.get_config()
|
||||
|
||||
def get_vertex_type(self) -> str:
    """Return the label used for entity vertices."""
    entity_label = GraphElemType.ENTITY
    return entity_label.value
|
||||
|
||||
def get_edge_type(self) -> str:
    """Return the label used for relation edges."""
    relation_label = GraphElemType.RELATION
    return relation_label.value
|
||||
|
||||
def get_triplets(self, subj: str) -> List[Tuple[str, str]]:
    """Get the outgoing ``(relation id, object id)`` pairs of an entity.

    Args:
        subj: Id of the subject entity.

    Returns:
        One ``(rel, obj)`` tuple per record returned by the query.
    """
    # The id is embedded in a double-quoted Cypher literal, so backslashes
    # and double quotes must be escaped (backslashes first) to keep the
    # literal well-formed and to prevent query injection.
    subj_escaped = subj.replace("\\", "\\\\").replace('"', '\\"')
    entity = GraphElemType.ENTITY.value
    triplet_query = (
        f"MATCH (n1:{entity})-[r]->(n2:{entity}) "
        f'WHERE n1.id = "{subj_escaped}" RETURN r.id as rel, n2.id as obj;'
    )
    records = self.graph_store.conn.run(triplet_query)
    return [(record["rel"], record["obj"]) for record in records]
|
||||
|
||||
def get_document_vertex(self, doc_name: str) -> Vertex:
    """Get the document vertex in the graph.

    Args:
        doc_name: Id of the document vertex to look up.

    Returns:
        The vertex with id ``doc_name`` taken from the query result.
    """
    # The id must be a quoted string literal in the query; without quotes
    # the raw text of ``doc_name`` is spliced into the statement.
    # NOTE(review): assumes callers pass the bare id, not a pre-quoted one.
    doc_id = doc_name.replace("\\", "\\\\").replace('"', '\\"')
    gql = f'MATCH (n) WHERE n.id = "{doc_id}" RETURN n'
    graph = self.query(gql)
    return graph.get_vertex(doc_name)
|
||||
|
||||
def get_schema(self, refresh: bool = False) -> str:
    """Return the graph schema reported by the graph store.

    Args:
        refresh: Accepted for interface compatibility; not used here.

    Returns:
        The ``schema`` field of the first row returned by the query.
    """
    rows = self.graph_store.conn.run(query="CALL dbms.graph.getGraphSchema()")
    first_row = rows[0]
    return first_row["schema"]
|
||||
|
||||
def get_full_graph(self, limit: Optional[int] = None) -> Graph:
    """Fetch up to ``limit`` vertex-edge-vertex matches as an in-memory graph.

    Args:
        limit: Maximum number of matches. A falsy value (``None`` or ``0``)
            falls back to ``MAX_QUERY_LIMIT``.

    Returns:
        A new ``MemoryGraph`` with the vertices and edges of the result.

    Raises:
        ValueError: If the effective limit is not greater than 0.
    """
    effective_limit = limit if limit else self.MAX_QUERY_LIMIT
    if effective_limit <= 0:
        raise ValueError("Limit must be greater than 0.")

    result = self.query(
        f"MATCH (n)-[r]-(m) RETURN n,r,m LIMIT {effective_limit}",
        white_list=["_community_id"],
    )

    merged = MemoryGraph()
    for vertex in result.vertices():
        merged.upsert_vertex(vertex)
    for edge in result.edges():
        merged.append_edge(edge)
    return merged
|
||||
|
||||
def upsert_entities(self, entities: Iterator[Vertex]) -> None:
    """Upsert entity vertices in one ``db.upsertVertex`` call.

    Args:
        entities: Vertices to write under the entity label. Each one
            contributes its id, name and ``description`` property; the
            ``_document_id``, ``_chunk_id`` and ``_community_id`` fields are
            always written as ``"0"``.
    """
    escape = self.graph_store._escape_quotes
    entity_list = []
    for entity in entities:
        entity_list.append(
            {
                "id": escape(entity.vid),
                "name": escape(entity.name),
                # Fall back to an empty string when there is no description.
                "description": escape(entity.get_prop("description")) or "",
                "_document_id": "0",
                "_chunk_id": "0",
                "_community_id": "0",
            }
        )

    label = GraphElemType.ENTITY.value
    entity_query = f'CALL db.upsertVertex("{label}", [{self._parser(entity_list)}])'
    self.graph_store.conn.run(query=entity_query)
|
||||
|
||||
def upsert_edge(
    self, edges: Iterator[Edge], edge_type: str, src_type: str, dst_type: str
) -> None:
    """Upsert edges of one label in one ``db.upsertEdge`` call.

    Args:
        edges: Edges to write.
        edge_type: Label of the edges.
        src_type: Label of the source vertices (matched through ``sid``).
        dst_type: Label of the target vertices (matched through ``tid``).
    """
    escape = self.graph_store._escape_quotes
    edge_list = []
    for edge in edges:
        edge_list.append(
            {
                "sid": escape(edge.sid),
                "tid": escape(edge.tid),
                # The edge name is written as both its id and its name.
                "id": escape(edge.name),
                "name": escape(edge.name),
                "description": escape(edge.get_prop("description")) or "",
                "_chunk_id": escape(edge.get_prop("_chunk_id")) or "",
            }
        )

    relation_query = f"""CALL db.upsertEdge("{edge_type}",
        {{type:"{src_type}", key:"sid"}},
        {{type:"{dst_type}", key:"tid"}},
        [{self._parser(edge_list)}])"""
    self.graph_store.conn.run(query=relation_query)
|
||||
|
||||
def upsert_chunks(self, chunks: Iterator[Vertex]) -> None:
    """Upsert chunk vertices in one ``db.upsertVertex`` call.

    Args:
        chunks: Vertices to write under the chunk label. Each one
            contributes its id, name and ``content`` property.
    """
    escape = self.graph_store._escape_quotes
    chunk_list = [
        {
            "id": escape(chunk.vid),
            "name": escape(chunk.name),
            # Fall back to "" when the chunk has no content, matching how
            # upsert_documents handles the same field, so a missing value
            # is never passed on to the query builder.
            "content": escape(chunk.get_prop("content")) or "",
        }
        for chunk in chunks
    ]
    chunk_query = (
        f"CALL db.upsertVertex("
        f'"{GraphElemType.CHUNK.value}", '
        f"[{self._parser(chunk_list)}])"
    )
    self.graph_store.conn.run(query=chunk_query)
|
||||
|
||||
def upsert_documents(self, documents: Iterator[Vertex]) -> None:
    """Upsert document vertices in one ``db.upsertVertex`` call.

    Args:
        documents: Vertices to write under the document label. Each one
            contributes its id, name and ``content`` property.
    """
    escape = self.graph_store._escape_quotes
    document_list = []
    for document in documents:
        document_list.append(
            {
                "id": escape(document.vid),
                "name": escape(document.name),
                # Fall back to an empty string when there is no content.
                "content": escape(document.get_prop("content")) or "",
            }
        )

    label = GraphElemType.DOCUMENT.value
    document_query = (
        f'CALL db.upsertVertex("{label}", [{self._parser(document_list)}])'
    )
    self.graph_store.conn.run(query=document_query)
|
||||
|
||||
def upsert_relations(self, relations: Iterator[Edge]) -> None:
    """Upsert relations.

    Currently a no-op: the arguments are ignored and nothing is written.
    """
    return None
|
||||
|
||||
def insert_triplet(self, subj: str, rel: str, obj: str) -> None:
    """Add a ``subj -[rel]-> obj`` triplet to the graph.

    Both entities are upserted first, then the relation edge between them.

    Args:
        subj: Id (and name) of the subject entity.
        rel: Id (and name) of the relation edge.
        obj: Id (and name) of the object entity.
    """

    def _escape(value: str) -> str:
        """Escape a value for use inside a quoted Cypher string literal."""
        # Backslashes must be escaped first; otherwise a value ending in a
        # backslash would un-escape the closing quote of the literal.
        return (
            value.replace("\\", "\\\\").replace("'", "\\'").replace('"', '\\"')
        )

    subj_escaped = _escape(subj)
    rel_escaped = _escape(rel)
    obj_escaped = _escape(obj)

    entity = GraphElemType.ENTITY.value
    relation = GraphElemType.RELATION.value

    vertex_query = (
        "CALL db.upsertVertex("
        f"'{entity}', "
        f"[{{id:'{subj_escaped}',name:'{subj_escaped}'}}, "
        f"{{id:'{obj_escaped}',name:'{obj_escaped}'}}])"
    )
    edge_query = (
        "CALL db.upsertEdge("
        f"'{relation}', "
        f'{{type:"{entity}",key:"sid"}}, '
        f'{{type:"{entity}", key:"tid"}}, '
        f'[{{sid:"{subj_escaped}", '
        f'tid: "{obj_escaped}", '
        f'id:"{rel_escaped}", '
        f'name: "{rel_escaped}"}}])'
    )

    # Vertices must exist before the edge that connects them is upserted.
    self.graph_store.conn.run(query=vertex_query)
    self.graph_store.conn.run(query=edge_query)
|
||||
|
||||
def upsert_graph(self, graph: MemoryGraph) -> None:
    """Add graph to the graph store.

    Vertices are selected by their ``vertex_type`` property and edges by
    their ``edge_type`` property. Entities, chunks and documents are
    upserted first, followed by the edge groups.

    Args:
        graph (Graph): The graph to be added.
    """

    def vertices_of(kind: GraphElemType) -> Iterator[Vertex]:
        """Select the vertices whose ``vertex_type`` equals ``kind``."""
        return graph.vertices(
            filter_fn=lambda v: v.get_prop("vertex_type") == kind.value
        )

    def edges_of(kind: GraphElemType) -> Iterator[Edge]:
        """Select the edges whose ``edge_type`` equals ``kind``."""
        return graph.edges(filter_fn=lambda e: e.get_prop("edge_type") == kind.value)

    # Get the iterators of all the vertices and the edges from the graph
    documents = vertices_of(GraphElemType.DOCUMENT)
    chunks = vertices_of(GraphElemType.CHUNK)
    entities = vertices_of(GraphElemType.ENTITY)

    # (edges, edge label, source vertex label, target vertex label)
    edge_batches = [
        (
            edges_of(GraphElemType.DOCUMENT_INCLUDE_CHUNK),
            GraphElemType.INCLUDE,
            GraphElemType.DOCUMENT,
            GraphElemType.CHUNK,
        ),
        (
            edges_of(GraphElemType.CHUNK_INCLUDE_CHUNK),
            GraphElemType.INCLUDE,
            GraphElemType.CHUNK,
            GraphElemType.CHUNK,
        ),
        (
            edges_of(GraphElemType.CHUNK_INCLUDE_ENTITY),
            GraphElemType.INCLUDE,
            GraphElemType.CHUNK,
            GraphElemType.ENTITY,
        ),
        (
            edges_of(GraphElemType.CHUNK_NEXT_CHUNK),
            GraphElemType.NEXT,
            GraphElemType.CHUNK,
            GraphElemType.CHUNK,
        ),
        (
            edges_of(GraphElemType.RELATION),
            GraphElemType.RELATION,
            GraphElemType.ENTITY,
            GraphElemType.ENTITY,
        ),
    ]

    # Upsert the vertices and the edges to the graph store
    self.upsert_entities(entities)
    self.upsert_chunks(chunks)
    self.upsert_documents(documents)
    for edges, edge_label, src_label, dst_label in edge_batches:
        self.upsert_edge(edges, edge_label.value, src_label.value, dst_label.value)
|
||||
|
||||
def delete_document(self, chunk_ids: str) -> None:
    """Delete the chunks of a document and everything left dangling.

    Three statements run in order: delete the listed chunk vertices, delete
    the relation edges whose ``_chunk_id`` is listed, then delete every
    vertex that has no edge left.

    Args:
        chunk_ids: Comma-separated chunk ids; surrounding whitespace of each
            id is stripped.
    """
    chunkids_list = [part.strip() for part in chunk_ids.split(",")]

    document = GraphElemType.DOCUMENT.value
    chunk = GraphElemType.CHUNK.value
    entity = GraphElemType.ENTITY.value
    relation = GraphElemType.RELATION.value

    # The Python list is interpolated directly into the query text.
    statements = (
        f"MATCH(m:{document})-[r]->(n:{chunk}) "
        f"WHERE n.id IN {chunkids_list} DELETE n",
        f"MATCH(m:{entity})-[r:{relation}]-(n:{entity}) "
        f"WHERE r._chunk_id IN {chunkids_list} DELETE r",
        "MATCH (n) WHERE NOT EXISTS((n)-[]-()) DELETE n",
    )
    for statement in statements:
        self.graph_store.conn.run(statement)
|
||||
|
||||
def delete_triplet(self, sub: str, rel: str, obj: str) -> None:
    """Delete the triplet ``sub -[rel]-> obj``.

    Args:
        sub: Id of the subject entity.
        rel: Id of the relation edge.
        obj: Id of the object entity.
    """

    def _escape(value: str) -> str:
        """Escape a value for a single-quoted Cypher string literal."""
        # Backslashes first, so the escapes added for quotes stay intact.
        return value.replace("\\", "\\\\").replace("'", "\\'")

    entity = GraphElemType.ENTITY.value
    relation = GraphElemType.RELATION.value
    # NOTE(review): the statement deletes both endpoint vertices as well as
    # the edge; confirm this is intended, since it also removes entities
    # that may take part in other triplets.
    del_query = (
        f"MATCH (n1:{entity} {{id:'{_escape(sub)}'}})"
        f"-[r:{relation} {{id:'{_escape(rel)}'}}]->"
        f"(n2:{entity} {{id:'{_escape(obj)}'}}) DELETE n1,n2,r"
    )
    self.graph_store.conn.run(query=del_query)
|
||||
|
||||
def drop(self):
    """Delete the graph named in the graph store configuration."""
    graph_name = self.get_graph_config().name
    self.graph_store.conn.delete_graph(graph_name)
|
||||
|
||||
def create_graph(self, graph_name: str):
    """Create a graph together with its vertex and edge labels.

    After the graph is created, one label is created per element type, in
    this order: document, chunk, entity (vertices), then relation, include
    and next (edges). If the store has summaries enabled, its plugin is
    uploaded afterwards.

    Args:
        graph_name: Name of the graph to create.
    """
    self.graph_store.conn.create_graph(graph_name=graph_name)

    def _property_schema(
        name: str,
        data_type: str = "STRING",
        optional: bool = False,
        index: Optional[bool] = None,
        **extra,
    ) -> Dict[str, Union[str, bool]]:
        """Format one property definition for a label schema.

        Args:
            name: The name of the property.
            data_type: The type of the property.
            optional: Whether the property is optional.
            index: Whether the property is indexed; omitted when ``None``.
            extra: Additional entries copied into the definition.

        Returns:
            The formatted property.
        """
        schema: Dict[str, Union[str, bool]] = {
            "name": name,
            "type": data_type,
            "optional": optional,
        }
        if index is not None:
            schema["index"] = index
        # Add any additional keyword arguments to the property dictionary
        schema.update(extra)
        return schema

    def _label_properties(
        *specific: Dict[str, Union[str, bool]],
    ) -> List[Dict[str, Union[str, bool]]]:
        """Prefix the mandatory ``id`` and ``name`` properties of a label."""
        return [
            _property_schema("id", "STRING", False),
            _property_schema("name", "STRING", False),
            *specific,
        ]

    # One entry per label, in creation order.
    # NOTE(review): the document label declares no "content" property,
    # while upsert_documents writes one; confirm this is intended.
    label_schemas = [
        (
            GraphElemType.DOCUMENT,
            _label_properties(
                _property_schema("_community_id", "STRING", True, True),
            ),
        ),
        (
            GraphElemType.CHUNK,
            _label_properties(
                _property_schema("_community_id", "STRING", True, True),
                _property_schema("content", "STRING", True, True),
            ),
        ),
        (
            GraphElemType.ENTITY,
            _label_properties(
                _property_schema("_community_id", "STRING", True, True),
                _property_schema("description", "STRING", True, True),
            ),
        ),
        (
            GraphElemType.RELATION,
            _label_properties(
                _property_schema("_chunk_id", "STRING", True, True),
                _property_schema("description", "STRING", True, True),
            ),
        ),
        (
            GraphElemType.INCLUDE,
            _label_properties(
                _property_schema("description", "STRING", True),
            ),
        ),
        (
            GraphElemType.NEXT,
            _label_properties(
                _property_schema("description", "STRING", True),
            ),
        ),
    ]
    for elem_type, properties in label_schemas:
        self.create_graph_label(
            graph_elem_type=elem_type, graph_properties=properties
        )

    if self.graph_store._enable_summary:
        self.graph_store._upload_plugin()
|
||||
|
||||
def create_graph_label(
    self,
    graph_elem_type: GraphElemType,
    graph_properties: List[Dict[str, Union[str, bool]]],
) -> None:
    """Create a graph label unless the database already knows it.

    A graph label identifies and distinguishes the different types of vertices
    and edges in the graph. The label schema is serialized to JSON and passed to
    the matching ``db.create*LabelByJson`` procedure.

    Args:
        graph_elem_type: The element type; its ``value`` is used as label name.
        graph_properties: The property schemas to attach to the label.

    Raises:
        ValueError: If an edge type has no source/target constraints defined.
    """

    def _edge_constraints(elem_type: GraphElemType) -> List[List[str]]:
        """Return the allowed ``[source, target]`` label pairs of an edge type.

        `include` edge: document -> chunk, chunk -> entity, chunk -> chunk
        `next` edge: chunk -> chunk
        `relation` edge: entity -> entity
        """
        if elem_type.is_vertex():
            raise ValueError("The graph element type must be an edge.")

        if elem_type == GraphElemType.INCLUDE:
            return [
                [GraphElemType.DOCUMENT.value, GraphElemType.CHUNK.value],
                [GraphElemType.CHUNK.value, GraphElemType.ENTITY.value],
                [GraphElemType.CHUNK.value, GraphElemType.CHUNK.value],
            ]
        if elem_type == GraphElemType.NEXT:
            return [[GraphElemType.CHUNK.value, GraphElemType.CHUNK.value]]
        if elem_type == GraphElemType.RELATION:
            return [[GraphElemType.ENTITY.value, GraphElemType.ENTITY.value]]

        raise ValueError("Invalid graph element type.")

    label = graph_elem_type.value

    # Key order matters: it determines the JSON text sent to the database.
    if graph_elem_type.is_vertex():
        kind, procedure = "VERTEX", "db.createVertexLabelByJson"
        schema = {
            "label": label,
            "type": kind,
            "primary": "id",
            "properties": graph_properties,
        }
    else:
        kind, procedure = "EDGE", "db.createEdgeLabelByJson"
        schema = {
            "label": label,
            "type": kind,
            "constraints": _edge_constraints(graph_elem_type),
            "properties": graph_properties,
        }

    payload = json.dumps(schema)
    gql_create = f"CALL {procedure}('{payload}')"
    gql_check_exist = f"CALL db.getLabelSchema('{kind}', '{label}')"

    # Querying the schema of a label that does not exist raises an exception;
    # that failure is the signal to create the label.
    try:
        self.graph_store.conn.run(gql_check_exist)
    except Exception:
        self.graph_store.conn.run(gql_create)
    else:
        logger.info(f"Graph label {graph_elem_type.value} already exists.")
|
||||
|
||||
def truncate(self):
    """Truncate Graph.

    Runs a single statement that matches every vertex and deletes it.
    """
    self.graph_store.conn.run("MATCH (n) DELETE n")
|
||||
|
||||
def check_label(self, graph_elem_type: GraphElemType) -> bool:
    """Check if the label exists in the graph.

    The connection's ``get_table_names()`` result is unpacked into the vertex
    labels and the edge labels; the lookup is done in whichever collection
    matches the kind of ``graph_elem_type``.

    Args:
        graph_elem_type (GraphElemType): The type of the graph element.

    Returns:
        True if the label exists in the specified graph element type, otherwise
        False.
    """
    vertex_tables, edge_tables = self.graph_store.conn.get_table_names()
    tables = vertex_tables if graph_elem_type.is_vertex() else edge_tables

    # Labels are registered under ``graph_elem_type.value``, so the table names
    # are presumably plain strings (confirm against the connector). Comparing
    # only the enum member would then never match; accept either form.
    return graph_elem_type in tables or graph_elem_type.value in tables
|
||||
|
||||
def explore(
    self,
    subs: List[str],
    direct: Direction = Direction.BOTH,
    depth: Optional[int] = None,
    limit: Optional[int] = None,
    search_scope: Optional[
        Literal["knowledge_graph", "document_graph"]
    ] = "knowledge_graph",
) -> MemoryGraph:
    """Explore the graph from given subjects up to a depth.

    Args:
        subs: The subjects to start from. An empty list yields an empty graph.
        direct: The traversal direction; only used for the knowledge graph.
        depth: The maximum number of hops. ``None``, negative values and
            values above ``MAX_HIERARCHY_LEVEL`` fall back to
            ``MAX_HIERARCHY_LEVEL``.
        limit: The maximum number of returned paths per query, if given.
        search_scope: ``"knowledge_graph"`` follows relation edges between
            entities whose start id is in ``subs``. Any other value searches
            include-paths between documents and chunks whose content contains
            a subject, one query per subject.

    Returns:
        The explored sub-graph.
    """
    if not subs:
        return MemoryGraph()

    if depth is None or depth < 0 or depth > self.MAX_HIERARCHY_LEVEL:
        # TODO: to be discussed, be none or MAX_HIERARCHY_LEVEL
        depth = self.MAX_HIERARCHY_LEVEL
    hops = f"1..{depth}"
    limit_clause = "" if limit is None else f"LIMIT {limit}"

    if search_scope != "knowledge_graph":
        # Document graph: one query per subject, results merged into one graph.
        merged = MemoryGraph()
        for sub in subs:
            cypher = (
                f"MATCH p=(n:{GraphElemType.DOCUMENT.value})-"
                f"[r:{GraphElemType.INCLUDE.value}*{hops}]-"
                f"(m:{GraphElemType.CHUNK.value})WHERE m.content CONTAINS '{sub}' "
                f"RETURN p {limit_clause}"
            )
            found = self.query(cypher)
            for vertex in found.vertices():
                merged.upsert_vertex(vertex)
            for edge in found.edges():
                merged.append_edge(edge)
        return merged

    # Knowledge graph: a single variable-length relation pattern; the arrow
    # heads around it encode the requested direction.
    relation = f"[r:{GraphElemType.RELATION.value}*{hops}]"
    if direct.name == "OUT":
        pattern = f"-{relation}->"
    elif direct.name == "IN":
        pattern = f"<-{relation}-"
    else:
        pattern = f"-{relation}-"

    cypher = (
        f"MATCH p=(n:{GraphElemType.ENTITY.value})"
        f"{pattern}(m:{GraphElemType.ENTITY.value}) "
        f"WHERE n.id IN {subs} RETURN p {limit_clause}"
    )
    return self.query(cypher)
|
||||
|
||||
def query(self, query: str, **kwargs) -> MemoryGraph:
    """Execute a query on graph.

    Args:
        query: The query text handed to the connection.
        **kwargs: ``white_list`` (List[str]) may be supplied to choose which
            properties are passed on as the white list when the result is
            converted; a built-in list is used when it is absent.

    Returns:
        A ``MemoryGraph`` holding the vertices and edges found in the result.
    """
    records = self.graph_store.conn.run(query=query)

    default_white_list = [
        "id",
        "name",
        "description",
        "_document_id",
        "_chunk_id",
        "_community_id",
    ]
    white_list: List[str] = kwargs.get("white_list", default_white_list)

    vertices, edges = self._get_nodes_edges_from_queried_data(records, white_list)

    result = MemoryGraph()
    for vertex in vertices:
        result.upsert_vertex(vertex)
    for edge in edges:
        result.append_edge(edge)
    return result
|
||||
|
||||
async def stream_query(self, query: str, **kwargs) -> AsyncGenerator[Graph, None]:
    """Execute a stream query.

    One ``MemoryGraph`` is yielded per record produced by the connection's
    ``run_stream``. Each value of a record is converted by type:

    * node: one vertex whose id and name are the node's ``id`` property;
    * relationship: one edge between the ``id`` properties of its two nodes;
    * path: every node as a vertex, and one edge per relationship linking
      consecutive nodes of the path;
    * anything else: a vertex with id ``json_node`` carrying the raw value as
      description.

    Args:
        query: The query text handed to the connection.
        **kwargs: Accepted for interface compatibility; not used here.

    Yields:
        The graph built from a single record.

    Raises:
        KeyError: If a node or relationship lacks the ``id`` or ``description``
            property.
    """
    from neo4j import graph

    async for record in self.graph_store.conn.run_stream(query):
        mg = MemoryGraph()
        for key in record.keys():
            value = record[key]
            if isinstance(value, graph.Node):
                node_id = value._properties["id"]
                description = value._properties["description"]
                vertex = Vertex(vid=node_id, name=node_id, description=description)
                mg.upsert_vertex(vertex)
            elif isinstance(value, graph.Relationship):
                edge_nodes = value.nodes
                prop_id = value._properties["id"]
                assert edge_nodes and edge_nodes[0] and edge_nodes[1]
                src_id = edge_nodes[0]._properties["id"]
                dst_id = edge_nodes[1]._properties["id"]
                description = value._properties["description"]
                edge = Edge(
                    sid=src_id, tid=dst_id, name=prop_id, description=description
                )
                mg.append_edge(edge)
            elif isinstance(value, graph.Path):
                # Read the path from the value being inspected. Looking it up
                # as record["p"] only works when the column is named "p".
                nodes = list(value.nodes)
                rels = list(value.relationships)
                for idx, node in enumerate(nodes):
                    node_id = node._properties["id"]
                    mg.upsert_vertex(
                        Vertex(
                            vid=node_id,
                            name=node_id,
                            description=node._properties["description"],
                        )
                    )
                    # The idx-th relationship links this node to the next one
                    # in path order.
                    if idx < len(rels) and idx + 1 < len(nodes):
                        rel = rels[idx]
                        mg.append_edge(
                            Edge(
                                sid=node_id,
                                tid=nodes[idx + 1]._properties["id"],
                                name=rel._properties["id"],
                                description=rel._properties["description"],
                            )
                        )
            else:
                vertex = Vertex(
                    vid="json_node", name="json_node", description=value
                )
                mg.upsert_vertex(vertex)
        yield mg
|
||||
|
||||
def _get_nodes_edges_from_queried_data(
    self,
    data: List[Dict[str, Any]],
    white_prop_list: List[str],
) -> Tuple[List[Vertex], List[Edge]]:
    """Convert queried records into de-duplicated vertices and edges.

    Args:
        data: The queried records; each record's values are inspected.
        white_prop_list: The white list of properties. Underscore-prefixed
            properties are only kept when listed here.

    Returns:
        Tuple[List[Vertex], List[Edge]]: The formatted vertices and edges, in
        first-seen order.
    """
    # id, src_id, dst_id and name are passed to the constructors explicitly,
    # so they must not be forwarded a second time as extra properties.
    reserved = ["id", "src_id", "dst_id", "name"]
    extra_props = [prop for prop in white_prop_list if prop not in reserved]

    from neo4j import graph

    vertices: List[Vertex] = []
    edges: List[Edge] = []

    def _extra(properties: Dict[str, Any]) -> Dict[str, Any]:
        """Select the properties forwarded as extra keyword arguments.

        A property is kept when it is white-listed, or when its key neither
        starts with an underscore nor is ``id``/``name``.
        """
        kept: Dict[str, Any] = {}
        for key, val in properties.items():
            is_plain = not key.startswith("_") and key not in ["id", "name"]
            if is_plain or key in extra_props:
                kept[key] = val
        return kept

    def _add_vertex(vid: Any, name: Any, properties: Dict[str, Any]) -> None:
        """Build a vertex and record it unless an equal one is already known."""
        vertex = Vertex(vid=vid, name=name, **_extra(properties))
        if vertex not in vertices:
            vertices.append(vertex)

    def _add_relationship(rel: Any) -> None:
        """Record both end nodes of a relationship, then the edge itself."""
        for node in rel.nodes:  # num of nodes is 2
            assert node and node._properties
            _add_vertex(
                node._properties.get("id", ""),
                node._properties.get("name"),
                node._properties,
            )

        assert rel.nodes and rel.nodes[0] and rel.nodes[1]
        edge = Edge(
            sid=rel.nodes[0]._properties.get("id", ""),
            tid=rel.nodes[1]._properties.get("id", ""),
            name=rel._properties.get("name", ""),
            **_extra(rel._properties),
        )
        if edge not in edges:
            edges.append(edge)

    # Parse the data to nodes and relationships
    for record in data:
        for value in record.values():
            if isinstance(value, graph.Node):
                assert value._properties.get("id")
                _add_vertex(
                    value._properties.get("id", ""),
                    value._properties.get("name"),
                    value._properties,
                )
            elif isinstance(value, graph.Relationship):
                _add_relationship(value)
            elif isinstance(value, graph.Path):
                for rel in value.relationships:
                    _add_relationship(rel)
            else:  # json_node
                _add_vertex("json_node", "json_node", {"description": value})

    return vertices, edges
|
||||
|
||||
def _parser(self, entity_list: List[Dict[str, Any]]) -> str:
|
||||
"""Parse entities to string."""
|
||||
formatted_nodes = [
|
||||
"{"
|
||||
+ ", ".join(
|
||||
f'{k}: "{v}"' if isinstance(v, str) else f"{k}: {v}"
|
||||
for k, v in node.items()
|
||||
)
|
||||
+ "}"
|
||||
for node in entity_list
|
||||
]
|
||||
return f"""{", ".join(formatted_nodes)}"""
|
Reference in New Issue
Block a user