feat: add document structure into GraphRAG (#2033)

Co-authored-by: Appointat <kuda.czk@antgroup.com>
Co-authored-by: tpoisonooo <khj.application@aliyun.com>
Co-authored-by: vritser <vritser@163.com>
This commit is contained in:
lipengfei
2024-10-18 22:03:08 +08:00
committed by GitHub
parent 811ce63493
commit 88e3d12bd3
29 changed files with 1909 additions and 935 deletions

View File

@@ -157,12 +157,15 @@ EXECUTE_LOCAL_COMMANDS=False
#*******************************************************************#
VECTOR_STORE_TYPE=Chroma
GRAPH_STORE_TYPE=TuGraph
GRAPH_COMMUNITY_SUMMARY_ENABLED=True
KNOWLEDGE_GRAPH_EXTRACT_SEARCH_TOP_SIZE=5
KNOWLEDGE_GRAPH_EXTRACT_SEARCH_RECALL_SCORE=0.3
KNOWLEDGE_GRAPH_COMMUNITY_SEARCH_TOP_SIZE=20
KNOWLEDGE_GRAPH_COMMUNITY_SEARCH_RECALL_SCORE=0.0
ENABLE_GRAPH_COMMUNITY_SUMMARY=True # enable the graph community summary
ENABLE_TRIPLET_GRAPH=True # enable the graph search for triplets
ENABLE_DOCUMENT_GRAPH=True # enable the graph search for documents and chunks
### Chroma vector db config
#CHROMA_PERSIST_PATH=/root/DB-GPT/pilot/data

View File

@@ -213,8 +213,8 @@ class Config(metaclass=Singleton):
# Vector Store Configuration
self.VECTOR_STORE_TYPE = os.getenv("VECTOR_STORE_TYPE", "Chroma")
self.GRAPH_COMMUNITY_SUMMARY_ENABLED = (
os.getenv("GRAPH_COMMUNITY_SUMMARY_ENABLED", "").lower() == "true"
self.ENABLE_GRAPH_COMMUNITY_SUMMARY = (
os.getenv("ENABLE_GRAPH_COMMUNITY_SUMMARY", "").lower() == "true"
)
self.MILVUS_URL = os.getenv("MILVUS_URL", "127.0.0.1")
self.MILVUS_PORT = os.getenv("MILVUS_PORT", "19530")

View File

@@ -12,7 +12,6 @@ from dbgpt.app.knowledge.document_db import (
KnowledgeDocumentEntity,
)
from dbgpt.app.knowledge.request.request import (
ChunkEditRequest,
ChunkQueryRequest,
DocumentQueryRequest,
DocumentRecallTestRequest,
@@ -650,12 +649,17 @@ class KnowledgeService:
{
"id": node.vid,
"communityId": node.get_prop("_community_id"),
"name": node.vid,
"type": "",
"name": node.name,
"type": node.get_prop("type") or "",
}
)
for edge in graph.edges():
res["edges"].append(
{"source": edge.sid, "target": edge.tid, "name": edge.name, "type": ""}
{
"source": edge.sid,
"target": edge.tid,
"name": edge.name,
"type": edge.get_prop("type") or "",
}
)
return res

View File

@@ -1,7 +1,7 @@
"""TuGraph Connector."""
import json
from typing import Dict, Generator, List, cast
from typing import Dict, Generator, List, Tuple, cast
from .base import BaseConnector
@@ -21,8 +21,7 @@ class TuGraphConnector(BaseConnector):
self._session = None
def create_graph(self, graph_name: str) -> None:
"""Create a new graph."""
# run the query to get vertex labels
"""Create a new graph in the database if it doesn't already exist."""
try:
with self._driver.session(database="default") as session:
graph_list = session.run("CALL dbms.graph.listGraphs()").data()
@@ -32,10 +31,10 @@ class TuGraphConnector(BaseConnector):
f"CALL dbms.graph.createGraph('{graph_name}', '', 2048)"
)
except Exception as e:
raise Exception(f"Failed to create graph '{graph_name}': {str(e)}")
raise Exception(f"Failed to create graph '{graph_name}': {str(e)}") from e
def delete_graph(self, graph_name: str) -> None:
"""Delete a graph."""
"""Delete a graph in the database if it exists."""
with self._driver.session(database="default") as session:
graph_list = session.run("CALL dbms.graph.listGraphs()").data()
exists = any(item["graph_name"] == graph_name for item in graph_list)
@@ -61,17 +60,20 @@ class TuGraphConnector(BaseConnector):
"`pip install neo4j`"
) from err
def get_table_names(self) -> Dict[str, List[str]]:
def get_table_names(self) -> Tuple[List[str], List[str]]:
"""Get all table names from the TuGraph by Neo4j driver."""
# run the query to get vertex labels
with self._driver.session(database=self._graph) as session:
v_result = session.run("CALL db.vertexLabels()").data()
v_data = [table_name["label"] for table_name in v_result]
# Run the query to get vertex labels
raw_vertex_labels: Dict[str, str] = session.run(
"CALL db.vertexLabels()"
).data()
vertex_labels = [table_name["label"] for table_name in raw_vertex_labels]
# run the query to get edge labels
e_result = session.run("CALL db.edgeLabels()").data()
e_data = [table_name["label"] for table_name in e_result]
return {"vertex_tables": v_data, "edge_tables": e_data}
# Run the query to get edge labels
raw_edge_labels: Dict[str, str] = session.run("CALL db.edgeLabels()").data()
edge_labels = [table_name["label"] for table_name in raw_edge_labels]
return vertex_labels, edge_labels
def get_grants(self):
"""Get grants."""
@@ -100,7 +102,7 @@ class TuGraphConnector(BaseConnector):
result = session.run(query)
return list(result)
except Exception as e:
raise Exception(f"Query execution failed: {e}")
raise Exception(f"Query execution failed: {e}\nQuery: {query}") from e
def run_stream(self, query: str) -> Generator:
"""Run GQL."""
@@ -109,11 +111,15 @@ class TuGraphConnector(BaseConnector):
yield from result
def get_columns(self, table_name: str, table_type: str = "vertex") -> List[Dict]:
"""Get fields about specified graph.
"""Retrieve the column for a specified vertex or edge table in the graph db.
This function queries the schema of a given table (vertex or edge) and returns
detailed information about its columns (properties).
Args:
table_name (str): table name (graph name)
table_type (str): table type (vertex or edge)
Returns:
columns: List[Dict], which contains name: str, type: str,
default_expression: str, is_in_primary_key: bool, comment: str
@@ -146,8 +152,8 @@ class TuGraphConnector(BaseConnector):
"""Get table indexes about specified table.
Args:
table_name:(str) table name
table_type:(str'vertex' | 'edge'
table_name (str): table name
table_type (str): 'vertex' | 'edge'
Returns:
List[Dict]:eg:[{'name': 'idx_key', 'column_names': ['id']}]
"""

View File

@@ -65,7 +65,9 @@ class GraphExtractor(LLMExtractor):
match = re.match(r"\((.*?)#(.*?)\)", line)
if match:
name, summary = [part.strip() for part in match.groups()]
graph.upsert_vertex(Vertex(name, description=summary))
graph.upsert_vertex(
Vertex(name, description=summary, vertex_type="entity")
)
elif current_section == "Relationships":
match = re.match(r"\((.*?)#(.*?)#(.*?)#(.*?)\)", line)
if match:
@@ -74,7 +76,13 @@ class GraphExtractor(LLMExtractor):
]
edge_count += 1
graph.append_edge(
Edge(source, target, name, description=summary)
Edge(
source,
target,
name,
description=summary,
edge_type="relation",
)
)
if limit and edge_count >= limit:

View File

@@ -1,4 +1,5 @@
"""KeywordExtractor class."""
import logging
from typing import List, Optional
@@ -39,12 +40,15 @@ class KeywordExtractor(LLMExtractor):
def _parse_response(self, text: str, limit: Optional[int] = None) -> List[str]:
keywords = set()
for part in text.split(";"):
for s in part.strip().split(","):
keyword = s.strip()
if keyword:
keywords.add(keyword)
if limit and len(keywords) >= limit:
return list(keywords)
lines = text.replace(":", "\n").split("\n")
for line in lines:
for part in line.split(";"):
for s in part.strip().split(","):
keyword = s.strip()
if keyword:
keywords.add(keyword)
if limit and len(keywords) >= limit:
return list(keywords)
return list(keywords)

View File

@@ -128,7 +128,7 @@ class VectorStoreConnector:
def __rewrite_index_store_type(self, index_store_type):
# Rewrite Knowledge Graph Type
if CFG.GRAPH_COMMUNITY_SUMMARY_ENABLED:
if CFG.ENABLE_GRAPH_COMMUNITY_SUMMARY:
if index_store_type == "KnowledgeGraph":
return "CommunitySummaryKnowledgeGraph"
return index_store_type

View File

