mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-07 03:50:42 +00:00
feat: add document structure into GraphRAG (#2033)
Co-authored-by: Appointat <kuda.czk@antgroup.com> Co-authored-by: tpoisonooo <khj.application@aliyun.com> Co-authored-by: vritser <vritser@163.com>
This commit is contained in:
@@ -157,12 +157,15 @@ EXECUTE_LOCAL_COMMANDS=False
|
||||
#*******************************************************************#
|
||||
VECTOR_STORE_TYPE=Chroma
|
||||
GRAPH_STORE_TYPE=TuGraph
|
||||
GRAPH_COMMUNITY_SUMMARY_ENABLED=True
|
||||
KNOWLEDGE_GRAPH_EXTRACT_SEARCH_TOP_SIZE=5
|
||||
KNOWLEDGE_GRAPH_EXTRACT_SEARCH_RECALL_SCORE=0.3
|
||||
KNOWLEDGE_GRAPH_COMMUNITY_SEARCH_TOP_SIZE=20
|
||||
KNOWLEDGE_GRAPH_COMMUNITY_SEARCH_RECALL_SCORE=0.0
|
||||
|
||||
ENABLE_GRAPH_COMMUNITY_SUMMARY=True # enable the graph community summary
|
||||
ENABLE_TRIPLET_GRAPH=True # enable the graph search for triplets
|
||||
ENABLE_DOCUMENT_GRAPH=True # enable the graph search for documents and chunks
|
||||
|
||||
### Chroma vector db config
|
||||
#CHROMA_PERSIST_PATH=/root/DB-GPT/pilot/data
|
||||
|
||||
|
@@ -213,8 +213,8 @@ class Config(metaclass=Singleton):
|
||||
|
||||
# Vector Store Configuration
|
||||
self.VECTOR_STORE_TYPE = os.getenv("VECTOR_STORE_TYPE", "Chroma")
|
||||
self.GRAPH_COMMUNITY_SUMMARY_ENABLED = (
|
||||
os.getenv("GRAPH_COMMUNITY_SUMMARY_ENABLED", "").lower() == "true"
|
||||
self.ENABLE_GRAPH_COMMUNITY_SUMMARY = (
|
||||
os.getenv("ENABLE_GRAPH_COMMUNITY_SUMMARY", "").lower() == "true"
|
||||
)
|
||||
self.MILVUS_URL = os.getenv("MILVUS_URL", "127.0.0.1")
|
||||
self.MILVUS_PORT = os.getenv("MILVUS_PORT", "19530")
|
||||
|
@@ -12,7 +12,6 @@ from dbgpt.app.knowledge.document_db import (
|
||||
KnowledgeDocumentEntity,
|
||||
)
|
||||
from dbgpt.app.knowledge.request.request import (
|
||||
ChunkEditRequest,
|
||||
ChunkQueryRequest,
|
||||
DocumentQueryRequest,
|
||||
DocumentRecallTestRequest,
|
||||
@@ -650,12 +649,17 @@ class KnowledgeService:
|
||||
{
|
||||
"id": node.vid,
|
||||
"communityId": node.get_prop("_community_id"),
|
||||
"name": node.vid,
|
||||
"type": "",
|
||||
"name": node.name,
|
||||
"type": node.get_prop("type") or "",
|
||||
}
|
||||
)
|
||||
for edge in graph.edges():
|
||||
res["edges"].append(
|
||||
{"source": edge.sid, "target": edge.tid, "name": edge.name, "type": ""}
|
||||
{
|
||||
"source": edge.sid,
|
||||
"target": edge.tid,
|
||||
"name": edge.name,
|
||||
"type": edge.get_prop("type") or "",
|
||||
}
|
||||
)
|
||||
return res
|
||||
|
@@ -1,7 +1,7 @@
|
||||
"""TuGraph Connector."""
|
||||
|
||||
import json
|
||||
from typing import Dict, Generator, List, cast
|
||||
from typing import Dict, Generator, List, Tuple, cast
|
||||
|
||||
from .base import BaseConnector
|
||||
|
||||
@@ -21,8 +21,7 @@ class TuGraphConnector(BaseConnector):
|
||||
self._session = None
|
||||
|
||||
def create_graph(self, graph_name: str) -> None:
|
||||
"""Create a new graph."""
|
||||
# run the query to get vertex labels
|
||||
"""Create a new graph in the database if it doesn't already exist."""
|
||||
try:
|
||||
with self._driver.session(database="default") as session:
|
||||
graph_list = session.run("CALL dbms.graph.listGraphs()").data()
|
||||
@@ -32,10 +31,10 @@ class TuGraphConnector(BaseConnector):
|
||||
f"CALL dbms.graph.createGraph('{graph_name}', '', 2048)"
|
||||
)
|
||||
except Exception as e:
|
||||
raise Exception(f"Failed to create graph '{graph_name}': {str(e)}")
|
||||
raise Exception(f"Failed to create graph '{graph_name}': {str(e)}") from e
|
||||
|
||||
def delete_graph(self, graph_name: str) -> None:
|
||||
"""Delete a graph."""
|
||||
"""Delete a graph in the database if it exists."""
|
||||
with self._driver.session(database="default") as session:
|
||||
graph_list = session.run("CALL dbms.graph.listGraphs()").data()
|
||||
exists = any(item["graph_name"] == graph_name for item in graph_list)
|
||||
@@ -61,17 +60,20 @@ class TuGraphConnector(BaseConnector):
|
||||
"`pip install neo4j`"
|
||||
) from err
|
||||
|
||||
def get_table_names(self) -> Dict[str, List[str]]:
|
||||
def get_table_names(self) -> Tuple[List[str], List[str]]:
|
||||
"""Get all table names from the TuGraph by Neo4j driver."""
|
||||
# run the query to get vertex labels
|
||||
with self._driver.session(database=self._graph) as session:
|
||||
v_result = session.run("CALL db.vertexLabels()").data()
|
||||
v_data = [table_name["label"] for table_name in v_result]
|
||||
# Run the query to get vertex labels
|
||||
raw_vertex_labels: Dict[str, str] = session.run(
|
||||
"CALL db.vertexLabels()"
|
||||
).data()
|
||||
vertex_labels = [table_name["label"] for table_name in raw_vertex_labels]
|
||||
|
||||
# run the query to get edge labels
|
||||
e_result = session.run("CALL db.edgeLabels()").data()
|
||||
e_data = [table_name["label"] for table_name in e_result]
|
||||
return {"vertex_tables": v_data, "edge_tables": e_data}
|
||||
# Run the query to get edge labels
|
||||
raw_edge_labels: Dict[str, str] = session.run("CALL db.edgeLabels()").data()
|
||||
edge_labels = [table_name["label"] for table_name in raw_edge_labels]
|
||||
|
||||
return vertex_labels, edge_labels
|
||||
|
||||
def get_grants(self):
|
||||
"""Get grants."""
|
||||
@@ -100,7 +102,7 @@ class TuGraphConnector(BaseConnector):
|
||||
result = session.run(query)
|
||||
return list(result)
|
||||
except Exception as e:
|
||||
raise Exception(f"Query execution failed: {e}")
|
||||
raise Exception(f"Query execution failed: {e}\nQuery: {query}") from e
|
||||
|
||||
def run_stream(self, query: str) -> Generator:
|
||||
"""Run GQL."""
|
||||
@@ -109,11 +111,15 @@ class TuGraphConnector(BaseConnector):
|
||||
yield from result
|
||||
|
||||
def get_columns(self, table_name: str, table_type: str = "vertex") -> List[Dict]:
|
||||
"""Get fields about specified graph.
|
||||
"""Retrieve the column for a specified vertex or edge table in the graph db.
|
||||
|
||||
This function queries the schema of a given table (vertex or edge) and returns
|
||||
detailed information about its columns (properties).
|
||||
|
||||
Args:
|
||||
table_name (str): table name (graph name)
|
||||
table_type (str): table type (vertex or edge)
|
||||
|
||||
Returns:
|
||||
columns: List[Dict], which contains name: str, type: str,
|
||||
default_expression: str, is_in_primary_key: bool, comment: str
|
||||
@@ -146,8 +152,8 @@ class TuGraphConnector(BaseConnector):
|
||||
"""Get table indexes about specified table.
|
||||
|
||||
Args:
|
||||
table_name:(str) table name
|
||||
table_type:(str)'vertex' | 'edge'
|
||||
table_name (str): table name
|
||||
table_type (str): 'vertex' | 'edge'
|
||||
Returns:
|
||||
List[Dict]:eg:[{'name': 'idx_key', 'column_names': ['id']}]
|
||||
"""
|
||||
|
@@ -65,7 +65,9 @@ class GraphExtractor(LLMExtractor):
|
||||
match = re.match(r"\((.*?)#(.*?)\)", line)
|
||||
if match:
|
||||
name, summary = [part.strip() for part in match.groups()]
|
||||
graph.upsert_vertex(Vertex(name, description=summary))
|
||||
graph.upsert_vertex(
|
||||
Vertex(name, description=summary, vertex_type="entity")
|
||||
)
|
||||
elif current_section == "Relationships":
|
||||
match = re.match(r"\((.*?)#(.*?)#(.*?)#(.*?)\)", line)
|
||||
if match:
|
||||
@@ -74,7 +76,13 @@ class GraphExtractor(LLMExtractor):
|
||||
]
|
||||
edge_count += 1
|
||||
graph.append_edge(
|
||||
Edge(source, target, name, description=summary)
|
||||
Edge(
|
||||
source,
|
||||
target,
|
||||
name,
|
||||
description=summary,
|
||||
edge_type="relation",
|
||||
)
|
||||
)
|
||||
|
||||
if limit and edge_count >= limit:
|
||||
|
@@ -1,4 +1,5 @@
|
||||
"""KeywordExtractor class."""
|
||||
|
||||
import logging
|
||||
from typing import List, Optional
|
||||
|
||||
@@ -39,12 +40,15 @@ class KeywordExtractor(LLMExtractor):
|
||||
def _parse_response(self, text: str, limit: Optional[int] = None) -> List[str]:
|
||||
keywords = set()
|
||||
|
||||
for part in text.split(";"):
|
||||
for s in part.strip().split(","):
|
||||
keyword = s.strip()
|
||||
if keyword:
|
||||
keywords.add(keyword)
|
||||
if limit and len(keywords) >= limit:
|
||||
return list(keywords)
|
||||
lines = text.replace(":", "\n").split("\n")
|
||||
|
||||
for line in lines:
|
||||
for part in line.split(";"):
|
||||
for s in part.strip().split(","):
|
||||
keyword = s.strip()
|
||||
if keyword:
|
||||
keywords.add(keyword)
|
||||
if limit and len(keywords) >= limit:
|
||||
return list(keywords)
|
||||
|
||||
return list(keywords)
|
||||
|
@@ -128,7 +128,7 @@ class VectorStoreConnector:
|
||||
|
||||
def __rewrite_index_store_type(self, index_store_type):
|
||||
# Rewrite Knowledge Graph Type
|
||||
if CFG.GRAPH_COMMUNITY_SUMMARY_ENABLED:
|
||||
if CFG.ENABLE_GRAPH_COMMUNITY_SUMMARY:
|
||||
if index_store_type == "KnowledgeGraph":
|
||||
return "CommunitySummaryKnowledgeGraph"
|
||||
return index_store_type
|
||||
|
@@ -1,11 +1,11 @@
|
||||
"""Graph store base class."""
|
||||
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Generator, List, Optional, Tuple
|
||||
from typing import Optional
|
||||
|
||||
from dbgpt._private.pydantic import BaseModel, ConfigDict, Field
|
||||
from dbgpt.core import Embeddings
|
||||
from dbgpt.storage.graph_store.graph import Direction, Graph
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -23,78 +23,36 @@ class GraphStoreConfig(BaseModel):
|
||||
default=None,
|
||||
description="The embedding function of graph store, optional.",
|
||||
)
|
||||
summary_enabled: bool = Field(
|
||||
enable_summary: bool = Field(
|
||||
default=False,
|
||||
description="Enable graph community summary or not.",
|
||||
)
|
||||
enable_document_graph: bool = Field(
|
||||
default=True,
|
||||
description="Enable document graph search or not.",
|
||||
)
|
||||
enable_triplet_graph: bool = Field(
|
||||
default=True,
|
||||
description="Enable knowledge graph search or not.",
|
||||
)
|
||||
|
||||
|
||||
class GraphStoreBase(ABC):
|
||||
"""Graph store base class."""
|
||||
|
||||
def __init__(self, config: GraphStoreConfig):
|
||||
"""Initialize graph store."""
|
||||
self._config = config
|
||||
self._conn = None
|
||||
|
||||
@abstractmethod
|
||||
def get_config(self) -> GraphStoreConfig:
|
||||
"""Get the graph store config."""
|
||||
|
||||
@abstractmethod
|
||||
def get_vertex_type(self) -> str:
|
||||
"""Get the vertex type."""
|
||||
def _escape_quotes(self, text: str) -> str:
|
||||
"""Escape single and double quotes in a string for queries."""
|
||||
|
||||
@abstractmethod
|
||||
def get_edge_type(self) -> str:
|
||||
"""Get the edge type."""
|
||||
|
||||
@abstractmethod
|
||||
def insert_triplet(self, sub: str, rel: str, obj: str):
|
||||
"""Add triplet."""
|
||||
|
||||
@abstractmethod
|
||||
def insert_graph(self, graph: Graph):
|
||||
"""Add graph."""
|
||||
|
||||
@abstractmethod
|
||||
def get_triplets(self, sub: str) -> List[Tuple[str, str]]:
|
||||
"""Get triplets."""
|
||||
|
||||
@abstractmethod
|
||||
def delete_triplet(self, sub: str, rel: str, obj: str):
|
||||
"""Delete triplet."""
|
||||
|
||||
@abstractmethod
|
||||
def truncate(self):
|
||||
"""Truncate Graph."""
|
||||
|
||||
@abstractmethod
|
||||
def drop(self):
|
||||
"""Drop graph."""
|
||||
|
||||
@abstractmethod
|
||||
def get_schema(self, refresh: bool = False) -> str:
|
||||
"""Get schema."""
|
||||
|
||||
@abstractmethod
|
||||
def get_full_graph(self, limit: Optional[int] = None) -> Graph:
|
||||
"""Get full graph."""
|
||||
|
||||
@abstractmethod
|
||||
def explore(
|
||||
self,
|
||||
subs: List[str],
|
||||
direct: Direction = Direction.BOTH,
|
||||
depth: Optional[int] = None,
|
||||
fan: Optional[int] = None,
|
||||
limit: Optional[int] = None,
|
||||
) -> Graph:
|
||||
"""Explore on graph."""
|
||||
|
||||
@abstractmethod
|
||||
def query(self, query: str, **args) -> Graph:
|
||||
"""Execute a query."""
|
||||
|
||||
def aquery(self, query: str, **args) -> Graph:
|
||||
"""Async execute a query."""
|
||||
return self.query(query, **args)
|
||||
|
||||
@abstractmethod
|
||||
def stream_query(self, query: str) -> Generator[Graph, None, None]:
|
||||
"""Execute stream query."""
|
||||
# @abstractmethod
|
||||
# def _paser(self, entities: List[Vertex]) -> str:
|
||||
# """Parse entities to string."""
|
||||
|
@@ -1,4 +1,5 @@
|
||||
"""Graph store factory."""
|
||||
|
||||
import logging
|
||||
from typing import Tuple, Type
|
||||
|
||||
|
@@ -1,4 +1,5 @@
|
||||
"""Graph definition."""
|
||||
|
||||
import itertools
|
||||
import json
|
||||
import logging
|
||||
@@ -6,13 +7,41 @@ import re
|
||||
from abc import ABC, abstractmethod
|
||||
from collections import defaultdict
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, Iterator, List, Optional, Set, Tuple
|
||||
from typing import Any, Callable, Dict, Iterator, List, Optional, Set, Tuple
|
||||
|
||||
import networkx as nx
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class GraphElemType(Enum):
    """Kinds of vertices and edges that can appear in a graph.

    ``DOCUMENT``, ``CHUNK`` and ``ENTITY`` are reported as vertex kinds by
    ``is_vertex``; every other member is reported as an edge kind by ``is_edge``.
    """

    DOCUMENT = "document"
    CHUNK = "chunk"
    ENTITY = "entity"  # view as general vertex in the general case
    RELATION = "relation"  # view as general edge in the general case
    INCLUDE = "include"
    NEXT = "next"

    DOCUMENT_INCLUDE_CHUNK = "document_include_chunk"
    CHUNK_INCLUDE_CHUNK = "chunk_include_chunk"
    CHUNK_INCLUDE_ENTITY = "chunk_include_entity"
    CHUNK_NEXT_CHUNK = "chunk_next_chunk"

    def is_vertex(self) -> bool:
        """Tell whether this element type denotes a vertex."""
        vertex_kinds = (
            GraphElemType.DOCUMENT,
            GraphElemType.CHUNK,
            GraphElemType.ENTITY,
        )
        return self in vertex_kinds

    def is_edge(self) -> bool:
        """Tell whether this element type denotes an edge (i.e. not a vertex)."""
        if self.is_vertex():
            return False
        return True
|
||||
|
||||
|
||||
class Direction(Enum):
|
||||
"""Direction class."""
|
||||
|
||||
@@ -41,7 +70,7 @@ class Elem(ABC):
|
||||
|
||||
def set_prop(self, key: str, value: Any):
|
||||
"""Set a property of ELem."""
|
||||
self._props[key] = value
|
||||
self._props[key] = value # note: always update the value
|
||||
|
||||
def get_prop(self, key: str):
|
||||
"""Get one of the properties of Elem."""
|
||||
@@ -124,6 +153,18 @@ class Edge(Elem):
|
||||
for k, v in props.items():
|
||||
self.set_prop(k, v)
|
||||
|
||||
def __eq__(self, other):
    """Check whether two edges are equal.

    Two edges are considered equal when they have the same source vertex ID,
    target vertex ID, and edge label. The properties are not considered.

    Args:
        other: The object to compare against.

    Returns:
        The result of comparing the ``(sid, tid, name)`` triples, or
        ``NotImplemented`` when ``other`` is not an instance of this edge's
        type, so that Python falls back to its default comparison instead of
        raising ``AttributeError`` on the missing ``sid``/``tid``/``name``.
    """
    # Foreign objects (None, str, ...) have no sid/tid/name; defer to Python.
    if not isinstance(other, type(self)):
        return NotImplemented
    return (self.sid, self.tid, self.name) == (other.sid, other.tid, other.name)
|
||||
|
||||
def __hash__(self):
    """Hash the edge by its source ID, target ID and label.

    Properties do not take part in the hash.
    """
    identity = (self.sid, self.tid, self.name)
    return hash(identity)
|
||||
|
||||
@property
|
||||
def sid(self) -> str:
|
||||
"""Return the source vertex ID of the edge."""
|
||||
@@ -188,11 +229,15 @@ class Graph(ABC):
|
||||
"""Get neighbor edges."""
|
||||
|
||||
@abstractmethod
|
||||
def vertices(self) -> Iterator[Vertex]:
|
||||
def vertices(
|
||||
self, filter_fn: Optional[Callable[[Vertex], bool]] = None
|
||||
) -> Iterator[Vertex]:
|
||||
"""Get vertex iterator."""
|
||||
|
||||
@abstractmethod
|
||||
def edges(self) -> Iterator[Edge]:
|
||||
def edges(
|
||||
self, filter_fn: Optional[Callable[[Edge], bool]] = None
|
||||
) -> Iterator[Edge]:
|
||||
"""Get edge iterator."""
|
||||
|
||||
@abstractmethod
|
||||
@@ -241,7 +286,7 @@ class MemoryGraph(Graph):
|
||||
self._edge_prop_keys = set()
|
||||
self._edge_count = 0
|
||||
|
||||
# init vertices, out edges, in edges index
|
||||
# vertices index, out edges index, in edges index
|
||||
self._vs: Any = defaultdict()
|
||||
self._oes: Any = defaultdict(lambda: defaultdict(set))
|
||||
self._ies: Any = defaultdict(lambda: defaultdict(set))
|
||||
@@ -269,7 +314,7 @@ class MemoryGraph(Graph):
|
||||
# update metadata
|
||||
self._vertex_prop_keys.update(vertex.props.keys())
|
||||
|
||||
def append_edge(self, edge: Edge):
|
||||
def append_edge(self, edge: Edge) -> bool:
|
||||
"""Append an edge if it doesn't exist; requires edge label."""
|
||||
sid = edge.sid
|
||||
tid = edge.tid
|
||||
@@ -290,6 +335,34 @@ class MemoryGraph(Graph):
|
||||
self._edge_count += 1
|
||||
return True
|
||||
|
||||
def upsert_vertex_and_edge(
    self,
    src_vid: str,
    src_name: str,
    src_props: Dict[str, Any],
    dst_vid: str,
    dst_name: str,
    dst_props: Dict[str, Any],
    edge_name: str,
    edge_type: str,
):
    """Upsert the source and destination vertices, then append their edge.

    Args:
        src_vid: ID of the source vertex.
        src_name: Name of the source vertex.
        src_props: Extra properties passed as keyword arguments to the source
            ``Vertex``.
        dst_vid: ID of the destination vertex.
        dst_name: Name of the destination vertex.
        dst_props: Extra properties passed as keyword arguments to the
            destination ``Vertex``.
        edge_name: Label of the edge from ``src_vid`` to ``dst_vid``.
        edge_type: Value stored as the edge's ``edge_type`` property.
    """
    # Build every element first so nothing is written if construction fails.
    endpoints = [
        Vertex(src_vid, src_name, **src_props),
        Vertex(dst_vid, dst_name, **dst_props),
    ]
    link = Edge(src_vid, dst_vid, edge_name, edge_type=edge_type)

    for endpoint in endpoints:
        self.upsert_vertex(endpoint)
    self.append_edge(link)
|
||||
|
||||
def upsert_graph(self, graph: "MemoryGraph"):
    """Merge another graph into this one.

    All vertices of ``graph`` are upserted first; its edges are appended only
    after every vertex has been processed.

    Args:
        graph: The graph whose vertices and edges are copied into this graph.
    """
    stages = (
        (graph.vertices, self.upsert_vertex),
        (graph.edges, self.append_edge),
    )
    for produce, consume in stages:
        # Each producer is called only when its stage starts, vertices first.
        for element in produce():
            consume(element)
|
||||
|
||||
def has_vertex(self, vid: str) -> bool:
|
||||
"""Return True if a vertex with ID ``vid`` is present in the vertex index."""
|
||||
return vid in self._vs
|
||||
@@ -335,13 +408,26 @@ class MemoryGraph(Graph):
|
||||
|
||||
return itertools.islice(es, limit) if limit else es
|
||||
|
||||
def vertices(self) -> Iterator[Vertex]:
|
||||
def vertices(
|
||||
self, filter_fn: Optional[Callable[[Vertex], bool]] = None
|
||||
) -> Iterator[Vertex]:
|
||||
"""Return vertices."""
|
||||
return iter(self._vs.values())
|
||||
# Get all vertices in the graph
|
||||
all_vertices = self._vs.values()
|
||||
|
||||
def edges(self) -> Iterator[Edge]:
|
||||
return all_vertices if filter_fn is None else filter(filter_fn, all_vertices)
|
||||
|
||||
def edges(
|
||||
self, filter_fn: Optional[Callable[[Edge], bool]] = None
|
||||
) -> Iterator[Edge]:
|
||||
"""Return edges."""
|
||||
return iter(e for nbs in self._oes.values() for es in nbs.values() for e in es)
|
||||
# Get all edges in the graph
|
||||
all_edges = (e for nbs in self._oes.values() for es in nbs.values() for e in es)
|
||||
|
||||
if filter_fn is None:
|
||||
return all_edges
|
||||
else:
|
||||
return filter(filter_fn, all_edges)
|
||||
|
||||
def del_vertices(self, *vids: str):
|
||||
"""Delete specified vertices."""
|
||||
@@ -353,7 +439,7 @@ class MemoryGraph(Graph):
|
||||
"""Delete edges."""
|
||||
old_edge_cnt = len(self._oes[sid][tid])
|
||||
|
||||
def remove_matches(es):
|
||||
def remove_matches(es: Set[Edge]):
|
||||
return set(
|
||||
filter(
|
||||
lambda e: not (
|
||||
@@ -452,7 +538,7 @@ class MemoryGraph(Graph):
|
||||
]
|
||||
}
|
||||
|
||||
def format(self) -> str:
|
||||
def format(self, entities_only: Optional[bool] = False) -> str:
|
||||
"""Format graph to string."""
|
||||
vs_str = "\n".join(v.format() for v in self.vertices())
|
||||
es_str = "\n".join(
|
||||
@@ -461,11 +547,14 @@ class MemoryGraph(Graph):
|
||||
f"{self.get_vertex(e.tid).format(concise=True)}"
|
||||
for e in self.edges()
|
||||
)
|
||||
return (
|
||||
f"Entities:\n{vs_str}\n\n" f"Relationships:\n{es_str}"
|
||||
if (vs_str or es_str)
|
||||
else ""
|
||||
)
|
||||
if entities_only:
|
||||
return f"Entities:\n{vs_str}" if vs_str else ""
|
||||
else:
|
||||
return (
|
||||
f"Entities:\n{vs_str}\n\nRelationships:\n{es_str}"
|
||||
if (vs_str or es_str)
|
||||
else ""
|
||||
)
|
||||
|
||||
def truncate(self):
|
||||
"""Truncate graph."""
|
||||
|
@@ -1,11 +1,10 @@
|
||||
"""Memory graph store."""
|
||||
import json
|
||||
|
||||
import logging
|
||||
from typing import Generator, List, Optional, Tuple
|
||||
|
||||
from dbgpt._private.pydantic import ConfigDict
|
||||
from dbgpt.storage.graph_store.base import GraphStoreBase, GraphStoreConfig
|
||||
from dbgpt.storage.graph_store.graph import Direction, Edge, Graph, MemoryGraph
|
||||
from dbgpt.storage.graph_store.graph import MemoryGraph
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -28,77 +27,8 @@ class MemoryGraphStore(GraphStoreBase):
|
||||
"""Get the graph store config."""
|
||||
return self._graph_store_config
|
||||
|
||||
def get_edge_type(self) -> str:
|
||||
"""Get the edge type."""
|
||||
raise NotImplementedError("Memory graph store does not have edge type")
|
||||
|
||||
def get_vertex_type(self) -> str:
|
||||
"""Get the vertex type."""
|
||||
raise NotImplementedError("Memory graph store does not have vertex type")
|
||||
|
||||
def insert_triplet(self, sub: str, rel: str, obj: str):
|
||||
"""Insert a triplet into the graph."""
|
||||
self._graph.append_edge(Edge(sub, obj, rel))
|
||||
|
||||
def insert_graph(self, graph: Graph):
|
||||
"""Add graph."""
|
||||
for vertex in graph.vertices():
|
||||
self._graph.upsert_vertex(vertex)
|
||||
|
||||
for edge in graph.edges():
|
||||
self._graph.append_edge(edge)
|
||||
|
||||
def get_triplets(self, sub: str) -> List[Tuple[str, str]]:
|
||||
"""Retrieve triplets originating from a subject."""
|
||||
subgraph = self.explore([sub], direct=Direction.OUT, depth=1)
|
||||
return [(e.name, e.tid) for e in subgraph.edges()]
|
||||
|
||||
def delete_triplet(self, sub: str, rel: str, obj: str):
|
||||
"""Delete a specific triplet from the graph."""
|
||||
self._graph.del_edges(sub, obj, rel)
|
||||
|
||||
def truncate(self):
|
||||
"""Truncate graph."""
|
||||
self._graph.truncate()
|
||||
|
||||
def drop(self):
|
||||
"""Drop graph."""
|
||||
self._graph = None
|
||||
|
||||
def get_schema(self, refresh: bool = False) -> str:
|
||||
"""Return the graph schema as a JSON string."""
|
||||
return json.dumps(self._graph.schema())
|
||||
|
||||
def get_full_graph(self, limit: Optional[int] = None) -> Graph:
|
||||
"""Return self."""
|
||||
if not limit:
|
||||
return self._graph
|
||||
|
||||
subgraph = MemoryGraph()
|
||||
for count, edge in enumerate(self._graph.edges()):
|
||||
if count >= limit:
|
||||
break
|
||||
subgraph.upsert_vertex(self._graph.get_vertex(edge.sid))
|
||||
subgraph.upsert_vertex(self._graph.get_vertex(edge.tid))
|
||||
subgraph.append_edge(edge)
|
||||
count += 1
|
||||
return subgraph
|
||||
|
||||
def explore(
|
||||
self,
|
||||
subs: List[str],
|
||||
direct: Direction = Direction.BOTH,
|
||||
depth: Optional[int] = None,
|
||||
fan: Optional[int] = None,
|
||||
limit: Optional[int] = None,
|
||||
) -> MemoryGraph:
|
||||
"""Explore the graph from given subjects up to a depth."""
|
||||
return self._graph.search(subs, direct, depth, fan, limit)
|
||||
|
||||
def query(self, query: str, **args) -> Graph:
|
||||
"""Execute a query on graph."""
|
||||
raise NotImplementedError("Query memory graph not allowed")
|
||||
|
||||
def stream_query(self, query: str) -> Generator[Graph, None, None]:
|
||||
"""Execute stream query."""
|
||||
raise NotImplementedError("Stream query memory graph not allowed")
|
||||
def _escape_quotes(self, text: str) -> str:
    """Escape single and double quotes in a string for queries.

    The in-memory graph store does not provide this operation.

    Raises:
        NotImplementedError: Always.
    """
    reason = "_escape_quotes is not implemented by MemoryGraphStore"
    raise NotImplementedError(reason)
|
||||
|
@@ -1,14 +1,15 @@
|
||||
"""TuGraph store."""
|
||||
|
||||
import base64
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from typing import Any, Generator, Iterator, List, Optional, Tuple
|
||||
from typing import List
|
||||
|
||||
from dbgpt._private.pydantic import ConfigDict, Field
|
||||
from dbgpt.datasource.conn_tugraph import TuGraphConnector
|
||||
from dbgpt.storage.graph_store.base import GraphStoreBase, GraphStoreConfig
|
||||
from dbgpt.storage.graph_store.graph import Direction, Edge, Graph, MemoryGraph, Vertex
|
||||
from dbgpt.storage.graph_store.graph import GraphElemType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -35,12 +36,28 @@ class TuGraphStoreConfig(GraphStoreConfig):
|
||||
description="login password",
|
||||
)
|
||||
vertex_type: str = Field(
|
||||
default="entity",
|
||||
description="The type of vertex, `entity` by default.",
|
||||
default=GraphElemType.ENTITY.value,
|
||||
description="The type of entity vertex, `entity` by default.",
|
||||
)
|
||||
document_type: str = Field(
|
||||
default=GraphElemType.DOCUMENT.value,
|
||||
description="The type of document vertex, `document` by default.",
|
||||
)
|
||||
# Type name used for chunk vertices; the default is GraphElemType.CHUNK ("chunk").
chunk_type: str = Field(
    default=GraphElemType.CHUNK.value,
    description="The type of chunk vertex, `chunk` by default.",
)
|
||||
edge_type: str = Field(
|
||||
default="relation",
|
||||
description="The type of edge, `relation` by default.",
|
||||
default=GraphElemType.RELATION.value,
|
||||
description="The type of relation edge, `relation` by default.",
|
||||
)
|
||||
include_type: str = Field(
|
||||
default=GraphElemType.INCLUDE.value,
|
||||
description="The type of include edge, `include` by default.",
|
||||
)
|
||||
next_type: str = Field(
|
||||
default=GraphElemType.NEXT.value,
|
||||
description="The type of next edge, `next` by default.",
|
||||
)
|
||||
plugin_names: List[str] = Field(
|
||||
default=["leiden"],
|
||||
@@ -62,17 +79,24 @@ class TuGraphStore(GraphStoreBase):
|
||||
self._port = int(os.getenv("TUGRAPH_PORT", config.port))
|
||||
self._username = os.getenv("TUGRAPH_USERNAME", config.username)
|
||||
self._password = os.getenv("TUGRAPH_PASSWORD", config.password)
|
||||
self._summary_enabled = (
|
||||
os.getenv("GRAPH_COMMUNITY_SUMMARY_ENABLED", "").lower() == "true"
|
||||
or config.summary_enabled
|
||||
self._enable_summary = (
|
||||
os.getenv("ENABLE_GRAPH_COMMUNITY_SUMMARY", "").lower() == "true"
|
||||
or config.enable_summary
|
||||
)
|
||||
self._enable_document_graph = (
|
||||
os.getenv("ENABLE_DOCUMENT_GRAPH", "").lower() == "true"
|
||||
or config.enable_document_graph
|
||||
)
|
||||
self._enable_triplet_graph = (
|
||||
os.getenv("ENABLE_TRIPLET_GRAPH", "").lower() == "true"
|
||||
or config.enable_triplet_graph
|
||||
)
|
||||
self._plugin_names = (
|
||||
os.getenv("TUGRAPH_PLUGIN_NAMES", "leiden").split(",")
|
||||
or config.plugin_names
|
||||
)
|
||||
|
||||
self._graph_name = config.name
|
||||
self._vertex_type = os.getenv("TUGRAPH_VERTEX_TYPE", config.vertex_type)
|
||||
self._edge_type = os.getenv("TUGRAPH_EDGE_TYPE", config.edge_type)
|
||||
|
||||
self.conn = TuGraphConnector.from_uri_db(
|
||||
host=self._host,
|
||||
@@ -82,34 +106,29 @@ class TuGraphStore(GraphStoreBase):
|
||||
db_name=config.name,
|
||||
)
|
||||
|
||||
self._create_graph(config.name)
|
||||
|
||||
def get_vertex_type(self) -> str:
|
||||
"""Get the vertex type."""
|
||||
return self._vertex_type
|
||||
|
||||
def get_edge_type(self) -> str:
|
||||
"""Get the edge type."""
|
||||
return self._edge_type
|
||||
|
||||
def _create_graph(self, graph_name: str):
|
||||
self.conn.create_graph(graph_name=graph_name)
|
||||
self._create_schema()
|
||||
if self._summary_enabled:
|
||||
self._upload_plugin()
|
||||
|
||||
def _check_label(self, elem_type: str):
|
||||
result = self.conn.get_table_names()
|
||||
if elem_type == "vertex":
|
||||
return self._vertex_type in result["vertex_tables"]
|
||||
if elem_type == "edge":
|
||||
return self._edge_type in result["edge_tables"]
|
||||
def get_config(self) -> TuGraphStoreConfig:
|
||||
"""Get the TuGraph store config."""
|
||||
return self._config
|
||||
|
||||
def _add_vertex_index(self, field_name):
|
||||
gql = f"CALL db.addIndex('{self._vertex_type}', '{field_name}', false)"
|
||||
"""Add an index to the vertex table."""
|
||||
# TODO: Not used in the current implementation.
|
||||
gql = f"CALL db.addIndex('{GraphElemType.ENTITY.value}', '{field_name}', false)"
|
||||
self.conn.run(gql)
|
||||
|
||||
def _upload_plugin(self):
|
||||
"""Upload missing plugins to the TuGraph database.
|
||||
|
||||
This method checks for the presence of required plugins in the database and
|
||||
uploads any missing plugins. It performs the following steps:
|
||||
1. Lists existing plugins in the database.
|
||||
2. Identifies missing plugins by comparing with the required plugin list.
|
||||
3. For each missing plugin, reads its binary content, encodes it, and uploads to
|
||||
the database.
|
||||
|
||||
The method uses the 'leiden' plugin as an example, but can be extended for other
|
||||
plugins.
|
||||
"""
|
||||
gql = "CALL db.plugin.listPlugin('CPP','v1')"
|
||||
result = self.conn.run(gql)
|
||||
result_names = [
|
||||
@@ -122,8 +141,8 @@ class TuGraphStore(GraphStoreBase):
|
||||
if len(missing_plugins):
|
||||
for name in missing_plugins:
|
||||
try:
|
||||
from dbgpt_tugraph_plugins import ( # type: ignore # noqa
|
||||
get_plugin_binary_path,
|
||||
from dbgpt_tugraph_plugins import (
|
||||
get_plugin_binary_path, # type:ignore[import-untyped]
|
||||
)
|
||||
except ImportError:
|
||||
logger.error(
|
||||
@@ -136,375 +155,12 @@ class TuGraphStore(GraphStoreBase):
|
||||
content = f.read()
|
||||
content = base64.b64encode(content).decode()
|
||||
gql = (
|
||||
f"CALL db.plugin.loadPlugin('CPP', '{name}', '{content}', "
|
||||
"'SO', '{name} Plugin', false, 'v1')"
|
||||
f"CALL db.plugin.loadPlugin('CPP', '{name}', '{content}', 'SO', "
|
||||
f"'{name} Plugin', false, 'v1')"
|
||||
)
|
||||
self.conn.run(gql)
|
||||
|
||||
def _create_schema(self):
|
||||
if not self._check_label("vertex"):
|
||||
if self._summary_enabled:
|
||||
create_vertex_gql = (
|
||||
f"CALL db.createLabel("
|
||||
f"'vertex', '{self._vertex_type}', "
|
||||
f"'id', ['id',string,false],"
|
||||
f"['name',string,false],"
|
||||
f"['_document_id',string,true],"
|
||||
f"['_chunk_id',string,true],"
|
||||
f"['_community_id',string,true],"
|
||||
f"['description',string,true])"
|
||||
)
|
||||
self.conn.run(create_vertex_gql)
|
||||
self._add_vertex_index("_community_id")
|
||||
else:
|
||||
create_vertex_gql = (
|
||||
f"CALL db.createLabel("
|
||||
f"'vertex', '{self._vertex_type}', "
|
||||
f"'id', ['id',string,false],"
|
||||
f"['name',string,false])"
|
||||
)
|
||||
self.conn.run(create_vertex_gql)
|
||||
|
||||
if not self._check_label("edge"):
|
||||
create_edge_gql = f"""CALL db.createLabel(
|
||||
'edge', '{self._edge_type}',
|
||||
'[["{self._vertex_type}",
|
||||
"{self._vertex_type}"]]',
|
||||
["id",STRING,false],
|
||||
["name",STRING,false])"""
|
||||
if self._summary_enabled:
|
||||
create_edge_gql = f"""CALL db.createLabel(
|
||||
'edge', '{self._edge_type}',
|
||||
'[["{self._vertex_type}",
|
||||
"{self._vertex_type}"]]',
|
||||
["id",STRING,false],
|
||||
["name",STRING,false],
|
||||
["description",STRING,true])"""
|
||||
self.conn.run(create_edge_gql)
|
||||
|
||||
def _format_query_data(self, data, white_prop_list: List[str]):
    """Convert raw driver records into ``Vertex`` and ``Edge`` lists.

    Args:
        data: Iterable of result records; each record exposes ``keys()`` and
            item access by key.
        white_prop_list: Property names that are kept even though they start
            with ``_`` or are named ``id`` / ``name``.

    Returns:
        A dict ``{"nodes": [Vertex, ...], "edges": [Edge, ...]}``. Edges are
        de-duplicated on ``(name, src_id, dst_id)``; nodes are not.
    """
    from neo4j import graph

    nodes_list: List[Any] = []
    rels_list: List[Any] = []
    # Keys of the edges collected so far. A set lookup replaces the linear
    # scan of rels_list that used to run for every relationship.
    seen_edges = set()

    def get_filtered_properties(properties, white_list):
        # Drop private ("_"-prefixed) and identity properties unless they
        # are explicitly white-listed.
        return {
            key: value
            for key, value in properties.items()
            if (not key.startswith("_") and key not in ["id", "name"])
            or key in white_list
        }

    def process_node(node):
        props = node._properties
        nodes_list.append(
            {
                "id": props.get("id"),
                "name": props.get("name"),
                "properties": get_filtered_properties(props, white_prop_list),
            }
        )

    def process_relationship(rel):
        name = rel._properties.get("name", "")
        rel_nodes = rel.nodes
        src_id = rel_nodes[0]._properties.get("id")
        dst_id = rel_nodes[1]._properties.get("id")
        for node in rel_nodes:
            process_node(node)
        edge_properties = get_filtered_properties(
            rel._properties, white_prop_list
        )
        edge_key = (name, src_id, dst_id)
        if edge_key in seen_edges:
            return
        seen_edges.add(edge_key)
        rels_list.append(
            {
                "src_id": src_id,
                "dst_id": dst_id,
                "name": name,
                "properties": edge_properties,
            }
        )

    def process_path(path):
        for rel in path.relationships:
            process_relationship(rel)

    def process_other(value):
        # Non-graph values are wrapped in one synthetic "json_node" vertex;
        # only the first such value of a result set is kept.
        if not any(
            existing_node.get("id") == "json_node" for existing_node in nodes_list
        ):
            nodes_list.append(
                {
                    "id": "json_node",
                    "name": "json_node",
                    "properties": {"description": value},
                }
            )

    for record in data:
        for key in record.keys():
            value = record[key]
            if isinstance(value, graph.Node):
                process_node(value)
            elif isinstance(value, graph.Relationship):
                process_relationship(value)
            elif isinstance(value, graph.Path):
                process_path(value)
            else:
                process_other(value)

    nodes = [
        Vertex(node["id"], node["name"], **node["properties"])
        for node in nodes_list
    ]
    rels = [
        Edge(edge["src_id"], edge["dst_id"], edge["name"], **edge["properties"])
        for edge in rels_list
    ]
    return {"nodes": nodes, "edges": rels}
|
||||
|
||||
def get_config(self):
    """Return the configuration object this store was built with."""
    config = self._config
    return config
|
||||
|
||||
def get_triplets(self, subj: str) -> List[Tuple[str, str]]:
    """Return the outgoing ``(relation id, object id)`` pairs of a vertex.

    Args:
        subj: Id of the source vertex.

    Returns:
        One ``(rel, obj)`` tuple per record returned by the query.
    """
    # Escape backslashes first, then double quotes, so the value cannot
    # terminate the double-quoted literal it is embedded in.
    subj_literal = subj.replace("\\", "\\\\").replace('"', '\\"')
    query = (
        f"MATCH (n1:{self._vertex_type})-[r]->(n2:{self._vertex_type}) "
        f'WHERE n1.id = "{subj_literal}" RETURN r.id as rel, n2.id as obj;'
    )
    data = self.conn.run(query)
    return [(record["rel"], record["obj"]) for record in data]
|
||||
|
||||
def insert_triplet(self, subj: str, rel: str, obj: str) -> None:
    """Upsert both endpoint vertices, then the edge that links them."""

    def escape_quotes(value: str) -> str:
        """Backslash-escape single and double quotes for query literals."""
        return value.replace("'", "\\'").replace('"', '\\"')

    subj_escaped, rel_escaped, obj_escaped = map(escape_quotes, (subj, rel, obj))
    vertex_type = self._vertex_type
    edge_type = self._edge_type

    node_query = f"""CALL db.upsertVertex(
        '{vertex_type}',
        [{{id:'{subj_escaped}',name:'{subj_escaped}'}},
        {{id:'{obj_escaped}',name:'{obj_escaped}'}}])"""
    edge_query = f"""CALL db.upsertEdge(
        '{edge_type}',
        {{type:"{vertex_type}",key:"sid"}},
        {{type:"{vertex_type}", key:"tid"}},
        [{{sid:"{subj_escaped}",
        tid: "{obj_escaped}",
        id:"{rel_escaped}",
        name: "{rel_escaped}"}}])"""

    # Vertices must exist before the edge that references them.
    for statement in (node_query, edge_query):
        self.conn.run(query=statement)
|
||||
|
||||
def insert_graph(self, graph: Graph) -> None:
    """Upsert every vertex and edge of ``graph`` into the store.

    Vertices are written first with one ``db.upsertVertex`` call, then all
    edges with one ``db.upsertEdge`` call.

    Args:
        graph: Graph whose ``vertices()`` and ``edges()`` are written.
    """

    def escape_quotes(value: str) -> str:
        """Remove single and double quotes; ``None`` is passed through."""
        if value is None:
            return None
        return value.replace("'", "").replace('"', "")

    def parser(items):
        """Render dicts as comma-separated map literals for the query."""
        return ", ".join(
            "{"
            + ", ".join(
                f'{k}: "{v}"' if isinstance(v, str) else f"{k}: {v}"
                for k, v in item.items()
            )
            + "}"
            for item in items
        )

    node_list = [
        {
            "id": escape_quotes(node.vid),
            "name": escape_quotes(node.name),
            "description": escape_quotes(node.get_prop("description")) or "",
            "_document_id": "0",
            "_chunk_id": "0",
            "_community_id": "0",
        }
        for node in graph.vertices()
    ]
    edge_list = [
        {
            "sid": escape_quotes(edge.sid),
            "tid": escape_quotes(edge.tid),
            "id": escape_quotes(edge.name),
            "name": escape_quotes(edge.name),
            # Fall back to "" like the vertex description does; otherwise a
            # missing description is rendered as the bare token None.
            "description": escape_quotes(edge.get_prop("description")) or "",
        }
        for edge in graph.edges()
    ]

    node_query = (
        f"""CALL db.upsertVertex("{self._vertex_type}", [{parser(node_list)}])"""
    )
    edge_query = f"""CALL db.upsertEdge(
        "{self._edge_type}",
        {{type:"{self._vertex_type}", key:"sid"}},
        {{type:"{self._vertex_type}", key:"tid"}},
        [{parser(edge_list)}])"""
    self.conn.run(query=node_query)
    self.conn.run(query=edge_query)
|
||||
|
||||
def truncate(self):
    """Run a delete statement that matches every vertex in the graph."""
    self.conn.run("MATCH (n) DELETE n")
|
||||
|
||||
def drop(self):
    """Ask the connection to delete the graph named ``self._graph_name``."""
    graph_name = self._graph_name
    self.conn.delete_graph(graph_name)
|
||||
|
||||
def delete_triplet(self, sub: str, rel: str, obj: str) -> None:
    """Delete the matched edge together with both of its endpoint vertices.

    Args:
        sub: Id of the source vertex.
        rel: Id of the edge.
        obj: Id of the target vertex.
    """

    def as_literal(value: str) -> str:
        # Escape backslashes first, then single quotes, so an id cannot
        # close the single-quoted literal it is embedded in.
        return value.replace("\\", "\\\\").replace("'", "\\'")

    # NOTE(review): the statement deletes n1 and n2 as well as r; confirm
    # that removing the endpoint vertices is intended.
    del_query = (
        f"MATCH (n1:{self._vertex_type} {{id:'{as_literal(sub)}'}})"
        f"-[r:{self._edge_type} {{id:'{as_literal(rel)}'}}]->"
        f"(n2:{self._vertex_type} {{id:'{as_literal(obj)}'}}) DELETE n1,n2,r"
    )
    self.conn.run(query=del_query)
|
||||
|
||||
def get_schema(self, refresh: bool = False) -> str:
    """Return the ``schema`` field of the first schema-procedure record.

    ``refresh`` is accepted but not used by this implementation.
    """
    records = self.conn.run(query="CALL dbms.graph.getGraphSchema()")
    return records[0]["schema"]
|
||||
|
||||
def get_full_graph(self, limit: Optional[int] = None) -> Graph:
    """Return up to ``limit`` connected vertex/edge rows as a memory graph.

    Raises:
        Exception: If ``limit`` is ``None`` or ``0``.
    """
    if limit:
        fetched = self.query(
            f"MATCH (n)-[r]-(m) RETURN n,r,m LIMIT {limit}",
            white_list=["_community_id"],
        )
        # Copy the result into a fresh graph: vertices first, then edges.
        all_graph = MemoryGraph()
        for vertex in fetched.vertices():
            all_graph.upsert_vertex(vertex)
        for edge in fetched.edges():
            all_graph.append_edge(edge)
        return all_graph
    raise Exception("limit must be set")
|
||||
|
||||
def explore(
    self,
    subs: List[str],
    direct: Direction = Direction.BOTH,
    depth: Optional[int] = None,
    fan: Optional[int] = None,
    limit: Optional[int] = None,
) -> Graph:
    """Return the paths reachable from ``subs`` as a graph.

    Args:
        subs: Ids of the start vertices; an empty list yields an empty graph.
        direct: Edge direction to follow (``OUT``, ``IN``, anything else is
            treated as undirected).
        depth: Maximum hop count; ``None`` leaves the range open (``*..``).
        fan: Not supported; must be ``None``.
        limit: Maximum number of returned paths; ``None`` means no LIMIT.

    Raises:
        ValueError: If ``fan`` is given.
    """
    if not subs:
        return MemoryGraph()
    if fan is not None:
        raise ValueError("Fan functionality is not supported at this time.")

    depth_string = ".." if depth is None else f"1..{depth}"
    limit_string = "" if limit is None else f"LIMIT {limit}"

    hops = f"[r:{self._edge_type}*{depth_string}]"
    direction = direct.name
    if direction == "OUT":
        rel = f"-{hops}->"
    elif direction == "IN":
        rel = f"<-{hops}-"
    else:
        rel = f"-{hops}-"

    vertex_type = self._vertex_type
    query = (
        f"MATCH p=(n:{vertex_type}){rel}(m:{vertex_type}) "
        f"WHERE n.id IN {subs} RETURN p {limit_string}"
    )
    return self.query(query)
|
||||
|
||||
def query(self, query: str, **args) -> MemoryGraph:
    """Run ``query`` and return the formatted result as a ``MemoryGraph``.

    The optional ``white_list`` keyword names properties that survive the
    property filtering of ``_format_query_data``.
    """
    records = self.conn.run(query=query)
    formatted = self._format_query_data(records, args.get("white_list", []))

    memory_graph = MemoryGraph()
    for vertex in formatted["nodes"]:
        memory_graph.upsert_vertex(vertex)
    for edge in formatted["edges"]:
        memory_graph.append_edge(edge)
    return memory_graph
|
||||
|
||||
def stream_query(self, query: str) -> Generator[Graph, None, None]:
    """Run ``query`` in streaming mode and yield one graph per record.

    Each value of a record is converted by type: a node becomes a vertex, a
    relationship becomes an edge, a path becomes its vertices and edges, and
    anything else is wrapped in a synthetic ``json_node`` vertex.

    Args:
        query: Query text passed to ``self.conn.run_stream``.

    Yields:
        A ``MemoryGraph`` holding the elements of one record.
    """
    from neo4j import graph

    def to_vertex(node):
        props = node._properties
        # "description" is optional: labels created without the summary
        # schema have no such property.
        return Vertex(
            props["id"], name=props["id"], description=props.get("description")
        )

    def to_edge(rel, src, dst):
        props = rel._properties
        return Edge(
            src._properties["id"],
            dst._properties["id"],
            name=props["id"],
            description=props.get("description"),
        )

    for record in self.conn.run_stream(query):
        mg = MemoryGraph()
        for key in record.keys():
            value = record[key]
            if isinstance(value, graph.Node):
                mg.upsert_vertex(to_vertex(value))
            elif isinstance(value, graph.Relationship):
                rel_nodes = value.nodes
                mg.append_edge(to_edge(value, rel_nodes[0], rel_nodes[1]))
            elif isinstance(value, graph.Path):
                # Use the value being dispatched rather than a hard-coded
                # record["p"], so the path may be returned under any alias.
                nodes = list(value.nodes)
                rels = list(value.relationships)
                for i, node in enumerate(nodes):
                    mg.upsert_vertex(to_vertex(node))
                    # rels[i] connects nodes[i] to nodes[i + 1].
                    if i < len(rels) and i + 1 < len(nodes):
                        mg.append_edge(to_edge(rels[i], node, nodes[i + 1]))
            else:
                vertex = Vertex("json_node", name="json_node", description=value)
                mg.upsert_vertex(vertex)
        yield mg
|
||||
def _escape_quotes(self, value: str) -> str:
|
||||
"""Escape single and double quotes in a string for queries."""
|
||||
if value is not None:
|
||||
return value.replace("'", "").replace('"', "")
|
||||
|
@@ -1,4 +1,5 @@
|
||||
"""Knowledge graph base class."""
|
||||
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Optional
|
||||
@@ -27,6 +28,6 @@ class KnowledgeGraphBase(IndexStoreBase, ABC):
|
||||
def query_graph(self, limit: Optional[int] = None) -> Graph:
|
||||
"""Get graph data."""
|
||||
|
||||
@abstractmethod
|
||||
def delete_by_ids(self, ids: str) -> List[str]:
|
||||
"""Delete document by ids."""
|
||||
raise Exception("Delete document not supported by knowledge graph")
|
||||
|
@@ -1,11 +1,19 @@
|
||||
"""Define Classes about Community."""
|
||||
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional
|
||||
from typing import AsyncGenerator, Iterator, List, Optional
|
||||
|
||||
from dbgpt.storage.graph_store.base import GraphStoreBase
|
||||
from dbgpt.storage.graph_store.graph import Graph
|
||||
from dbgpt.storage.graph_store.graph import (
|
||||
Direction,
|
||||
Edge,
|
||||
Graph,
|
||||
GraphElemType,
|
||||
MemoryGraph,
|
||||
Vertex,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -24,7 +32,7 @@ class CommunityTree:
|
||||
"""Represents a community tree."""
|
||||
|
||||
|
||||
class CommunityStoreAdapter(ABC):
|
||||
class GraphStoreAdapter(ABC):
|
||||
"""Community Store Adapter."""
|
||||
|
||||
def __init__(self, graph_store: GraphStoreBase):
|
||||
@@ -44,6 +52,113 @@ class CommunityStoreAdapter(ABC):
|
||||
async def get_community(self, community_id: str) -> Community:
|
||||
"""Get community."""
|
||||
|
||||
@abstractmethod
|
||||
def get_graph_config(self):
|
||||
"""Get config."""
|
||||
|
||||
@abstractmethod
|
||||
def get_vertex_type(self) -> str:
|
||||
"""Get vertex type."""
|
||||
|
||||
@abstractmethod
|
||||
def get_edge_type(self) -> str:
|
||||
"""Get edge type."""
|
||||
|
||||
@abstractmethod
|
||||
def get_triplets(self, sub: str) -> List[tuple[str, str]]:
|
||||
"""Get triplets."""
|
||||
|
||||
@abstractmethod
|
||||
def get_document_vertex(self, doc_name: str) -> Vertex:
|
||||
"""Get document vertex."""
|
||||
|
||||
@abstractmethod
|
||||
def get_schema(self, refresh: bool = False) -> str:
|
||||
"""Get schema."""
|
||||
|
||||
@abstractmethod
|
||||
def get_full_graph(self, limit: Optional[int] = None) -> Graph:
|
||||
"""Get full graph."""
|
||||
|
||||
@abstractmethod
|
||||
def upsert_entities(self, entities: Iterator[Vertex]) -> None:
|
||||
"""Upsert entity."""
|
||||
|
||||
@abstractmethod
|
||||
def upsert_edge(
|
||||
self, edges: Iterator[Edge], edge_type: str, src_type: str, dst_type: str
|
||||
):
|
||||
"""Upsert edge."""
|
||||
|
||||
@abstractmethod
|
||||
def upsert_chunks(self, chunk: Iterator[Vertex]) -> None:
|
||||
"""Upsert chunk."""
|
||||
|
||||
@abstractmethod
|
||||
def upsert_documents(self, documents: Iterator[Vertex]) -> None:
|
||||
"""Upsert documents."""
|
||||
|
||||
@abstractmethod
|
||||
def upsert_relations(self, relations: Iterator[Edge]) -> None:
|
||||
"""Upsert relations."""
|
||||
|
||||
@abstractmethod
|
||||
def insert_triplet(self, sub: str, rel: str, obj: str) -> None:
|
||||
"""Insert triplet."""
|
||||
|
||||
@abstractmethod
|
||||
def upsert_graph(self, graph: Graph) -> None:
|
||||
"""Insert graph."""
|
||||
|
||||
@abstractmethod
|
||||
def delete_document(self, chunk_id: str) -> None:
|
||||
"""Delete document in graph store."""
|
||||
|
||||
@abstractmethod
|
||||
def delete_triplet(self, sub: str, rel: str, obj: str) -> None:
|
||||
"""Delete triplet."""
|
||||
|
||||
@abstractmethod
|
||||
def drop(self) -> None:
|
||||
"""Drop graph."""
|
||||
|
||||
@abstractmethod
|
||||
def create_graph(self, graph_name: str) -> None:
|
||||
"""Create graph."""
|
||||
|
||||
@abstractmethod
|
||||
def create_graph_label(self) -> None:
|
||||
"""Create a graph label.
|
||||
|
||||
The graph label is used to identify and distinguish different types of nodes
|
||||
(vertices) and edges in the graph.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def truncate(self) -> None:
|
||||
"""Truncate graph."""
|
||||
|
||||
@abstractmethod
|
||||
def check_label(self, graph_elem_type: GraphElemType) -> bool:
|
||||
"""Check if the label exists in the graph."""
|
||||
|
||||
@abstractmethod
|
||||
def explore(
|
||||
self,
|
||||
subs: List[str],
|
||||
direct: Direction = Direction.BOTH,
|
||||
depth: Optional[int] = None,
|
||||
) -> MemoryGraph:
|
||||
"""Explore the graph from given subjects up to a depth."""
|
||||
|
||||
@abstractmethod
|
||||
def query(self, query: str, **kwargs) -> MemoryGraph:
|
||||
"""Execute a query on graph."""
|
||||
|
||||
@abstractmethod
|
||||
async def stream_query(self, query: str, **kwargs) -> AsyncGenerator[Graph, None]:
|
||||
"""Execute a stream query."""
|
||||
|
||||
|
||||
class CommunityMetastore(ABC):
|
||||
"""Community metastore class."""
|
||||
|
@@ -4,10 +4,7 @@ import logging
|
||||
from typing import List
|
||||
|
||||
from dbgpt.rag.transformer.community_summarizer import CommunitySummarizer
|
||||
from dbgpt.storage.knowledge_graph.community.base import (
|
||||
Community,
|
||||
CommunityStoreAdapter,
|
||||
)
|
||||
from dbgpt.storage.knowledge_graph.community.base import Community, GraphStoreAdapter
|
||||
from dbgpt.storage.knowledge_graph.community.community_metastore import (
|
||||
BuiltinCommunityMetastore,
|
||||
)
|
||||
@@ -21,23 +18,23 @@ class CommunityStore:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
community_store_adapter: CommunityStoreAdapter,
|
||||
graph_store_adapter: GraphStoreAdapter,
|
||||
community_summarizer: CommunitySummarizer,
|
||||
vector_store: VectorStoreBase,
|
||||
):
|
||||
"""Initialize the CommunityStore class."""
|
||||
self._community_store_adapter = community_store_adapter
|
||||
self._graph_store_adapter = graph_store_adapter
|
||||
self._community_summarizer = community_summarizer
|
||||
self._meta_store = BuiltinCommunityMetastore(vector_store)
|
||||
|
||||
async def build_communities(self):
|
||||
"""Discover communities."""
|
||||
community_ids = await self._community_store_adapter.discover_communities()
|
||||
community_ids = await self._graph_store_adapter.discover_communities()
|
||||
|
||||
# summarize communities
|
||||
communities = []
|
||||
for community_id in community_ids:
|
||||
community = await self._community_store_adapter.get_community(community_id)
|
||||
community = await self._graph_store_adapter.get_community(community_id)
|
||||
graph = community.data.format()
|
||||
if not graph:
|
||||
break
|
||||
@@ -65,7 +62,7 @@ class CommunityStore:
|
||||
self._community_summarizer.truncate()
|
||||
|
||||
logger.info("Truncate graph")
|
||||
self._community_store_adapter.graph_store.truncate()
|
||||
self._graph_store_adapter.truncate()
|
||||
|
||||
def drop(self):
|
||||
"""Drop community store."""
|
||||
@@ -76,4 +73,4 @@ class CommunityStore:
|
||||
self._community_summarizer.drop()
|
||||
|
||||
logger.info("Remove graph")
|
||||
self._community_store_adapter.graph_store.drop()
|
||||
self._graph_store_adapter.drop()
|
||||
|
@@ -1,28 +1,29 @@
|
||||
"""CommunityStoreAdapter factory."""
|
||||
"""GraphStoreAdapter factory."""
|
||||
|
||||
import logging
|
||||
|
||||
from dbgpt.storage.graph_store.base import GraphStoreBase
|
||||
from dbgpt.storage.graph_store.tugraph_store import TuGraphStore
|
||||
from dbgpt.storage.knowledge_graph.community.base import CommunityStoreAdapter
|
||||
from dbgpt.storage.knowledge_graph.community.tugraph_adapter import (
|
||||
TuGraphCommunityStoreAdapter,
|
||||
from dbgpt.storage.knowledge_graph.community.base import GraphStoreAdapter
|
||||
from dbgpt.storage.knowledge_graph.community.tugraph_store_adapter import (
|
||||
TuGraphStoreAdapter,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CommunityStoreAdapterFactory:
|
||||
class GraphStoreAdapterFactory:
|
||||
"""Factory for community store adapter."""
|
||||
|
||||
@staticmethod
|
||||
def create(graph_store: GraphStoreBase) -> CommunityStoreAdapter:
|
||||
"""Create a CommunityStoreAdapter instance.
|
||||
def create(graph_store: GraphStoreBase) -> GraphStoreAdapter:
|
||||
"""Create a GraphStoreAdapter instance.
|
||||
|
||||
Args:
|
||||
- graph_store_type: graph store type Memory, TuGraph, Neo4j
|
||||
"""
|
||||
if isinstance(graph_store, TuGraphStore):
|
||||
return TuGraphCommunityStoreAdapter(graph_store)
|
||||
return TuGraphStoreAdapter(graph_store)
|
||||
else:
|
||||
raise Exception(
|
||||
"create community store adapter for %s failed",
|
||||
|
@@ -0,0 +1,186 @@
|
||||
"""TuGraph Community Store Adapter."""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import AsyncGenerator, Iterator, List, Optional, Tuple
|
||||
|
||||
from dbgpt.storage.graph_store.graph import (
|
||||
Direction,
|
||||
Edge,
|
||||
Graph,
|
||||
GraphElemType,
|
||||
MemoryGraph,
|
||||
Vertex,
|
||||
)
|
||||
from dbgpt.storage.graph_store.memgraph_store import (
|
||||
MemoryGraphStore,
|
||||
MemoryGraphStoreConfig,
|
||||
)
|
||||
from dbgpt.storage.knowledge_graph.community.base import Community, GraphStoreAdapter
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MemGraphStoreAdapter(GraphStoreAdapter):
    """Graph store adapter backed by an in-process ``MemoryGraphStore``.

    Triplet, graph and exploration calls operate on the store's in-memory
    graph. Community discovery, document/chunk/entity upserts, label
    management and raw queries are no-op stubs in this adapter.
    """

    MAX_HIERARCHY_LEVEL = 3

    def __init__(self, enable_summary: bool = False):
        """Create the backing memory store and register it with the base.

        Args:
            enable_summary: Stored on the instance as ``_enable_summary``.
        """
        self._graph_store = MemoryGraphStore(MemoryGraphStoreConfig())
        self._enable_summary = enable_summary

        super().__init__(self._graph_store)

        # Create the graph
        self.create_graph(self._graph_store.get_config().name)

    async def discover_communities(self, **kwargs) -> List[str]:
        """Run community discovery (not implemented for the memory graph)."""
        pass

    async def get_community(self, community_id: str) -> Community:
        """Get community (not implemented for the memory graph)."""
        pass

    def get_graph_config(self):
        """Get the graph store config."""
        return self._graph_store.get_config()

    def get_vertex_type(self) -> str:
        """Get the vertex type; this adapter returns an empty string."""
        # raise NotImplementedError("Memory graph store does not have vertex type")
        return ""

    def get_edge_type(self) -> str:
        """Get the edge type; this adapter returns an empty string."""
        # raise NotImplementedError("Memory graph store does not have edge type")
        return ""

    def get_triplets(self, subj: str) -> List[Tuple[str, str]]:
        """Return ``(edge name, target id)`` for edges one hop out of ``subj``."""
        subgraph = self.explore([subj], direct=Direction.OUT, depth=1)
        return [(e.name, e.tid) for e in subgraph.edges()]

    def get_document_vertex(self, doc_name: str) -> Vertex:
        """Get the document vertex in the graph.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError("Memory graph store does not have document vertex")

    def get_schema(self, refresh: bool = False) -> str:
        """Return the in-memory graph's schema serialized as JSON."""
        return json.dumps(self._graph_store._graph.schema())

    def get_full_graph(self, limit: Optional[int] = None) -> Graph:
        """Return the whole graph, or a subgraph of at most ``limit`` edges.

        Args:
            limit: Maximum number of edges to copy; ``None`` or ``0`` returns
                the store's graph object itself.
        """
        if not limit:
            return self._graph_store._graph

        source = self._graph_store._graph
        subgraph = MemoryGraph()
        # enumerate() supplies the running edge count; no manual increment
        # is needed.
        for count, edge in enumerate(source.edges()):
            if count >= limit:
                break
            subgraph.upsert_vertex(source.get_vertex(edge.sid))
            subgraph.upsert_vertex(source.get_vertex(edge.tid))
            subgraph.append_edge(edge)
        return subgraph

    def upsert_entities(self, entities: Iterator[Vertex]) -> None:
        """Upsert entities (no-op)."""
        pass

    def upsert_edge(
        self, edges: Iterator[Edge], edge_type: str, src_type: str, dst_type: str
    ) -> None:
        """Upsert edges (no-op)."""
        pass

    def upsert_chunks(self, chunks: Iterator[Vertex]) -> None:
        """Upsert chunks (no-op)."""
        pass

    def upsert_documents(self, documents: Iterator[Vertex]) -> None:
        """Upsert documents (no-op)."""
        pass

    def upsert_relations(self, relations: Iterator[Edge]) -> None:
        """Upsert relations (no-op)."""
        pass

    def insert_triplet(self, subj: str, rel: str, obj: str) -> None:
        """Append an edge named ``rel`` from ``subj`` to ``obj``."""
        self._graph_store._graph.append_edge(Edge(subj, obj, rel))

    def upsert_graph(self, graph: Graph) -> None:
        """Add graph to the graph store.

        Args:
            graph (Graph): The graph to be added.
        """
        for vertex in graph.vertices():
            self._graph_store._graph.upsert_vertex(vertex)

        for edge in graph.edges():
            self._graph_store._graph.append_edge(edge)

    def delete_document(self, chunk_ids: str) -> None:
        """Delete document in the graph (no-op)."""
        pass

    def delete_triplet(self, sub: str, rel: str, obj: str) -> None:
        """Delete the edges named ``rel`` between ``sub`` and ``obj``."""
        self._graph_store._graph.del_edges(sub, obj, rel)

    def drop(self):
        """Drop the graph by resetting the store's graph reference to None."""
        self._graph_store._graph = None

    def create_graph(self, graph_name: str):
        """Create a graph (no-op)."""
        pass

    def create_graph_label(
        self,
    ) -> None:
        """Create a graph label (no-op).

        The graph label is used to identify and distinguish different types of nodes
        (vertices) and edges in the graph.
        """
        pass

    def truncate(self):
        """Truncate Graph."""
        self._graph_store._graph.truncate()

    def check_label(self, graph_elem_type: GraphElemType) -> bool:
        """Check if the label exists in the graph (not implemented).

        Args:
            graph_elem_type (GraphElemType): The type of the graph element.

        Returns:
            Currently always ``None`` because the body is a stub.
        """
        pass

    def explore(
        self,
        subs: List[str],
        direct: Direction = Direction.BOTH,
        depth: Optional[int] = None,
        fan: Optional[int] = None,
        limit: Optional[int] = None,
    ) -> MemoryGraph:
        """Explore the graph from given subjects up to a depth."""
        return self._graph_store._graph.search(subs, direct, depth, fan, limit)

    def query(self, query: str, **kwargs) -> MemoryGraph:
        """Execute a query on graph (not implemented)."""
        pass

    async def stream_query(self, query: str, **kwargs) -> AsyncGenerator[Graph, None]:
        """Execute a stream query (not implemented)."""
        pass
|
@@ -1,52 +0,0 @@
|
||||
"""TuGraph Community Store Adapter."""
|
||||
import json
|
||||
import logging
|
||||
from typing import List
|
||||
|
||||
from dbgpt.storage.graph_store.graph import MemoryGraph
|
||||
from dbgpt.storage.knowledge_graph.community.base import (
|
||||
Community,
|
||||
CommunityStoreAdapter,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TuGraphCommunityStoreAdapter(CommunityStoreAdapter):
|
||||
"""TuGraph Community Store Adapter."""
|
||||
|
||||
MAX_HIERARCHY_LEVEL = 3
|
||||
|
||||
async def discover_communities(self, **kwargs) -> List[str]:
|
||||
"""Run community discovery with leiden."""
|
||||
mg = self._graph_store.query(
|
||||
"CALL db.plugin.callPlugin"
|
||||
"('CPP','leiden','{\"leiden_val\":\"_community_id\"}',60.00,false)"
|
||||
)
|
||||
result = mg.get_vertex("json_node").get_prop("description")
|
||||
community_ids = json.loads(result)["community_id_list"]
|
||||
logger.info(f"Discovered {len(community_ids)} communities.")
|
||||
return community_ids
|
||||
|
||||
async def get_community(self, community_id: str) -> Community:
|
||||
"""Get community."""
|
||||
query = (
|
||||
f"MATCH (n:{self._graph_store.get_vertex_type()})"
|
||||
f"WHERE n._community_id = '{community_id}' RETURN n"
|
||||
)
|
||||
edge_query = (
|
||||
f"MATCH (n:{self._graph_store.get_vertex_type()})-"
|
||||
f"[r:{self._graph_store.get_edge_type()}]-"
|
||||
f"(m:{self._graph_store.get_vertex_type()})"
|
||||
f"WHERE n._community_id = '{community_id}' RETURN n,r,m"
|
||||
)
|
||||
|
||||
all_vertex_graph = self._graph_store.aquery(query)
|
||||
all_edge_graph = self._graph_store.aquery(edge_query)
|
||||
all_graph = MemoryGraph()
|
||||
for vertex in all_vertex_graph.vertices():
|
||||
all_graph.upsert_vertex(vertex)
|
||||
for edge in all_edge_graph.edges():
|
||||
all_graph.append_edge(edge)
|
||||
|
||||
return Community(id=community_id, data=all_graph)
|
808
dbgpt/storage/knowledge_graph/community/tugraph_store_adapter.py
Normal file
808
dbgpt/storage/knowledge_graph/community/tugraph_store_adapter.py
Normal file
@@ -0,0 +1,808 @@
|
||||
"""TuGraph Community Store Adapter."""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import (
|
||||
Any,
|
||||
AsyncGenerator,
|
||||
Dict,
|
||||
Iterator,
|
||||
List,
|
||||
Literal,
|
||||
Optional,
|
||||
Tuple,
|
||||
Union,
|
||||
)
|
||||
|
||||
from dbgpt.storage.graph_store.graph import (
|
||||
Direction,
|
||||
Edge,
|
||||
Graph,
|
||||
GraphElemType,
|
||||
MemoryGraph,
|
||||
Vertex,
|
||||
)
|
||||
from dbgpt.storage.graph_store.tugraph_store import TuGraphStore
|
||||
from dbgpt.storage.knowledge_graph.community.base import Community, GraphStoreAdapter
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TuGraphStoreAdapter(GraphStoreAdapter):
|
||||
"""TuGraph Community Store Adapter."""
|
||||
|
||||
MAX_QUERY_LIMIT = 1000
|
||||
MAX_HIERARCHY_LEVEL = 3
|
||||
|
||||
def __init__(self, graph_store: TuGraphStore):
    """Register ``graph_store`` with the base adapter and create its graph."""
    super().__init__(graph_store)

    # The graph name comes from the store's own configuration.
    graph_name = self.graph_store.get_config().name
    self.create_graph(graph_name)
|
||||
|
||||
async def discover_communities(self, **kwargs) -> List[str]:
    """Call the ``leiden`` plugin and return the community ids it reports.

    The plugin result is read from the ``description`` property of the
    ``json_node`` vertex and parsed as JSON; its ``community_id_list`` entry
    is returned.
    """
    plugin_call = (
        "CALL db.plugin.callPlugin('CPP',"
        "'leiden','{\"leiden_val\":\"_community_id\"}',60.00,false)"
    )
    payload = self.query(plugin_call).get_vertex("json_node").get_prop("description")
    community_ids = json.loads(payload)["community_id_list"]
    logger.info(f"Discovered {len(community_ids)} communities.")
    return community_ids
|
||||
|
||||
async def get_community(self, community_id: str) -> Community:
    """Load the vertices and edges of one community.

    Two queries are run: one for the vertices whose ``_community_id`` equals
    ``community_id``, one for the edges touching such vertices. The results
    are merged into a single ``MemoryGraph``.
    """
    vertex_label = self.get_vertex_type()
    edge_label = self.get_edge_type()

    vertex_query = (
        f"MATCH (n:{vertex_label}) WHERE n._community_id = "
        f"'{community_id}' RETURN n"
    )
    edge_query = (
        f"MATCH (n:{vertex_label})-"
        f"[r:{edge_label}]-"
        f"(m:{vertex_label})"
        f"WHERE n._community_id = '{community_id}' RETURN n,r,m"
    )

    vertex_result = self.query(vertex_query)
    edge_result = self.query(edge_query)

    community_graph = MemoryGraph()
    for vertex in vertex_result.vertices():
        community_graph.upsert_vertex(vertex)
    for edge in edge_result.edges():
        community_graph.append_edge(edge)

    return Community(id=community_id, data=community_graph)
|
||||
|
||||
@property
def graph_store(self) -> TuGraphStore:
    """Return the store held in ``self._graph_store``."""
    store = self._graph_store
    return store
|
||||
|
||||
def get_graph_config(self):
    """Return whatever the underlying graph store's ``get_config()`` returns."""
    store = self.graph_store
    return store.get_config()
|
||||
|
||||
def get_vertex_type(self) -> str:
    """Return the label used for entity vertices."""
    entity_label = GraphElemType.ENTITY.value
    return entity_label
|
||||
|
||||
def get_edge_type(self) -> str:
    """Return the label used for relation edges."""
    relation_label = GraphElemType.RELATION.value
    return relation_label
|
||||
|
||||
def get_triplets(self, subj: str) -> List[Tuple[str, str]]:
    """Return the outgoing ``(relation id, object id)`` pairs of an entity.

    Args:
        subj: Id of the source entity vertex.

    Returns:
        One ``(rel, obj)`` tuple per record returned by the query.
    """
    # Escape backslashes first, then double quotes, so the value cannot
    # terminate the double-quoted literal it is embedded in.
    subj_literal = subj.replace("\\", "\\\\").replace('"', '\\"')
    triplet_query = (
        f"MATCH (n1:{GraphElemType.ENTITY.value})-[r]->(n2:"
        f"{GraphElemType.ENTITY.value}) "
        f'WHERE n1.id = "{subj_literal}" RETURN r.id as rel, n2.id as obj;'
    )
    data = self.graph_store.conn.run(triplet_query)
    return [(record["rel"], record["obj"]) for record in data]
|
||||
|
||||
def get_document_vertex(self, doc_name: str) -> Vertex:
    """Look up the vertex whose id equals ``doc_name``.

    Args:
        doc_name: Id of the document vertex.

    Returns:
        The result of ``get_vertex(doc_name)`` on the queried graph.
    """
    # The id is a string, so it must be sent as a quoted literal; embedded
    # backslashes and double quotes are escaped to keep the literal intact.
    doc_literal = doc_name.replace("\\", "\\\\").replace('"', '\\"')
    gql = f"""MATCH (n) WHERE n.id = "{doc_literal}" RETURN n"""
    graph = self.query(gql)
    vertex = graph.get_vertex(doc_name)
    return vertex
|
||||
|
||||
def get_schema(self, refresh: bool = False) -> str:
    """Return the ``schema`` field of the first schema-procedure record.

    ``refresh`` is accepted but not used by this implementation.
    """
    records = self.graph_store.conn.run(query="CALL dbms.graph.getGraphSchema()")
    return records[0]["schema"]
|
||||
|
||||
def get_full_graph(self, limit: Optional[int] = None) -> Graph:
    """Return up to ``limit`` connected vertex/edge rows as a memory graph.

    A missing or zero ``limit`` falls back to ``MAX_QUERY_LIMIT``.

    Raises:
        ValueError: If the effective limit is not positive.
    """
    effective_limit = limit or self.MAX_QUERY_LIMIT
    if effective_limit <= 0:
        raise ValueError("Limit must be greater than 0.")

    fetched = self.query(
        f"MATCH (n)-[r]-(m) RETURN n,r,m LIMIT {effective_limit}",
        white_list=["_community_id"],
    )
    # Copy the result into a fresh graph: vertices first, then edges.
    full_graph = MemoryGraph()
    for vertex in fetched.vertices():
        full_graph.upsert_vertex(vertex)
    for edge in fetched.edges():
        full_graph.append_edge(edge)
    return full_graph
|
||||
|
||||
def upsert_entities(self, entities: Iterator[Vertex]) -> None:
    """Write all ``entities`` with a single ``db.upsertVertex`` call.

    Quotes are stripped from id, name and description; the document, chunk
    and community ids are initialised to ``"0"``.
    """
    strip = self.graph_store._escape_quotes
    entity_list = []
    for entity in entities:
        entity_list.append(
            {
                "id": strip(entity.vid),
                "name": strip(entity.name),
                "description": strip(entity.get_prop("description")) or "",
                "_document_id": "0",
                "_chunk_id": "0",
                "_community_id": "0",
            }
        )

    label = GraphElemType.ENTITY.value
    entity_query = f'CALL db.upsertVertex("{label}", [{self._parser(entity_list)}])'
    self.graph_store.conn.run(query=entity_query)
|
||||
|
||||
def upsert_edge(
    self, edges: Iterator[Edge], edge_type: str, src_type: str, dst_type: str
) -> None:
    """Write all ``edges`` with a single ``db.upsertEdge`` call.

    Args:
        edges: Edges to write; quotes are stripped from their string fields.
        edge_type: Label of the edges.
        src_type: Label of the source vertices (matched on ``sid``).
        dst_type: Label of the target vertices (matched on ``tid``).
    """
    strip = self.graph_store._escape_quotes
    edge_list = []
    for edge in edges:
        edge_name = strip(edge.name)
        edge_list.append(
            {
                "sid": strip(edge.sid),
                "tid": strip(edge.tid),
                "id": edge_name,
                "name": edge_name,
                "description": strip(edge.get_prop("description")) or "",
                "_chunk_id": strip(edge.get_prop("_chunk_id")) or "",
            }
        )

    relation_query = f"""CALL db.upsertEdge("{edge_type}",
        {{type:"{src_type}", key:"sid"}},
        {{type:"{dst_type}", key:"tid"}},
        [{self._parser(edge_list)}])"""
    self.graph_store.conn.run(query=relation_query)
|
||||
|
||||
def upsert_chunks(self, chunks: Iterator[Vertex]) -> None:
    """Upsert chunk vertices into the graph store.

    Args:
        chunks: The chunk vertices to write. ``vid``, ``name`` and the
            ``content`` property are passed through the store's quote
            escaping before being rendered into the query.
    """
    escape = self.graph_store._escape_quotes
    chunk_list = [
        {
            "id": escape(chunk.vid),
            "name": escape(chunk.name),
            # Fall back to "" like the entity/document upserts do: a missing
            # content would otherwise be rendered by `_parser` as the bare
            # token `None` inside the query.
            "content": escape(chunk.get_prop("content")) or "",
        }
        for chunk in chunks
    ]
    chunk_query = (
        f"CALL db.upsertVertex("
        f'"{GraphElemType.CHUNK.value}", '
        f"[{self._parser(chunk_list)}])"
    )
    self.graph_store.conn.run(query=chunk_query)
|
||||
|
||||
def upsert_documents(self, documents: Iterator[Vertex]) -> None:
    """Write document vertices to the graph store in a single upsert call.

    Args:
        documents: The document vertices to write. ``vid``, ``name`` and the
            ``content`` property are passed through the store's quote
            escaping; a falsy escaped content is written as ``""``.
    """
    escape = self.graph_store._escape_quotes

    def _to_row(doc):
        return {
            "id": escape(doc.vid),
            "name": escape(doc.name),
            "content": escape(doc.get_prop("content")) or "",
        }

    rows = list(map(_to_row, documents))
    label = GraphElemType.DOCUMENT.value
    document_query = f'CALL db.upsertVertex("{label}", [{self._parser(rows)}])'
    self.graph_store.conn.run(query=document_query)
|
||||
|
||||
def upsert_relations(self, relations: Iterator[Edge]) -> None:
    """Accept relation edges without writing anything.

    This is a placeholder: the argument is ignored and nothing is sent to
    the graph store.

    Args:
        relations: The relation edges (unused).
    """
    return None
|
||||
|
||||
def insert_triplet(self, subj: str, rel: str, obj: str) -> None:
    """Add a (subject, relation, object) triplet to the graph store.

    Two queries are run: one upserting the subject and object as entity
    vertices (id and name both set to the text), then one upserting the
    relation edge between them (id and name both set to ``rel``).

    Args:
        subj: The subject entity text.
        rel: The relation text.
        obj: The object entity text.
    """

    def _escape(text: str) -> str:
        # The backslash must be escaped first: it is the escape character of
        # the query string literals, so a raw one would swallow the closing
        # quote (or be read as an escape sequence) and corrupt the query.
        return (
            text.replace("\\", "\\\\").replace("'", "\\'").replace('"', '\\"')
        )

    subj_escaped = _escape(subj)
    rel_escaped = _escape(rel)
    obj_escaped = _escape(obj)

    vertex_query = f"""CALL db.upsertVertex(
        '{GraphElemType.ENTITY.value}',
        [{{id:'{subj_escaped}',name:'{subj_escaped}'}},
        {{id:'{obj_escaped}',name:'{obj_escaped}'}}])"""
    edge_query = f"""CALL db.upsertEdge(
        '{GraphElemType.RELATION.value}',
        {{type:"{GraphElemType.ENTITY.value}",key:"sid"}},
        {{type:"{GraphElemType.ENTITY.value}", key:"tid"}},
        [{{sid:"{subj_escaped}",
        tid: "{obj_escaped}",
        id:"{rel_escaped}",
        name: "{rel_escaped}"}}])"""

    # The vertices must exist before the edge that connects them.
    self.graph_store.conn.run(query=vertex_query)
    self.graph_store.conn.run(query=edge_query)
|
||||
|
||||
def upsert_graph(self, graph: MemoryGraph) -> None:
    """Persist every vertex and edge of an in-memory graph.

    Vertices are selected by their ``vertex_type`` property and edges by
    their ``edge_type`` property. Entities, chunks and documents are written
    first, followed by the five edge groups.

    Args:
        graph (MemoryGraph): The graph to be added.
    """

    def _vertices_of(kind: GraphElemType) -> Iterator[Vertex]:
        return graph.vertices(
            filter_fn=lambda v: v.get_prop("vertex_type") == kind.value
        )

    def _edges_of(kind: GraphElemType) -> Iterator[Edge]:
        return graph.edges(
            filter_fn=lambda e: e.get_prop("edge_type") == kind.value
        )

    # Build all the iterators before anything is written to the store.
    documents = _vertices_of(GraphElemType.DOCUMENT)
    chunks = _vertices_of(GraphElemType.CHUNK)
    entities = _vertices_of(GraphElemType.ENTITY)

    # (edges, edge label, source vertex label, destination vertex label)
    edge_batches = [
        (
            _edges_of(GraphElemType.DOCUMENT_INCLUDE_CHUNK),
            GraphElemType.INCLUDE,
            GraphElemType.DOCUMENT,
            GraphElemType.CHUNK,
        ),
        (
            _edges_of(GraphElemType.CHUNK_INCLUDE_CHUNK),
            GraphElemType.INCLUDE,
            GraphElemType.CHUNK,
            GraphElemType.CHUNK,
        ),
        (
            _edges_of(GraphElemType.CHUNK_INCLUDE_ENTITY),
            GraphElemType.INCLUDE,
            GraphElemType.CHUNK,
            GraphElemType.ENTITY,
        ),
        (
            _edges_of(GraphElemType.CHUNK_NEXT_CHUNK),
            GraphElemType.NEXT,
            GraphElemType.CHUNK,
            GraphElemType.CHUNK,
        ),
        (
            _edges_of(GraphElemType.RELATION),
            GraphElemType.RELATION,
            GraphElemType.ENTITY,
            GraphElemType.ENTITY,
        ),
    ]

    # Upsert the vertices first, then the edges that connect them.
    self.upsert_entities(entities)
    self.upsert_chunks(chunks)
    self.upsert_documents(documents)
    for edges, edge_label, src_label, dst_label in edge_batches:
        self.upsert_edge(
            edges, edge_label.value, src_label.value, dst_label.value
        )
|
||||
|
||||
def delete_document(self, chunk_ids: str) -> None:
    """Remove the chunks of a document, and what depended on them.

    Three statements are run in order: delete the chunk vertices that a
    document vertex points to and whose id is listed, delete the relation
    edges between entities whose ``_chunk_id`` is listed, then delete every
    vertex that is left without any edge.

    Args:
        chunk_ids: The chunk ids, separated by commas. Surrounding
            whitespace of each id is stripped.
    """
    id_list = list(map(str.strip, chunk_ids.split(",")))

    document = GraphElemType.DOCUMENT.value
    chunk = GraphElemType.CHUNK.value
    entity = GraphElemType.ENTITY.value
    relation = GraphElemType.RELATION.value

    statements = (
        # The id list is rendered with its Python repr, as a list literal.
        f"MATCH(m:{document})-[r]->(n:{chunk}) WHERE n.id IN {id_list} DELETE n",
        f"MATCH(m:{entity})-[r:{relation}]-(n:{entity}) "
        f"WHERE r._chunk_id IN {id_list} DELETE r",
        "MATCH (n) WHERE NOT EXISTS((n)-[]-()) DELETE n",
    )
    for gql in statements:
        self.graph_store.conn.run(gql)
|
||||
|
||||
def delete_triplet(self, sub: str, rel: str, obj: str) -> None:
    """Delete a (subject, relation, object) triplet from the graph store.

    The query matches the entity vertices and the relation edge by their
    ``id`` and deletes all three of them (both vertices and the edge).

    Args:
        sub: The id of the subject entity.
        rel: The id of the relation edge.
        obj: The id of the object entity.
    """

    def _escape(text: str) -> str:
        # Escape like `insert_triplet` does, so a triplet inserted with a
        # quote in its text can be matched again instead of breaking the
        # query. The backslash goes first because it is the escape character.
        return (
            text.replace("\\", "\\\\").replace("'", "\\'").replace('"', '\\"')
        )

    sub_escaped = _escape(sub)
    rel_escaped = _escape(rel)
    obj_escaped = _escape(obj)

    del_query = (
        f"MATCH (n1:{GraphElemType.ENTITY.value} {{id:'{sub_escaped}'}})"
        f"-[r:{GraphElemType.RELATION.value} {{id:'{rel_escaped}'}}]->"
        f"(n2:{GraphElemType.ENTITY.value} {{id:'{obj_escaped}'}}) DELETE n1,n2,r"
    )
    self.graph_store.conn.run(query=del_query)
|
||||
|
||||
def drop(self):
    """Delete the graph named by the current graph configuration."""
    delete_graph = self.graph_store.conn.delete_graph
    config = self.get_graph_config()
    delete_graph(config.name)
|
||||
|
||||
def create_graph(self, graph_name: str):
    """Create a graph and register the labels used by this adapter.

    After the graph is created, the vertex labels (document, chunk, entity)
    and the edge labels (relation, include, next) are registered in that
    order through ``create_graph_label``. Finally the store's plugin is
    uploaded when the store has summaries enabled.

    Args:
        graph_name: The name of the graph to create.
    """
    self.graph_store.conn.create_graph(graph_name=graph_name)

    def _prop(
        name: str,
        optional: bool = False,
        index: Optional[bool] = None,
    ) -> Dict[str, Union[str, bool]]:
        """Describe one STRING property; ``index`` is omitted when None."""
        spec: Dict[str, Union[str, bool]] = {
            "name": name,
            "type": "STRING",
            "optional": optional,
        }
        if index is not None:
            spec["index"] = index
        return spec

    # One entry per label, in registration order. `id` and `name` are
    # mandatory everywhere; the other properties are optional.
    label_schemas = (
        (
            GraphElemType.DOCUMENT,
            [_prop("id"), _prop("name"), _prop("_community_id", True, True)],
        ),
        (
            GraphElemType.CHUNK,
            [
                _prop("id"),
                _prop("name"),
                _prop("_community_id", True, True),
                _prop("content", True, True),
            ],
        ),
        (
            GraphElemType.ENTITY,
            [
                _prop("id"),
                _prop("name"),
                _prop("_community_id", True, True),
                _prop("description", True, True),
            ],
        ),
        (
            GraphElemType.RELATION,
            [
                _prop("id"),
                _prop("name"),
                _prop("_chunk_id", True, True),
                _prop("description", True, True),
            ],
        ),
        (
            GraphElemType.INCLUDE,
            [_prop("id"), _prop("name"), _prop("description", True)],
        ),
        (
            GraphElemType.NEXT,
            [_prop("id"), _prop("name"), _prop("description", True)],
        ),
    )
    for elem_type, properties in label_schemas:
        self.create_graph_label(
            graph_elem_type=elem_type, graph_properties=properties
        )

    if self.graph_store._enable_summary:
        self.graph_store._upload_plugin()
|
||||
|
||||
def create_graph_label(
    self,
    graph_elem_type: GraphElemType,
    graph_properties: List[Dict[str, Union[str, bool]]],
) -> None:
    """Create the label of a vertex or edge type unless it already exists.

    Vertex labels use ``id`` as primary key. Edge labels carry endpoint
    constraints: ``include`` connects document -> chunk, chunk -> entity and
    chunk -> chunk; ``next`` connects chunk -> chunk; ``relation`` connects
    entity -> entity.

    Args:
        graph_elem_type: The element type whose label is created.
        graph_properties: The property definitions of the label.

    Raises:
        ValueError: If the type is an edge type without known constraints.
    """
    if graph_elem_type.is_vertex():
        element_kind = "VERTEX"
        create_procedure = "db.createVertexLabelByJson"
        schema = {
            "label": graph_elem_type.value,
            "type": "VERTEX",
            "primary": "id",
            "properties": graph_properties,
        }
    else:
        document = GraphElemType.DOCUMENT.value
        chunk = GraphElemType.CHUNK.value
        entity = GraphElemType.ENTITY.value
        endpoints_by_type = {
            GraphElemType.INCLUDE: [
                [document, chunk],
                [chunk, entity],
                [chunk, chunk],
            ],
            GraphElemType.NEXT: [[chunk, chunk]],
            GraphElemType.RELATION: [[entity, entity]],
        }
        if graph_elem_type not in endpoints_by_type:
            raise ValueError("Invalid graph element type.")

        element_kind = "EDGE"
        create_procedure = "db.createEdgeLabelByJson"
        schema = {
            "label": graph_elem_type.value,
            "type": "EDGE",
            "constraints": endpoints_by_type[graph_elem_type],
            "properties": graph_properties,
        }

    data = json.dumps(schema)
    gql = f"CALL {create_procedure}('{data}')"
    gql_check_exist = (
        f"CALL db.getLabelSchema('{element_kind}', '{graph_elem_type.value}')"
    )

    # The schema lookup raises when the label is missing; that failure is
    # the signal to create it.
    try:
        self.graph_store.conn.run(gql_check_exist)
    except Exception:
        self.graph_store.conn.run(gql)
        return

    logger.info(f"Graph label {graph_elem_type.value} already exists.")
|
||||
|
||||
def truncate(self):
    """Remove all vertices by running ``MATCH (n) DELETE n`` on the store."""
    self.graph_store.conn.run("MATCH (n) DELETE n")
|
||||
|
||||
def check_label(self, graph_elem_type: GraphElemType) -> bool:
    """Check if the label exists in the graph.

    Args:
        graph_elem_type (GraphElemType): The type of the graph element.

    Returns:
        True if the label exists in the specified graph element type,
        otherwise False.
    """
    vertex_tables, edge_tables = self.graph_store.conn.get_table_names()

    # Labels are created from `graph_elem_type.value`, so the lookup has to
    # compare the label name, not the enum member itself (which never equals
    # a table name unless the enum happens to subclass str).
    label = graph_elem_type.value
    tables = vertex_tables if graph_elem_type.is_vertex() else edge_tables
    return label in tables
|
||||
|
||||
def explore(
    self,
    subs: List[str],
    direct: Direction = Direction.BOTH,
    depth: Optional[int] = None,
    limit: Optional[int] = None,
    search_scope: Optional[
        Literal["knowledge_graph", "document_graph"]
    ] = "knowledge_graph",
) -> MemoryGraph:
    """Explore the graph from given subjects up to a depth.

    Args:
        subs: The subjects to start from. In the ``knowledge_graph`` scope
            they are entity ids; otherwise they are keywords searched in the
            chunk contents.
        direct: The direction of the relation edges to follow. Only used in
            the ``knowledge_graph`` scope.
        depth: The maximum number of hops. ``None``, negative values and
            values above ``MAX_HIERARCHY_LEVEL`` fall back to
            ``MAX_HIERARCHY_LEVEL``.
        limit: The maximum number of returned paths per query, unlimited
            when ``None``.
        search_scope: ``"knowledge_graph"`` to walk the relations between
            entities; any other value searches the document/chunk structure.

    Returns:
        The explored subgraph, empty when ``subs`` is empty.
    """
    if not subs:
        return MemoryGraph()

    if depth is None or depth < 0 or depth > self.MAX_HIERARCHY_LEVEL:
        # TODO: to be discussed, be none or MAX_HIERARCHY_LEVEL
        depth = self.MAX_HIERARCHY_LEVEL
    depth_string = f"1..{depth}"

    limit_string = "" if limit is None else f"LIMIT {limit}"

    if search_scope == "knowledge_graph":
        if direct.name == "OUT":
            rel = f"-[r:{GraphElemType.RELATION.value}*{depth_string}]->"
        elif direct.name == "IN":
            rel = f"<-[r:{GraphElemType.RELATION.value}*{depth_string}]-"
        else:
            rel = f"-[r:{GraphElemType.RELATION.value}*{depth_string}]-"
        query = (
            f"MATCH p=(n:{GraphElemType.ENTITY.value})"
            f"{rel}(m:{GraphElemType.ENTITY.value}) "
            f"WHERE n.id IN {subs} RETURN p {limit_string}"
        )
        return self.query(query)

    # Document graph: one query per keyword, merged into a single graph.
    graph = MemoryGraph()
    for sub in subs:
        # The keyword sits inside a single-quoted literal, so a quote (or a
        # backslash, the escape character) in it would break the query.
        keyword = sub.replace("\\", "\\\\").replace("'", "\\'")
        query = (
            f"MATCH p=(n:{GraphElemType.DOCUMENT.value})-"
            f"[r:{GraphElemType.INCLUDE.value}*{depth_string}]-"
            f"(m:{GraphElemType.CHUNK.value})WHERE m.content CONTAINS '{keyword}' "
            f"RETURN p {limit_string}"
        )  # if it contains the subjects
        result = self.query(query)
        for vertex in result.vertices():
            graph.upsert_vertex(vertex)
        for edge in result.edges():
            graph.append_edge(edge)

    return graph
|
||||
|
||||
def query(self, query: str, **kwargs) -> MemoryGraph:
    """Run a query on the store and collect the result in a memory graph.

    Args:
        query: The query text to execute.
        **kwargs: May contain ``white_list`` (List[str]), the properties
            that are kept on the returned vertices and edges. It defaults to
            ``id``, ``name``, ``description``, ``_document_id``,
            ``_chunk_id`` and ``_community_id``.

    Returns:
        A graph holding the vertices and edges found in the query result.
    """
    query_result = self.graph_store.conn.run(query=query)

    if "white_list" in kwargs:
        white_list: List[str] = kwargs["white_list"]
    else:
        white_list = [
            "id",
            "name",
            "description",
            "_document_id",
            "_chunk_id",
            "_community_id",
        ]

    found_vertices, found_edges = self._get_nodes_edges_from_queried_data(
        query_result, white_list
    )

    result_graph = MemoryGraph()
    for item in found_vertices:
        result_graph.upsert_vertex(item)
    for item in found_edges:
        result_graph.append_edge(item)
    return result_graph
|
||||
|
||||
async def stream_query(self, query: str, **kwargs) -> AsyncGenerator[Graph, None]:
    """Execute a stream query, yielding one graph per result record.

    Every value of a record is converted according to its type: a node
    becomes a vertex, a relationship an edge, a path the chain of its
    vertices and edges, and anything else a ``json_node`` vertex carrying
    the raw value as description.

    Args:
        query: The query text to execute.
        **kwargs: Unused.

    Yields:
        A memory graph built from a single record.
    """
    from neo4j import graph

    def _id(element) -> str:
        return element._properties["id"]

    def _description(element) -> str:
        # Not every label declares a description (the document and chunk
        # labels created by this adapter do not), so do not require it.
        return element._properties.get("description", "")

    async for record in self.graph_store.conn.run_stream(query):
        mg = MemoryGraph()
        for key in record.keys():
            value = record[key]
            if isinstance(value, graph.Node):
                node_id = _id(value)
                mg.upsert_vertex(
                    Vertex(
                        vid=node_id,
                        name=node_id,
                        description=_description(value),
                    )
                )
            elif isinstance(value, graph.Relationship):
                edge_nodes = value.nodes
                prop_id = _id(value)
                assert edge_nodes and edge_nodes[0] and edge_nodes[1]
                mg.append_edge(
                    Edge(
                        sid=_id(edge_nodes[0]),
                        tid=_id(edge_nodes[1]),
                        name=prop_id,
                        description=_description(value),
                    )
                )
            elif isinstance(value, graph.Path):
                # Read the path from the value at hand: the column is not
                # necessarily named "p".
                nodes = list(value.nodes)
                rels = list(value.relationships)
                for position, node in enumerate(nodes):
                    node_id = _id(node)
                    mg.upsert_vertex(
                        Vertex(
                            vid=node_id,
                            name=node_id,
                            description=_description(node),
                        )
                    )
                    # The i-th relationship links node i to node i + 1.
                    if position < len(rels) and position + 1 < len(nodes):
                        mg.append_edge(
                            Edge(
                                sid=node_id,
                                tid=_id(nodes[position + 1]),
                                name=_id(rels[position]),
                                description=_description(rels[position]),
                            )
                        )
            else:
                mg.upsert_vertex(
                    Vertex(vid="json_node", name="json_node", description=value)
                )
        yield mg
|
||||
|
||||
def _get_nodes_edges_from_queried_data(
    self,
    data: List[Dict[str, Any]],
    white_prop_list: List[str],
) -> Tuple[List[Vertex], List[Edge]]:
    """Convert the records of a query result into vertices and edges.

    Nodes become vertices; relationships (direct or as part of a path)
    contribute their two end vertices and an edge; any other value is
    wrapped in a ``json_node`` vertex. Duplicates are dropped.

    Args:
        data: The records returned by the query.
        white_prop_list: The properties to keep in addition to the public
            ones (those that do not start with an underscore).

    Returns:
        Tuple[List[Vertex], List[Edge]]: The collected vertices and edges.
    """
    from neo4j import graph

    vertex_list: List[Vertex] = []
    edge_list: List[Edge] = []

    # id, src_id, dst_id and name are passed explicitly to the constructors,
    # so they must not show up a second time among the extra properties.
    reserved = ("id", "src_id", "dst_id", "name")
    _white_list = [prop for prop in white_prop_list if prop not in reserved]

    def _kept_properties(properties: Dict[str, Any]) -> Dict[str, Any]:
        """Keep the public properties plus the white-listed ones."""
        kept: Dict[str, Any] = {}
        for key, val in properties.items():
            is_public = not key.startswith("_") and key not in ["id", "name"]
            if is_public or key in _white_list:
                kept[key] = val
        return kept

    def _remember_vertex(candidate: Vertex) -> None:
        # TODO: Do we really need to check it every time?
        if candidate not in vertex_list:
            vertex_list.append(candidate)

    def _vertex_from(node: Any) -> Vertex:
        props = node._properties
        return Vertex(
            vid=props.get("id", ""),
            name=props.get("name"),
            **_kept_properties(props),
        )

    def _absorb_relationship(relationship: Any) -> None:
        """Record both end vertices of a relationship, then its edge."""
        for endpoint in relationship.nodes:  # num of nodes is 2
            assert endpoint and endpoint._properties
            _remember_vertex(_vertex_from(endpoint))

        ends = relationship.nodes
        assert ends and ends[0] and ends[1]
        edge = Edge(
            sid=ends[0]._properties.get("id", ""),
            tid=ends[1]._properties.get("id", ""),
            name=relationship._properties.get("name", ""),
            **_kept_properties(relationship._properties),
        )
        if edge not in edge_list:
            edge_list.append(edge)

    # Parse the data to nodes and relationships
    for record in data:
        for value in record.values():
            if isinstance(value, graph.Node):
                assert value._properties.get("id")
                _remember_vertex(_vertex_from(value))
            elif isinstance(value, graph.Relationship):
                _absorb_relationship(value)
            elif isinstance(value, graph.Path):
                for rel in value.relationships:
                    _absorb_relationship(rel)
            else:  # json_node
                _remember_vertex(
                    Vertex(
                        vid="json_node",
                        name="json_node",
                        **_kept_properties({"description": value}),
                    )
                )
    return vertex_list, edge_list
|
||||
|
||||
def _parser(self, entity_list: List[Dict[str, Any]]) -> str:
|
||||
"""Parse entities to string."""
|
||||
formatted_nodes = [
|
||||
"{"
|
||||
+ ", ".join(
|
||||
f'{k}: "{v}"' if isinstance(v, str) else f"{k}: {v}"
|
||||
for k, v in node.items()
|
||||
)
|
||||
+ "}"
|
||||
for node in entity_list
|
||||
]
|
||||
return f"""{", ".join(formatted_nodes)}"""
|
@@ -2,14 +2,16 @@
|
||||
|
||||
import logging
|
||||
import os
|
||||
import uuid
|
||||
from typing import List, Optional
|
||||
|
||||
from dbgpt._private.pydantic import ConfigDict, Field
|
||||
from dbgpt.core import Chunk
|
||||
from dbgpt.rag.transformer.community_summarizer import CommunitySummarizer
|
||||
from dbgpt.rag.transformer.graph_extractor import GraphExtractor
|
||||
from dbgpt.storage.graph_store.graph import GraphElemType, MemoryGraph
|
||||
from dbgpt.storage.knowledge_graph.community.community_store import CommunityStore
|
||||
from dbgpt.storage.knowledge_graph.community.factory import CommunityStoreAdapterFactory
|
||||
from dbgpt.storage.knowledge_graph.community.factory import GraphStoreAdapterFactory
|
||||
from dbgpt.storage.knowledge_graph.knowledge_graph import (
|
||||
BuiltinKnowledgeGraph,
|
||||
BuiltinKnowledgeGraphConfig,
|
||||
@@ -27,7 +29,8 @@ class CommunitySummaryKnowledgeGraphConfig(BuiltinKnowledgeGraphConfig):
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
vector_store_type: str = Field(
|
||||
default="Chroma", description="The type of vector store."
|
||||
default="Chroma",
|
||||
description="The type of vector store.",
|
||||
)
|
||||
user: Optional[str] = Field(
|
||||
default=None,
|
||||
@@ -36,7 +39,8 @@ class CommunitySummaryKnowledgeGraphConfig(BuiltinKnowledgeGraphConfig):
|
||||
password: Optional[str] = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"The password of vector store, if not set, will use the default password."
|
||||
"The password of vector store, "
|
||||
"if not set, will use the default password."
|
||||
),
|
||||
)
|
||||
extract_topk: int = Field(
|
||||
@@ -120,7 +124,7 @@ class CommunitySummaryKnowledgeGraph(BuiltinKnowledgeGraph):
|
||||
cfg.score_threshold = self._community_score_threshold
|
||||
|
||||
self._community_store = CommunityStore(
|
||||
CommunityStoreAdapterFactory.create(self._graph_store),
|
||||
GraphStoreAdapterFactory.create(self._graph_store),
|
||||
CommunitySummarizer(self._llm_client, self._model_name),
|
||||
VectorStoreFactory.create(
|
||||
self._vector_store_type,
|
||||
@@ -135,21 +139,165 @@ class CommunitySummaryKnowledgeGraph(BuiltinKnowledgeGraph):
|
||||
|
||||
async def aload_document(self, chunks: List[Chunk]) -> List[str]:
    """Extract the graphs of a document and persist them.

    Depending on the graph store configuration, two structures are merged
    into one in-memory graph, which is then upserted in a single step:

    * the document graph (``enable_document_graph``): document -> chunk and
      chunk -> chunk ``include`` edges plus ``next`` edges between
      consecutive chunks;
    * the triplet graph (``enable_triplet_graph``): the entities and
      relations extracted from each chunk, linked to their chunk by
      ``include`` edges when the document graph is enabled too.

    The communities are rebuilt afterwards.

    Args:
        chunks: The chunks of the document to load.

    Returns:
        The ids of the given chunks.
    """
    data_list = self._parse_chunks(chunks)
    graph_of_all = MemoryGraph()

    config = self._graph_store.get_config()
    enable_document_graph = config.enable_document_graph
    enable_triplet_graph = config.enable_triplet_graph

    def _chunk_props(content):
        return {"vertex_type": GraphElemType.CHUNK.value, "content": content}

    # Support graph search by the document and the chunks
    if enable_document_graph:
        doc_vid = str(uuid.uuid4())
        # Tolerate an empty chunk list and chunks without a `source` entry:
        # the "Text_Node" fallback is meant to cover both.
        source = chunks[0].metadata.get("source") if chunks else None
        doc_name = os.path.basename(source or "Text_Node")

        for chunk_index, chunk in enumerate(data_list):
            if chunk["parent_id"] != "document":
                # chunk -> include -> chunk
                graph_of_all.upsert_vertex_and_edge(
                    src_vid=chunk["parent_id"],
                    src_name=chunk["parent_title"],
                    src_props=_chunk_props(chunk["content"]),
                    dst_vid=chunk["id"],
                    dst_name=chunk["title"],
                    dst_props=_chunk_props(chunk["content"]),
                    edge_name=GraphElemType.INCLUDE.value,
                    edge_type=GraphElemType.CHUNK_INCLUDE_CHUNK.value,
                )
            else:
                # document -> include -> chunk
                graph_of_all.upsert_vertex_and_edge(
                    src_vid=doc_vid,
                    src_name=doc_name,
                    src_props={
                        "vertex_type": GraphElemType.DOCUMENT.value,
                        "content": "",
                    },
                    dst_vid=chunk["id"],
                    dst_name=chunk["title"],
                    dst_props=_chunk_props(chunk["content"]),
                    edge_name=GraphElemType.INCLUDE.value,
                    edge_type=GraphElemType.DOCUMENT_INCLUDE_CHUNK.value,
                )

            # chunk -> next -> chunk
            if chunk_index >= 1:
                previous = data_list[chunk_index - 1]
                graph_of_all.upsert_vertex_and_edge(
                    src_vid=previous["id"],
                    src_name=previous["title"],
                    src_props=_chunk_props(previous["content"]),
                    dst_vid=chunk["id"],
                    dst_name=chunk["title"],
                    dst_props=_chunk_props(chunk["content"]),
                    edge_name=GraphElemType.NEXT.value,
                    edge_type=GraphElemType.CHUNK_NEXT_CHUNK.value,
                )

    # Support knowledge graph search by the entities and the relationships
    if enable_triplet_graph:
        for chunk in data_list:
            # TODO: Use asyncio to extract graph to accelerate the process
            # (attention to the CAP of the graph db)
            graphs: List[MemoryGraph] = await self._graph_extractor.extract(
                chunk["content"]
            )

            for graph in graphs:
                graph_of_all.upsert_graph(graph)

                # chunk -> include -> entity
                if enable_document_graph:
                    for vertex in graph.vertices():
                        graph_of_all.upsert_vertex_and_edge(
                            src_vid=chunk["id"],
                            src_name=chunk["title"],
                            src_props=_chunk_props(chunk["content"]),
                            dst_vid=vertex.vid,
                            dst_name=vertex.name,
                            dst_props={
                                "vertex_type": GraphElemType.ENTITY.value,
                                "description": vertex.props.get("description", ""),
                            },  # note: description is only used for the entity
                            edge_name=GraphElemType.INCLUDE.value,
                            edge_type=GraphElemType.CHUNK_INCLUDE_ENTITY.value,
                        )

    self._graph_store_apdater.upsert_graph(graph_of_all)

    # build communities and save
    await self._community_store.build_communities()

    return [chunk.chunk_id for chunk in chunks]
|
||||
|
||||
def _parse_chunks(slef, chunks: List[Chunk]):
|
||||
"""Parse the chunks by anlyzing the markdown chunks."""
|
||||
# TODO: Need to refact.
|
||||
data = []
|
||||
for chunk_index, chunk in enumerate(chunks):
|
||||
parent = None
|
||||
directory_keys = list(chunk.metadata.keys())[:-1]
|
||||
parent_level = directory_keys[-2] if len(directory_keys) > 1 else None
|
||||
current_level = directory_keys[-1] if directory_keys else "Header0"
|
||||
|
||||
chunk_data = {
|
||||
"id": chunk.chunk_id,
|
||||
"title": chunk.metadata.get(current_level, "none_header_chunk"),
|
||||
"directory_keys": directory_keys,
|
||||
"level": current_level,
|
||||
"content": chunk.content,
|
||||
"parent_id": None,
|
||||
"parent_title": None,
|
||||
"type": "chunk",
|
||||
"chunk_index": chunk_index,
|
||||
}
|
||||
|
||||
# Find the parent chunk
|
||||
if parent_level:
|
||||
for parent_direct in reversed(directory_keys[:-1]):
|
||||
parent_titile = chunk.metadata.get(parent_direct, None)
|
||||
for n in range(chunk_index - 1, -1, -1):
|
||||
metadata = chunks[n].metadata
|
||||
keys = list(metadata.keys())[:-1]
|
||||
if (
|
||||
metadata
|
||||
and parent_direct == keys[-1]
|
||||
and parent_titile == metadata.get(parent_direct)
|
||||
):
|
||||
parent = chunks[n]
|
||||
chunk_data["parent_id"] = parent.chunk_id
|
||||
chunk_data["parent_title"] = parent_titile
|
||||
break
|
||||
if chunk_index - n > len(directory_keys):
|
||||
break
|
||||
if chunk_data["parent_id"]:
|
||||
break
|
||||
|
||||
if not chunk_data["parent_id"]:
|
||||
chunk_data["parent_id"] = "document"
|
||||
data.append(chunk_data)
|
||||
return data
|
||||
|
||||
async def asimilar_search_with_scores(
|
||||
self,
|
||||
text,
|
||||
@@ -158,7 +306,7 @@ class CommunitySummaryKnowledgeGraph(BuiltinKnowledgeGraph):
|
||||
filters: Optional[MetadataFilters] = None,
|
||||
) -> List[Chunk]:
|
||||
"""Retrieve relevant community summaries."""
|
||||
# global search: retrieve relevant community summaries
|
||||
# Global search: retrieve relevant community summaries
|
||||
communities = await self._community_store.search_communities(text)
|
||||
summaries = [
|
||||
f"Section {i + 1}:\n{community.summary}"
|
||||
@@ -166,16 +314,53 @@ class CommunitySummaryKnowledgeGraph(BuiltinKnowledgeGraph):
|
||||
]
|
||||
context = "\n".join(summaries) if summaries else ""
|
||||
|
||||
# local search: extract keywords and explore subgraph
|
||||
keywords = await self._keyword_extractor.extract(text)
|
||||
subgraph = self._graph_store.explore(keywords, limit=topk).format()
|
||||
logger.info(f"Search subgraph from {len(keywords)} keywords")
|
||||
keywords: List[str] = await self._keyword_extractor.extract(text)
|
||||
|
||||
if not summaries and not subgraph:
|
||||
# Local search: extract keywords and explore subgraph
|
||||
subgraph = MemoryGraph()
|
||||
subgraph_for_doc = MemoryGraph()
|
||||
|
||||
enable_triplet_graph = self._graph_store.get_config().enable_triplet_graph
|
||||
enable_document_graph = self._graph_store.get_config().enable_document_graph
|
||||
|
||||
if enable_triplet_graph:
|
||||
subgraph: MemoryGraph = self._graph_store_apdater.explore(
|
||||
subs=keywords, limit=10, search_scope="knowledge_graph"
|
||||
)
|
||||
|
||||
if enable_document_graph:
|
||||
keywords_for_document_graph = keywords
|
||||
for vertex in subgraph.vertices():
|
||||
keywords_for_document_graph.append(vertex.name)
|
||||
|
||||
subgraph_for_doc = self._graph_store_apdater.explore(
|
||||
subs=keywords_for_document_graph,
|
||||
limit=5,
|
||||
search_scope="document_graph",
|
||||
)
|
||||
else:
|
||||
if enable_document_graph:
|
||||
subgraph_for_doc = self._graph_store_apdater.explore(
|
||||
subs=keywords,
|
||||
limit=10,
|
||||
search_scope="document_graph",
|
||||
)
|
||||
|
||||
knowledge_graph_str = subgraph.format()
|
||||
knowledge_graph_for_doc_str = subgraph_for_doc.format()
|
||||
|
||||
logger.info(f"Search subgraph from the following keywords:\n{len(keywords)}")
|
||||
|
||||
if not (summaries or knowledge_graph_str or knowledge_graph_for_doc_str):
|
||||
return []
|
||||
|
||||
# merge search results into context
|
||||
content = HYBRID_SEARCH_PT_CN.format(context=context, graph=subgraph)
|
||||
content = HYBRID_SEARCH_PT_CN.format(
|
||||
context=context,
|
||||
knowledge_graph=knowledge_graph_str,
|
||||
knowledge_graph_for_doc=knowledge_graph_for_doc_str,
|
||||
)
|
||||
logger.info(f"Final GraphRAG queried prompt:\n{content}")
|
||||
return [Chunk(content=content)]
|
||||
|
||||
def truncate(self) -> List[str]:
|
||||
@@ -200,174 +385,179 @@ class CommunitySummaryKnowledgeGraph(BuiltinKnowledgeGraph):
|
||||
self._graph_extractor.drop()
|
||||
|
||||
|
||||
HYBRID_SEARCH_PT_CN = (
|
||||
"## 角色\n"
|
||||
"你非常擅长结合提示词模板提供的[上下文]信息与[知识图谱]信息,"
|
||||
"准确恰当地回答用户的问题,并保证不会输出与上下文和知识图谱无关的信息。"
|
||||
"\n"
|
||||
"## 技能\n"
|
||||
"### 技能 1: 上下文理解\n"
|
||||
"- 准确地理解[上下文]提供的信息,上下文信息可能被拆分为多个章节。\n"
|
||||
"- 上下文的每个章节内容都会以[Section]开始,并按需进行了编号。\n"
|
||||
"- 上下文信息提供了与用户问题相关度最高的总结性描述,请合理使用它们。"
|
||||
"### 技能 2: 知识图谱理解\n"
|
||||
"- 准确地识别[知识图谱]中提供的[Entities:]章节中的实体信息"
|
||||
"和[Relationships:]章节中的关系信息,实体和关系信息的一般格式为:\n"
|
||||
"```"
|
||||
"* 实体信息格式:\n"
|
||||
"- (实体名)\n"
|
||||
"- (实体名:实体描述)\n"
|
||||
"- (实体名:实体属性表)\n"
|
||||
"- (文本块ID:文档块内容)\n"
|
||||
"- (目录ID:目录名)\n"
|
||||
"- (文档ID:文档名称)\n"
|
||||
"\n"
|
||||
"* 关系信息的格式:\n"
|
||||
"- (来源实体名)-[关系名]->(目标实体名)\n"
|
||||
"- (来源实体名)-[关系名:关系描述]->(目标实体名)\n"
|
||||
"- (来源实体名)-[关系名:关系属性表]->(目标实体名)\n"
|
||||
"- (文本块实体)-[包含]->(实体名)\n"
|
||||
"- (目录ID)-[包含]->(文本块实体)\n"
|
||||
"- (目录ID)-[包含]->(子目录ID)\n"
|
||||
"- (文档ID)-[包含]->(文本块实体)\n"
|
||||
"- (文档ID)-[包含]->(目录ID)\n"
|
||||
"```"
|
||||
"- 正确地将关系信息中的实体名/ID与实体信息关联,还原出图结构。"
|
||||
"- 将图结构所表达的信息作为用户提问的明细上下文,辅助生成更好的答案。\n"
|
||||
"\n"
|
||||
"## 约束条件\n"
|
||||
"- 不要在答案中描述你的思考过程,直接给出用户问题的答案,不要生成无关信息。\n"
|
||||
"- 若[知识图谱]没有提供信息,此时应根据[上下文]提供的信息回答问题。"
|
||||
"- 确保以第三人称书写,从客观角度结合[上下文]和[知识图谱]表达的信息回答问题。\n"
|
||||
"- 若提供的信息相互矛盾,请解决矛盾并提供一个单一、连贯的描述。\n"
|
||||
"- 避免使用停用词和过于常见的词汇。\n"
|
||||
"\n"
|
||||
"## 参考案例\n"
|
||||
"```\n"
|
||||
"[上下文]:\n"
|
||||
"Section 1:\n"
|
||||
"菲尔・贾伯的大儿子叫雅各布・贾伯。\n"
|
||||
"Section 2:\n"
|
||||
"菲尔・贾伯的小儿子叫比尔・贾伯。\n"
|
||||
"[知识图谱]:\n"
|
||||
"Entities:\n"
|
||||
"(菲尔・贾伯#菲尔兹咖啡创始人)\n"
|
||||
"(菲尔兹咖啡#加利福尼亚州伯克利创立的咖啡品牌)\n"
|
||||
"(雅各布・贾伯#菲尔・贾伯的儿子)\n"
|
||||
"(美国多地#菲尔兹咖啡的扩展地区)\n"
|
||||
"\n"
|
||||
"Relationships:\n"
|
||||
"(菲尔・贾伯#创建#菲尔兹咖啡#1978年在加利福尼亚州伯克利创立)\n"
|
||||
"(菲尔兹咖啡#位于#加利福尼亚州伯克利#菲尔兹咖啡的创立地点)\n"
|
||||
"(菲尔・贾伯#拥有#雅各布・贾伯#菲尔・贾伯的儿子)\n"
|
||||
"(雅各布・贾伯#担任#首席执行官#在2005年成为菲尔兹咖啡的首席执行官)\n"
|
||||
"(菲尔兹咖啡#扩展至#美国多地#菲尔兹咖啡的扩展范围)\n"
|
||||
"```\n"
|
||||
"\n"
|
||||
"----\n"
|
||||
"\n"
|
||||
"接下来的[上下文]和[知识图谱]的信息,可以帮助你回答更好地用户的问题。\n"
|
||||
"\n"
|
||||
"[上下文]:\n"
|
||||
"{context}\n"
|
||||
"\n"
|
||||
"[知识图谱]:\n"
|
||||
"{graph}\n"
|
||||
"\n"
|
||||
)
|
||||
HYBRID_SEARCH_PT_CN = """## 角色
|
||||
你非常擅长结合提示词模板提供的[上下文]信息与[知识图谱]信息,
|
||||
准确恰当地回答用户的问题,并保证不会输出与上下文和知识图谱无关的信息。
|
||||
|
||||
HYBRID_SEARCH_PT_EN = (
|
||||
"## Role\n"
|
||||
"You excel at combining the information provided in the [Context] with "
|
||||
"information from the [KnowledgeGraph] to accurately and appropriately "
|
||||
"answer user questions, ensuring that you do not output information "
|
||||
"unrelated to the context and knowledge graph.\n"
|
||||
"\n"
|
||||
"## Skills\n"
|
||||
"### Skill 1: Context Understanding\n"
|
||||
"- Accurately understand the information provided in the [Context], "
|
||||
"which may be divided into several sections.\n"
|
||||
"- Each section in the context will start with [Section] "
|
||||
"and may be numbered as needed.\n"
|
||||
"- The context provides a summary description most relevant to the user’s "
|
||||
"question, and it should be used wisely."
|
||||
"### Skill 2: Knowledge Graph Understanding\n"
|
||||
"- Accurately identify entity information in the [Entities:] section and "
|
||||
"relationship information in the [Relationships:] section "
|
||||
"of the [KnowledgeGraph]. The general format for entity "
|
||||
"and relationship information is:\n"
|
||||
"```"
|
||||
"* Entity Information Format:\n"
|
||||
"- (entity_name)\n"
|
||||
"- (entity_name: entity_description)\n"
|
||||
"- (entity_name: entity_property_map)\n"
|
||||
"- (chunk_id: chunk_content)\n"
|
||||
"- (catalog_id: catalog_name)\n"
|
||||
"- (document_id: document_name)\n"
|
||||
"\n"
|
||||
"* Relationship Information Format:\n"
|
||||
"- (source_entity_name)-[relationship_name]->(target_entity_name)\n"
|
||||
"- (source_entity_name)-[relationship_name: relationship_description]->"
|
||||
"(target_entity_name)\n"
|
||||
"- (source_entity_name)-[relationship_name: relationship_property_map]->"
|
||||
"(target_entity_name)\n"
|
||||
"- (chunk_id)-[Contains]->(entity_name)\n"
|
||||
"- (catalog_id)-[Contains]->(chunk_id)\n"
|
||||
"- (catalog_id)-[Contains]->(sub_catalog_id)\n"
|
||||
"- (document_id)-[Contains]->(chunk_id)\n"
|
||||
"- (document_id)-[Contains]->(catalog_id)\n"
|
||||
"```"
|
||||
"- Correctly associate entity names/IDs in the relationship information "
|
||||
"with entity information to restore the graph structure."
|
||||
"- Use the information expressed by the graph structure as detailed "
|
||||
"context for the user's query to assist in generating better answers.\n"
|
||||
"\n"
|
||||
"## Constraints\n"
|
||||
"- Don't describe your thought process in the answer, provide the answer "
|
||||
"to the user's question directly without generating irrelevant information."
|
||||
"- If the [KnowledgeGraph] does not provide information, you should answer "
|
||||
"the question based on the information provided in the [Context]."
|
||||
"- Ensure to write in the third person, responding to questions from "
|
||||
"an objective perspective based on the information combined from the "
|
||||
"[Context] and the [KnowledgeGraph].\n"
|
||||
"- If the provided information is contradictory, resolve the "
|
||||
"contradictions and provide a single, coherent description.\n"
|
||||
"- Avoid using stop words and overly common vocabulary.\n"
|
||||
"\n"
|
||||
"## Reference Example\n"
|
||||
"```\n"
|
||||
"[Context]:\n"
|
||||
"Section 1:\n"
|
||||
"Phil Schiller's eldest son is Jacob Schiller.\n"
|
||||
"Section 2:\n"
|
||||
"Phil Schiller's youngest son is Bill Schiller.\n"
|
||||
"[KnowledgeGraph]:\n"
|
||||
"Entities:\n"
|
||||
"(Phil Jaber#Founder of Philz Coffee)\n"
|
||||
"(Philz Coffee#Coffee brand founded in Berkeley, California)\n"
|
||||
"(Jacob Jaber#Son of Phil Jaber)\n"
|
||||
"(Multiple locations in the USA#Expansion regions of Philz Coffee)\n"
|
||||
"\n"
|
||||
"Relationships:\n"
|
||||
"(Phil Jaber#Created#Philz Coffee"
|
||||
"#Founded in Berkeley, California in 1978)\n"
|
||||
"(Philz Coffee#Located in#Berkeley, California"
|
||||
"#Founding location of Philz Coffee)\n"
|
||||
"(Phil Jaber#Has#Jacob Jaber#Son of Phil Jaber)\n"
|
||||
"(Jacob Jaber#Serves as#CEO#Became CEO of Philz Coffee in 2005)\n"
|
||||
"(Philz Coffee#Expanded to#Multiple locations in the USA"
|
||||
"#Expansion regions of Philz Coffee)\n"
|
||||
"```\n"
|
||||
"\n"
|
||||
"----\n"
|
||||
"\n"
|
||||
"The following information from the [Context] and [KnowledgeGraph] can "
|
||||
"help you better answer user questions.\n"
|
||||
"\n"
|
||||
"[Context]:\n"
|
||||
"{context}\n"
|
||||
"\n"
|
||||
"[KnowledgeGraph]:\n"
|
||||
"{graph}\n"
|
||||
"\n"
|
||||
)
|
||||
## 技能
|
||||
### 技能 1: 上下文理解
|
||||
- 准确地理解[上下文]提供的信息,上下文信息可能被拆分为多个章节。
|
||||
- 上下文的每个章节内容都会以[Section]开始,并按需进行了编号。
|
||||
- 上下文信息提供了与用户问题相关度最高的总结性描述,请合理使用它们。
|
||||
### 技能 2: 知识图谱理解
|
||||
- 准确地识别[知识图谱]中提供的[Entities:]章节中的实体信息和[Relationships:]章节中的关系信息,实体和关系信息的一般格式为:
|
||||
```
|
||||
* 实体信息格式:
|
||||
- (实体名)
|
||||
- (实体名:实体描述)
|
||||
- (实体名:实体属性表)
|
||||
- (文本块ID:文档块内容)
|
||||
- (目录ID:目录名)
|
||||
- (文档ID:文档名称)
|
||||
|
||||
* 关系信息的格式:
|
||||
- (来源实体名)-[关系名]->(目标实体名)
|
||||
- (来源实体名)-[关系名:关系描述]->(目标实体名)
|
||||
- (来源实体名)-[关系名:关系属性表]->(目标实体名)
|
||||
- (文本块实体)-[包含]->(实体名)
|
||||
- (目录ID)-[包含]->(文本块实体)
|
||||
- (目录ID)-[包含]->(子目录ID)
|
||||
- (文档ID)-[包含]->(文本块实体)
|
||||
- (文档ID)-[包含]->(目录ID)
|
||||
```
|
||||
- 正确地将关系信息中的实体名/ID与实体信息关联,还原出图结构。
|
||||
- 将图结构所表达的信息作为用户提问的明细上下文,辅助生成更好的答案。
|
||||
|
||||
|
||||
## 约束条件
|
||||
- 不要在答案中描述你的思考过程,直接给出用户问题的答案,不要生成无关信息。
|
||||
- 若[知识图谱]或者[知识库原文]没有提供信息,此时应根据[上下文]提供的信息回答问题。
|
||||
- 确保以第三人称书写,从客观角度结合[上下文]、[知识图谱]和[知识库原文]表达的信息回答问题。
|
||||
- 若提供的信息相互矛盾,请解决矛盾并提供一个单一、连贯的描述。
|
||||
- 避免使用停用词和过于常见的词汇。
|
||||
|
||||
## 参考案例
|
||||
```
|
||||
[上下文]:
|
||||
Section 1:
|
||||
菲尔・贾伯的大儿子叫雅各布・贾伯。
|
||||
Section 2:
|
||||
菲尔・贾伯的小儿子叫比尔・贾伯。
|
||||
|
||||
[知识图谱]:
|
||||
Entities:
|
||||
(菲尔・贾伯#菲尔兹咖啡创始人)
|
||||
(菲尔兹咖啡#加利福尼亚州伯克利创立的咖啡品牌)
|
||||
(雅各布・贾伯#菲尔・贾伯的儿子)
|
||||
(美国多地#菲尔兹咖啡的扩展地区)
|
||||
|
||||
Relationships:
|
||||
(菲尔・贾伯#创建#菲尔兹咖啡#1978年在加利福尼亚州伯克利创立)
|
||||
(菲尔兹咖啡#位于#加利福尼亚州伯克利#菲尔兹咖啡的创立地点)
|
||||
(菲尔・贾伯#拥有#雅各布・贾伯#菲尔・贾伯的儿子)
|
||||
(雅各布・贾伯#担任#首席执行官#在2005年成为菲尔兹咖啡的首席执行官)
|
||||
(菲尔兹咖啡#扩展至#美国多地#菲尔兹咖啡的扩展范围)
|
||||
|
||||
[知识库原文]:
|
||||
...
|
||||
```
|
||||
|
||||
----
|
||||
|
||||
接下来的[上下文]、[知识图谱]和[知识库原文]的信息,可以帮助你回答更好地用户的问题。
|
||||
|
||||
[上下文]:
|
||||
{context}
|
||||
|
||||
[知识图谱]:
|
||||
{knowledge_graph}
|
||||
|
||||
[知识库原文]
|
||||
{knowledge_graph_for_doc}
|
||||
""" # noqa: E501
|
||||
|
||||
HYBRID_SEARCH_PT_EN = """## Role
|
||||
You excel at combining the information provided in the [Context] with
|
||||
information from the [KnowledgeGraph] to accurately and appropriately
|
||||
answer user questions, ensuring that you do not output information
|
||||
unrelated to the context and knowledge graph.
|
||||
|
||||
## Skills
|
||||
### Skill 1: Context Understanding
|
||||
- Accurately understand the information provided in the [Context],
|
||||
which may be divided into several sections.
|
||||
- Each section in the context will start with [Section]
|
||||
and may be numbered as needed.
|
||||
- The context provides a summary description most relevant to the user's
|
||||
question, and it should be used wisely.
|
||||
### Skill 2: Knowledge Graph Understanding
|
||||
- Accurately identify entity information in the [Entities:] section and
|
||||
relationship information in the [Relationships:] section
|
||||
of the [KnowledgeGraph]. The general format for entity
|
||||
and relationship information is:
|
||||
```
|
||||
* Entity Information Format:
|
||||
- (entity_name)
|
||||
- (entity_name: entity_description)
|
||||
- (entity_name: entity_property_map)
|
||||
- (chunk_id: chunk_content)
|
||||
- (catalog_id: catalog_name)
|
||||
- (document_id: document_name)
|
||||
|
||||
* Relationship Information Format:
|
||||
- (source_entity_name)-[relationship_name]->(target_entity_name)
|
||||
- (source_entity_name)-[relationship_name: relationship_description]->(target_entity_name)
|
||||
- (source_entity_name)-[relationship_name: relationship_property_map]->(target_entity_name)
|
||||
- (chunk_id)-[Contains]->(entity_name)
|
||||
- (catalog_id)-[Contains]->(chunk_id)
|
||||
- (catalog_id)-[Contains]->(sub_catalog_id)
|
||||
- (document_id)-[Contains]->(chunk_id)
|
||||
- (document_id)-[Contains]->(catalog_id)
|
||||
```
|
||||
- Correctly associate entity names/IDs in the relationship information
|
||||
with entity information to restore the graph structure.
|
||||
- Use the information expressed by the graph structure as detailed
|
||||
context for the user's query to assist in generating better answers.
|
||||
|
||||
## Constraints
|
||||
- Don't describe your thought process in the answer, provide the answer
|
||||
to the user's question directly without generating irrelevant information.
|
||||
- If the [KnowledgeGraph] or [Knowledge base original text] does not provide information, you should answer
|
||||
the question based on the information provided in the [Context].
|
||||
- Ensure to write in the third person, responding to questions from
|
||||
an objective perspective based on the information combined from the
|
||||
[Context], the [KnowledgeGraph] and the [Knowledge base original text].
|
||||
- If the provided information is contradictory, resolve the
|
||||
contradictions and provide a single, coherent description.
|
||||
- Avoid using stop words and overly common vocabulary.
|
||||
|
||||
## Reference Example
|
||||
```
|
||||
[Context]:
|
||||
Section 1:
|
||||
Phil Schiller's eldest son is Jacob Schiller.
|
||||
Section 2:
|
||||
Phil Schiller's youngest son is Bill Schiller.
|
||||
|
||||
[KnowledgeGraph]:
|
||||
Entities:
|
||||
(Phil Jaber#Founder of Philz Coffee)
|
||||
(Philz Coffee#Coffee brand founded in Berkeley, California)
|
||||
(Jacob Jaber#Son of Phil Jaber)
|
||||
(Multiple locations in the USA#Expansion regions of Philz Coffee)
|
||||
|
||||
Relationships:
|
||||
(Phil Jaber#Created#Philz Coffee#Founded in Berkeley, California in 1978)
|
||||
(Philz Coffee#Located in#Berkeley, California#Founding location of Philz Coffee)
|
||||
(Phil Jaber#Has#Jacob Jaber#Son of Phil Jaber)
|
||||
(Jacob Jaber#Serves as#CEO#Became CEO of Philz Coffee in 2005)
|
||||
(Philz Coffee#Expanded to#Multiple locations in the USA#Expansion regions of Philz Coffee)
|
||||
|
||||
[Knowledge base original text]
|
||||
...
|
||||
```
|
||||
|
||||
----
|
||||
|
||||
The following information from the [Context], [KnowledgeGraph] and [Knowledge base original text]
|
||||
can help you better answer user questions.
|
||||
|
||||
[Context]:
|
||||
{context}
|
||||
|
||||
[KnowledgeGraph]:
|
||||
{knowledge_graph}
|
||||
|
||||
[Knowledge base original text]
|
||||
{knowledge_graph_for_doc}
|
||||
""" # noqa: E501
|
||||
|
@@ -1,4 +1,5 @@
|
||||
"""Knowledge graph class."""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
@@ -12,6 +13,8 @@ from dbgpt.storage.graph_store.base import GraphStoreBase, GraphStoreConfig
|
||||
from dbgpt.storage.graph_store.factory import GraphStoreFactory
|
||||
from dbgpt.storage.graph_store.graph import Graph
|
||||
from dbgpt.storage.knowledge_graph.base import KnowledgeGraphBase, KnowledgeGraphConfig
|
||||
from dbgpt.storage.knowledge_graph.community.base import GraphStoreAdapter
|
||||
from dbgpt.storage.knowledge_graph.community.factory import GraphStoreAdapterFactory
|
||||
from dbgpt.storage.vector_store.filters import MetadataFilters
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -46,9 +49,10 @@ class BuiltinKnowledgeGraph(KnowledgeGraphBase):
|
||||
self._model_name = config.model_name
|
||||
self._triplet_extractor = TripletExtractor(self._llm_client, self._model_name)
|
||||
self._keyword_extractor = KeywordExtractor(self._llm_client, self._model_name)
|
||||
self._graph_store = self.__init_graph_store(config)
|
||||
self._graph_store: GraphStoreBase = self.__init_graph_store(config)
|
||||
self._graph_store_apdater: GraphStoreAdapter = self.__init_graph_store_adapter()
|
||||
|
||||
def __init_graph_store(self, config) -> GraphStoreBase:
|
||||
def __init_graph_store(self, config: BuiltinKnowledgeGraphConfig) -> GraphStoreBase:
|
||||
def configure(cfg: GraphStoreConfig):
|
||||
cfg.name = config.name
|
||||
cfg.embedding_fn = config.embedding_fn
|
||||
@@ -56,6 +60,9 @@ class BuiltinKnowledgeGraph(KnowledgeGraphBase):
|
||||
graph_store_type = os.getenv("GRAPH_STORE_TYPE") or config.graph_store_type
|
||||
return GraphStoreFactory.create(graph_store_type, configure)
|
||||
|
||||
def __init_graph_store_adapter(self):
|
||||
return GraphStoreAdapterFactory.create(self._graph_store)
|
||||
|
||||
def get_config(self) -> BuiltinKnowledgeGraphConfig:
|
||||
"""Get the knowledge graph config."""
|
||||
return self._config
|
||||
@@ -63,10 +70,10 @@ class BuiltinKnowledgeGraph(KnowledgeGraphBase):
|
||||
def load_document(self, chunks: List[Chunk]) -> List[str]:
|
||||
"""Extract and persist triplets to graph store."""
|
||||
|
||||
async def process_chunk(chunk):
|
||||
async def process_chunk(chunk: Chunk):
|
||||
triplets = await self._triplet_extractor.extract(chunk.content)
|
||||
for triplet in triplets:
|
||||
self._graph_store.insert_triplet(*triplet)
|
||||
self._graph_store_apdater.insert_triplet(*triplet)
|
||||
logger.info(f"load {len(triplets)} triplets from chunk {chunk.chunk_id}")
|
||||
return chunk.chunk_id
|
||||
|
||||
@@ -89,7 +96,7 @@ class BuiltinKnowledgeGraph(KnowledgeGraphBase):
|
||||
for chunk in chunks:
|
||||
triplets = await self._triplet_extractor.extract(chunk.content)
|
||||
for triplet in triplets:
|
||||
self._graph_store.insert_triplet(*triplet)
|
||||
self._graph_store_apdater.insert_triplet(*triplet)
|
||||
logger.info(f"load {len(triplets)} triplets from chunk {chunk.chunk_id}")
|
||||
return [chunk.chunk_id for chunk in chunks]
|
||||
|
||||
@@ -116,7 +123,8 @@ class BuiltinKnowledgeGraph(KnowledgeGraphBase):
|
||||
|
||||
# extract keywords and explore graph store
|
||||
keywords = await self._keyword_extractor.extract(text)
|
||||
subgraph = self._graph_store.explore(keywords, limit=topk).format()
|
||||
subgraph = self._graph_store_apdater.explore(keywords, limit=topk).format()
|
||||
|
||||
logger.info(f"Search subgraph from {len(keywords)} keywords")
|
||||
|
||||
if not subgraph:
|
||||
@@ -147,12 +155,12 @@ class BuiltinKnowledgeGraph(KnowledgeGraphBase):
|
||||
|
||||
def query_graph(self, limit: Optional[int] = None) -> Graph:
|
||||
"""Query graph."""
|
||||
return self._graph_store.get_full_graph(limit)
|
||||
return self._graph_store_apdater.get_full_graph(limit)
|
||||
|
||||
def truncate(self) -> List[str]:
|
||||
"""Truncate knowledge graph."""
|
||||
logger.info(f"Truncate graph {self._config.name}")
|
||||
self._graph_store.truncate()
|
||||
self._graph_store_apdater.truncate()
|
||||
|
||||
logger.info("Truncate keyword extractor")
|
||||
self._keyword_extractor.truncate()
|
||||
@@ -165,10 +173,15 @@ class BuiltinKnowledgeGraph(KnowledgeGraphBase):
|
||||
def delete_vector_name(self, index_name: str):
|
||||
"""Delete vector name."""
|
||||
logger.info(f"Drop graph {index_name}")
|
||||
self._graph_store.drop()
|
||||
self._graph_store_apdater.drop()
|
||||
|
||||
logger.info("Drop keyword extractor")
|
||||
self._keyword_extractor.drop()
|
||||
|
||||
logger.info("Drop triplet extractor")
|
||||
self._triplet_extractor.drop()
|
||||
|
||||
def delete_by_ids(self, ids: str) -> List[str]:
|
||||
"""Delete by ids."""
|
||||
self._graph_store_apdater.delete_document(chunk_ids=ids)
|
||||
return []
|
||||
|
@@ -1,4 +1,5 @@
|
||||
"""The database manager."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
@@ -10,7 +10,7 @@ You can refer to the python example file `DB-GPT/examples/rag/graph_rag_example.
|
||||
First, you need to install the `dbgpt` library.
|
||||
|
||||
```bash
|
||||
pip install "dbgpt[rag]>=0.6.0"
|
||||
pip install "dbgpt[graph_rag]>=0.6.1"
|
||||
````
|
||||
|
||||
### Prepare Graph Database
|
||||
@@ -112,7 +112,9 @@ TUGRAPH_HOST=127.0.0.1
|
||||
TUGRAPH_PORT=7687
|
||||
TUGRAPH_USERNAME=admin
|
||||
TUGRAPH_PASSWORD=73@TuGraph
|
||||
GRAPH_COMMUNITY_SUMMARY_ENABLED=True
|
||||
ENABLE_GRAPH_COMMUNITY_SUMMARY=True # enable the graph community summary
|
||||
ENABLE_TRIPLET_GRAPH=True # enable the graph search for the triplets
|
||||
ENABLE_DOCUMENT_GRAPH=True # enable the graph search for documents and chunks
|
||||
```
|
||||
|
||||
|
||||
@@ -250,23 +252,23 @@ Performance testing is based on the `gpt-4o-mini` model.
|
||||
|
||||
#### Indexing Performance
|
||||
|
||||
| | DB-GPT | GraphRAG(microsoft) |
|
||||
|----------|----------|------------------------|
|
||||
| Document Tokens | 42631 | 42631 |
|
||||
| Graph Size | 808 nodes, 1170 edges | 779 nodes, 967 edges |
|
||||
| Prompt Tokens | 452614 | 744990 |
|
||||
| Completion Tokens | 48325 | 227230 |
|
||||
| Total Tokens | 500939 | 972220 |
|
||||
| | DB-GPT | GraphRAG(microsoft) |
|
||||
| ----------------- | --------------------- | -------------------- |
|
||||
| Document Tokens | 42631 | 42631 |
|
||||
| Graph Size | 808 nodes, 1170 edges | 779 nodes, 967 edges |
|
||||
| Prompt Tokens | 452614 | 744990 |
|
||||
| Completion Tokens | 48325 | 227230 |
|
||||
| Total Tokens | 500939 | 972220 |
|
||||
|
||||
|
||||
#### Querying Performance
|
||||
|
||||
**Global Search**
|
||||
|
||||
| | DB-GPT | GraphRAG(microsoft) |
|
||||
|----------|----------|------------------------|
|
||||
| Time | 8s | 40s |
|
||||
| Tokens| 7432 | 63317 |
|
||||
| | DB-GPT | GraphRAG(microsoft) |
|
||||
| ------ | ------ | ------------------- |
|
||||
| Time | 8s | 40s |
|
||||
| Tokens | 7432 | 63317 |
|
||||
|
||||
**Question**
|
||||
```
|
||||
@@ -304,10 +306,10 @@ Performance testing is based on the `gpt-4o-mini` model.
|
||||
|
||||
**Local Search**
|
||||
|
||||
| | DB-GPT | GraphRAG(microsoft) |
|
||||
|----------|----------|------------------------|
|
||||
| Time | 15s | 15s |
|
||||
| Tokens| 9230 | 11619 |
|
||||
| | DB-GPT | GraphRAG(microsoft) |
|
||||
| ------ | ------ | ------------------- |
|
||||
| Time | 15s | 15s |
|
||||
| Tokens | 9230 | 11619 |
|
||||
|
||||
**Question**
|
||||
|
||||
@@ -352,3 +354,28 @@ DB-GPT社区与TuGraph社区的比较
|
||||
总结
|
||||
总体而言,DB-GPT社区和TuGraph社区在社区贡献、生态系统和开发者参与等方面各具特色。DB-GPT社区更侧重于AI应用的多样性和组织间的合作,而TuGraph社区则专注于图数据的高效管理和分析。两者的共同点在于都强调了开源和社区合作的重要性,推动了各自领域的技术进步和应用发展。
|
||||
```
|
||||
|
||||
### Latest Updates
|
||||
|
||||
In version 0.6.1 of DB-GPT, we have added a new feature:
|
||||
- Retrieval of triplets combined with **retrieval of the document structure**
|
||||
|
||||
We have expanded the definition of 'Graph' in GraphRAG:
|
||||
```
|
||||
Knowledge Graph = Triplets Graph + Document Structure Graph
|
||||
```
|
||||
|
||||
<p align="left">
|
||||
<img src={'/img/chat_knowledge/graph_rag/image_graphrag_0_6_1.png'} width="1000px"/>
|
||||
</p>
|
||||
|
||||
How?
|
||||
|
||||
We decompose standard-format files (Markdown files are currently best supported) into a directed graph based on their hierarchy and layout information, and store it in a graph database. In this graph:
|
||||
- Each node represents a chunk of the file
|
||||
- Each edge represents the structural relationship between different chunks in the original document
|
||||
- The document structure graph is merged into the triplets graph
|
||||
|
||||
What's next?
|
||||
|
||||
We aim to construct a more complex Graph that covers more comprehensive information to support more sophisticated retrieval algorithms in our GraphRAG.
|
BIN
docs/docs/cookbook/rag/image_graphrag_0_6_1.png
Normal file
BIN
docs/docs/cookbook/rag/image_graphrag_0_6_1.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 195 KiB |
BIN
docs/static/img/chat_knowledge/graph_rag/image_graphrag_0_6_1.png
vendored
Normal file
BIN
docs/static/img/chat_knowledge/graph_rag/image_graphrag_0_6_1.png
vendored
Normal file
Binary file not shown.
After Width: | Height: | Size: 195 KiB |
@@ -1,4 +1,5 @@
|
||||
import pytest
|
||||
|
||||
from dbgpt.datasource.conn_tugraph import TuGraphConnector
|
||||
|
||||
# Set database connection parameters.
|
||||
@@ -21,10 +22,10 @@ def connector():
|
||||
|
||||
def test_get_table_names(connector):
|
||||
"""Test retrieving table names from the graph database."""
|
||||
table_names = connector.get_table_names()
|
||||
vertex_tables, edge_tables = connector.get_table_names()
|
||||
# Verify the quantity of vertex and edge tables.
|
||||
assert len(table_names["vertex_tables"]) == 5
|
||||
assert len(table_names["edge_tables"]) == 8
|
||||
assert len(vertex_tables) == 5
|
||||
assert len(edge_tables) == 8
|
||||
|
||||
|
||||
def test_get_columns(connector):
|
||||
|
@@ -4,6 +4,9 @@ from dbgpt.storage.graph_store.memgraph_store import (
|
||||
MemoryGraphStore,
|
||||
MemoryGraphStoreConfig,
|
||||
)
|
||||
from dbgpt.storage.knowledge_graph.community.memgraph_store_adapter import (
|
||||
MemGraphStoreAdapter,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -11,31 +14,37 @@ def graph_store():
|
||||
yield MemoryGraphStore(MemoryGraphStoreConfig())
|
||||
|
||||
|
||||
def test_graph_store(graph_store):
|
||||
graph_store.insert_triplet("A", "0", "A")
|
||||
graph_store.insert_triplet("A", "1", "A")
|
||||
graph_store.insert_triplet("A", "2", "B")
|
||||
graph_store.insert_triplet("B", "3", "C")
|
||||
graph_store.insert_triplet("B", "4", "D")
|
||||
graph_store.insert_triplet("C", "5", "D")
|
||||
graph_store.insert_triplet("B", "6", "E")
|
||||
graph_store.insert_triplet("F", "7", "E")
|
||||
graph_store.insert_triplet("E", "8", "F")
|
||||
@pytest.fixture
|
||||
def graph_store_adapter(graph_store: MemoryGraphStore):
|
||||
memgraph_store_adapter = MemGraphStoreAdapter(graph_store)
|
||||
yield memgraph_store_adapter
|
||||
|
||||
subgraph = graph_store.explore(["A"])
|
||||
|
||||
def test_graph_store(graph_store_adapter: MemGraphStoreAdapter):
|
||||
graph_store_adapter.insert_triplet("A", "0", "A")
|
||||
graph_store_adapter.insert_triplet("A", "1", "A")
|
||||
graph_store_adapter.insert_triplet("A", "2", "B")
|
||||
graph_store_adapter.insert_triplet("B", "3", "C")
|
||||
graph_store_adapter.insert_triplet("B", "4", "D")
|
||||
graph_store_adapter.insert_triplet("C", "5", "D")
|
||||
graph_store_adapter.insert_triplet("B", "6", "E")
|
||||
graph_store_adapter.insert_triplet("F", "7", "E")
|
||||
graph_store_adapter.insert_triplet("E", "8", "F")
|
||||
|
||||
subgraph = graph_store_adapter.explore(["A"])
|
||||
print(f"\n{subgraph.format()}")
|
||||
assert subgraph.edge_count == 9
|
||||
|
||||
graph_store.delete_triplet("A", "0", "A")
|
||||
graph_store.delete_triplet("B", "4", "D")
|
||||
subgraph = graph_store.explore(["A"])
|
||||
graph_store_adapter.delete_triplet("A", "0", "A")
|
||||
graph_store_adapter.delete_triplet("B", "4", "D")
|
||||
subgraph = graph_store_adapter.explore(["A"])
|
||||
print(f"\n{subgraph.format()}")
|
||||
assert subgraph.edge_count == 7
|
||||
|
||||
triplets = graph_store.get_triplets("B")
|
||||
triplets = graph_store_adapter.get_triplets("B")
|
||||
print(f"\nTriplets of B: {triplets}")
|
||||
assert len(triplets) == 2
|
||||
|
||||
schema = graph_store.get_schema()
|
||||
schema = graph_store_adapter.get_schema()
|
||||
print(f"\nSchema: {schema}")
|
||||
assert len(schema) == 86
|
||||
|
@@ -1,43 +1,52 @@
|
||||
# test_tugraph_store.py
|
||||
# test_tugraph_tugraph_store_adapter.py
|
||||
|
||||
import pytest
|
||||
|
||||
from dbgpt.storage.graph_store.tugraph_store import TuGraphStore, TuGraphStoreConfig
|
||||
from dbgpt.storage.knowledge_graph.community.tugraph_store_adapter import (
|
||||
TuGraphStoreAdapter,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def store():
|
||||
config = TuGraphStoreConfig(name="TestGraph", summary_enabled=False)
|
||||
config = TuGraphStoreConfig(name="TestGraph", enable_summary=False)
|
||||
store = TuGraphStore(config=config)
|
||||
yield store
|
||||
store.conn.close()
|
||||
|
||||
|
||||
def test_insert_and_get_triplets(store):
|
||||
store.insert_triplet("A", "0", "A")
|
||||
store.insert_triplet("A", "1", "A")
|
||||
store.insert_triplet("A", "2", "B")
|
||||
store.insert_triplet("B", "3", "C")
|
||||
store.insert_triplet("B", "4", "D")
|
||||
store.insert_triplet("C", "5", "D")
|
||||
store.insert_triplet("B", "6", "E")
|
||||
store.insert_triplet("F", "7", "E")
|
||||
store.insert_triplet("E", "8", "F")
|
||||
triplets = store.get_triplets("A")
|
||||
@pytest.fixture(scope="module")
|
||||
def tugraph_store_adapter(store: TuGraphStore):
|
||||
tugraph_store_adapter = TuGraphStoreAdapter(store)
|
||||
yield tugraph_store_adapter
|
||||
|
||||
|
||||
def test_insert_and_get_triplets(tugraph_store_adapter: TuGraphStoreAdapter):
|
||||
tugraph_store_adapter.insert_triplet("A", "0", "A")
|
||||
tugraph_store_adapter.insert_triplet("A", "1", "A")
|
||||
tugraph_store_adapter.insert_triplet("A", "2", "B")
|
||||
tugraph_store_adapter.insert_triplet("B", "3", "C")
|
||||
tugraph_store_adapter.insert_triplet("B", "4", "D")
|
||||
tugraph_store_adapter.insert_triplet("C", "5", "D")
|
||||
tugraph_store_adapter.insert_triplet("B", "6", "E")
|
||||
tugraph_store_adapter.insert_triplet("F", "7", "E")
|
||||
tugraph_store_adapter.insert_triplet("E", "8", "F")
|
||||
triplets = tugraph_store_adapter.get_triplets("A")
|
||||
assert len(triplets) == 2
|
||||
triplets = store.get_triplets("B")
|
||||
triplets = tugraph_store_adapter.get_triplets("B")
|
||||
assert len(triplets) == 3
|
||||
triplets = store.get_triplets("C")
|
||||
triplets = tugraph_store_adapter.get_triplets("C")
|
||||
assert len(triplets) == 1
|
||||
triplets = store.get_triplets("D")
|
||||
triplets = tugraph_store_adapter.get_triplets("D")
|
||||
assert len(triplets) == 0
|
||||
triplets = store.get_triplets("E")
|
||||
triplets = tugraph_store_adapter.get_triplets("E")
|
||||
assert len(triplets) == 1
|
||||
triplets = store.get_triplets("F")
|
||||
triplets = tugraph_store_adapter.get_triplets("F")
|
||||
assert len(triplets) == 1
|
||||
|
||||
|
||||
def test_query(store):
|
||||
def test_query(store: TuGraphStore):
|
||||
query = "MATCH (n)-[r]->(n1) return n,n1,r limit 3"
|
||||
result = store.query(query)
|
||||
v_c = result.vertex_count
|
||||
@@ -45,18 +54,18 @@ def test_query(store):
|
||||
assert v_c == 3 and e_c == 3
|
||||
|
||||
|
||||
def test_explore(store):
|
||||
def test_explore(tugraph_store_adapter: TuGraphStoreAdapter):
|
||||
subs = ["A", "B"]
|
||||
result = store.explore(subs, depth=2, fan=None, limit=10)
|
||||
result = tugraph_store_adapter.explore(subs, depth=2, fan=None, limit=10)
|
||||
v_c = result.vertex_count
|
||||
e_c = result.edge_count
|
||||
assert v_c == 5 and e_c == 5
|
||||
|
||||
|
||||
def test_delete_triplet(store):
|
||||
def test_delete_triplet(tugraph_store_adapter: TuGraphStoreAdapter):
|
||||
subj = "A"
|
||||
rel = "0"
|
||||
obj = "B"
|
||||
store.delete_triplet(subj, rel, obj)
|
||||
triplets = store.get_triplets(subj)
|
||||
tugraph_store_adapter.delete_triplet(subj, rel, obj)
|
||||
triplets = tugraph_store_adapter.get_triplets(subj)
|
||||
assert len(triplets) == 0
|
||||
|
@@ -1,18 +1,27 @@
|
||||
import pytest
|
||||
|
||||
from dbgpt.storage.graph_store.graph import Edge, MemoryGraph, Vertex
|
||||
from dbgpt.storage.graph_store.tugraph_store import TuGraphStore, TuGraphStoreConfig
|
||||
from dbgpt.storage.graph_store.graph import MemoryGraph, Edge, Vertex
|
||||
from dbgpt.storage.knowledge_graph.community.tugraph_store_adapter import (
|
||||
TuGraphStoreAdapter,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def store():
|
||||
config = TuGraphStoreConfig(name="TestSummaryGraph", summary_enabled=True)
|
||||
config = TuGraphStoreConfig(name="TestSummaryGraph", enable_summary=True)
|
||||
store_instance = TuGraphStore(config=config)
|
||||
yield store_instance
|
||||
store_instance.conn.close()
|
||||
|
||||
|
||||
def test_insert_graph(store):
|
||||
@pytest.fixture(scope="module")
|
||||
def graph_store_adapter(store: TuGraphStore):
|
||||
tugraph_store_adapter = TuGraphStoreAdapter(store)
|
||||
yield tugraph_store_adapter
|
||||
|
||||
|
||||
def test_upsert_graph(tugraph_store_adapter: TuGraphStoreAdapter):
|
||||
graph = MemoryGraph()
|
||||
vertex_list = [
|
||||
Vertex("A", "A", description="Vertex A", _document_id="Test doc"),
|
||||
@@ -35,22 +44,22 @@ def test_insert_graph(store):
|
||||
graph.upsert_vertex(vertex)
|
||||
for edge in edge_list:
|
||||
graph.append_edge(edge)
|
||||
store.insert_graph(graph)
|
||||
tugraph_store_adapter.upsert_graph(graph)
|
||||
|
||||
|
||||
def test_leiden_query(store):
|
||||
def test_leiden_query(store: TuGraphStore):
|
||||
query = "CALL db.plugin.callPlugin('CPP','leiden','{\"leiden_val\":\"_community_id\"}',60.00,false)"
|
||||
result = store.query(query)
|
||||
assert result.vertex_count == 1
|
||||
|
||||
|
||||
def test_query_node_and_edge(store):
|
||||
def test_query_node_and_edge(store: TuGraphStore):
|
||||
query = 'MATCH (n)-[r]->(m) WHERE n._community_id = "0" RETURN n,r,m'
|
||||
result = store.query(query)
|
||||
assert result.vertex_count == 7 and result.edge_count == 6
|
||||
|
||||
|
||||
def test_stream_query_path(store):
|
||||
def test_stream_query_path(store: TuGraphStore):
|
||||
query = 'MATCH p=(n)-[r:relation*2]->(m) WHERE n._community_id = "0" RETURN p'
|
||||
result = store.query(query)
|
||||
for v in result.vertices():
|
||||
|
Reference in New Issue
Block a user