@@ -1,11 +1,11 @@
"""Graph store base class."""
import logging
from abc import ABC, abstractmethod
from typing import Generator, List, Optional, Tuple
from typing import Optional
from dbgpt._private.pydantic import BaseModel, ConfigDict, Field
from dbgpt.core import Embeddings
from dbgpt.storage.graph_store.graph import Direction, Graph
logger = logging.getLogger(__name__)
@@ -23,78 +23,36 @@ class GraphStoreConfig(BaseModel):
default=None,
description="The embedding function of graph store, optional.",
)
summary_enabled: bool = Field(
enable_summary: bool = Field(
default=False,
description="Enable graph community summary or not.",
)
enable_document_graph: bool = Field(
default=True,
description="Enable document graph search or not.",
)
enable_triplet_graph: bool = Field(
default=True,
description="Enable knowledge graph search or not.",
)
class GraphStoreBase(ABC):
"""Graph store base class."""
def __init__(self, config: GraphStoreConfig):
"""Initialize graph store."""
self._config = config
self._conn = None
@abstractmethod
def get_config(self) -> GraphStoreConfig:
"""Get the graph store config."""
@abstractmethod
def get_vertex_type(self) -> str:
"""Get the vertex type."""
def _escape_quotes(self, text: str) -> str:
"""Escape single and double quotes in a string for queries."""
@abstractmethod
def get_edge_type(self) -> str:
"""Get the edge type."""
@abstractmethod
def insert_triplet(self, sub: str, rel: str, obj: str):
"""Add triplet."""
@abstractmethod
def insert_graph(self, graph: Graph):
"""Add graph."""
@abstractmethod
def get_triplets(self, sub: str) -> List[Tuple[str, str]]:
"""Get triplets."""
@abstractmethod
def delete_triplet(self, sub: str, rel: str, obj: str):
"""Delete triplet."""
@abstractmethod
def truncate(self):
"""Truncate Graph."""
@abstractmethod
def drop(self):
"""Drop graph."""
@abstractmethod
def get_schema(self, refresh: bool = False) -> str:
"""Get schema."""
@abstractmethod
def get_full_graph(self, limit: Optional[int] = None) -> Graph:
"""Get full graph."""
@abstractmethod
def explore(
self,
subs: List[str],
direct: Direction = Direction.BOTH,
depth: Optional[int] = None,
fan: Optional[int] = None,
limit: Optional[int] = None,
) -> Graph:
"""Explore on graph."""
@abstractmethod
def query(self, query: str, **args) -> Graph:
"""Execute a query."""
def aquery(self, query: str, **args) -> Graph:
"""Async execute a query."""
return self.query(query, **args)
@abstractmethod
def stream_query(self, query: str) -> Generator[Graph, None, None]:
"""Execute stream query."""
# @abstractmethod
# def _paser(self, entities: List[Vertex]) -> str:
# """Parse entities to string."""

View File

@@ -1,4 +1,5 @@
"""Graph store factory."""
import logging
from typing import Tuple, Type

View File

@@ -1,4 +1,5 @@
"""Graph definition."""
import itertools
import json
import logging
@@ -6,13 +7,41 @@ import re
from abc import ABC, abstractmethod
from collections import defaultdict
from enum import Enum
from typing import Any, Dict, Iterator, List, Optional, Set, Tuple
from typing import Any, Callable, Dict, Iterator, List, Optional, Set, Tuple
import networkx as nx
logger = logging.getLogger(__name__)
class GraphElemType(Enum):
    """Kinds of vertices and edges that can appear in the graph."""

    # Vertex kinds.
    DOCUMENT = "document"
    CHUNK = "chunk"
    ENTITY = "entity"  # doubles as the generic vertex kind

    # Edge kinds.
    RELATION = "relation"  # doubles as the generic edge kind
    INCLUDE = "include"
    NEXT = "next"
    DOCUMENT_INCLUDE_CHUNK = "document_include_chunk"
    CHUNK_INCLUDE_CHUNK = "chunk_include_chunk"
    CHUNK_INCLUDE_ENTITY = "chunk_include_entity"
    CHUNK_NEXT_CHUNK = "chunk_next_chunk"

    def is_vertex(self) -> bool:
        """Return True when this member is DOCUMENT, CHUNK or ENTITY."""
        vertex_kinds = (
            GraphElemType.DOCUMENT,
            GraphElemType.CHUNK,
            GraphElemType.ENTITY,
        )
        return self in vertex_kinds

    def is_edge(self) -> bool:
        """Return True for every member that is not a vertex kind."""
        if self.is_vertex():
            return False
        return True
class Direction(Enum):
"""Direction class."""
@@ -41,7 +70,7 @@ class Elem(ABC):
def set_prop(self, key: str, value: Any):
"""Set a property of ELem."""
self._props[key] = value
self._props[key] = value # note: always update the value
def get_prop(self, key: str):
"""Get one of the properties of Elem."""
@@ -124,6 +153,18 @@ class Edge(Elem):
for k, v in props.items():
self.set_prop(k, v)
def __eq__(self, other):
    """Check if two edges are equal.

    Two edges are equal when they have the same source vertex ID, target
    vertex ID, and edge label. The properties are not considered.

    Returns:
        bool when ``other`` is an Edge; ``NotImplemented`` otherwise, so that
        comparing an edge with an unrelated object evaluates to False instead
        of raising AttributeError on the missing ``sid``/``tid``/``name``.
    """
    if not isinstance(other, Edge):
        return NotImplemented
    return (self.sid, self.tid, self.name) == (other.sid, other.tid, other.name)
def __hash__(self):
"""Return the hash of the (source ID, target ID, edge name) triple.

Only ``sid``, ``tid`` and ``name`` take part in the hash; edge properties
are not included.
"""
return hash((self.sid, self.tid, self.name))
@property
def sid(self) -> str:
"""Return the source vertex ID of the edge."""
@@ -188,11 +229,15 @@ class Graph(ABC):
"""Get neighbor edges."""
@abstractmethod
def vertices(self) -> Iterator[Vertex]:
def vertices(
self, filter_fn: Optional[Callable[[Vertex], bool]] = None
) -> Iterator[Vertex]:
"""Get vertex iterator."""
@abstractmethod
def edges(self) -> Iterator[Edge]:
def edges(
self, filter_fn: Optional[Callable[[Edge], bool]] = None
) -> Iterator[Edge]:
"""Get edge iterator."""
@abstractmethod
@@ -241,7 +286,7 @@ class MemoryGraph(Graph):
self._edge_prop_keys = set()
self._edge_count = 0
# init vertices, out edges, in edges index
# vertices index, out edges index, in edges index
self._vs: Any = defaultdict()
self._oes: Any = defaultdict(lambda: defaultdict(set))
self._ies: Any = defaultdict(lambda: defaultdict(set))
@@ -269,7 +314,7 @@ class MemoryGraph(Graph):
# update metadata
self._vertex_prop_keys.update(vertex.props.keys())
def append_edge(self, edge: Edge):
def append_edge(self, edge: Edge) -> bool:
"""Append an edge if it doesn't exist; requires edge label."""
sid = edge.sid
tid = edge.tid
@@ -290,6 +335,34 @@ class MemoryGraph(Graph):
self._edge_count += 1
return True
def upsert_vertex_and_edge(
    self,
    src_vid: str,
    src_name: str,
    src_props: Dict[str, Any],
    dst_vid: str,
    dst_name: str,
    dst_props: Dict[str, Any],
    edge_name: str,
    edge_type: str,
):
    """Upsert the source and destination vertices, then append their edge.

    Args:
        src_vid: ID of the source vertex.
        src_name: Name of the source vertex.
        src_props: Extra keyword properties passed to the source Vertex.
        dst_vid: ID of the destination vertex.
        dst_name: Name of the destination vertex.
        dst_props: Extra keyword properties passed to the destination Vertex.
        edge_name: Name (label) of the edge from source to destination.
        edge_type: Value passed to the Edge as its ``edge_type`` keyword.
    """
    # Build all three elements first so nothing is written to the graph if
    # any constructor fails.
    source = Vertex(src_vid, src_name, **src_props)
    target = Vertex(dst_vid, dst_name, **dst_props)
    link = Edge(src_vid, dst_vid, edge_name, edge_type=edge_type)

    for vertex in (source, target):
        self.upsert_vertex(vertex)
    self.append_edge(link)
def upsert_graph(self, graph: "MemoryGraph"):
    """Merge another graph into this one.

    Every vertex of ``graph`` is upserted first; afterwards every edge of
    ``graph`` is appended.
    """
    for incoming_vertex in graph.vertices():
        self.upsert_vertex(incoming_vertex)

    for incoming_edge in graph.edges():
        self.append_edge(incoming_edge)
def has_vertex(self, vid: str) -> bool:
"""Return whether ``vid`` is a key of the vertex index ``self._vs``."""
return vid in self._vs
@@ -335,13 +408,26 @@ class MemoryGraph(Graph):
return itertools.islice(es, limit) if limit else es
def vertices(self) -> Iterator[Vertex]:
def vertices(
self, filter_fn: Optional[Callable[[Vertex], bool]] = None
) -> Iterator[Vertex]:
"""Return vertices."""
return iter(self._vs.values())
# Get all vertices in the graph
all_vertices = self._vs.values()
def edges(self) -> Iterator[Edge]:
return all_vertices if filter_fn is None else filter(filter_fn, all_vertices)
def edges(
self, filter_fn: Optional[Callable[[Edge], bool]] = None
) -> Iterator[Edge]:
"""Return edges."""
return iter(e for nbs in self._oes.values() for es in nbs.values() for e in es)
# Get all edges in the graph
all_edges = (e for nbs in self._oes.values() for es in nbs.values() for e in es)
if filter_fn is None:
return all_edges
else:
return filter(filter_fn, all_edges)
def del_vertices(self, *vids: str):
"""Delete specified vertices."""
@@ -353,7 +439,7 @@ class MemoryGraph(Graph):
"""Delete edges."""
old_edge_cnt = len(self._oes[sid][tid])
def remove_matches(es):
def remove_matches(es: Set[Edge]):
return set(
filter(
lambda e: not (
@@ -452,7 +538,7 @@ class MemoryGraph(Graph):
]
}
def format(self) -> str:
def format(self, entities_only: Optional[bool] = False) -> str:
"""Format graph to string."""
vs_str = "\n".join(v.format() for v in self.vertices())
es_str = "\n".join(
@@ -461,11 +547,14 @@ class MemoryGraph(Graph):
f"{self.get_vertex(e.tid).format(concise=True)}"
for e in self.edges()
)
return (
f"Entities:\n{vs_str}\n\n" f"Relationships:\n{es_str}"
if (vs_str or es_str)
else ""
)
if entities_only:
return f"Entities:\n{vs_str}" if vs_str else ""
else:
return (
f"Entities:\n{vs_str}\n\nRelationships:\n{es_str}"
if (vs_str or es_str)
else ""
)
def truncate(self):
"""Truncate graph."""

View File

@@ -1,11 +1,10 @@
"""Memory graph store."""
import json
import logging
from typing import Generator, List, Optional, Tuple
from dbgpt._private.pydantic import ConfigDict
from dbgpt.storage.graph_store.base import GraphStoreBase, GraphStoreConfig
from dbgpt.storage.graph_store.graph import Direction, Edge, Graph, MemoryGraph
from dbgpt.storage.graph_store.graph import MemoryGraph
logger = logging.getLogger(__name__)
@@ -28,77 +27,8 @@ class MemoryGraphStore(GraphStoreBase):
"""Get the graph store config."""
return self._graph_store_config
def get_edge_type(self) -> str:
"""Get the edge type."""
raise NotImplementedError("Memory graph store does not have edge type")
def get_vertex_type(self) -> str:
"""Get the vertex type."""
raise NotImplementedError("Memory graph store does not have vertex type")
def insert_triplet(self, sub: str, rel: str, obj: str):
"""Insert a triplet into the graph."""
self._graph.append_edge(Edge(sub, obj, rel))
def insert_graph(self, graph: Graph):
"""Add graph."""
for vertex in graph.vertices():
self._graph.upsert_vertex(vertex)
for edge in graph.edges():
self._graph.append_edge(edge)
def get_triplets(self, sub: str) -> List[Tuple[str, str]]:
"""Retrieve triplets originating from a subject."""
subgraph = self.explore([sub], direct=Direction.OUT, depth=1)
return [(e.name, e.tid) for e in subgraph.edges()]
def delete_triplet(self, sub: str, rel: str, obj: str):
"""Delete a specific triplet from the graph."""
self._graph.del_edges(sub, obj, rel)
def truncate(self):
"""Truncate graph."""
self._graph.truncate()
def drop(self):
"""Drop graph."""
self._graph = None
def get_schema(self, refresh: bool = False) -> str:
"""Return the graph schema as a JSON string."""
return json.dumps(self._graph.schema())
def get_full_graph(self, limit: Optional[int] = None) -> Graph:
"""Return self."""
if not limit:
return self._graph
subgraph = MemoryGraph()
for count, edge in enumerate(self._graph.edges()):
if count >= limit:
break
subgraph.upsert_vertex(self._graph.get_vertex(edge.sid))
subgraph.upsert_vertex(self._graph.get_vertex(edge.tid))
subgraph.append_edge(edge)
count += 1
return subgraph
def explore(
self,
subs: List[str],
direct: Direction = Direction.BOTH,
depth: Optional[int] = None,
fan: Optional[int] = None,
limit: Optional[int] = None,
) -> MemoryGraph:
"""Explore the graph from given subjects up to a depth."""
return self._graph.search(subs, direct, depth, fan, limit)
def query(self, query: str, **args) -> Graph:
"""Execute a query on graph."""
raise NotImplementedError("Query memory graph not allowed")
def stream_query(self, query: str) -> Generator[Graph, None, None]:
"""Execute stream query."""
raise NotImplementedError("Stream query memory graph not allowed")
def _escape_quotes(self, text: str) -> str:
"""Escape single and double quotes in a string for queries."""
raise NotImplementedError(
"_escape_quotes is not implemented by MemoryGraphStore"
)

View File

@@ -1,14 +1,15 @@
"""TuGraph store."""
import base64
import json
import logging
import os
from typing import Any, Generator, Iterator, List, Optional, Tuple
from typing import List
from dbgpt._private.pydantic import ConfigDict, Field
from dbgpt.datasource.conn_tugraph import TuGraphConnector
from dbgpt.storage.graph_store.base import GraphStoreBase, GraphStoreConfig
from dbgpt.storage.graph_store.graph import Direction, Edge, Graph, MemoryGraph, Vertex
from dbgpt.storage.graph_store.graph import GraphElemType
logger = logging.getLogger(__name__)
@@ -35,12 +36,28 @@ class TuGraphStoreConfig(GraphStoreConfig):
description="login password",
)
vertex_type: str = Field(
default="entity",
description="The type of vertex, `entity` by default.",
default=GraphElemType.ENTITY.value,
description="The type of entity vertex, `entity` by default.",
)
document_type: str = Field(
default=GraphElemType.DOCUMENT.value,
description="The type of document vertex, `document` by default.",
)
chunk_type: str = Field(
    default=GraphElemType.CHUNK.value,
    # The description must name the real default: GraphElemType.CHUNK.value
    # is "chunk", not "relation".
    description="The type of chunk vertex, `chunk` by default.",
)
edge_type: str = Field(
default="relation",
description="The type of edge, `relation` by default.",
default=GraphElemType.RELATION.value,
description="The type of relation edge, `relation` by default.",
)
include_type: str = Field(
default=GraphElemType.INCLUDE.value,
description="The type of include edge, `include` by default.",
)
next_type: str = Field(
default=GraphElemType.NEXT.value,
description="The type of next edge, `next` by default.",
)
plugin_names: List[str] = Field(
default=["leiden"],
@@ -62,17 +79,24 @@ class TuGraphStore(GraphStoreBase):
self._port = int(os.getenv("TUGRAPH_PORT", config.port))
self._username = os.getenv("TUGRAPH_USERNAME", config.username)
self._password = os.getenv("TUGRAPH_PASSWORD", config.password)
self._summary_enabled = (
os.getenv("GRAPH_COMMUNITY_SUMMARY_ENABLED", "").lower() == "true"
or config.summary_enabled
self._enable_summary = (
os.getenv("ENABLE_GRAPH_COMMUNITY_SUMMARY", "").lower() == "true"
or config.enable_summary
)
self._enable_document_graph = (
os.getenv("ENABLE_DOCUMENT_GRAPH", "").lower() == "true"
or config.enable_document_graph
)
self._enable_triplet_graph = (
os.getenv("ENABLE_TRIPLET_GRAPH", "").lower() == "true"
or config.enable_triplet_graph
)
self._plugin_names = (
os.getenv("TUGRAPH_PLUGIN_NAMES", "leiden").split(",")
or config.plugin_names
)
self._graph_name = config.name
self._vertex_type = os.getenv("TUGRAPH_VERTEX_TYPE", config.vertex_type)
self._edge_type = os.getenv("TUGRAPH_EDGE_TYPE", config.edge_type)
self.conn = TuGraphConnector.from_uri_db(
host=self._host,
@@ -82,34 +106,29 @@ class TuGraphStore(GraphStoreBase):
db_name=config.name,
)
self._create_graph(config.name)
def get_vertex_type(self) -> str:
"""Get the vertex type."""
return self._vertex_type
def get_edge_type(self) -> str:
"""Get the edge type."""
return self._edge_type
def _create_graph(self, graph_name: str):
self.conn.create_graph(graph_name=graph_name)
self._create_schema()
if self._summary_enabled:
self._upload_plugin()
def _check_label(self, elem_type: str):
result = self.conn.get_table_names()
if elem_type == "vertex":
return self._vertex_type in result["vertex_tables"]
if elem_type == "edge":
return self._edge_type in result["edge_tables"]
def get_config(self) -> TuGraphStoreConfig:
"""Return the store configuration held in ``self._config``."""
return self._config
def _add_vertex_index(self, field_name):
gql = f"CALL db.addIndex('{self._vertex_type}', '{field_name}', false)"
"""Add an index to the vertex table."""
# TODO: Not used in the current implementation.
gql = f"CALL db.addIndex('{GraphElemType.ENTITY.value}', '{field_name}', false)"
self.conn.run(gql)
def _upload_plugin(self):
"""Upload missing plugins to the TuGraph database.
This method checks for the presence of required plugins in the database and
uploads any missing plugins. It performs the following steps:
1. Lists existing plugins in the database.
2. Identifies missing plugins by comparing with the required plugin list.
3. For each missing plugin, reads its binary content, encodes it, and uploads to
the database.
The method uses the 'leiden' plugin as an example, but can be extended for other
plugins.
"""
gql = "CALL db.plugin.listPlugin('CPP','v1')"
result = self.conn.run(gql)
result_names = [
@@ -122,8 +141,8 @@ class TuGraphStore(GraphStoreBase):
if len(missing_plugins):
for name in missing_plugins:
try:
from dbgpt_tugraph_plugins import ( # type: ignore # noqa
get_plugin_binary_path,
from dbgpt_tugraph_plugins import (
get_plugin_binary_path, # type:ignore[import-untyped]
)
except ImportError:
logger.error(
@@ -136,375 +155,12 @@ class TuGraphStore(GraphStoreBase):
content = f.read()
content = base64.b64encode(content).decode()
gql = (
f"CALL db.plugin.loadPlugin('CPP', '{name}', '{content}', "
"'SO', '{name} Plugin', false, 'v1')"
f"CALL db.plugin.loadPlugin('CPP', '{name}', '{content}', 'SO', "
f"'{name} Plugin', false, 'v1')"
)
self.conn.run(gql)
def _create_schema(self):
if not self._check_label("vertex"):
if self._summary_enabled:
create_vertex_gql = (
f"CALL db.createLabel("
f"'vertex', '{self._vertex_type}', "
f"'id', ['id',string,false],"
f"['name',string,false],"
f"['_document_id',string,true],"
f"['_chunk_id',string,true],"
f"['_community_id',string,true],"
f"['description',string,true])"
)
self.conn.run(create_vertex_gql)
self._add_vertex_index("_community_id")
else:
create_vertex_gql = (
f"CALL db.createLabel("
f"'vertex', '{self._vertex_type}', "
f"'id', ['id',string,false],"
f"['name',string,false])"
)
self.conn.run(create_vertex_gql)
if not self._check_label("edge"):
create_edge_gql = f"""CALL db.createLabel(
'edge', '{self._edge_type}',
'[["{self._vertex_type}",
"{self._vertex_type}"]]',
["id",STRING,false],
["name",STRING,false])"""
if self._summary_enabled:
create_edge_gql = f"""CALL db.createLabel(
'edge', '{self._edge_type}',
'[["{self._vertex_type}",
"{self._vertex_type}"]]',
["id",STRING,false],
["name",STRING,false],
["description",STRING,true])"""
self.conn.run(create_edge_gql)
def _format_query_data(self, data, white_prop_list: List[str]):
nodes_list = []
rels_list: List[Any] = []
_white_list = white_prop_list
from neo4j import graph
def get_filtered_properties(properties, white_list):
return {
key: value
for key, value in properties.items()
if (not key.startswith("_") and key not in ["id", "name"])
or key in white_list
}
def process_node(node: graph.Node):
node_id = node._properties.get("id")
node_name = node._properties.get("name")
node_properties = get_filtered_properties(node._properties, _white_list)
nodes_list.append(
{"id": node_id, "name": node_name, "properties": node_properties}
)
def process_relationship(rel: graph.Relationship):
name = rel._properties.get("name", "")
rel_nodes = rel.nodes
src_id = rel_nodes[0]._properties.get("id")
dst_id = rel_nodes[1]._properties.get("id")
for node in rel_nodes:
process_node(node)
edge_properties = get_filtered_properties(rel._properties, _white_list)
if not any(
existing_edge.get("name") == name
and existing_edge.get("src_id") == src_id
and existing_edge.get("dst_id") == dst_id
for existing_edge in rels_list
):
rels_list.append(
{
"src_id": src_id,
"dst_id": dst_id,
"name": name,
"properties": edge_properties,
}
)
def process_path(path: graph.Path):
for rel in path.relationships:
process_relationship(rel)
def process_other(value):
if not any(
existing_node.get("id") == "json_node" for existing_node in nodes_list
):
nodes_list.append(
{
"id": "json_node",
"name": "json_node",
"properties": {"description": value},
}
)
for record in data:
for key in record.keys():
value = record[key]
if isinstance(value, graph.Node):
process_node(value)
elif isinstance(value, graph.Relationship):
process_relationship(value)
elif isinstance(value, graph.Path):
process_path(value)
else:
process_other(value)
nodes = [
Vertex(node["id"], node["name"], **node["properties"])
for node in nodes_list
]
rels = [
Edge(edge["src_id"], edge["dst_id"], edge["name"], **edge["properties"])
for edge in rels_list
]
return {"nodes": nodes, "edges": rels}
def get_config(self):
"""Get the graph store config."""
return self._config
def get_triplets(self, subj: str) -> List[Tuple[str, str]]:
"""Get triplets."""
query = (
f"MATCH (n1:{self._vertex_type})-[r]->(n2:{self._vertex_type}) "
f'WHERE n1.id = "{subj}" RETURN r.id as rel, n2.id as obj;'
)
data = self.conn.run(query)
return [(record["rel"], record["obj"]) for record in data]
def insert_triplet(self, subj: str, rel: str, obj: str) -> None:
"""Add triplet."""
def escape_quotes(value: str) -> str:
"""Escape single and double quotes in a string for queries."""
return value.replace("'", "\\'").replace('"', '\\"')
subj_escaped = escape_quotes(subj)
rel_escaped = escape_quotes(rel)
obj_escaped = escape_quotes(obj)
node_query = f"""CALL db.upsertVertex(
'{self._vertex_type}',
[{{id:'{subj_escaped}',name:'{subj_escaped}'}},
{{id:'{obj_escaped}',name:'{obj_escaped}'}}])"""
edge_query = f"""CALL db.upsertEdge(
'{self._edge_type}',
{{type:"{self._vertex_type}",key:"sid"}},
{{type:"{self._vertex_type}", key:"tid"}},
[{{sid:"{subj_escaped}",
tid: "{obj_escaped}",
id:"{rel_escaped}",
name: "{rel_escaped}"}}])"""
self.conn.run(query=node_query)
self.conn.run(query=edge_query)
def insert_graph(self, graph: Graph) -> None:
"""Add graph."""
def escape_quotes(value: str) -> str:
"""Escape single and double quotes in a string for queries."""
if value is not None:
return value.replace("'", "").replace('"', "")
nodes: Iterator[Vertex] = graph.vertices()
edges: Iterator[Edge] = graph.edges()
node_list = []
edge_list = []
def parser(node_list):
formatted_nodes = [
"{"
+ ", ".join(
f'{k}: "{v}"' if isinstance(v, str) else f"{k}: {v}"
for k, v in node.items()
)
+ "}"
for node in node_list
]
return f"""{', '.join(formatted_nodes)}"""
for node in nodes:
node_list.append(
{
"id": escape_quotes(node.vid),
"name": escape_quotes(node.name),
"description": escape_quotes(node.get_prop("description")) or "",
"_document_id": "0",
"_chunk_id": "0",
"_community_id": "0",
}
)
node_query = (
f"""CALL db.upsertVertex("{self._vertex_type}", [{parser(node_list)}])"""
)
for edge in edges:
edge_list.append(
{
"sid": escape_quotes(edge.sid),
"tid": escape_quotes(edge.tid),
"id": escape_quotes(edge.name),
"name": escape_quotes(edge.name),
"description": escape_quotes(edge.get_prop("description")),
}
)
edge_query = f"""CALL db.upsertEdge(
"{self._edge_type}",
{{type:"{self._vertex_type}", key:"sid"}},
{{type:"{self._vertex_type}", key:"tid"}},
[{parser(edge_list)}])"""
self.conn.run(query=node_query)
self.conn.run(query=edge_query)
def truncate(self):
"""Truncate Graph."""
gql = "MATCH (n) DELETE n"
self.conn.run(gql)
def drop(self):
"""Delete Graph."""
self.conn.delete_graph(self._graph_name)
def delete_triplet(self, sub: str, rel: str, obj: str) -> None:
"""Delete triplet."""
del_query = (
f"MATCH (n1:{self._vertex_type} {{id:'{sub}'}})"
f"-[r:{self._edge_type} {{id:'{rel}'}}]->"
f"(n2:{self._vertex_type} {{id:'{obj}'}}) DELETE n1,n2,r"
)
self.conn.run(query=del_query)
def get_schema(self, refresh: bool = False) -> str:
"""Get the schema of the graph store."""
query = "CALL dbms.graph.getGraphSchema()"
data = self.conn.run(query=query)
schema = data[0]["schema"]
return schema
def get_full_graph(self, limit: Optional[int] = None) -> Graph:
"""Get full graph."""
if not limit:
raise Exception("limit must be set")
graph_result = self.query(
f"MATCH (n)-[r]-(m) RETURN n,r,m LIMIT {limit}",
white_list=["_community_id"],
)
all_graph = MemoryGraph()
for vertex in graph_result.vertices():
all_graph.upsert_vertex(vertex)
for edge in graph_result.edges():
all_graph.append_edge(edge)
return all_graph
def explore(
self,
subs: List[str],
direct: Direction = Direction.BOTH,
depth: Optional[int] = None,
fan: Optional[int] = None,
limit: Optional[int] = None,
) -> Graph:
"""Explore the graph from given subjects up to a depth."""
if not subs:
return MemoryGraph()
if fan is not None:
raise ValueError("Fan functionality is not supported at this time.")
else:
depth_string = f"1..{depth}"
if depth is None:
depth_string = ".."
limit_string = f"LIMIT {limit}"
if limit is None:
limit_string = ""
if direct.name == "OUT":
rel = f"-[r:{self._edge_type}*{depth_string}]->"
elif direct.name == "IN":
rel = f"<-[r:{self._edge_type}*{depth_string}]-"
else:
rel = f"-[r:{self._edge_type}*{depth_string}]-"
query = (
f"MATCH p=(n:{self._vertex_type})"
f"{rel}(m:{self._vertex_type}) "
f"WHERE n.id IN {subs} RETURN p {limit_string}"
)
return self.query(query)
def query(self, query: str, **args) -> MemoryGraph:
"""Execute a query on graph."""
result = self.conn.run(query=query)
white_list = args.get("white_list", [])
graph = self._format_query_data(result, white_list)
mg = MemoryGraph()
for vertex in graph["nodes"]:
mg.upsert_vertex(vertex)
for edge in graph["edges"]:
mg.append_edge(edge)
return mg
def stream_query(self, query: str) -> Generator[Graph, None, None]:
    """Execute a stream query, yielding one in-memory graph per record.

    Every column of a record is converted according to its driver type:

    * ``Node`` -> one vertex (id, description),
    * ``Relationship`` -> one edge between its two end nodes,
    * ``Path`` -> the path's vertices, chained by edges in path order,
    * anything else -> a single ``json_node`` vertex carrying the raw value
      as its description.

    Args:
        query: The query text handed to ``self.conn.run_stream``.

    Yields:
        A ``MemoryGraph`` holding the elements of one record.

    Raises:
        KeyError: If a node or relationship lacks ``id`` or ``description``.
    """
    from neo4j import graph

    for record in self.conn.run_stream(query):
        mg = MemoryGraph()
        for key in record.keys():
            value = record[key]
            if isinstance(value, graph.Node):
                node_id = value._properties["id"]
                description = value._properties["description"]
                mg.upsert_vertex(
                    Vertex(node_id, name=node_id, description=description)
                )
            elif isinstance(value, graph.Relationship):
                rel_nodes = value.nodes
                prop_id = value._properties["id"]
                src_id = rel_nodes[0]._properties["id"]
                dst_id = rel_nodes[1]._properties["id"]
                description = value._properties["description"]
                mg.append_edge(
                    Edge(src_id, dst_id, name=prop_id, description=description)
                )
            elif isinstance(value, graph.Path):
                # Read the path from the current column. The previous code
                # always looked up record["p"], which fails (or reads the
                # wrong column) when the path is returned under another alias.
                path_nodes = list(value.nodes)
                path_rels = list(value.relationships)
                for idx, node in enumerate(path_nodes):
                    node_id = node._properties["id"]
                    mg.upsert_vertex(
                        Vertex(
                            node_id,
                            name=node_id,
                            description=node._properties["description"],
                        )
                    )
                    # Relationship idx links node idx to node idx + 1.
                    if idx < len(path_rels) and idx + 1 < len(path_nodes):
                        hop = path_rels[idx]
                        mg.append_edge(
                            Edge(
                                node_id,
                                path_nodes[idx + 1]._properties["id"],
                                name=hop._properties["id"],
                                description=hop._properties["description"],
                            )
                        )
            else:
                # Scalar / JSON results are wrapped in a synthetic vertex.
                mg.upsert_vertex(
                    Vertex("json_node", name="json_node", description=value)
                )
        yield mg
def _escape_quotes(self, value: str) -> str:
"""Escape single and double quotes in a string for queries."""
if value is not None:
return value.replace("'", "").replace('"', "")

View File

@@ -1,4 +1,5 @@
"""Knowledge graph base class."""
import logging
from abc import ABC, abstractmethod
from typing import List, Optional
@@ -27,6 +28,6 @@ class KnowledgeGraphBase(IndexStoreBase, ABC):
def query_graph(self, limit: Optional[int] = None) -> Graph:
"""Get graph data."""
@abstractmethod
def delete_by_ids(self, ids: str) -> List[str]:
    """Delete document by ids.

    Args:
        ids: Identifier string of the document(s) to delete.

    Raises:
        Exception: Always, in this base implementation; subclasses must
            override the method to support deletion.
    """
    raise Exception("Delete document not supported by knowledge graph")

View File

@@ -1,11 +1,19 @@
"""Define Classes about Community."""
import logging
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import List, Optional
from typing import AsyncGenerator, Iterator, List, Optional
from dbgpt.storage.graph_store.base import GraphStoreBase
from dbgpt.storage.graph_store.graph import Graph
from dbgpt.storage.graph_store.graph import (
Direction,
Edge,
Graph,
GraphElemType,
MemoryGraph,
Vertex,
)
logger = logging.getLogger(__name__)
@@ -24,7 +32,7 @@ class CommunityTree:
"""Represents a community tree."""
class CommunityStoreAdapter(ABC):
class GraphStoreAdapter(ABC):
"""Community Store Adapter."""
def __init__(self, graph_store: GraphStoreBase):
@@ -44,6 +52,113 @@ class CommunityStoreAdapter(ABC):
async def get_community(self, community_id: str) -> Community:
"""Get community."""
@abstractmethod
def get_graph_config(self):
    """Return the configuration object of the underlying graph store."""

@abstractmethod
def get_vertex_type(self) -> str:
    """Return the label used for entity vertices."""

@abstractmethod
def get_edge_type(self) -> str:
    """Return the label used for relation edges."""

@abstractmethod
def get_triplets(self, sub: str) -> List[tuple[str, str]]:
    """Return ``(relation, object)`` pairs for edges leaving ``sub``.

    NOTE(review): the implementations in this change name the parameter
    ``subj``; keyword calls using ``sub=`` would not reach them.
    """

@abstractmethod
def get_document_vertex(self, doc_name: str) -> Vertex:
    """Return the vertex that represents the document ``doc_name``."""

@abstractmethod
def get_schema(self, refresh: bool = False) -> str:
    """Return the graph schema as a string."""

@abstractmethod
def get_full_graph(self, limit: Optional[int] = None) -> Graph:
    """Return the stored graph, optionally bounded by ``limit``."""

@abstractmethod
def upsert_entities(self, entities: Iterator[Vertex]) -> None:
    """Insert or update the given entity vertices."""

@abstractmethod
def upsert_edge(
    self, edges: Iterator[Edge], edge_type: str, src_type: str, dst_type: str
):
    """Insert or update edges of ``edge_type`` from ``src_type`` to ``dst_type``."""

@abstractmethod
def upsert_chunks(self, chunk: Iterator[Vertex]) -> None:
    """Insert or update the given chunk vertices."""

@abstractmethod
def upsert_documents(self, documents: Iterator[Vertex]) -> None:
    """Insert or update the given document vertices."""

@abstractmethod
def upsert_relations(self, relations: Iterator[Edge]) -> None:
    """Insert or update the given relation edges."""

@abstractmethod
def insert_triplet(self, sub: str, rel: str, obj: str) -> None:
    """Insert the triplet ``sub -[rel]-> obj``."""

@abstractmethod
def upsert_graph(self, graph: Graph) -> None:
    """Insert or update every vertex and edge of ``graph``."""

@abstractmethod
def delete_document(self, chunk_id: str) -> None:
    """Delete a document's data from the graph store.

    NOTE(review): the TuGraph implementation treats the argument as a
    comma-separated list of chunk ids - confirm the intended contract.
    """

@abstractmethod
def delete_triplet(self, sub: str, rel: str, obj: str) -> None:
    """Delete the triplet ``sub -[rel]-> obj``."""

@abstractmethod
def drop(self) -> None:
    """Drop the whole graph."""

@abstractmethod
def create_graph(self, graph_name: str) -> None:
    """Create the graph named ``graph_name``."""

@abstractmethod
def create_graph_label(self) -> None:
    """Create a graph label.

    The graph label is used to identify and distinguish different types of nodes
    (vertices) and edges in the graph.

    NOTE(review): the TuGraph implementation takes ``graph_elem_type`` and
    ``graph_properties`` arguments that this declaration does not list.
    """

@abstractmethod
def truncate(self) -> None:
    """Remove the graph's data while keeping the graph itself."""

@abstractmethod
def check_label(self, graph_elem_type: GraphElemType) -> bool:
    """Return whether the label of ``graph_elem_type`` exists in the graph."""

@abstractmethod
def explore(
    self,
    subs: List[str],
    direct: Direction = Direction.BOTH,
    depth: Optional[int] = None,
) -> MemoryGraph:
    """Explore the graph from the given subjects up to ``depth`` hops.

    NOTE(review): the implementations in this change accept additional
    optional parameters (``limit``, ``fan``, ``search_scope``).
    """

@abstractmethod
def query(self, query: str, **kwargs) -> MemoryGraph:
    """Execute ``query`` and return the result as an in-memory graph."""

@abstractmethod
async def stream_query(self, query: str, **kwargs) -> AsyncGenerator[Graph, None]:
    """Execute a stream query.

    NOTE(review): declared ``async def`` without a ``yield``, so this
    declaration itself is a coroutine function rather than an async
    generator - confirm how callers are expected to consume it.
    """
class CommunityMetastore(ABC):
"""Community metastore class."""

View File

@@ -4,10 +4,7 @@ import logging
from typing import List
from dbgpt.rag.transformer.community_summarizer import CommunitySummarizer
from dbgpt.storage.knowledge_graph.community.base import (
Community,
CommunityStoreAdapter,
)
from dbgpt.storage.knowledge_graph.community.base import Community, GraphStoreAdapter
from dbgpt.storage.knowledge_graph.community.community_metastore import (
BuiltinCommunityMetastore,
)
@@ -21,23 +18,23 @@ class CommunityStore:
def __init__(
self,
community_store_adapter: CommunityStoreAdapter,
graph_store_adapter: GraphStoreAdapter,
community_summarizer: CommunitySummarizer,
vector_store: VectorStoreBase,
):
"""Initialize the CommunityStore class."""
self._community_store_adapter = community_store_adapter
self._graph_store_adapter = graph_store_adapter
self._community_summarizer = community_summarizer
self._meta_store = BuiltinCommunityMetastore(vector_store)
async def build_communities(self):
"""Discover communities."""
community_ids = await self._community_store_adapter.discover_communities()
community_ids = await self._graph_store_adapter.discover_communities()
# summarize communities
communities = []
for community_id in community_ids:
community = await self._community_store_adapter.get_community(community_id)
community = await self._graph_store_adapter.get_community(community_id)
graph = community.data.format()
if not graph:
break
@@ -65,7 +62,7 @@ class CommunityStore:
self._community_summarizer.truncate()
logger.info("Truncate graph")
self._community_store_adapter.graph_store.truncate()
self._graph_store_adapter.truncate()
def drop(self):
"""Drop community store."""
@@ -76,4 +73,4 @@ class CommunityStore:
self._community_summarizer.drop()
logger.info("Remove graph")
self._community_store_adapter.graph_store.drop()
self._graph_store_adapter.drop()

View File

@@ -1,28 +1,29 @@
"""CommunityStoreAdapter factory."""
"""GraphStoreAdapter factory."""
import logging
from dbgpt.storage.graph_store.base import GraphStoreBase
from dbgpt.storage.graph_store.tugraph_store import TuGraphStore
from dbgpt.storage.knowledge_graph.community.base import CommunityStoreAdapter
from dbgpt.storage.knowledge_graph.community.tugraph_adapter import (
TuGraphCommunityStoreAdapter,
from dbgpt.storage.knowledge_graph.community.base import GraphStoreAdapter
from dbgpt.storage.knowledge_graph.community.tugraph_store_adapter import (
TuGraphStoreAdapter,
)
logger = logging.getLogger(__name__)
class CommunityStoreAdapterFactory:
class GraphStoreAdapterFactory:
"""Factory for community store adapter."""
@staticmethod
def create(graph_store: GraphStoreBase) -> CommunityStoreAdapter:
"""Create a CommunityStoreAdapter instance.
def create(graph_store: GraphStoreBase) -> GraphStoreAdapter:
"""Create a GraphStoreAdapter instance.
Args:
- graph_store_type: graph store type Memory, TuGraph, Neo4j
"""
if isinstance(graph_store, TuGraphStore):
return TuGraphCommunityStoreAdapter(graph_store)
return TuGraphStoreAdapter(graph_store)
else:
raise Exception(
"create community store adapter for %s failed",

View File

@@ -0,0 +1,186 @@
"""TuGraph Community Store Adapter."""
import json
import logging
from typing import AsyncGenerator, Iterator, List, Optional, Tuple
from dbgpt.storage.graph_store.graph import (
Direction,
Edge,
Graph,
GraphElemType,
MemoryGraph,
Vertex,
)
from dbgpt.storage.graph_store.memgraph_store import (
MemoryGraphStore,
MemoryGraphStoreConfig,
)
from dbgpt.storage.knowledge_graph.community.base import Community, GraphStoreAdapter
logger = logging.getLogger(__name__)
class MemGraphStoreAdapter(GraphStoreAdapter):
    """Graph store adapter backed by the in-memory ``MemoryGraphStore``.

    Triplet insertion/deletion, graph upsert, exploration, truncation and
    full-graph retrieval operate on the store's in-memory graph. Community
    discovery, document/chunk/entity upserts, label management and raw
    queries are no-op stubs that return ``None``.
    """

    MAX_HIERARCHY_LEVEL = 3

    def __init__(self, enable_summary: bool = False):
        """Create a fresh in-memory graph store and wrap it.

        Args:
            enable_summary: Stored on the adapter; not otherwise used here.
        """
        self._graph_store = MemoryGraphStore(MemoryGraphStoreConfig())
        self._enable_summary = enable_summary

        super().__init__(self._graph_store)

        # Create the graph (a no-op for the in-memory store).
        self.create_graph(self._graph_store.get_config().name)

    async def discover_communities(self, **kwargs) -> List[str]:
        """Run community discovery (not implemented; returns ``None``)."""
        pass

    async def get_community(self, community_id: str) -> Community:
        """Get a community (not implemented; returns ``None``)."""
        pass

    def get_graph_config(self):
        """Get the graph store config."""
        return self._graph_store.get_config()

    def get_vertex_type(self) -> str:
        """Get the vertex type; the memory graph has none, so return ``""``."""
        return ""

    def get_edge_type(self) -> str:
        """Get the edge type; the memory graph has none, so return ``""``."""
        return ""

    def get_triplets(self, subj: str) -> List[Tuple[str, str]]:
        """Return ``(edge name, target id)`` for every edge leaving ``subj``."""
        subgraph = self.explore([subj], direct=Direction.OUT, depth=1)
        return [(e.name, e.tid) for e in subgraph.edges()]

    def get_document_vertex(self, doc_name: str) -> Vertex:
        """Get the document vertex in the graph.

        Raises:
            NotImplementedError: Always; the memory graph stores no documents.
        """
        raise NotImplementedError("Memory graph store does not have document vertex")

    def get_schema(self, refresh: bool = False) -> str:
        """Return the in-memory graph's schema serialized as JSON."""
        return json.dumps(self._graph_store._graph.schema())

    def get_full_graph(self, limit: Optional[int] = None) -> Graph:
        """Return the whole graph, or a subgraph of its first ``limit`` edges.

        Args:
            limit: Maximum number of edges to copy; ``None`` or ``0`` returns
                the store's own graph object (not a copy).
        """
        if not limit:
            return self._graph_store._graph

        subgraph = MemoryGraph()
        for count, edge in enumerate(self._graph_store._graph.edges()):
            if count >= limit:
                break
            # Copy both endpoints before the edge that connects them.
            subgraph.upsert_vertex(self._graph_store._graph.get_vertex(edge.sid))
            subgraph.upsert_vertex(self._graph_store._graph.get_vertex(edge.tid))
            subgraph.append_edge(edge)
        return subgraph

    def upsert_entities(self, entities: Iterator[Vertex]) -> None:
        """Upsert entities (no-op for the memory graph)."""
        pass

    def upsert_edge(
        self, edges: Iterator[Edge], edge_type: str, src_type: str, dst_type: str
    ) -> None:
        """Upsert edges (no-op for the memory graph)."""
        pass

    def upsert_chunks(self, chunks: Iterator[Vertex]) -> None:
        """Upsert chunks (no-op for the memory graph)."""
        pass

    def upsert_documents(self, documents: Iterator[Vertex]) -> None:
        """Upsert documents (no-op for the memory graph)."""
        pass

    def upsert_relations(self, relations: Iterator[Edge]) -> None:
        """Upsert relations (no-op for the memory graph)."""
        pass

    def insert_triplet(self, subj: str, rel: str, obj: str) -> None:
        """Append the edge ``subj -[rel]-> obj`` to the in-memory graph."""
        self._graph_store._graph.append_edge(Edge(subj, obj, rel))

    def upsert_graph(self, graph: Graph) -> None:
        """Add graph to the graph store.

        Args:
            graph (Graph): The graph to be added.
        """
        for vertex in graph.vertices():
            self._graph_store._graph.upsert_vertex(vertex)

        for edge in graph.edges():
            self._graph_store._graph.append_edge(edge)

    def delete_document(self, chunk_ids: str) -> None:
        """Delete document in the graph (no-op for the memory graph)."""
        pass

    def delete_triplet(self, sub: str, rel: str, obj: str) -> None:
        """Delete the edges named ``rel`` between ``sub`` and ``obj``."""
        self._graph_store._graph.del_edges(sub, obj, rel)

    def drop(self):
        """Delete the graph by discarding the store's graph object.

        NOTE(review): the graph is set to ``None``, so any later graph
        operation on this adapter will fail - confirm this is intended.
        """
        self._graph_store._graph = None

    def create_graph(self, graph_name: str):
        """Create a graph (no-op for the memory graph)."""
        pass

    def create_graph_label(
        self,
    ) -> None:
        """Create a graph label (no-op for the memory graph).

        The graph label is used to identify and distinguish different types of nodes
        (vertices) and edges in the graph.
        """
        pass

    def truncate(self):
        """Truncate the in-memory graph."""
        self._graph_store._graph.truncate()

    def check_label(self, graph_elem_type: GraphElemType) -> bool:
        """Check if the label exists in the graph.

        Not implemented for the memory graph: currently returns ``None``.

        Args:
            graph_elem_type (GraphElemType): The type of the graph element.

        Returns:
            True if the label exists in the specified graph element type, otherwise
            False.
        """
        pass

    def explore(
        self,
        subs: List[str],
        direct: Direction = Direction.BOTH,
        depth: Optional[int] = None,
        fan: Optional[int] = None,
        limit: Optional[int] = None,
    ) -> MemoryGraph:
        """Explore the graph from given subjects up to a depth.

        All arguments are forwarded unchanged to the in-memory graph's
        ``search`` method.
        """
        return self._graph_store._graph.search(subs, direct, depth, fan, limit)

    def query(self, query: str, **kwargs) -> MemoryGraph:
        """Execute a query on graph (not implemented; returns ``None``)."""
        pass

    async def stream_query(self, query: str, **kwargs) -> AsyncGenerator[Graph, None]:
        """Execute a stream query (not implemented; yields nothing)."""
        pass

View File

@@ -1,52 +0,0 @@
"""TuGraph Community Store Adapter."""
import json
import logging
from typing import List
from dbgpt.storage.graph_store.graph import MemoryGraph
from dbgpt.storage.knowledge_graph.community.base import (
Community,
CommunityStoreAdapter,
)
logger = logging.getLogger(__name__)
class TuGraphCommunityStoreAdapter(CommunityStoreAdapter):
    """Community store adapter that talks to a TuGraph graph store."""

    MAX_HIERARCHY_LEVEL = 3

    async def discover_communities(self, **kwargs) -> List[str]:
        """Invoke the ``leiden`` plugin and return the discovered community ids.

        The plugin answer is read from the ``description`` of the
        ``json_node`` vertex of the query result and parsed as JSON.
        """
        plugin_call = (
            "CALL db.plugin.callPlugin("
            "'CPP','leiden','{\"leiden_val\":\"_community_id\"}',60.00,false)"
        )
        answer = self._graph_store.query(plugin_call)
        payload = answer.get_vertex("json_node").get_prop("description")
        community_ids = json.loads(payload)["community_id_list"]
        logger.info(f"Discovered {len(community_ids)} communities.")
        return community_ids

    async def get_community(self, community_id: str) -> Community:
        """Collect the vertices and edges tagged with ``community_id``.

        NOTE(review): the results of ``aquery`` are used without ``await`` -
        confirm that ``aquery`` is synchronous.
        """
        vertex_label = self._graph_store.get_vertex_type()
        edge_label = self._graph_store.get_edge_type()
        membership = f"WHERE n._community_id = '{community_id}'"

        vertex_query = f"MATCH (n:{vertex_label}){membership} RETURN n"
        edge_query = (
            f"MATCH (n:{vertex_label})-[r:{edge_label}]-(m:{vertex_label})"
            f"{membership} RETURN n,r,m"
        )

        member_vertices = self._graph_store.aquery(vertex_query)
        member_edges = self._graph_store.aquery(edge_query)

        # Merge both results into a single in-memory graph.
        community_graph = MemoryGraph()
        for node in member_vertices.vertices():
            community_graph.upsert_vertex(node)
        for link in member_edges.edges():
            community_graph.append_edge(link)

        return Community(id=community_id, data=community_graph)

View File

@@ -0,0 +1,808 @@
"""TuGraph Community Store Adapter."""
import json
import logging
from typing import (
Any,
AsyncGenerator,
Dict,
Iterator,
List,
Literal,
Optional,
Tuple,
Union,
)
from dbgpt.storage.graph_store.graph import (
Direction,
Edge,
Graph,
GraphElemType,
MemoryGraph,
Vertex,
)
from dbgpt.storage.graph_store.tugraph_store import TuGraphStore
from dbgpt.storage.knowledge_graph.community.base import Community, GraphStoreAdapter
logger = logging.getLogger(__name__)
class TuGraphStoreAdapter(GraphStoreAdapter):
"""TuGraph Community Store Adapter."""
MAX_QUERY_LIMIT = 1000
MAX_HIERARCHY_LEVEL = 3
def __init__(self, graph_store: TuGraphStore):
    """Wrap ``graph_store`` and create the graph named in its configuration."""
    super().__init__(graph_store)

    # Create the graph
    configured_name = self.graph_store.get_config().name
    self.create_graph(configured_name)
async def discover_communities(self, **kwargs) -> List[str]:
    """Invoke the ``leiden`` plugin and return the discovered community ids.

    The plugin answer is read from the ``description`` of the ``json_node``
    vertex of the query result and parsed as JSON.
    """
    plugin_call = (
        "CALL db.plugin.callPlugin("
        "'CPP','leiden','{\"leiden_val\":\"_community_id\"}',60.00,false)"
    )
    answer = self.query(plugin_call)
    payload = answer.get_vertex("json_node").get_prop("description")
    community_ids = json.loads(payload)["community_id_list"]
    logger.info(f"Discovered {len(community_ids)} communities.")
    return community_ids
async def get_community(self, community_id: str) -> Community:
    """Collect the entity vertices and relation edges of one community.

    Args:
        community_id: Value of ``_community_id`` to match.

    Returns:
        A ``Community`` whose data is the merged in-memory graph.
    """
    vertex_label = self.get_vertex_type()
    edge_label = self.get_edge_type()

    vertex_query = (
        f"MATCH (n:{vertex_label}) "
        f"WHERE n._community_id = '{community_id}' RETURN n"
    )
    edge_query = (
        f"MATCH (n:{vertex_label})-[r:{edge_label}]-(m:{vertex_label})"
        f"WHERE n._community_id = '{community_id}' RETURN n,r,m"
    )

    member_vertices = self.query(vertex_query)
    member_edges = self.query(edge_query)

    # Merge both results into a single in-memory graph.
    community_graph = MemoryGraph()
    for node in member_vertices.vertices():
        community_graph.upsert_vertex(node)
    for link in member_edges.edges():
        community_graph.append_edge(link)

    return Community(id=community_id, data=community_graph)
@property
def graph_store(self) -> TuGraphStore:
    """The wrapped TuGraph store instance."""
    store = self._graph_store
    return store
def get_graph_config(self):
    """Return the configuration object of the wrapped graph store."""
    config = self.graph_store.get_config()
    return config
def get_vertex_type(self) -> str:
    """Return the label used for entity vertices."""
    entity_label = GraphElemType.ENTITY.value
    return entity_label
def get_edge_type(self) -> str:
    """Return the label used for relation edges."""
    relation_label = GraphElemType.RELATION.value
    return relation_label
def get_triplets(self, subj: str) -> List[Tuple[str, str]]:
    """Return ``(relation id, object id)`` for every edge leaving ``subj``.

    Only edges between two entity vertices are considered.
    """
    entity_label = GraphElemType.ENTITY.value
    triplet_query = (
        f"MATCH (n1:{entity_label})-[r]->(n2:{entity_label}) "
        f'WHERE n1.id = "{subj}" RETURN r.id as rel, n2.id as obj;'
    )
    rows = self.graph_store.conn.run(triplet_query)

    pairs = []
    for row in rows:
        pairs.append((row["rel"], row["obj"]))
    return pairs
def get_document_vertex(self, doc_name: str) -> Vertex:
    """Get the document vertex in the graph.

    Args:
        doc_name: Id of the document vertex to look up.

    Returns:
        The vertex stored under ``doc_name`` in the query result.
    """
    # The id is a string property, so it must be embedded as a quoted
    # literal; quotes inside the value are escaped the same way as in
    # ``insert_triplet``.
    escaped_name = doc_name.replace("'", "\\'").replace('"', '\\"')
    gql = f"MATCH (n) WHERE n.id = '{escaped_name}' RETURN n"
    graph = self.query(gql)
    vertex = graph.get_vertex(doc_name)
    return vertex
def get_schema(self, refresh: bool = False) -> str:
    """Return the schema string reported by the graph database.

    Args:
        refresh: Accepted for interface compatibility; not used here.

    Returns:
        The ``schema`` field of the first row returned by
        ``dbms.graph.getGraphSchema()``.
    """
    rows = self.graph_store.conn.run(query="CALL dbms.graph.getGraphSchema()")
    return rows[0]["schema"]
def get_full_graph(self, limit: Optional[int] = None) -> Graph:
    """Return the vertices and edges of up to ``limit`` matched rows.

    Args:
        limit: Maximum number of ``(n)-[r]-(m)`` rows to fetch; ``None`` or
            ``0`` falls back to ``MAX_QUERY_LIMIT``.

    Raises:
        ValueError: If ``limit`` is negative.
    """
    row_limit = limit if limit else self.MAX_QUERY_LIMIT
    if row_limit <= 0:
        raise ValueError("Limit must be greater than 0.")

    cypher = f"MATCH (n)-[r]-(m) RETURN n,r,m LIMIT {row_limit}"
    fetched = self.query(cypher, white_list=["_community_id"])

    # Copy the query result into a fresh in-memory graph.
    snapshot = MemoryGraph()
    for node in fetched.vertices():
        snapshot.upsert_vertex(node)
    for link in fetched.edges():
        snapshot.append_edge(link)
    return snapshot
def upsert_entities(self, entities: Iterator[Vertex]) -> None:
    """Insert or update entity vertices with a single ``db.upsertVertex`` call.

    Quotes are stripped from id, name and description; the document, chunk
    and community ids are initialised to ``"0"``.
    """
    strip_quotes = self.graph_store._escape_quotes

    rows = []
    for entity in entities:
        rows.append(
            {
                "id": strip_quotes(entity.vid),
                "name": strip_quotes(entity.name),
                "description": strip_quotes(entity.get_prop("description")) or "",
                "_document_id": "0",
                "_chunk_id": "0",
                "_community_id": "0",
            }
        )

    entity_label = GraphElemType.ENTITY.value
    statement = f'CALL db.upsertVertex("{entity_label}", [{self._parser(rows)}])'
    self.graph_store.conn.run(query=statement)
def upsert_edge(
    self, edges: Iterator[Edge], edge_type: str, src_type: str, dst_type: str
) -> None:
    """Insert or update edges with a single ``db.upsertEdge`` call.

    Args:
        edges: Edges to write; quotes are stripped from their fields.
        edge_type: Label of the edges.
        src_type: Vertex label matched against each row's ``sid``.
        dst_type: Vertex label matched against each row's ``tid``.
    """
    # The edge name is written to both the ``id`` and ``name`` properties;
    # missing description / chunk id fall back to an empty string.
    edge_list = [
        {
            "sid": self.graph_store._escape_quotes(edge.sid),
            "tid": self.graph_store._escape_quotes(edge.tid),
            "id": self.graph_store._escape_quotes(edge.name),
            "name": self.graph_store._escape_quotes(edge.name),
            "description": self.graph_store._escape_quotes(
                edge.get_prop("description")
            )
            or "",
            "_chunk_id": self.graph_store._escape_quotes(edge.get_prop("_chunk_id"))
            or "",
        }
        for edge in edges
    ]
    relation_query = f"""CALL db.upsertEdge("{edge_type}",
{{type:"{src_type}", key:"sid"}},
{{type:"{dst_type}", key:"tid"}},
[{self._parser(edge_list)}])"""
    self.graph_store.conn.run(query=relation_query)
def upsert_chunks(self, chunks: Iterator[Vertex]) -> None:
    """Insert or update chunk vertices with a single ``db.upsertVertex`` call.

    Quotes are stripped from id, name and content. A chunk without content
    is written with an empty string, matching the fallback used by the other
    upsert methods, so ``None`` never reaches the query text.
    """
    chunk_list = [
        {
            "id": self.graph_store._escape_quotes(chunk.vid),
            "name": self.graph_store._escape_quotes(chunk.name),
            "content": self.graph_store._escape_quotes(chunk.get_prop("content"))
            or "",
        }
        for chunk in chunks
    ]
    chunk_query = (
        f"CALL db.upsertVertex("
        f'"{GraphElemType.CHUNK.value}", '
        f"[{self._parser(chunk_list)}])"
    )
    self.graph_store.conn.run(query=chunk_query)
def upsert_documents(self, documents: Iterator[Vertex]) -> None:
    """Insert or update document vertices with a single ``db.upsertVertex`` call.

    Quotes are stripped from id, name and content; a missing content is
    written as an empty string.
    """
    strip_quotes = self.graph_store._escape_quotes

    rows = []
    for document in documents:
        rows.append(
            {
                "id": strip_quotes(document.vid),
                "name": strip_quotes(document.name),
                "content": strip_quotes(document.get_prop("content")) or "",
            }
        )

    document_label = GraphElemType.DOCUMENT.value
    statement = f'CALL db.upsertVertex("{document_label}", [{self._parser(rows)}])'
    self.graph_store.conn.run(query=statement)
def upsert_relations(self, relations: Iterator[Edge]) -> None:
    """Upsert relations (currently a no-op that ignores ``relations``)."""
    return None
def insert_triplet(self, subj: str, rel: str, obj: str) -> None:
    """Add the triplet ``subj -[rel]-> obj`` to the graph.

    Both entity vertices are upserted first (id and name set to the same
    value), then the relation edge between them is upserted.

    Args:
        subj: Id of the source entity.
        rel: Id and name of the relation edge.
        obj: Id of the target entity.
    """
    # Backslash-escape quotes so the values can sit inside quoted literals.
    subj_escaped = subj.replace("'", "\\'").replace('"', '\\"')
    rel_escaped = rel.replace("'", "\\'").replace('"', '\\"')
    obj_escaped = obj.replace("'", "\\'").replace('"', '\\"')
    vertex_query = f"""CALL db.upsertVertex(
'{GraphElemType.ENTITY.value}',
[{{id:'{subj_escaped}',name:'{subj_escaped}'}},
{{id:'{obj_escaped}',name:'{obj_escaped}'}}])"""
    edge_query = f"""CALL db.upsertEdge(
'{GraphElemType.RELATION.value}',
{{type:"{GraphElemType.ENTITY.value}",key:"sid"}},
{{type:"{GraphElemType.ENTITY.value}", key:"tid"}},
[{{sid:"{subj_escaped}",
tid: "{obj_escaped}",
id:"{rel_escaped}",
name: "{rel_escaped}"}}])"""
    # Vertices must exist before the edge that references them.
    self.graph_store.conn.run(query=vertex_query)
    self.graph_store.conn.run(query=edge_query)
def upsert_graph(self, graph: MemoryGraph) -> None:
    """Add graph to the graph store.

    Vertices are selected by their ``vertex_type`` property and edges by
    their ``edge_type`` property; vertices are written before edges.

    Args:
        graph (Graph): The graph to be added.
    """

    def vertices_of(kind: GraphElemType) -> Iterator[Vertex]:
        """Iterate the vertices whose ``vertex_type`` equals ``kind``."""
        return graph.vertices(
            filter_fn=lambda v: v.get_prop("vertex_type") == kind.value
        )

    def edges_of(kind: GraphElemType) -> Iterator[Edge]:
        """Iterate the edges whose ``edge_type`` equals ``kind``."""
        return graph.edges(filter_fn=lambda e: e.get_prop("edge_type") == kind.value)

    # Get the iterators of all the vertices and the edges from the graph
    documents = vertices_of(GraphElemType.DOCUMENT)
    chunks = vertices_of(GraphElemType.CHUNK)
    entities = vertices_of(GraphElemType.ENTITY)

    # (edges, stored edge label, source vertex label, target vertex label)
    edge_plan = [
        (
            edges_of(GraphElemType.DOCUMENT_INCLUDE_CHUNK),
            GraphElemType.INCLUDE,
            GraphElemType.DOCUMENT,
            GraphElemType.CHUNK,
        ),
        (
            edges_of(GraphElemType.CHUNK_INCLUDE_CHUNK),
            GraphElemType.INCLUDE,
            GraphElemType.CHUNK,
            GraphElemType.CHUNK,
        ),
        (
            edges_of(GraphElemType.CHUNK_INCLUDE_ENTITY),
            GraphElemType.INCLUDE,
            GraphElemType.CHUNK,
            GraphElemType.ENTITY,
        ),
        (
            edges_of(GraphElemType.CHUNK_NEXT_CHUNK),
            GraphElemType.NEXT,
            GraphElemType.CHUNK,
            GraphElemType.CHUNK,
        ),
        (
            edges_of(GraphElemType.RELATION),
            GraphElemType.RELATION,
            GraphElemType.ENTITY,
            GraphElemType.ENTITY,
        ),
    ]

    # Upsert the vertices and the edges to the graph store
    self.upsert_entities(entities)
    self.upsert_chunks(chunks)
    self.upsert_documents(documents)
    for edges, edge_label, src_label, dst_label in edge_plan:
        self.upsert_edge(edges, edge_label.value, src_label.value, dst_label.value)
def delete_document(self, chunk_ids: str) -> None:
    """Delete the chunks of a document and everything derived from them.

    Three statements run in order: delete the listed chunk vertices, delete
    the relation edges tagged with those chunk ids, then delete every vertex
    left without any edge.

    Args:
        chunk_ids: Comma-separated chunk ids.
    """
    id_list = [part.strip() for part in chunk_ids.split(",")]

    document_label = GraphElemType.DOCUMENT.value
    chunk_label = GraphElemType.CHUNK.value
    entity_label = GraphElemType.ENTITY.value
    relation_label = GraphElemType.RELATION.value

    statements = (
        f"MATCH(m:{document_label})-[r]->(n:{chunk_label}) "
        f"WHERE n.id IN {id_list} DELETE n",
        f"MATCH(m:{entity_label})-[r:{relation_label}]-(n:{entity_label}) "
        f"WHERE r._chunk_id IN {id_list} DELETE r",
        "MATCH (n) WHERE NOT EXISTS((n)-[]-()) DELETE n",
    )
    for statement in statements:
        self.graph_store.conn.run(statement)
def delete_triplet(self, sub: str, rel: str, obj: str) -> None:
    """Delete the triplet ``sub -[rel]-> obj``.

    The statement deletes the matched relation edge together with both
    entity vertices.

    Args:
        sub: Id of the source entity.
        rel: Id of the relation edge.
        obj: Id of the target entity.
    """

    def _escape(text: str) -> str:
        """Backslash-escape quotes, mirroring ``insert_triplet``."""
        return text.replace("'", "\\'").replace('"', '\\"')

    # Without escaping, a quote inside an id would terminate the literal
    # early and break (or alter) the statement.
    sub_escaped = _escape(sub)
    rel_escaped = _escape(rel)
    obj_escaped = _escape(obj)
    del_query = (
        f"MATCH (n1:{GraphElemType.ENTITY.value} {{id:'{sub_escaped}'}})"
        f"-[r:{GraphElemType.RELATION.value} {{id:'{rel_escaped}'}}]->"
        f"(n2:{GraphElemType.ENTITY.value} {{id:'{obj_escaped}'}}) DELETE n1,n2,r"
    )
    self.graph_store.conn.run(query=del_query)
def drop(self):
    """Delete the graph named in the store configuration."""
    graph_name = self.get_graph_config().name
    self.graph_store.conn.delete_graph(graph_name)
def create_graph(self, graph_name: str):
    """Create a graph together with its vertex and edge labels.

    After the graph itself, the labels are created in this order: document,
    chunk and entity vertices, then relation, include and next edges. When
    the store has summaries enabled, its plugin is uploaded last.

    Args:
        graph_name: Name of the graph to create.
    """
    self.graph_store.conn.create_graph(graph_name=graph_name)

    def _format_property(
        name: str,
        prop_type: str = "STRING",
        optional: bool = False,
        index: Optional[bool] = None,
        **kwargs: Union[str, bool],
    ) -> Dict[str, Union[str, bool]]:
        """Build one property entry of a TuGraph label schema.

        ``index`` is only written when it is given; extra keyword arguments
        are copied into the entry unchanged.
        """
        schema: Dict[str, Union[str, bool]] = {
            "name": name,
            "type": prop_type,
            "optional": optional,
        }
        if index is not None:
            schema["index"] = index
        schema.update(kwargs)
        return schema

    def _identity() -> List[Dict[str, Union[str, bool]]]:
        """Return the mandatory ``id`` and ``name`` properties of every label."""
        return [_format_property("id"), _format_property("name")]

    label_schemas: List[Tuple[GraphElemType, List[Dict[str, Union[str, bool]]]]] = [
        # Vertex labels
        (
            GraphElemType.DOCUMENT,
            _identity() + [_format_property("_community_id", "STRING", True, True)],
        ),
        (
            GraphElemType.CHUNK,
            _identity()
            + [
                _format_property("_community_id", "STRING", True, True),
                _format_property("content", "STRING", True, True),
            ],
        ),
        (
            GraphElemType.ENTITY,
            _identity()
            + [
                _format_property("_community_id", "STRING", True, True),
                _format_property("description", "STRING", True, True),
            ],
        ),
        # Edge labels
        (
            GraphElemType.RELATION,
            _identity()
            + [
                _format_property("_chunk_id", "STRING", True, True),
                _format_property("description", "STRING", True, True),
            ],
        ),
        (
            GraphElemType.INCLUDE,
            _identity() + [_format_property("description", "STRING", True)],
        ),
        (
            GraphElemType.NEXT,
            _identity() + [_format_property("description", "STRING", True)],
        ),
    ]
    for elem_type, properties in label_schemas:
        self.create_graph_label(
            graph_elem_type=elem_type, graph_properties=properties
        )

    if self.graph_store._enable_summary:
        self.graph_store._upload_plugin()
def create_graph_label(
    self,
    graph_elem_type: GraphElemType,
    graph_properties: List[Dict[str, Union[str, bool]]],
) -> None:
    """Create a graph label unless it already exists.

    The graph label is used to identify and distinguish different types of nodes
    (vertices) and edges in the graph.

    Args:
        graph_elem_type: The vertex or edge type whose label is created.
        graph_properties: Property schema entries of the label.

    Raises:
        ValueError: If an edge type other than include, next or relation
            is given.
    """

    def edge_constraints(elem_type: GraphElemType) -> List[List[str]]:
        """Return the allowed ``[source, target]`` vertex label pairs.

        `include` edge: document -> chunk, chunk -> entity, chunk -> chunk
        `next` edge: chunk -> chunk
        `relation` edge: entity -> entity
        """
        if elem_type.is_vertex():
            raise ValueError("The graph element type must be an edge.")

        document = GraphElemType.DOCUMENT.value
        chunk = GraphElemType.CHUNK.value
        entity = GraphElemType.ENTITY.value
        if elem_type == GraphElemType.INCLUDE:
            return [[document, chunk], [chunk, entity], [chunk, chunk]]
        if elem_type == GraphElemType.NEXT:
            return [[chunk, chunk]]
        if elem_type == GraphElemType.RELATION:
            return [[entity, entity]]
        raise ValueError("Invalid graph element type.")

    label = graph_elem_type.value
    if graph_elem_type.is_vertex():
        payload = json.dumps(
            {
                "label": label,
                "type": "VERTEX",
                "primary": "id",
                "properties": graph_properties,
            }
        )
        create_gql = f"CALL db.createVertexLabelByJson('{payload}')"
        exists_gql = f"CALL db.getLabelSchema('VERTEX', '{label}')"
    else:
        payload = json.dumps(
            {
                "label": label,
                "type": "EDGE",
                "constraints": edge_constraints(graph_elem_type),
                "properties": graph_properties,
            }
        )
        create_gql = f"CALL db.createEdgeLabelByJson('{payload}')"
        exists_gql = f"CALL db.getLabelSchema('EDGE', '{label}')"

    connection = self.graph_store.conn
    try:
        # Querying the schema of a missing label raises an exception.
        connection.run(exists_gql)
    except Exception:
        connection.run(create_gql)  # create the graph label
    else:
        logger.info(f"Graph label {graph_elem_type.value} already exists.")
def truncate(self):
    """Delete every vertex of the graph."""
    self.graph_store.conn.run("MATCH (n) DELETE n")
def check_label(self, graph_elem_type: GraphElemType) -> bool:
    """Check if the label exists in the graph.

    Args:
        graph_elem_type (GraphElemType): The type of the graph element.

    Returns:
        True if the label exists in the specified graph element type, otherwise
        False.
    """
    vertex_tables, edge_tables = self.graph_store.conn.get_table_names()

    # The table lists hold label names, so compare the enum's value (the
    # label string used throughout this adapter) rather than the enum member.
    tables = vertex_tables if graph_elem_type.is_vertex() else edge_tables
    return graph_elem_type.value in tables
def explore(
    self,
    subs: List[str],
    direct: Direction = Direction.BOTH,
    depth: Optional[int] = None,
    limit: Optional[int] = None,
    search_scope: Optional[
        Literal["knowledge_graph", "document_graph"]
    ] = "knowledge_graph",
) -> MemoryGraph:
    """Explore the graph from given subjects up to a depth.

    Args:
        subs: Entity ids (knowledge graph) or text fragments searched in
            chunk content (any other scope); empty yields an empty graph.
        direct: Relation direction for the knowledge graph search.
        depth: Maximum number of hops; ``None``, negative or too large
            values are replaced by ``MAX_HIERARCHY_LEVEL``.
        limit: Optional ``LIMIT`` for each query.
        search_scope: ``"knowledge_graph"`` follows relation edges between
            entities; any other value searches document -> chunk paths.
    """
    if not subs:
        return MemoryGraph()

    if depth is None or depth < 0 or depth > self.MAX_HIERARCHY_LEVEL:
        # TODO: to be discussed, be none or MAX_HIERARCHY_LEVEL
        depth = self.MAX_HIERARCHY_LEVEL
    hop_range = f"1..{depth}"
    limit_clause = "" if limit is None else f"LIMIT {limit}"

    if search_scope == "knowledge_graph":
        # The arrow heads around the edge pattern encode the direction.
        edge_pattern = f"[r:{GraphElemType.RELATION.value}*{hop_range}]"
        if direct.name == "OUT":
            rel = f"-{edge_pattern}->"
        elif direct.name == "IN":
            rel = f"<-{edge_pattern}-"
        else:
            rel = f"-{edge_pattern}-"
        entity_label = GraphElemType.ENTITY.value
        cypher = (
            f"MATCH p=(n:{entity_label}){rel}(m:{entity_label}) "
            f"WHERE n.id IN {subs} RETURN p {limit_clause}"
        )
        return self.query(cypher)

    # Document graph: one query per subject, merged into a single graph.
    merged = MemoryGraph()
    for sub in subs:
        cypher = (
            f"MATCH p=(n:{GraphElemType.DOCUMENT.value})-"
            f"[r:{GraphElemType.INCLUDE.value}*{hop_range}]-"
            f"(m:{GraphElemType.CHUNK.value})WHERE m.content CONTAINS '{sub}' "
            f"RETURN p {limit_clause}"
        )  # if it contains the subjects
        found = self.query(cypher)
        for node in found.vertices():
            merged.upsert_vertex(node)
        for link in found.edges():
            merged.append_edge(link)
    return merged
def query(self, query: str, **kwargs) -> MemoryGraph:
    """Run a query on the graph store and return the result as a memory graph.

    Keyword Args:
        white_list (List[str]): Property names handed to the result parser as
            the white list of properties to keep. Defaults to ``id``,
            ``name``, ``description``, ``_document_id``, ``_chunk_id`` and
            ``_community_id``.

    Returns:
        A ``MemoryGraph`` holding every vertex and edge parsed from the result.
    """
    records = self.graph_store.conn.run(query=query)

    default_white_list = [
        "id",
        "name",
        "description",
        "_document_id",
        "_chunk_id",
        "_community_id",
    ]
    white_list: List[str] = kwargs.get("white_list", default_white_list)

    parsed_vertices, parsed_edges = self._get_nodes_edges_from_queried_data(
        records, white_list
    )

    result = MemoryGraph()
    for vertex in parsed_vertices:
        result.upsert_vertex(vertex)
    for edge in parsed_edges:
        result.append_edge(edge)
    return result
async def stream_query(self, query: str, **kwargs) -> AsyncGenerator[Graph, None]:
    """Execute a stream query, yielding one in-memory graph per record.

    Args:
        query: The query text handed to ``self.graph_store.conn.run_stream``.
        **kwargs: Accepted for interface compatibility; not read here.

    Yields:
        A ``MemoryGraph`` built from the values of each streamed record:
        nodes become vertices, relationships become edges, paths contribute
        both, and any other value is wrapped in a single ``json_node`` vertex.

    Raises:
        KeyError: If a node or relationship has no ``id`` or ``description``
            property.
    """
    from neo4j import graph

    def to_vertex(node):
        # Vertices are keyed and named by the node's "id" property.
        node_id = node._properties["id"]
        return Vertex(
            vid=node_id,
            name=node_id,
            description=node._properties["description"],
        )

    async for record in self.graph_store.conn.run_stream(query):
        mg = MemoryGraph()
        for key in record.keys():
            value = record[key]
            if isinstance(value, graph.Node):
                mg.upsert_vertex(to_vertex(value))
            elif isinstance(value, graph.Relationship):
                edge_nodes = value.nodes
                prop_id = value._properties["id"]
                assert edge_nodes and edge_nodes[0] and edge_nodes[1]
                mg.append_edge(
                    Edge(
                        sid=edge_nodes[0]._properties["id"],
                        tid=edge_nodes[1]._properties["id"],
                        name=prop_id,
                        description=value._properties["description"],
                    )
                )
            elif isinstance(value, graph.Path):
                # Read the path from the matched value itself. Looking it up
                # under a hard-coded "p" column fails (or reads the wrong
                # column) whenever the query returns the path under another
                # alias.
                nodes = list(value.nodes)
                rels = list(value.relationships)
                for idx, node in enumerate(nodes):
                    mg.upsert_vertex(to_vertex(node))
                    # Relationship idx links node idx to node idx + 1, in
                    # path order.
                    if idx < len(rels) and idx + 1 < len(nodes):
                        rel = rels[idx]
                        mg.append_edge(
                            Edge(
                                sid=node._properties["id"],
                                tid=nodes[idx + 1]._properties["id"],
                                name=rel._properties["id"],
                                description=rel._properties["description"],
                            )
                        )
            else:
                # Anything that is not a graph element is kept as raw payload.
                mg.upsert_vertex(
                    Vertex(vid="json_node", name="json_node", description=value)
                )
        yield mg
def _get_nodes_edges_from_queried_data(
self,
data: List[Dict[str, Any]],
white_prop_list: List[str],
) -> Tuple[List[Vertex], List[Edge]]:
"""Convert raw query records into de-duplicated vertices and edges.
Every value of every record is inspected: nodes become vertices,
relationships and paths contribute both their end nodes and an edge, and
any other value is wrapped in a vertex whose id and name are "json_node".
Args:
data: The queried records to be formatted.
white_prop_list: The white list of property names to keep.
Returns:
Tuple[List[Vertex], List[Edge]]: The formatted vertices and edges.
"""
vertex_list: List[Vertex] = []
edge_list: List[Edge] = []
# Remove id, src_id, dst_id and name from the white list
# to avoid duplication in the initialisation of the vertex and edge
_white_list = [
prop
for prop in white_prop_list
if prop not in ["id", "src_id", "dst_id", "name"]
]
from neo4j import graph
def filter_properties(
properties: dict[str, Any], white_list: List[str]
) -> Dict[str, Any]:
"""Filter the properties.
A property is kept when its key is in the white list, or when the key
neither starts with "_" nor is "id" or "name"; everything else is dropped.
The expected properties are:
entity_properties = ["id", "name", "description", "_document_id",
"_chunk_id", "_community_id"]
edge_properties = ["id", "name", "description", "_chunk_id"]
"""
return {
key: value
for key, value in properties.items()
if (not key.startswith("_") and key not in ["id", "name"])
or key in white_list
}
# Parse the data to nodes and relationships
for record in data:
for value in record.values():
if isinstance(value, graph.Node):
# A standalone node must carry a non-empty "id" property.
assert value._properties.get("id")
vertex = Vertex(
vid=value._properties.get("id", ""),
name=value._properties.get("name"),
**filter_properties(value._properties, _white_list),
)
# Membership is tested against a plain list on every insertion.
if vertex not in vertex_list:
# TODO: Do we really need to check it every time?
vertex_list.append(vertex)
elif isinstance(value, graph.Relationship):
for node in value.nodes: # num of nodes is 2
assert node and node._properties
vertex = Vertex(
vid=node._properties.get("id", ""),
name=node._properties.get("name"),
**filter_properties(node._properties, _white_list),
)
if vertex not in vertex_list:
vertex_list.append(vertex)
# The edge runs from the first end node to the second one.
assert value.nodes and value.nodes[0] and value.nodes[1]
edge = Edge(
sid=value.nodes[0]._properties.get("id", ""),
tid=value.nodes[1]._properties.get("id", ""),
name=value._properties.get("name", ""),
**filter_properties(value._properties, _white_list),
)
if edge not in edge_list:
edge_list.append(edge)
elif isinstance(value, graph.Path):
# Only the relationships of the path are walked; vertices are taken
# from the end nodes of each relationship.
for rel in value.relationships:
for node in rel.nodes: # num of nodes is 2
assert node and node._properties
vertex = Vertex(
vid=node._properties.get("id", ""),
name=node._properties.get("name"),
**filter_properties(node._properties, _white_list),
)
if vertex not in vertex_list:
vertex_list.append(vertex)
assert rel.nodes and rel.nodes[0] and rel.nodes[1]
edge = Edge(
sid=rel.nodes[0]._properties.get("id", ""),
tid=rel.nodes[1]._properties.get("id", ""),
name=rel._properties.get("name", ""),
**filter_properties(rel._properties, _white_list),
)
if edge not in edge_list:
edge_list.append(edge)
else: # json_node
# Any other value is stored as the description of a "json_node" vertex.
vertex = Vertex(
vid="json_node",
name="json_node",
**filter_properties({"description": value}, _white_list),
)
if vertex not in vertex_list:
vertex_list.append(vertex)
return vertex_list, edge_list
def _parser(self, entity_list: List[Dict[str, Any]]) -> str:
    """Render a list of property dicts as comma-separated ``{key: value}`` maps.

    String values are wrapped in double quotes; every other value is
    rendered with its default string form. Keys are never quoted.
    """

    def render(key, value):
        if isinstance(value, str):
            return f'{key}: "{value}"'
        return f"{key}: {value}"

    rendered = []
    for entity in entity_list:
        body = ", ".join(render(key, value) for key, value in entity.items())
        rendered.append("{" + body + "}")
    return ", ".join(rendered)

View File

@@ -2,14 +2,16 @@
import logging
import os
import uuid
from typing import List, Optional
from dbgpt._private.pydantic import ConfigDict, Field
from dbgpt.core import Chunk
from dbgpt.rag.transformer.community_summarizer import CommunitySummarizer
from dbgpt.rag.transformer.graph_extractor import GraphExtractor
from dbgpt.storage.graph_store.graph import GraphElemType, MemoryGraph
from dbgpt.storage.knowledge_graph.community.community_store import CommunityStore
from dbgpt.storage.knowledge_graph.community.factory import CommunityStoreAdapterFactory
from dbgpt.storage.knowledge_graph.community.factory import GraphStoreAdapterFactory
from dbgpt.storage.knowledge_graph.knowledge_graph import (
BuiltinKnowledgeGraph,
BuiltinKnowledgeGraphConfig,
@@ -27,7 +29,8 @@ class CommunitySummaryKnowledgeGraphConfig(BuiltinKnowledgeGraphConfig):
model_config = ConfigDict(arbitrary_types_allowed=True)
vector_store_type: str = Field(
default="Chroma", description="The type of vector store."
default="Chroma",
description="The type of vector store.",
)
user: Optional[str] = Field(
default=None,
@@ -36,7 +39,8 @@ class CommunitySummaryKnowledgeGraphConfig(BuiltinKnowledgeGraphConfig):
password: Optional[str] = Field(
default=None,
description=(
"The password of vector store, if not set, will use the default password."
"The password of vector store, "
"if not set, will use the default password."
),
)
extract_topk: int = Field(
@@ -120,7 +124,7 @@ class CommunitySummaryKnowledgeGraph(BuiltinKnowledgeGraph):
cfg.score_threshold = self._community_score_threshold
self._community_store = CommunityStore(
CommunityStoreAdapterFactory.create(self._graph_store),
GraphStoreAdapterFactory.create(self._graph_store),
CommunitySummarizer(self._llm_client, self._model_name),
VectorStoreFactory.create(
self._vector_store_type,
@@ -135,21 +139,165 @@ class CommunitySummaryKnowledgeGraph(BuiltinKnowledgeGraph):
async def aload_document(self, chunks: List[Chunk]) -> List[str]:
"""Extract and persist graph."""
# todo add doc node
for chunk in chunks:
# todo add chunk node
# todo add relation doc-chunk
data_list = self._parse_chunks(chunks) # parse the chunks by def _lod_doc_graph
graph_of_all = MemoryGraph()
# extract graphs and save
graphs = await self._graph_extractor.extract(chunk.content)
for graph in graphs:
self._graph_store.insert_graph(graph)
# Support graph search by the document and the chunks
if self._graph_store.get_config().enable_document_graph:
doc_vid = str(uuid.uuid4())
doc_name = os.path.basename(chunks[0].metadata["source"] or "Text_Node")
for chunk_index, chunk in enumerate(data_list):
if chunk["parent_id"] != "document":
# chunk -> include -> chunk
graph_of_all.upsert_vertex_and_edge(
src_vid=chunk["parent_id"],
src_name=chunk["parent_title"],
src_props={
"vertex_type": GraphElemType.CHUNK.value,
"content": chunk["content"],
},
dst_vid=chunk["id"],
dst_name=chunk["title"],
dst_props={
"vertex_type": GraphElemType.CHUNK.value,
"content": chunk["content"],
},
edge_name=GraphElemType.INCLUDE.value,
edge_type=GraphElemType.CHUNK_INCLUDE_CHUNK.value,
)
else:
# document -> include -> chunk
graph_of_all.upsert_vertex_and_edge(
src_vid=doc_vid,
src_name=doc_name,
src_props={
"vertex_type": GraphElemType.DOCUMENT.value,
"content": "",
},
dst_vid=chunk["id"],
dst_name=chunk["title"],
dst_props={
"vertex_type": GraphElemType.CHUNK.value,
"content": chunk["content"],
},
edge_name=GraphElemType.INCLUDE.value,
edge_type=GraphElemType.DOCUMENT_INCLUDE_CHUNK.value,
)
# chunk -> next -> chunk
if chunk_index >= 1:
graph_of_all.upsert_vertex_and_edge(
src_vid=data_list[chunk_index - 1]["id"],
src_name=data_list[chunk_index - 1]["title"],
src_props={
"vertex_type": GraphElemType.CHUNK.value,
"content": data_list[chunk_index - 1]["content"],
},
dst_vid=chunk["id"],
dst_name=chunk["title"],
dst_props={
"vertex_type": GraphElemType.CHUNK.value,
"content": chunk["content"],
},
edge_name=GraphElemType.NEXT.value,
edge_type=GraphElemType.CHUNK_NEXT_CHUNK.value,
)
# Support knowledge graph search by the entities and the relationships
if self._graph_store.get_config().enable_triplet_graph:
for chunk_index, chunk in enumerate(data_list):
# TODO: Use asyncio to extract graph to accelerate the process
# (attention to the CAP of the graph db)
graphs: List[MemoryGraph] = await self._graph_extractor.extract(
chunk["content"]
)
for graph in graphs:
graph_of_all.upsert_graph(graph)
# chunk -> include -> entity
if self._graph_store.get_config().enable_document_graph:
for vertex in graph.vertices():
graph_of_all.upsert_vertex_and_edge(
src_vid=chunk["id"],
src_name=chunk["title"],
src_props={
"vertex_type": GraphElemType.CHUNK.value,
"content": chunk["content"],
},
dst_vid=vertex.vid,
dst_name=vertex.name,
dst_props={
"vertex_type": GraphElemType.ENTITY.value,
"description": vertex.props.get("description", ""),
}, # note: description is only used for the entity
edge_name=GraphElemType.INCLUDE.value,
edge_type=GraphElemType.CHUNK_INCLUDE_ENTITY.value,
)
self._graph_store_apdater.upsert_graph(graph_of_all)
# use asyncio.gather
# tasks = [self._graph_extractor.extract(chunk.content) for chunk in chunks]
# results = await asyncio.gather(*tasks)
# for result in results:
# self._graph_store_apdater.upsert_graph(result[0])
# build communities and save
await self._community_store.build_communities()
return [chunk.chunk_id for chunk in chunks]
def _parse_chunks(self, chunks: List["Chunk"]):
    """Derive the parent/child hierarchy of chunks from their metadata keys.

    For every chunk, all metadata keys except the last one are treated as
    the chunk's directory (header) levels, ordered from outermost to
    innermost. The parent of a chunk is the nearest preceding chunk whose
    innermost level is one of this chunk's ancestor levels and carries the
    same title.

    Args:
        chunks: Chunks in document order. Each one must expose ``chunk_id``,
            ``content`` and a ``metadata`` mapping.
            NOTE(review): presumably the trailing metadata key is a
            non-header entry such as ``source`` - confirm against the
            splitter that produces the chunks.

    Returns:
        One dict per chunk with the keys ``id``, ``title``,
        ``directory_keys``, ``level``, ``content``, ``parent_id``,
        ``parent_title``, ``type`` and ``chunk_index``. ``parent_id`` is the
        literal ``"document"`` when no parent chunk was found.
    """
    # TODO: Need to refactor.
    data = []
    for chunk_index, chunk in enumerate(chunks):
        directory_keys = list(chunk.metadata.keys())[:-1]
        parent_level = directory_keys[-2] if len(directory_keys) > 1 else None
        current_level = directory_keys[-1] if directory_keys else "Header0"

        chunk_data = {
            "id": chunk.chunk_id,
            "title": chunk.metadata.get(current_level, "none_header_chunk"),
            "directory_keys": directory_keys,
            "level": current_level,
            "content": chunk.content,
            "parent_id": None,
            "parent_title": None,
            "type": "chunk",
            "chunk_index": chunk_index,
        }

        # Find the parent chunk: try ancestor levels from nearest to farthest
        # and scan backwards through the preceding chunks for each of them.
        if parent_level:
            for parent_direct in reversed(directory_keys[:-1]):
                parent_title = chunk.metadata.get(parent_direct, None)
                for n in range(chunk_index - 1, -1, -1):
                    candidate = chunks[n]
                    keys = list(candidate.metadata.keys())[:-1]
                    # `keys` is empty when the candidate has fewer than two
                    # metadata entries; testing it first avoids indexing an
                    # empty list for header-less chunks.
                    if (
                        keys
                        and parent_direct == keys[-1]
                        and parent_title == candidate.metadata.get(parent_direct)
                    ):
                        chunk_data["parent_id"] = candidate.chunk_id
                        chunk_data["parent_title"] = parent_title
                        break
                    # Stop once the scan has gone back further than the
                    # number of directory levels of this chunk.
                    if chunk_index - n > len(directory_keys):
                        break
                if chunk_data["parent_id"]:
                    break

        if not chunk_data["parent_id"]:
            chunk_data["parent_id"] = "document"
        data.append(chunk_data)
    return data
async def asimilar_search_with_scores(
self,
text,
@@ -158,7 +306,7 @@ class CommunitySummaryKnowledgeGraph(BuiltinKnowledgeGraph):
filters: Optional[MetadataFilters] = None,
) -> List[Chunk]:
"""Retrieve relevant community summaries."""
# global search: retrieve relevant community summaries
# Global search: retrieve relevant community summaries
communities = await self._community_store.search_communities(text)
summaries = [
f"Section {i + 1}:\n{community.summary}"
@@ -166,16 +314,53 @@ class CommunitySummaryKnowledgeGraph(BuiltinKnowledgeGraph):
]
context = "\n".join(summaries) if summaries else ""
# local search: extract keywords and explore subgraph
keywords = await self._keyword_extractor.extract(text)
subgraph = self._graph_store.explore(keywords, limit=topk).format()
logger.info(f"Search subgraph from {len(keywords)} keywords")
keywords: List[str] = await self._keyword_extractor.extract(text)
if not summaries and not subgraph:
# Local search: extract keywords and explore subgraph
subgraph = MemoryGraph()
subgraph_for_doc = MemoryGraph()
enable_triplet_graph = self._graph_store.get_config().enable_triplet_graph
enable_document_graph = self._graph_store.get_config().enable_document_graph
if enable_triplet_graph:
subgraph: MemoryGraph = self._graph_store_apdater.explore(
subs=keywords, limit=10, search_scope="knowledge_graph"
)
if enable_document_graph:
keywords_for_document_graph = keywords
for vertex in subgraph.vertices():
keywords_for_document_graph.append(vertex.name)
subgraph_for_doc = self._graph_store_apdater.explore(
subs=keywords_for_document_graph,
limit=5,
search_scope="document_graph",
)
else:
if enable_document_graph:
subgraph_for_doc = self._graph_store_apdater.explore(
subs=keywords,
limit=10,
search_scope="document_graph",
)
knowledge_graph_str = subgraph.format()
knowledge_graph_for_doc_str = subgraph_for_doc.format()
logger.info(f"Search subgraph from the following keywords:\n{len(keywords)}")
if not (summaries or knowledge_graph_str or knowledge_graph_for_doc_str):
return []
# merge search results into context
content = HYBRID_SEARCH_PT_CN.format(context=context, graph=subgraph)
content = HYBRID_SEARCH_PT_CN.format(
context=context,
knowledge_graph=knowledge_graph_str,
knowledge_graph_for_doc=knowledge_graph_for_doc_str,
)
logger.info(f"Final GraphRAG queried prompt:\n{content}")
return [Chunk(content=content)]
def truncate(self) -> List[str]:
@@ -200,174 +385,179 @@ class CommunitySummaryKnowledgeGraph(BuiltinKnowledgeGraph):
self._graph_extractor.drop()
HYBRID_SEARCH_PT_CN = (
"## 角色\n"
"你非常擅长结合提示词模板提供的[上下文]信息与[知识图谱]信息"
"准确恰当地回答用户的问题,并保证不会输出与上下文和知识图谱无关的信息。"
"\n"
"## 技能\n"
"### 技能 1: 上下文理解\n"
"- 准确地理解[上下文]提供的信息,上下文信息可能被拆分为多个章节。\n"
"- 上下文的每个章节内容都会以[Section]开始,并按需进行了编号。\n"
"- 上下文信息提供了与用户问题相关度最高的总结性描述,请合理使用它们。"
"### 技能 2: 知识图谱理解\n"
"- 准确地识别[知识图谱]中提供的[Entities:]章节中的实体信息"
"和[Relationships:]章节中的关系信息,实体和关系信息的一般格式为:\n"
"```"
"* 实体信息格式:\n"
"- (实体名)\n"
"- (实体名:实体描述)\n"
"- (实体名:实体属性表)\n"
"- (文本块ID:文档块内容)\n"
"- (目录ID:目录名)\n"
"- (文档ID:文档名称)\n"
"\n"
"* 关系信息的格式:\n"
"- (来源实体名)-[关系名]->(目标实体名)\n"
"- (来源实体名)-[关系名:关系描述]->(目标实体名)\n"
"- (来源实体名)-[关系名:关系属性表]->(目标实体名)\n"
"- (文本块实体)-[包含]->(实体名)\n"
"- (目录ID)-[包含]->(文本块实体)\n"
"- (目录ID)-[包含]->(子目录ID)\n"
"- (文档ID)-[包含]->(文本块实体)\n"
"- (文档ID)-[包含]->(目录ID)\n"
"```"
"- 正确地将关系信息中的实体名/ID与实体信息关联还原出图结构。"
"- 将图结构所表达的信息作为用户提问的明细上下文,辅助生成更好的答案。\n"
"\n"
"## 约束条件\n"
"- 不要在答案中描述你的思考过程,直接给出用户问题的答案,不要生成无关信息。\n"
"- 若[知识图谱]没有提供信息,此时应根据[上下文]提供的信息回答问题。"
"- 确保以第三人称书写,从客观角度结合[上下文]和[知识图谱]表达的信息回答问题。\n"
"- 若提供的信息相互矛盾,请解决矛盾并提供一个单一、连贯的描述。\n"
"- 避免使用停用词和过于常见的词汇。\n"
"\n"
"## 参考案例\n"
"```\n"
"[上下文]:\n"
"Section 1:\n"
"菲尔・贾伯的大儿子叫雅各布・贾伯。\n"
"Section 2:\n"
"菲尔・贾伯的小儿子叫比尔・贾伯。\n"
"[知识图谱]:\n"
"Entities:\n"
"(菲尔・贾伯#菲尔兹咖啡创始人)\n"
"(菲尔兹咖啡#加利福尼亚州伯克利创立的咖啡品牌)\n"
"(雅各布・贾伯#菲尔・贾伯的儿子)\n"
"(美国多地#菲尔兹咖啡的扩展地区)\n"
"\n"
"Relationships:\n"
"(菲尔・贾伯#创建#菲尔兹咖啡#1978年在加利福尼亚州伯克利创立)\n"
"(菲尔兹咖啡#位于#加利福尼亚州伯克利#菲尔兹咖啡的创立地点)\n"
"(菲尔・贾伯#拥有#雅各布・贾伯#菲尔・贾伯的儿子)\n"
"(雅各布・贾伯#担任#首席执行官#在2005年成为菲尔兹咖啡的首席执行官)\n"
"(菲尔兹咖啡#扩展至#美国多地#菲尔兹咖啡的扩展范围)\n"
"```\n"
"\n"
"----\n"
"\n"
"接下来的[上下文]和[知识图谱]的信息,可以帮助你回答更好地用户的问题。\n"
"\n"
"[上下文]:\n"
"{context}\n"
"\n"
"[知识图谱]:\n"
"{graph}\n"
"\n"
)
HYBRID_SEARCH_PT_CN = """## 角色
你非常擅长结合提示词模板提供的[上下文]信息与[知识图谱]信息,
准确恰当地回答用户的问题,并保证不会输出与上下文和知识图谱无关的信息
HYBRID_SEARCH_PT_EN = (
"## Role\n"
"You excel at combining the information provided in the [Context] with "
"information from the [KnowledgeGraph] to accurately and appropriately "
"answer user questions, ensuring that you do not output information "
"unrelated to the context and knowledge graph.\n"
"\n"
"## Skills\n"
"### Skill 1: Context Understanding\n"
"- Accurately understand the information provided in the [Context], "
"which may be divided into several sections.\n"
"- Each section in the context will start with [Section] "
"and may be numbered as needed.\n"
"- The context provides a summary description most relevant to the users "
"question, and it should be used wisely."
"### Skill 2: Knowledge Graph Understanding\n"
"- Accurately identify entity information in the [Entities:] section and "
"relationship information in the [Relationships:] section "
"of the [KnowledgeGraph]. The general format for entity "
"and relationship information is:\n"
"```"
"* Entity Information Format:\n"
"- (entity_name)\n"
"- (entity_name: entity_description)\n"
"- (entity_name: entity_property_map)\n"
"- (chunk_id: chunk_content)\n"
"- (catalog_id: catalog_name)\n"
"- (document_id: document_name)\n"
"\n"
"* Relationship Information Format:\n"
"- (source_entity_name)-[relationship_name]->(target_entity_name)\n"
"- (source_entity_name)-[relationship_name: relationship_description]->"
"(target_entity_name)\n"
"- (source_entity_name)-[relationship_name: relationship_property_map]->"
"(target_entity_name)\n"
"- (chunk_id)-[Contains]->(entity_name)\n"
"- (catalog_id)-[Contains]->(chunk_id)\n"
"- (catalog_id)-[Contains]->(sub_catalog_id)\n"
"- (document_id)-[Contains]->(chunk_id)\n"
"- (document_id)-[Contains]->(catalog_id)\n"
"```"
"- Correctly associate entity names/IDs in the relationship information "
"with entity information to restore the graph structure."
"- Use the information expressed by the graph structure as detailed "
"context for the user's query to assist in generating better answers.\n"
"\n"
"## Constraints\n"
"- Don't describe your thought process in the answer, provide the answer "
"to the user's question directly without generating irrelevant information."
"- If the [KnowledgeGraph] does not provide information, you should answer "
"the question based on the information provided in the [Context]."
"- Ensure to write in the third person, responding to questions from "
"an objective perspective based on the information combined from the "
"[Context] and the [KnowledgeGraph].\n"
"- If the provided information is contradictory, resolve the "
"contradictions and provide a single, coherent description.\n"
"- Avoid using stop words and overly common vocabulary.\n"
"\n"
"## Reference Example\n"
"```\n"
"[Context]:\n"
"Section 1:\n"
"Phil Schiller's eldest son is Jacob Schiller.\n"
"Section 2:\n"
"Phil Schiller's youngest son is Bill Schiller.\n"
"[KnowledgeGraph]:\n"
"Entities:\n"
"(Phil Jaber#Founder of Philz Coffee)\n"
"(Philz Coffee#Coffee brand founded in Berkeley, California)\n"
"(Jacob Jaber#Son of Phil Jaber)\n"
"(Multiple locations in the USA#Expansion regions of Philz Coffee)\n"
"\n"
"Relationships:\n"
"(Phil Jaber#Created#Philz Coffee"
"#Founded in Berkeley, California in 1978)\n"
"(Philz Coffee#Located in#Berkeley, California"
"#Founding location of Philz Coffee)\n"
"(Phil Jaber#Has#Jacob Jaber#Son of Phil Jaber)\n"
"(Jacob Jaber#Serves as#CEO#Became CEO of Philz Coffee in 2005)\n"
"(Philz Coffee#Expanded to#Multiple locations in the USA"
"#Expansion regions of Philz Coffee)\n"
"```\n"
"\n"
"----\n"
"\n"
"The following information from the [Context] and [KnowledgeGraph] can "
"help you better answer user questions.\n"
"\n"
"[Context]:\n"
"{context}\n"
"\n"
"[KnowledgeGraph]:\n"
"{graph}\n"
"\n"
)
## 技能
### 技能 1: 上下文理解
- 准确地理解[上下文]提供的信息,上下文信息可能被拆分为多个章节。
- 上下文的每个章节内容都会以[Section]开始,并按需进行了编号。
- 上下文信息提供了与用户问题相关度最高的总结性描述,请合理使用它们。
### 技能 2: 知识图谱理解
- 准确地识别[知识图谱]中提供的[Entities:]章节中的实体信息和[Relationships:]章节中的关系信息,实体和关系信息的一般格式为:
```
* 实体信息格式:
- (实体名)
- (实体名:实体描述)
- (实体名:实体属性表)
- (文本块ID:文档块内容)
- (目录ID:目录名)
- (文档ID:文档名称)
* 关系信息的格式:
- (来源实体名)-[关系名]->(目标实体名)
- (来源实体名)-[关系名:关系描述]->(目标实体名)
- (来源实体名)-[关系名:关系属性表]->(目标实体名)
- (文本块实体)-[包含]->(实体名)
- (目录ID)-[包含]->(文本块实体)
- (目录ID)-[包含]->(子目录ID)
- (文档ID)-[包含]->(文本块实体)
- (文档ID)-[包含]->(目录ID)
```
- 正确地将关系信息中的实体名/ID与实体信息关联还原出图结构。
- 将图结构所表达的信息作为用户提问的明细上下文,辅助生成更好的答案。
## 约束条件
- 不要在答案中描述你的思考过程,直接给出用户问题的答案,不要生成无关信息。
- 若[知识图谱]或者[知识库原文]没有提供信息,此时应根据[上下文]提供的信息回答问题。
- 确保以第三人称书写,从客观角度结合[上下文]、[知识图谱]和[知识库原文]表达的信息回答问题。
- 若提供的信息相互矛盾,请解决矛盾并提供一个单一、连贯的描述。
- 避免使用停用词和过于常见的词汇。
## 参考案例
```
[上下文]:
Section 1:
菲尔・贾伯的大儿子叫雅各布・贾伯。
Section 2:
菲尔・贾伯的小儿子叫比尔・贾伯。
[知识图谱]:
Entities:
(菲尔・贾伯#菲尔兹咖啡创始人)
(菲尔兹咖啡#加利福尼亚州伯克利创立的咖啡品牌)
(雅各布・贾伯#菲尔・贾伯的儿子)
(美国多地#菲尔兹咖啡的扩展地区)
Relationships:
(菲尔・贾伯#创建#菲尔兹咖啡#1978年在加利福尼亚州伯克利创立)
(菲尔兹咖啡#位于#加利福尼亚州伯克利#菲尔兹咖啡的创立地点)
(菲尔・贾伯#拥有#雅各布・贾伯#菲尔・贾伯的儿子)
(雅各布・贾伯#担任#首席执行官#在2005年成为菲尔兹咖啡的首席执行官)
(菲尔兹咖啡#扩展至#美国多地#菲尔兹咖啡的扩展范围)
[知识库原文]:
...
```
----
接下来的[上下文]、[知识图谱]和[知识库原文]的信息,可以帮助你回答更好地用户的问题。
[上下文]:
{context}
[知识图谱]:
{knowledge_graph}
[知识库原文]
{knowledge_graph_for_doc}
""" # noqa: E501
HYBRID_SEARCH_PT_EN = """## Role
You excel at combining the information provided in the [Context] with
information from the [KnowledgeGraph] to accurately and appropriately
answer user questions, ensuring that you do not output information
unrelated to the context and knowledge graph.
## Skills
### Skill 1: Context Understanding
- Accurately understand the information provided in the [Context],
which may be divided into several sections.
- Each section in the context will start with [Section]
and may be numbered as needed.
- The context provides a summary description most relevant to the user's
question, and it should be used wisely.
### Skill 2: Knowledge Graph Understanding
- Accurately identify entity information in the [Entities:] section and
relationship information in the [Relationships:] section
of the [KnowledgeGraph]. The general format for entity
and relationship information is:
```
* Entity Information Format:
- (entity_name)
- (entity_name: entity_description)
- (entity_name: entity_property_map)
- (chunk_id: chunk_content)
- (catalog_id: catalog_name)
- (document_id: document_name)
* Relationship Information Format:
- (source_entity_name)-[relationship_name]->(target_entity_name)
- (source_entity_name)-[relationship_name: relationship_description]->(target_entity_name)
- (source_entity_name)-[relationship_name: relationship_property_map]->(target_entity_name)
- (chunk_id)-[Contains]->(entity_name)
- (catalog_id)-[Contains]->(chunk_id)
- (catalog_id)-[Contains]->(sub_catalog_id)
- (document_id)-[Contains]->(chunk_id)
- (document_id)-[Contains]->(catalog_id)
```
- Correctly associate entity names/IDs in the relationship information
with entity information to restore the graph structure.
- Use the information expressed by the graph structure as detailed
context for the user's query to assist in generating better answers.
## Constraints
- Don't describe your thought process in the answer, provide the answer
to the user's question directly without generating irrelevant information.
- If the [KnowledgeGraph] or [Knowledge base original text] does not provide information, you should answer
the question based on the information provided in the [Context].
- Ensure to write in the third person, responding to questions from
an objective perspective based on the information combined from the
[Context], the [KnowledgeGraph] and the [Knowledge base original text].
- If the provided information is contradictory, resolve the
contradictions and provide a single, coherent description.
- Avoid using stop words and overly common vocabulary.
## Reference Example
```
[Context]:
Section 1:
Phil Schiller's eldest son is Jacob Schiller.
Section 2:
Phil Schiller's youngest son is Bill Schiller.
[KnowledgeGraph]:
Entities:
(Phil Jaber#Founder of Philz Coffee)
(Philz Coffee#Coffee brand founded in Berkeley, California)
(Jacob Jaber#Son of Phil Jaber)
(Multiple locations in the USA#Expansion regions of Philz Coffee)
Relationships:
(Phil Jaber#Created#Philz Coffee#Founded in Berkeley, California in 1978)
(Philz Coffee#Located in#Berkeley, California#Founding location of Philz Coffee)
(Phil Jaber#Has#Jacob Jaber#Son of Phil Jaber)
(Jacob Jaber#Serves as#CEO#Became CEO of Philz Coffee in 2005)
(Philz Coffee#Expanded to#Multiple locations in the USA#Expansion regions of Philz Coffee)
[Knowledge base original text]
...
```
----
The following information from the [Context], [KnowledgeGraph] and [Knowledge base original text]
can help you better answer user questions.
[Context]:
{context}
[KnowledgeGraph]:
{knowledge_graph}
[Knowledge base original text]
{knowledge_graph_for_doc}
""" # noqa: E501

View File

@@ -1,4 +1,5 @@
"""Knowledge graph class."""
import asyncio
import logging
import os
@@ -12,6 +13,8 @@ from dbgpt.storage.graph_store.base import GraphStoreBase, GraphStoreConfig
from dbgpt.storage.graph_store.factory import GraphStoreFactory
from dbgpt.storage.graph_store.graph import Graph
from dbgpt.storage.knowledge_graph.base import KnowledgeGraphBase, KnowledgeGraphConfig
from dbgpt.storage.knowledge_graph.community.base import GraphStoreAdapter
from dbgpt.storage.knowledge_graph.community.factory import GraphStoreAdapterFactory
from dbgpt.storage.vector_store.filters import MetadataFilters
logger = logging.getLogger(__name__)
@@ -46,9 +49,10 @@ class BuiltinKnowledgeGraph(KnowledgeGraphBase):
self._model_name = config.model_name
self._triplet_extractor = TripletExtractor(self._llm_client, self._model_name)
self._keyword_extractor = KeywordExtractor(self._llm_client, self._model_name)
self._graph_store = self.__init_graph_store(config)
self._graph_store: GraphStoreBase = self.__init_graph_store(config)
self._graph_store_apdater: GraphStoreAdapter = self.__init_graph_store_adapter()
def __init_graph_store(self, config) -> GraphStoreBase:
def __init_graph_store(self, config: BuiltinKnowledgeGraphConfig) -> GraphStoreBase:
def configure(cfg: GraphStoreConfig):
cfg.name = config.name
cfg.embedding_fn = config.embedding_fn
@@ -56,6 +60,9 @@ class BuiltinKnowledgeGraph(KnowledgeGraphBase):
graph_store_type = os.getenv("GRAPH_STORE_TYPE") or config.graph_store_type
return GraphStoreFactory.create(graph_store_type, configure)
def __init_graph_store_adapter(self) -> GraphStoreAdapter:
    """Create the graph store adapter for this instance's graph store."""
    return GraphStoreAdapterFactory.create(self._graph_store)
def get_config(self) -> BuiltinKnowledgeGraphConfig:
    """Return the knowledge graph config stored on this instance."""
    return self._config
@@ -63,10 +70,10 @@ class BuiltinKnowledgeGraph(KnowledgeGraphBase):
def load_document(self, chunks: List[Chunk]) -> List[str]:
"""Extract and persist triplets to graph store."""
async def process_chunk(chunk):
async def process_chunk(chunk: Chunk):
triplets = await self._triplet_extractor.extract(chunk.content)
for triplet in triplets:
self._graph_store.insert_triplet(*triplet)
self._graph_store_apdater.insert_triplet(*triplet)
logger.info(f"load {len(triplets)} triplets from chunk {chunk.chunk_id}")
return chunk.chunk_id
@@ -89,7 +96,7 @@ class BuiltinKnowledgeGraph(KnowledgeGraphBase):
for chunk in chunks:
triplets = await self._triplet_extractor.extract(chunk.content)
for triplet in triplets:
self._graph_store.insert_triplet(*triplet)
self._graph_store_apdater.insert_triplet(*triplet)
logger.info(f"load {len(triplets)} triplets from chunk {chunk.chunk_id}")
return [chunk.chunk_id for chunk in chunks]
@@ -116,7 +123,8 @@ class BuiltinKnowledgeGraph(KnowledgeGraphBase):
# extract keywords and explore graph store
keywords = await self._keyword_extractor.extract(text)
subgraph = self._graph_store.explore(keywords, limit=topk).format()
subgraph = self._graph_store_apdater.explore(keywords, limit=topk).format()
logger.info(f"Search subgraph from {len(keywords)} keywords")
if not subgraph:
@@ -147,12 +155,12 @@ class BuiltinKnowledgeGraph(KnowledgeGraphBase):
def query_graph(self, limit: Optional[int] = None) -> Graph:
"""Query graph."""
return self._graph_store.get_full_graph(limit)
return self._graph_store_apdater.get_full_graph(limit)
def truncate(self) -> List[str]:
"""Truncate knowledge graph."""
logger.info(f"Truncate graph {self._config.name}")
self._graph_store.truncate()
self._graph_store_apdater.truncate()
logger.info("Truncate keyword extractor")
self._keyword_extractor.truncate()
@@ -165,10 +173,15 @@ class BuiltinKnowledgeGraph(KnowledgeGraphBase):
def delete_vector_name(self, index_name: str):
"""Delete vector name."""
logger.info(f"Drop graph {index_name}")
self._graph_store.drop()
self._graph_store_apdater.drop()
logger.info("Drop keyword extractor")
self._keyword_extractor.drop()
logger.info("Drop triplet extractor")
self._triplet_extractor.drop()
def delete_by_ids(self, ids: str) -> List[str]:
    """Delete graph data by chunk ids.

    Args:
        ids: Chunk ids, forwarded unchanged as ``chunk_ids`` to the graph
            store adapter's ``delete_document``.

    Returns:
        Always an empty list; the deleted ids are not reported back.
    """
    self._graph_store_apdater.delete_document(chunk_ids=ids)
    return []

View File

@@ -1,4 +1,5 @@
"""The database manager."""
from __future__ import annotations
import logging

View File

@@ -10,7 +10,7 @@ You can refer to the python example file `DB-GPT/examples/rag/graph_rag_example.
First, you need to install the `dbgpt` library.
```bash
pip install "dbgpt[rag]>=0.6.0"
pip install "dbgpt[graph_rag]>=0.6.1"
````
### Prepare Graph Database
@@ -112,7 +112,9 @@ TUGRAPH_HOST=127.0.0.1
TUGRAPH_PORT=7687
TUGRAPH_USERNAME=admin
TUGRAPH_PASSWORD=73@TuGraph
GRAPH_COMMUNITY_SUMMARY_ENABLED=True
ENABLE_GRAPH_COMMUNITY_SUMMARY=True # enable the graph community summary
ENABLE_TRIPLET_GRAPH=True # enable the graph search for the triplets
ENABLE_DOCUMENT_GRAPH=True # enable the graph search for documents and chunks
```
@@ -250,23 +252,23 @@ Performance testing is based on the `gpt-4o-mini` model.
#### Indexing Performance
| | DB-GPT | GraphRAG(microsoft) |
|----------|----------|------------------------|
| Document Tokens | 42631 | 42631 |
| Graph Size | 808 nodes, 1170 edges | 779 nodes, 967 edges |
| Prompt Tokens | 452614 | 744990 |
| Completion Tokens | 48325 | 227230 |
| Total Tokens | 500939 | 972220 |
| | DB-GPT | GraphRAG(microsoft) |
| ----------------- | --------------------- | -------------------- |
| Document Tokens | 42631 | 42631 |
| Graph Size | 808 nodes, 1170 edges | 779 nodes, 967 edges |
| Prompt Tokens | 452614 | 744990 |
| Completion Tokens | 48325 | 227230 |
| Total Tokens | 500939 | 972220 |
#### Querying Performance
**Global Search**
| | DB-GPT | GraphRAG(microsoft) |
|----------|----------|------------------------|
| Time | 8s | 40s |
| Tokens| 7432 | 63317 |
| | DB-GPT | GraphRAG(microsoft) |
| ------ | ------ | ------------------- |
| Time | 8s | 40s |
| Tokens | 7432 | 63317 |
**Question**
```
@@ -304,10 +306,10 @@ Performance testing is based on the `gpt-4o-mini` model.
**Local Search**
| | DB-GPT | GraphRAG(microsoft) |
|----------|----------|------------------------|
| Time | 15s | 15s |
| Tokens| 9230 | 11619 |
| | DB-GPT | GraphRAG(microsoft) |
| ------ | ------ | ------------------- |
| Time | 15s | 15s |
| Tokens | 9230 | 11619 |
**Question**
@@ -352,3 +354,28 @@ DB-GPT社区与TuGraph社区的比较
总结
总体而言DB-GPT社区和TuGraph社区在社区贡献、生态系统和开发者参与等方面各具特色。DB-GPT社区更侧重于AI应用的多样性和组织间的合作而TuGraph社区则专注于图数据的高效管理和分析。两者的共同点在于都强调了开源和社区合作的重要性推动了各自领域的技术进步和应用发展。
```
### Latest Updates
In version 0.6.1 of DB-GPT, we have added a new feature:
- Retrieval of triplets combined with **retrieval of the document structure**
We have expanded the definition of 'Graph' in GraphRAG:
```
Knowledge Graph = Triplets Graph + Document Structure Graph
```
<p align="left">
<img src={'/img/chat_knowledge/graph_rag/image_graphrag_0_6_1.png'} width="1000px"/>
</p>
How?
We decompose files in standard formats (Markdown is currently the best supported) into a directed graph based on their hierarchy and layout information, and store it in a graph database. In this graph:
- Each node represents a chunk of the file
- Each edge represents the structural relationship between different chunks in the original document
- The document structure graph is merged into the triplets graph
What's next?
We aim to construct a more complex Graph that covers more comprehensive information to support more sophisticated retrieval algorithms in our GraphRAG.

Binary file not shown.

After

Width:  |  Height:  |  Size: 195 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 195 KiB

View File

@@ -1,4 +1,5 @@
import pytest
from dbgpt.datasource.conn_tugraph import TuGraphConnector
# Set database connection parameters.
@@ -21,10 +22,10 @@ def connector():
def test_get_table_names(connector):
"""Test retrieving table names from the graph database."""
table_names = connector.get_table_names()
vertex_tables, edge_tables = connector.get_table_names()
# Verify the quantity of vertex and edge tables.
assert len(table_names["vertex_tables"]) == 5
assert len(table_names["edge_tables"]) == 8
assert len(vertex_tables) == 5
assert len(edge_tables) == 8
def test_get_columns(connector):

View File

@@ -4,6 +4,9 @@ from dbgpt.storage.graph_store.memgraph_store import (
MemoryGraphStore,
MemoryGraphStoreConfig,
)
from dbgpt.storage.knowledge_graph.community.memgraph_store_adapter import (
MemGraphStoreAdapter,
)
@pytest.fixture
@@ -11,31 +14,37 @@ def graph_store():
yield MemoryGraphStore(MemoryGraphStoreConfig())
def test_graph_store(graph_store):
graph_store.insert_triplet("A", "0", "A")
graph_store.insert_triplet("A", "1", "A")
graph_store.insert_triplet("A", "2", "B")
graph_store.insert_triplet("B", "3", "C")
graph_store.insert_triplet("B", "4", "D")
graph_store.insert_triplet("C", "5", "D")
graph_store.insert_triplet("B", "6", "E")
graph_store.insert_triplet("F", "7", "E")
graph_store.insert_triplet("E", "8", "F")
# Fixture: wraps the `graph_store` fixture in a MemGraphStoreAdapter and
# yields the adapter to the requesting test.
@pytest.fixture
def graph_store_adapter(graph_store: MemoryGraphStore):
memgraph_store_adapter = MemGraphStoreAdapter(graph_store)
yield memgraph_store_adapter
subgraph = graph_store.explore(["A"])
def test_graph_store(graph_store_adapter: MemGraphStoreAdapter):
graph_store_adapter.insert_triplet("A", "0", "A")
graph_store_adapter.insert_triplet("A", "1", "A")
graph_store_adapter.insert_triplet("A", "2", "B")
graph_store_adapter.insert_triplet("B", "3", "C")
graph_store_adapter.insert_triplet("B", "4", "D")
graph_store_adapter.insert_triplet("C", "5", "D")
graph_store_adapter.insert_triplet("B", "6", "E")
graph_store_adapter.insert_triplet("F", "7", "E")
graph_store_adapter.insert_triplet("E", "8", "F")
subgraph = graph_store_adapter.explore(["A"])
print(f"\n{subgraph.format()}")
assert subgraph.edge_count == 9
graph_store.delete_triplet("A", "0", "A")
graph_store.delete_triplet("B", "4", "D")
subgraph = graph_store.explore(["A"])
graph_store_adapter.delete_triplet("A", "0", "A")
graph_store_adapter.delete_triplet("B", "4", "D")
subgraph = graph_store_adapter.explore(["A"])
print(f"\n{subgraph.format()}")
assert subgraph.edge_count == 7
triplets = graph_store.get_triplets("B")
triplets = graph_store_adapter.get_triplets("B")
print(f"\nTriplets of B: {triplets}")
assert len(triplets) == 2
schema = graph_store.get_schema()
schema = graph_store_adapter.get_schema()
print(f"\nSchema: {schema}")
assert len(schema) == 86

View File

@@ -1,43 +1,52 @@
# test_tugraph_store.py
# test_tugraph_tugraph_store_adapter.py
import pytest
from dbgpt.storage.graph_store.tugraph_store import TuGraphStore, TuGraphStoreConfig
from dbgpt.storage.knowledge_graph.community.tugraph_store_adapter import (
TuGraphStoreAdapter,
)
@pytest.fixture(scope="module")
def store():
config = TuGraphStoreConfig(name="TestGraph", summary_enabled=False)
config = TuGraphStoreConfig(name="TestGraph", enable_summary=False)
store = TuGraphStore(config=config)
yield store
store.conn.close()
def test_insert_and_get_triplets(store):
store.insert_triplet("A", "0", "A")
store.insert_triplet("A", "1", "A")
store.insert_triplet("A", "2", "B")
store.insert_triplet("B", "3", "C")
store.insert_triplet("B", "4", "D")
store.insert_triplet("C", "5", "D")
store.insert_triplet("B", "6", "E")
store.insert_triplet("F", "7", "E")
store.insert_triplet("E", "8", "F")
triplets = store.get_triplets("A")
# Module-scoped fixture: wraps the `store` fixture in a TuGraphStoreAdapter
# and yields the adapter to the requesting tests.
@pytest.fixture(scope="module")
def tugraph_store_adapter(store: TuGraphStore):
tugraph_store_adapter = TuGraphStoreAdapter(store)
yield tugraph_store_adapter
def test_insert_and_get_triplets(tugraph_store_adapter: TuGraphStoreAdapter):
tugraph_store_adapter.insert_triplet("A", "0", "A")
tugraph_store_adapter.insert_triplet("A", "1", "A")
tugraph_store_adapter.insert_triplet("A", "2", "B")
tugraph_store_adapter.insert_triplet("B", "3", "C")
tugraph_store_adapter.insert_triplet("B", "4", "D")
tugraph_store_adapter.insert_triplet("C", "5", "D")
tugraph_store_adapter.insert_triplet("B", "6", "E")
tugraph_store_adapter.insert_triplet("F", "7", "E")
tugraph_store_adapter.insert_triplet("E", "8", "F")
triplets = tugraph_store_adapter.get_triplets("A")
assert len(triplets) == 2
triplets = store.get_triplets("B")
triplets = tugraph_store_adapter.get_triplets("B")
assert len(triplets) == 3
triplets = store.get_triplets("C")
triplets = tugraph_store_adapter.get_triplets("C")
assert len(triplets) == 1
triplets = store.get_triplets("D")
triplets = tugraph_store_adapter.get_triplets("D")
assert len(triplets) == 0
triplets = store.get_triplets("E")
triplets = tugraph_store_adapter.get_triplets("E")
assert len(triplets) == 1
triplets = store.get_triplets("F")
triplets = tugraph_store_adapter.get_triplets("F")
assert len(triplets) == 1
def test_query(store):
def test_query(store: TuGraphStore):
query = "MATCH (n)-[r]->(n1) return n,n1,r limit 3"
result = store.query(query)
v_c = result.vertex_count
@@ -45,18 +54,18 @@ def test_query(store):
assert v_c == 3 and e_c == 3
def test_explore(store):
def test_explore(tugraph_store_adapter: TuGraphStoreAdapter):
subs = ["A", "B"]
result = store.explore(subs, depth=2, fan=None, limit=10)
result = tugraph_store_adapter.explore(subs, depth=2, fan=None, limit=10)
v_c = result.vertex_count
e_c = result.edge_count
assert v_c == 5 and e_c == 5
def test_delete_triplet(store):
def test_delete_triplet(tugraph_store_adapter: TuGraphStoreAdapter):
subj = "A"
rel = "0"
obj = "B"
store.delete_triplet(subj, rel, obj)
triplets = store.get_triplets(subj)
tugraph_store_adapter.delete_triplet(subj, rel, obj)
triplets = tugraph_store_adapter.get_triplets(subj)
assert len(triplets) == 0

View File

@@ -1,18 +1,27 @@
import pytest
from dbgpt.storage.graph_store.graph import Edge, MemoryGraph, Vertex
from dbgpt.storage.graph_store.tugraph_store import TuGraphStore, TuGraphStoreConfig
from dbgpt.storage.graph_store.graph import MemoryGraph, Edge, Vertex
from dbgpt.storage.knowledge_graph.community.tugraph_store_adapter import (
TuGraphStoreAdapter,
)
@pytest.fixture(scope="module")
def store():
config = TuGraphStoreConfig(name="TestSummaryGraph", summary_enabled=True)
config = TuGraphStoreConfig(name="TestSummaryGraph", enable_summary=True)
store_instance = TuGraphStore(config=config)
yield store_instance
store_instance.conn.close()
def test_insert_graph(store):
@pytest.fixture(scope="module")
def tugraph_store_adapter(store: TuGraphStore):
    """Yield a TuGraphStoreAdapter wrapping the module-scoped ``store`` fixture.

    pytest resolves fixtures by function name, so this fixture is named
    ``tugraph_store_adapter`` to match the argument that ``test_upsert_graph``
    requests.
    """
    yield TuGraphStoreAdapter(store)
def test_upsert_graph(tugraph_store_adapter: TuGraphStoreAdapter):
graph = MemoryGraph()
vertex_list = [
Vertex("A", "A", description="Vertex A", _document_id="Test doc"),
@@ -35,22 +44,22 @@ def test_insert_graph(store):
graph.upsert_vertex(vertex)
for edge in edge_list:
graph.append_edge(edge)
store.insert_graph(graph)
tugraph_store_adapter.upsert_graph(graph)
def test_leiden_query(store):
def test_leiden_query(store: TuGraphStore):
query = "CALL db.plugin.callPlugin('CPP','leiden','{\"leiden_val\":\"_community_id\"}',60.00,false)"
result = store.query(query)
assert result.vertex_count == 1
def test_query_node_and_edge(store):
def test_query_node_and_edge(store: TuGraphStore):
query = 'MATCH (n)-[r]->(m) WHERE n._community_id = "0" RETURN n,r,m'
result = store.query(query)
assert result.vertex_count == 7 and result.edge_count == 6
def test_stream_query_path(store):
def test_stream_query_path(store: TuGraphStore):
query = 'MATCH p=(n)-[r:relation*2]->(m) WHERE n._community_id = "0" RETURN p'
result = store.query(query)
for v in result.vertices():