refactor: the document structure of GraphRAG (#2084)

Authored by Appointat on 2024-10-21 18:05:55 +08:00; committed by GitHub
parent 584f090847
commit 6d6667812b
15 changed files with 401 additions and 284 deletions

View File

@@ -162,9 +162,11 @@ KNOWLEDGE_GRAPH_EXTRACT_SEARCH_RECALL_SCORE=0.3
 KNOWLEDGE_GRAPH_COMMUNITY_SEARCH_TOP_SIZE=20
 KNOWLEDGE_GRAPH_COMMUNITY_SEARCH_RECALL_SCORE=0.0
-ENABLE_GRAPH_COMMUNITY_SUMMARY=True # enable the graph community summary
-ENABLE_TRIPLET_GRAPH=True # enable the graph search for triplets
-ENABLE_DOCUMENT_GRAPH=True # enable the graph search for documents and chunks
+GRAPH_COMMUNITY_SUMMARY_ENABLED=True # enable the graph community summary
+TRIPLET_GRAPH_ENABLED=True # enable the graph search for triplets
+DOCUMENT_GRAPH_ENABLED=True # enable the graph search for documents and chunks
+KNOWLEDGE_GRAPH_CHUNK_SEARCH_TOP_SIZE=5 # the top size of knowledge graph search for chunks
 ### Chroma vector db config
 #CHROMA_PERSIST_PATH=/root/DB-GPT/pilot/data
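For reference, a minimal sketch of how the renamed flags can be read in Python, mirroring the `os.getenv(...).lower() == "true"` pattern that appears in the Config and TuGraphStore hunks below; how `KNOWLEDGE_GRAPH_CHUNK_SEARCH_TOP_SIZE` is actually parsed is an assumption here, since the diff only shows its config-field default of 5.

```python
import os

# Mirrors the parsing pattern used elsewhere in this commit.
# An unset flag defaults to False because "".lower() != "true".
graph_community_summary_enabled = (
    os.getenv("GRAPH_COMMUNITY_SUMMARY_ENABLED", "").lower() == "true"
)
triplet_graph_enabled = os.getenv("TRIPLET_GRAPH_ENABLED", "").lower() == "true"
document_graph_enabled = os.getenv("DOCUMENT_GRAPH_ENABLED", "").lower() == "true"

# Assumption: the chunk-search top size is read as an integer with default 5,
# matching the new knowledge_graph_chunk_search_top_size config field.
chunk_search_top_size = int(os.getenv("KNOWLEDGE_GRAPH_CHUNK_SEARCH_TOP_SIZE", "5"))
```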

View File

@@ -213,8 +213,8 @@ class Config(metaclass=Singleton):
         # Vector Store Configuration
         self.VECTOR_STORE_TYPE = os.getenv("VECTOR_STORE_TYPE", "Chroma")
-        self.ENABLE_GRAPH_COMMUNITY_SUMMARY = (
-            os.getenv("ENABLE_GRAPH_COMMUNITY_SUMMARY", "").lower() == "true"
+        self.GRAPH_COMMUNITY_SUMMARY_ENABLED = (
+            os.getenv("GRAPH_COMMUNITY_SUMMARY_ENABLED", "").lower() == "true"
         )
         self.MILVUS_URL = os.getenv("MILVUS_URL", "127.0.0.1")
         self.MILVUS_PORT = os.getenv("MILVUS_PORT", "19530")

View File

@@ -49,6 +49,7 @@ class Chunk(Document):
     chunk_id: str = Field(
         default_factory=lambda: str(uuid.uuid4()), description="unique id for the chunk"
     )
+    chunk_name: str = Field(default="", description="chunk name")
     content: str = Field(default="", description="chunk text content")
     metadata: Dict[str, Any] = Field(

View File

@@ -128,7 +128,7 @@ class VectorStoreConnector:
     def __rewrite_index_store_type(self, index_store_type):
         # Rewrite Knowledge Graph Type
-        if CFG.ENABLE_GRAPH_COMMUNITY_SUMMARY:
+        if CFG.GRAPH_COMMUNITY_SUMMARY_ENABLED:
             if index_store_type == "KnowledgeGraph":
                 return "CommunitySummaryKnowledgeGraph"
         return index_store_type

View File

@ -27,11 +27,11 @@ class GraphStoreConfig(BaseModel):
default=False, default=False,
description="Enable graph community summary or not.", description="Enable graph community summary or not.",
) )
enable_document_graph: bool = Field( document_graph_enabled: bool = Field(
default=True, default=True,
description="Enable document graph search or not.", description="Enable document graph search or not.",
) )
enable_triplet_graph: bool = Field( triplet_graph_enabled: bool = Field(
default=True, default=True,
description="Enable knowledge graph search or not.", description="Enable knowledge graph search or not.",
) )
@ -48,11 +48,3 @@ class GraphStoreBase(ABC):
@abstractmethod @abstractmethod
def get_config(self) -> GraphStoreConfig: def get_config(self) -> GraphStoreConfig:
"""Get the graph store config.""" """Get the graph store config."""
@abstractmethod
def _escape_quotes(self, text: str) -> str:
"""Escape single and double quotes in a string for queries."""
# @abstractmethod
# def _paser(self, entities: List[Vertex]) -> str:
# """Parse entities to string."""

View File

@@ -19,8 +19,8 @@ class GraphElemType(Enum):
     DOCUMENT = "document"
     CHUNK = "chunk"
-    ENTITY = "entity"  # view as general vertex in the general case
-    RELATION = "relation"  # view as general edge in the general case
+    ENTITY = "entity"  # default vertex type in knowledge graph
+    RELATION = "relation"  # default edge type in knowledge graph
     INCLUDE = "include"
     NEXT = "next"
@@ -39,7 +39,15 @@ class GraphElemType(Enum):
     def is_edge(self) -> bool:
         """Check if the element is an edge."""
-        return not self.is_vertex()
+        return self in [
+            GraphElemType.RELATION,
+            GraphElemType.INCLUDE,
+            GraphElemType.NEXT,
+            GraphElemType.DOCUMENT_INCLUDE_CHUNK,
+            GraphElemType.CHUNK_INCLUDE_CHUNK,
+            GraphElemType.CHUNK_INCLUDE_ENTITY,
+            GraphElemType.CHUNK_NEXT_CHUNK,
+        ]

 class Direction(Enum):
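As a quick illustration of the new membership-based check, here is a self-contained sketch; the values of the `*_INCLUDE_*` / `*_NEXT_*` members are assumptions, since they are not part of this hunk, and the class name is a stand-in rather than the real `GraphElemType`.

```python
from enum import Enum


class ElemTypeSketch(Enum):
    # Illustrative stand-in for GraphElemType; the last four values are assumed.
    DOCUMENT = "document"
    CHUNK = "chunk"
    ENTITY = "entity"
    RELATION = "relation"
    INCLUDE = "include"
    NEXT = "next"
    DOCUMENT_INCLUDE_CHUNK = "document_include_chunk"
    CHUNK_INCLUDE_CHUNK = "chunk_include_chunk"
    CHUNK_INCLUDE_ENTITY = "chunk_include_entity"
    CHUNK_NEXT_CHUNK = "chunk_next_chunk"

    def is_edge(self) -> bool:
        # Edges are now an explicit whitelist instead of "not a vertex".
        return self in [
            ElemTypeSketch.RELATION,
            ElemTypeSketch.INCLUDE,
            ElemTypeSketch.NEXT,
            ElemTypeSketch.DOCUMENT_INCLUDE_CHUNK,
            ElemTypeSketch.CHUNK_INCLUDE_CHUNK,
            ElemTypeSketch.CHUNK_INCLUDE_ENTITY,
            ElemTypeSketch.CHUNK_NEXT_CHUNK,
        ]


assert ElemTypeSketch.CHUNK_NEXT_CHUNK.is_edge()
assert not ElemTypeSketch.CHUNK.is_edge()
```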
@@ -335,26 +343,6 @@ class MemoryGraph(Graph):
             self._edge_count += 1
             return True

-    def upsert_vertex_and_edge(
-        self,
-        src_vid: str,
-        src_name: str,
-        src_props: Dict[str, Any],
-        dst_vid: str,
-        dst_name: str,
-        dst_props: Dict[str, Any],
-        edge_name: str,
-        edge_type: str,
-    ):
-        """Uperst src and dst vertex, and edge."""
-        src_vertex = Vertex(src_vid, src_name, **src_props)
-        dst_vertex = Vertex(dst_vid, dst_name, **dst_props)
-        edge = Edge(src_vid, dst_vid, edge_name, **{"edge_type": edge_type})
-        self.upsert_vertex(src_vertex)
-        self.upsert_vertex(dst_vertex)
-        self.append_edge(edge)
-
     def upsert_graph(self, graph: "MemoryGraph"):
         """Upsert a graph."""
         for vertex in graph.vertices():

View File

@@ -26,9 +26,3 @@ class MemoryGraphStore(GraphStoreBase):
     def get_config(self):
         """Get the graph store config."""
         return self._graph_store_config
-
-    def _escape_quotes(self, text: str) -> str:
-        """Escape single and double quotes in a string for queries."""
-        raise NotImplementedError(
-            "_escape_quotes is not implemented by MemoryGraphStore"
-        )

View File

@@ -80,16 +80,16 @@ class TuGraphStore(GraphStoreBase):
         self._username = os.getenv("TUGRAPH_USERNAME", config.username)
         self._password = os.getenv("TUGRAPH_PASSWORD", config.password)
         self._enable_summary = (
-            os.getenv("ENABLE_GRAPH_COMMUNITY_SUMMARY", "").lower() == "true"
+            os.getenv("GRAPH_COMMUNITY_SUMMARY_ENABLED", "").lower() == "true"
             or config.enable_summary
         )
         self._enable_document_graph = (
-            os.getenv("ENABLE_DOCUMENT_GRAPH", "").lower() == "true"
-            or config.enable_document_graph
+            os.getenv("DOCUMENT_GRAPH_ENABLED", "").lower() == "true"
+            or config.document_graph_enabled
         )
         self._enable_triplet_graph = (
-            os.getenv("ENABLE_TRIPLET_GRAPH", "").lower() == "true"
-            or config.enable_triplet_graph
+            os.getenv("TRIPLET_GRAPH_ENABLED", "").lower() == "true"
+            or config.triplet_graph_enabled
         )
         self._plugin_names = (
             os.getenv("TUGRAPH_PLUGIN_NAMES", "leiden").split(",")
@@ -159,8 +159,3 @@ class TuGraphStore(GraphStoreBase):
                 f"'{name} Plugin', false, 'v1')"
             )
             self.conn.run(gql)
-
-    def _escape_quotes(self, value: str) -> str:
-        """Escape single and double quotes in a string for queries."""
-        if value is not None:
-            return value.replace("'", "").replace('"', "")

View File

@@ -4,7 +4,10 @@ import logging
 from abc import ABC, abstractmethod
 from typing import List, Optional

+from pydantic import Field
+
 from dbgpt._private.pydantic import ConfigDict
+from dbgpt.core import Chunk
 from dbgpt.rag.index.base import IndexStoreBase, IndexStoreConfig
 from dbgpt.storage.graph_store.graph import Graph
@@ -31,3 +34,14 @@ class KnowledgeGraphBase(IndexStoreBase, ABC):
     @abstractmethod
     def delete_by_ids(self, ids: str) -> List[str]:
         """Delete document by ids."""
+
+
+class ParagraphChunk(Chunk):
+    """Loaded chunk, used in GraphRAG."""
+
+    chunk_parent_id: str = Field(default=None, description="id of parent chunk")
+    chunk_parent_name: str = Field(default=None, description="parent chunk name")
+    parent_content: str = Field(default=None, description="parent chunk text content")
+    parent_is_document: bool = Field(
+        default=False, description="is parent chunk a document"
+    )
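A standalone sketch of how a section chunk might carry its parent link under this model; the class below merely mirrors `ParagraphChunk`'s fields for illustration (the real class extends `dbgpt.core.Chunk`), and the sample values are hypothetical.

```python
from typing import Optional

from pydantic import BaseModel


class ParagraphChunkSketch(BaseModel):
    """Illustrative mirror of ParagraphChunk's parent-link fields."""

    chunk_id: str
    chunk_name: str = ""
    content: str = ""
    chunk_parent_id: Optional[str] = None
    chunk_parent_name: Optional[str] = None
    parent_content: Optional[str] = None
    parent_is_document: bool = False


# A second-level markdown section whose parent is a first-level section.
section = ParagraphChunkSketch(
    chunk_id="chunk-2",
    chunk_name="Install",
    content="Run `pip install dbgpt` ...",
    chunk_parent_id="chunk-1",
    chunk_parent_name="Getting Started",
    parent_content="This guide covers ...",
    parent_is_document=False,
)
```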

View File

@@ -3,7 +3,7 @@
 import logging
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
-from typing import AsyncGenerator, Iterator, List, Optional
+from typing import AsyncGenerator, Iterator, List, Optional, Union

 from dbgpt.storage.graph_store.base import GraphStoreBase
 from dbgpt.storage.graph_store.graph import (
@@ -14,6 +14,7 @@ from dbgpt.storage.graph_store.graph import (
     MemoryGraph,
     Vertex,
 )
+from dbgpt.storage.knowledge_graph.base import ParagraphChunk

 logger = logging.getLogger(__name__)
@@ -91,17 +92,17 @@ class GraphStoreAdapter(ABC):
         """Upsert edge."""

     @abstractmethod
-    def upsert_chunks(self, chunk: Iterator[Vertex]) -> None:
+    def upsert_chunks(
+        self, chunks: Union[Iterator[Vertex], Iterator[ParagraphChunk]]
+    ) -> None:
         """Upsert chunk."""

     @abstractmethod
-    def upsert_documents(self, documents: Iterator[Vertex]) -> None:
+    def upsert_documents(
+        self, documents: Union[Iterator[Vertex], Iterator[ParagraphChunk]]
+    ) -> None:
         """Upsert documents."""

-    @abstractmethod
-    def upsert_relations(self, relations: Iterator[Edge]) -> None:
-        """Upsert relations."""
-
     @abstractmethod
     def insert_triplet(self, sub: str, rel: str, obj: str) -> None:
         """Insert triplet."""
@@ -110,6 +111,34 @@ class GraphStoreAdapter(ABC):
     def upsert_graph(self, graph: Graph) -> None:
         """Insert graph."""

+    @abstractmethod
+    def upsert_doc_include_chunk(
+        self,
+        chunk: ParagraphChunk,
+    ) -> None:
+        """Convert chunk to document include chunk."""
+
+    @abstractmethod
+    def upsert_chunk_include_chunk(
+        self,
+        chunk: ParagraphChunk,
+    ) -> None:
+        """Convert chunk to chunk include chunk."""
+
+    @abstractmethod
+    def upsert_chunk_next_chunk(
+        self,
+        chunk: ParagraphChunk,
+        next_chunk: ParagraphChunk,
+    ):
+        """Upsert the vertices and the edge in chunk_next_chunk."""
+
+    @abstractmethod
+    def upsert_chunk_include_entity(
+        self, chunk: ParagraphChunk, entity: Vertex
+    ) -> None:
+        """Convert chunk to chunk include entity."""
+
     @abstractmethod
     def delete_document(self, chunk_id: str) -> None:
         """Delete document in graph store."""

View File

@@ -2,7 +2,7 @@
 import json
 import logging
-from typing import AsyncGenerator, Iterator, List, Optional, Tuple
+from typing import AsyncGenerator, Iterator, List, Optional, Tuple, Union

 from dbgpt.storage.graph_store.graph import (
     Direction,
@@ -16,6 +16,7 @@ from dbgpt.storage.graph_store.memgraph_store import (
     MemoryGraphStore,
     MemoryGraphStoreConfig,
 )
+from dbgpt.storage.knowledge_graph.base import ParagraphChunk
 from dbgpt.storage.knowledge_graph.community.base import Community, GraphStoreAdapter

 logger = logging.getLogger(__name__)
@@ -38,11 +39,11 @@ class MemGraphStoreAdapter(GraphStoreAdapter):
     async def discover_communities(self, **kwargs) -> List[str]:
         """Run community discovery with leiden."""
-        pass
+        return []

     async def get_community(self, community_id: str) -> Community:
         """Get community."""
-        pass
+        raise NotImplementedError("Memory graph store does not have community")

     def get_graph_config(self):
         """Get the graph store config."""
@@ -96,11 +97,15 @@ class MemGraphStoreAdapter(GraphStoreAdapter):
         """Upsert edges."""
         pass

-    def upsert_chunks(self, chunks: Iterator[Vertex]) -> None:
+    def upsert_chunks(
+        self, chunks: Union[Iterator[Vertex], Iterator[ParagraphChunk]]
+    ) -> None:
         """Upsert chunks."""
         pass

-    def upsert_documents(self, documents: Iterator[Vertex]) -> None:
+    def upsert_documents(
+        self, documents: Union[Iterator[Vertex], Iterator[ParagraphChunk]]
+    ) -> None:
         """Upsert documents."""
         pass
@@ -108,6 +113,32 @@ class MemGraphStoreAdapter(GraphStoreAdapter):
         """Upsert relations."""
         pass

+    def upsert_doc_include_chunk(
+        self,
+        chunk: ParagraphChunk,
+    ) -> None:
+        """Convert chunk to document include chunk."""
+        pass
+
+    def upsert_chunk_include_chunk(
+        self,
+        chunk: ParagraphChunk,
+    ) -> None:
+        """Convert chunk to chunk include chunk."""
+        pass
+
+    def upsert_chunk_next_chunk(
+        self, chunk: ParagraphChunk, next_chunk: ParagraphChunk
+    ):
+        """Upsert the vertices and the edge in chunk_next_chunk."""
+        pass
+
+    def upsert_chunk_include_entity(
+        self, chunk: ParagraphChunk, entity: Vertex
+    ) -> None:
+        """Convert chunk to chunk include entity."""
+        pass
+
     def insert_triplet(self, subj: str, rel: str, obj: str) -> None:
         """Add triplet."""
         self._graph_store._graph.append_edge(Edge(subj, obj, rel))

View File

@@ -23,6 +23,7 @@ from dbgpt.storage.graph_store.graph import (
     Vertex,
 )
 from dbgpt.storage.graph_store.tugraph_store import TuGraphStore
+from dbgpt.storage.knowledge_graph.base import ParagraphChunk
 from dbgpt.storage.knowledge_graph.community.base import Community, GraphStoreAdapter

 logger = logging.getLogger(__name__)
@@ -137,11 +138,9 @@ class TuGraphStoreAdapter(GraphStoreAdapter):
         """Upsert entities."""
         entity_list = [
             {
-                "id": self.graph_store._escape_quotes(entity.vid),
-                "name": self.graph_store._escape_quotes(entity.name),
-                "description": self.graph_store._escape_quotes(
-                    entity.get_prop("description")
-                )
+                "id": self._escape_quotes(entity.vid),
+                "name": self._escape_quotes(entity.name),
+                "description": self._escape_quotes(entity.get_prop("description"))
                 or "",
                 "_document_id": "0",
                 "_chunk_id": "0",
@@ -152,7 +151,7 @@ class TuGraphStoreAdapter(GraphStoreAdapter):
         entity_query = (
             f"CALL db.upsertVertex("
            f'"{GraphElemType.ENTITY.value}", '
-            f"[{self._parser(entity_list)}])"
+            f"[{self._convert_dict_to_str(entity_list)}])"
         )
         self.graph_store.conn.run(query=entity_query)
@@ -162,64 +161,81 @@ class TuGraphStoreAdapter(GraphStoreAdapter):
         """Upsert edges."""
         edge_list = [
             {
-                "sid": self.graph_store._escape_quotes(edge.sid),
-                "tid": self.graph_store._escape_quotes(edge.tid),
-                "id": self.graph_store._escape_quotes(edge.name),
-                "name": self.graph_store._escape_quotes(edge.name),
-                "description": self.graph_store._escape_quotes(
-                    edge.get_prop("description")
-                )
-                or "",
-                "_chunk_id": self.graph_store._escape_quotes(edge.get_prop("_chunk_id"))
-                or "",
+                "sid": self._escape_quotes(edge.sid),
+                "tid": self._escape_quotes(edge.tid),
+                "id": self._escape_quotes(edge.name),
+                "name": self._escape_quotes(edge.name),
+                "description": self._escape_quotes(edge.get_prop("description")) or "",
+                "_chunk_id": self._escape_quotes(edge.get_prop("_chunk_id")) or "",
             }
             for edge in edges
         ]
         relation_query = f"""CALL db.upsertEdge("{edge_type}",
             {{type:"{src_type}", key:"sid"}},
             {{type:"{dst_type}", key:"tid"}},
-            [{self._parser(edge_list)}])"""
+            [{self._convert_dict_to_str(edge_list)}])"""
         self.graph_store.conn.run(query=relation_query)

-    def upsert_chunks(self, chunks: Iterator[Vertex]) -> None:
+    def upsert_chunks(
+        self, chunks: Union[Iterator[Vertex], Iterator[ParagraphChunk]]
+    ) -> None:
         """Upsert chunks."""
-        chunk_list = [
-            {
-                "id": self.graph_store._escape_quotes(chunk.vid),
-                "name": self.graph_store._escape_quotes(chunk.name),
-                "content": self.graph_store._escape_quotes(chunk.get_prop("content")),
-            }
-            for chunk in chunks
-        ]
+        chunks_list = list(chunks)
+        if chunks_list and isinstance(chunks_list[0], ParagraphChunk):
+            chunk_list = [
+                {
+                    "id": self._escape_quotes(chunk.chunk_id),
+                    "name": self._escape_quotes(chunk.chunk_name),
+                    "content": self._escape_quotes(chunk.content),
+                }
+                for chunk in chunks_list
+            ]
+        else:
+            chunk_list = [
+                {
+                    "id": self._escape_quotes(chunk.vid),
+                    "name": self._escape_quotes(chunk.name),
+                    "content": self._escape_quotes(chunk.get_prop("content")),
+                }
+                for chunk in chunks_list
+            ]
         chunk_query = (
             f"CALL db.upsertVertex("
             f'"{GraphElemType.CHUNK.value}", '
-            f"[{self._parser(chunk_list)}])"
+            f"[{self._convert_dict_to_str(chunk_list)}])"
         )
         self.graph_store.conn.run(query=chunk_query)

-    def upsert_documents(self, documents: Iterator[Vertex]) -> None:
+    def upsert_documents(
+        self, documents: Union[Iterator[Vertex], Iterator[ParagraphChunk]]
+    ) -> None:
         """Upsert documents."""
-        document_list = [
-            {
-                "id": self.graph_store._escape_quotes(document.vid),
-                "name": self.graph_store._escape_quotes(document.name),
-                "content": self.graph_store._escape_quotes(document.get_prop("content"))
-                or "",
-            }
-            for document in documents
-        ]
+        documents_list = list(documents)
+        if documents_list and isinstance(documents_list[0], ParagraphChunk):
+            document_list = [
+                {
+                    "id": self._escape_quotes(document.chunk_id),
+                    "name": self._escape_quotes(document.chunk_name),
+                    "content": "",
+                }
+                for document in documents_list
+            ]
+        else:
+            document_list = [
+                {
+                    "id": self._escape_quotes(document.vid),
+                    "name": self._escape_quotes(document.name),
+                    "content": self._escape_quotes(document.get_prop("content")) or "",
+                }
+                for document in documents_list
+            ]
         document_query = (
             "CALL db.upsertVertex("
             f'"{GraphElemType.DOCUMENT.value}", '
-            f"[{self._parser(document_list)}])"
+            f"[{self._convert_dict_to_str(document_list)}])"
         )
         self.graph_store.conn.run(query=document_query)

-    def upsert_relations(self, relations: Iterator[Edge]) -> None:
-        """Upsert relations."""
-        pass
-
     def insert_triplet(self, subj: str, rel: str, obj: str) -> None:
         """Add triplet."""
         subj_escaped = subj.replace("'", "\\'").replace('"', '\\"')
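To make the generated call concrete, here is a rough sketch of the Cypher string `upsert_chunks` would produce for a single chunk; the exact property rendering of `_convert_dict_to_str` is not shown in this hunk, so the `key: "value"` formatting below is an assumption, and the payload is hypothetical.

```python
# Hypothetical chunk payload; field names match the dict built above.
chunk = {"id": "chunk-1", "name": "Getting Started", "content": "Install DB-GPT ..."}

# Assumed rendering of _convert_dict_to_str(): {key: "value", ...}
props = ", ".join(f'{k}: "{v}"' for k, v in chunk.items())
query = f'CALL db.upsertVertex("chunk", [{{{props}}}])'
print(query)
# CALL db.upsertVertex("chunk", [{id: "chunk-1", name: "Getting Started", content: "Install DB-GPT ..."}])
```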
@@ -465,10 +481,6 @@ class TuGraphStoreAdapter(GraphStoreAdapter):
                 "properties": graph_properties,
             })
             gql = f"""CALL db.createVertexLabelByJson('{data}')"""
-            gql_check_exist = (
-                f"""CALL db.getLabelSchema('VERTEX', '{graph_elem_type.value}')"""
-            )
         else:  # edge

             def edge_direction(graph_elem_type: GraphElemType) -> List[List[str]]:
@@ -501,20 +513,7 @@ class TuGraphStoreAdapter(GraphStoreAdapter):
             })
             gql = f"""CALL db.createEdgeLabelByJson('{data}')"""
-            gql_check_exist = (
-                f"""CALL db.getLabelSchema('EDGE', '{graph_elem_type.value}')"""
-            )
-
-        # Make sure the graph label is identical
-        try:
-            self.graph_store.conn.run(
-                gql_check_exist
-            )  # if not exist, qurying raises an exception
-        except Exception:
-            self.graph_store.conn.run(gql)  # create the graph label
-            return
-
-        logger.info(f"Graph label {graph_elem_type.value} already exists.")
+
+        self.graph_store.conn.run(gql)

     def truncate(self):
         """Truncate Graph."""
@@ -542,7 +541,7 @@ class TuGraphStoreAdapter(GraphStoreAdapter):
         self,
         subs: List[str],
         direct: Direction = Direction.BOTH,
-        depth: Optional[int] = None,
+        depth: int = 3,
         limit: Optional[int] = None,
         search_scope: Optional[
             Literal["knowledge_graph", "document_graph"]
@@ -552,10 +551,8 @@ class TuGraphStoreAdapter(GraphStoreAdapter):
         if not subs:
             return MemoryGraph()

-        if depth is None or depth < 0 or depth > self.MAX_HIERARCHY_LEVEL:
-            # TODO: to be discussed, be none or MAX_HIERARCHY_LEVEL
-            # depth_string = ".."
-            depth = self.MAX_HIERARCHY_LEVEL
+        if depth < 0:
+            depth = 3
         depth_string = f"1..{depth}"

         if limit is None:
@@ -573,7 +570,8 @@ class TuGraphStoreAdapter(GraphStoreAdapter):
             query = (
                 f"MATCH p=(n:{GraphElemType.ENTITY.value})"
                 f"{rel}(m:{GraphElemType.ENTITY.value}) "
-                f"WHERE n.id IN {subs} RETURN p {limit_string}"
+                f"WHERE n.id IN {[self._escape_quotes(sub) for sub in subs]} "
+                f"RETURN p {limit_string}"
             )
             return self.query(query)
         else:
@@ -583,7 +581,8 @@ class TuGraphStoreAdapter(GraphStoreAdapter):
                 query = (
                     f"MATCH p=(n:{GraphElemType.DOCUMENT.value})-"
                     f"[r:{GraphElemType.INCLUDE.value}*{depth_string}]-"
-                    f"(m:{GraphElemType.CHUNK.value})WHERE m.content CONTAINS '{sub}' "
+                    f"(m:{GraphElemType.CHUNK.value})WHERE m.content CONTAINS "
+                    f"'{self._escape_quotes(sub)}' "
                     f"RETURN p {limit_string}"
                 )  # if it contains the subjects
                 result = self.query(query)
@@ -794,8 +793,8 @@ class TuGraphStoreAdapter(GraphStoreAdapter):
                 vertex_list.append(vertex)
         return vertex_list, edge_list

-    def _parser(self, entity_list: List[Dict[str, Any]]) -> str:
-        """Parse entities to string."""
+    def _convert_dict_to_str(self, entity_list: List[Dict[str, Any]]) -> str:
+        """Convert a list of entities to a formatted string representation."""
         formatted_nodes = [
             "{"
             + ", ".join(
@@ -806,3 +805,90 @@ class TuGraphStoreAdapter(GraphStoreAdapter):
             for node in entity_list
         ]
         return f"""{", ".join(formatted_nodes)}"""
+
+    def _escape_quotes(self, value: str) -> str:
+        """Escape single and double quotes in a string for queries."""
+        if value is not None:
+            return value.replace("'", "").replace('"', "")
+
+    def upsert_doc_include_chunk(
+        self,
+        chunk: ParagraphChunk,
+    ) -> None:
+        """Convert chunk to document include chunk."""
+        assert (
+            chunk.chunk_parent_id and chunk.chunk_parent_name
+        ), "Chunk parent ID and name are required (document_include_chunk)"
+
+        edge = Edge(
+            sid=chunk.chunk_parent_id,
+            tid=chunk.chunk_id,
+            name=GraphElemType.INCLUDE.value,
+            edge_type=GraphElemType.DOCUMENT_INCLUDE_CHUNK.value,
+        )
+
+        self.upsert_edge(
+            edges=iter([edge]),
+            edge_type=GraphElemType.INCLUDE.value,
+            src_type=GraphElemType.DOCUMENT.value,
+            dst_type=GraphElemType.CHUNK.value,
+        )
+
+    def upsert_chunk_include_chunk(
+        self,
+        chunk: ParagraphChunk,
+    ) -> None:
+        """Convert chunk to chunk include chunk."""
+        assert (
+            chunk.chunk_parent_id and chunk.chunk_parent_name
+        ), "Chunk parent ID and name are required (chunk_include_chunk)"
+
+        edge = Edge(
+            sid=chunk.chunk_parent_id,
+            tid=chunk.chunk_id,
+            name=GraphElemType.INCLUDE.value,
+            edge_type=GraphElemType.CHUNK_INCLUDE_CHUNK.value,
+        )
+
+        self.upsert_edge(
+            edges=iter([edge]),
+            edge_type=GraphElemType.INCLUDE.value,
+            src_type=GraphElemType.CHUNK.value,
+            dst_type=GraphElemType.CHUNK.value,
+        )
+
+    def upsert_chunk_next_chunk(
+        self, chunk: ParagraphChunk, next_chunk: ParagraphChunk
+    ):
+        """Upsert the vertices and the edge in chunk_next_chunk."""
+        edge = Edge(
+            sid=chunk.chunk_id,
+            tid=next_chunk.chunk_id,
+            name=GraphElemType.NEXT.value,
+            edge_type=GraphElemType.CHUNK_NEXT_CHUNK.value,
+        )
+
+        self.upsert_edge(
+            edges=iter([edge]),
+            edge_type=GraphElemType.NEXT.value,
+            src_type=GraphElemType.CHUNK.value,
+            dst_type=GraphElemType.CHUNK.value,
+        )
+
+    def upsert_chunk_include_entity(
+        self, chunk: ParagraphChunk, entity: Vertex
+    ) -> None:
+        """Convert chunk to chunk include entity."""
+        edge = Edge(
+            sid=chunk.chunk_id,
+            tid=entity.vid,
+            name=GraphElemType.INCLUDE.value,
+            edge_type=GraphElemType.CHUNK_INCLUDE_ENTITY.value,
+        )
+
+        self.upsert_edge(
+            edges=iter([edge]),
+            edge_type=GraphElemType.INCLUDE.value,
+            src_type=GraphElemType.CHUNK.value,
+            dst_type=GraphElemType.ENTITY.value,
+        )
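For orientation, the document-structure edges these four helpers emit for one document with a top-level section, a subsection, and one extracted entity; the identifiers below are hypothetical and the printout only visualizes the resulting topology.

```python
# (source id, target id, edge name, edge type) tuples mirroring the helpers above.
edges = [
    ("doc-1", "chunk-1", "include", "document_include_chunk"),      # top-level section
    ("chunk-1", "chunk-1.1", "include", "chunk_include_chunk"),     # subsection of chunk-1
    ("doc-1", "chunk-2", "include", "document_include_chunk"),      # next top-level section
    ("chunk-1", "chunk-1.1", "next", "chunk_next_chunk"),           # reading order
    ("chunk-1.1", "chunk-2", "next", "chunk_next_chunk"),
    ("chunk-1.1", "entity-1", "include", "chunk_include_entity"),   # extracted entity
]
for sid, tid, name, edge_type in edges:
    print(f"({sid})-[:{name} {{edge_type: {edge_type}}}]->({tid})")
```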

View File

@@ -3,15 +3,15 @@
 import logging
 import os
 import uuid
-from typing import List, Optional
+from typing import List, Optional, Tuple

 from dbgpt._private.pydantic import ConfigDict, Field
 from dbgpt.core import Chunk
 from dbgpt.rag.transformer.community_summarizer import CommunitySummarizer
 from dbgpt.rag.transformer.graph_extractor import GraphExtractor
-from dbgpt.storage.graph_store.graph import GraphElemType, MemoryGraph
+from dbgpt.storage.graph_store.graph import MemoryGraph
+from dbgpt.storage.knowledge_graph.base import ParagraphChunk
 from dbgpt.storage.knowledge_graph.community.community_store import CommunityStore
-from dbgpt.storage.knowledge_graph.community.factory import GraphStoreAdapterFactory
 from dbgpt.storage.knowledge_graph.knowledge_graph import (
     BuiltinKnowledgeGraph,
     BuiltinKnowledgeGraphConfig,
@@ -59,6 +59,10 @@ class CommunitySummaryKnowledgeGraphConfig(BuiltinKnowledgeGraphConfig):
         default=0.0,
         description="Recall score of community search in knowledge graph",
     )
+    knowledge_graph_chunk_search_top_size: int = Field(
+        default=5,
+        description="Top size of knowledge graph chunk search",
+    )


 class CommunitySummaryKnowledgeGraph(BuiltinKnowledgeGraph):
@@ -124,7 +128,7 @@ class CommunitySummaryKnowledgeGraph(BuiltinKnowledgeGraph):
             cfg.score_threshold = self._community_score_threshold

         self._community_store = CommunityStore(
-            GraphStoreAdapterFactory.create(self._graph_store),
+            self._graph_store_apdater,
             CommunitySummarizer(self._llm_client, self._model_name),
             VectorStoreFactory.create(
                 self._vector_store_type,
@@ -138,145 +142,107 @@ class CommunitySummaryKnowledgeGraph(BuiltinKnowledgeGraph):
         return self._config

     async def aload_document(self, chunks: List[Chunk]) -> List[str]:
-        """Extract and persist graph."""
-        data_list = self._parse_chunks(chunks)  # parse the chunks by def _lod_doc_graph
-        graph_of_all = MemoryGraph()
-
-        # Support graph search by the document and the chunks
-        if self._graph_store.get_config().enable_document_graph:
-            doc_vid = str(uuid.uuid4())
-            doc_name = os.path.basename(chunks[0].metadata["source"] or "Text_Node")
-            for chunk_index, chunk in enumerate(data_list):
-                if chunk["parent_id"] != "document":
-                    # chunk -> include -> chunk
-                    graph_of_all.upsert_vertex_and_edge(
-                        src_vid=chunk["parent_id"],
-                        src_name=chunk["parent_title"],
-                        src_props={
-                            "vertex_type": GraphElemType.CHUNK.value,
-                            "content": chunk["content"],
-                        },
-                        dst_vid=chunk["id"],
-                        dst_name=chunk["title"],
-                        dst_props={
-                            "vertex_type": GraphElemType.CHUNK.value,
-                            "content": chunk["content"],
-                        },
-                        edge_name=GraphElemType.INCLUDE.value,
-                        edge_type=GraphElemType.CHUNK_INCLUDE_CHUNK.value,
-                    )
-                else:
-                    # document -> include -> chunk
-                    graph_of_all.upsert_vertex_and_edge(
-                        src_vid=doc_vid,
-                        src_name=doc_name,
-                        src_props={
-                            "vertex_type": GraphElemType.DOCUMENT.value,
-                            "content": "",
-                        },
-                        dst_vid=chunk["id"],
-                        dst_name=chunk["title"],
-                        dst_props={
-                            "vertex_type": GraphElemType.CHUNK.value,
-                            "content": chunk["content"],
-                        },
-                        edge_name=GraphElemType.INCLUDE.value,
-                        edge_type=GraphElemType.DOCUMENT_INCLUDE_CHUNK.value,
-                    )
-
-                # chunk -> next -> chunk
-                if chunk_index >= 1:
-                    graph_of_all.upsert_vertex_and_edge(
-                        src_vid=data_list[chunk_index - 1]["id"],
-                        src_name=data_list[chunk_index - 1]["title"],
-                        src_props={
-                            "vertex_type": GraphElemType.CHUNK.value,
-                            "content": data_list[chunk_index - 1]["content"],
-                        },
-                        dst_vid=chunk["id"],
-                        dst_name=chunk["title"],
-                        dst_props={
-                            "vertex_type": GraphElemType.CHUNK.value,
-                            "content": chunk["content"],
-                        },
-                        edge_name=GraphElemType.NEXT.value,
-                        edge_type=GraphElemType.CHUNK_NEXT_CHUNK.value,
-                    )
-
-        # Support knowledge graph search by the entities and the relationships
-        if self._graph_store.get_config().enable_triplet_graph:
-            for chunk_index, chunk in enumerate(data_list):
-                # TODO: Use asyncio to extract graph to accelerate the process
-                # (attention to the CAP of the graph db)
-                graphs: List[MemoryGraph] = await self._graph_extractor.extract(
-                    chunk["content"]
-                )
-                for graph in graphs:
-                    graph_of_all.upsert_graph(graph)
-
-                    # chunk -> include -> entity
-                    if self._graph_store.get_config().enable_document_graph:
-                        for vertex in graph.vertices():
-                            graph_of_all.upsert_vertex_and_edge(
-                                src_vid=chunk["id"],
-                                src_name=chunk["title"],
-                                src_props={
-                                    "vertex_type": GraphElemType.CHUNK.value,
-                                    "content": chunk["content"],
-                                },
-                                dst_vid=vertex.vid,
-                                dst_name=vertex.name,
-                                dst_props={
-                                    "vertex_type": GraphElemType.ENTITY.value,
-                                    "description": vertex.props.get("description", ""),
-                                },  # note: description is only used for the entity
-                                edge_name=GraphElemType.INCLUDE.value,
-                                edge_type=GraphElemType.CHUNK_INCLUDE_ENTITY.value,
-                            )
-
-        self._graph_store_apdater.upsert_graph(graph_of_all)
-
-        # use asyncio.gather
-        # tasks = [self._graph_extractor.extract(chunk.content) for chunk in chunks]
-        # results = await asyncio.gather(*tasks)
-        # for result in results:
-        #     self._graph_store_apdater.upsert_graph(result[0])
-
-        # build communities and save
+        """Extract and persist graph from the document file."""
+        await self._aload_document_graph(chunks)
+        await self._aload_triplet_graph(chunks)
+
         await self._community_store.build_communities()
         return [chunk.chunk_id for chunk in chunks]

-    def _parse_chunks(slef, chunks: List[Chunk]):
-        """Parse the chunks by anlyzing the markdown chunks."""
-        # TODO: Need to refact.
-        data = []
+    async def _aload_document_graph(self, chunks: List[Chunk]) -> List[str]:
+        """Load the knowledge graph from the chunks.
+
+        The chunks include the doc structure.
+        """
+        if not self._graph_store.get_config().document_graph_enabled:
+            return []
+
+        chunks: List[ParagraphChunk] = [
+            ParagraphChunk.model_validate(chunk.model_dump()) for chunk in chunks
+        ]
+        documment_chunk, chunks = self._load_chunks(chunks)
+
+        # upsert the document and chunks vertices
+        self._graph_store_apdater.upsert_documents(iter([documment_chunk]))
+        self._graph_store_apdater.upsert_chunks(iter(chunks))
+
+        # upsert the document structure
+        for chunk_index, chunk in enumerate(chunks):
+            # document -> include -> chunk
+            if chunk.parent_is_document:
+                self._graph_store_apdater.upsert_doc_include_chunk(chunk=chunk)
+            else:  # chunk -> include -> chunk
+                self._graph_store_apdater.upsert_chunk_include_chunk(chunk=chunk)
+
+            # chunk -> next -> chunk
+            if chunk_index >= 1:
+                self._graph_store_apdater.upsert_chunk_next_chunk(
+                    chunk=chunks[chunk_index - 1], next_chunk=chunk
+                )
+
+    async def _aload_triplet_graph(self, chunks: List[Chunk]) -> None:
+        """Load the knowledge graph from the chunks.
+
+        The chunks include the doc structure.
+        """
+        if not self._graph_store.get_config().triplet_graph_enabled:
+            return
+
+        document_graph_enabled = self._graph_store.get_config().document_graph_enabled
+        for chunk in chunks:
+            # TODO: Use asyncio to extract graph to accelerate the process
+            # (attention to the CAP of the graph db)
+            graphs: List[MemoryGraph] = await self._graph_extractor.extract(
+                chunk.content
+            )
+
+            for graph in graphs:
+                if document_graph_enabled:
+                    # append the chunk id to the edge
+                    for edge in graph.edges():
+                        edge.set_prop("_chunk_id", chunk.chunk_id)
+                        graph.append_edge(edge=edge)
+
+                # upsert the graph
+                self._graph_store_apdater.upsert_graph(graph)
+
+                # chunk -> include -> entity
+                if document_graph_enabled:
+                    for vertex in graph.vertices():
+                        self._graph_store_apdater.upsert_chunk_include_entity(
+                            chunk=chunk, entity=vertex
+                        )
+
+    def _load_chunks(
+        self, chunks: List[ParagraphChunk]
+    ) -> Tuple[ParagraphChunk, List[ParagraphChunk]]:
+        """Load the chunks, and add the parent-child relationship within chunks."""
+        # init default document
+        doc_id = str(uuid.uuid4())
+        doc_name = os.path.basename(chunks[0].metadata["source"] or "Text_Node")
+        doc_chunk = ParagraphChunk(
+            chunk_id=doc_id,
+            chunk_name=doc_name,
+        )
+
+        # chunk.metadata = {"Header0": "title", "Header1": "title", ..., "source": "source_path"}  # noqa: E501
         for chunk_index, chunk in enumerate(chunks):
             parent = None
-            directory_keys = list(chunk.metadata.keys())[:-1]
+            directory_keys = list(chunk.metadata.keys())[
+                :-1
+            ]  # ex: ['Header0', 'Header1', 'Header2', ...]
             parent_level = directory_keys[-2] if len(directory_keys) > 1 else None
             current_level = directory_keys[-1] if directory_keys else "Header0"

-            chunk_data = {
-                "id": chunk.chunk_id,
-                "title": chunk.metadata.get(current_level, "none_header_chunk"),
-                "directory_keys": directory_keys,
-                "level": current_level,
-                "content": chunk.content,
-                "parent_id": None,
-                "parent_title": None,
-                "type": "chunk",
-                "chunk_index": chunk_index,
-            }
+            chunk.chunk_name = chunk.metadata.get(current_level, "none_header_chunk")

-            # Find the parent chunk
+            # Find the parent chunk for every chunk
+            # parent chunk -> chunk
             if parent_level:
                 for parent_direct in reversed(directory_keys[:-1]):
                     parent_titile = chunk.metadata.get(parent_direct, None)
-                    for n in range(chunk_index - 1, -1, -1):
+                    for n in reversed(range(chunk_index)):
                         metadata = chunks[n].metadata
                         keys = list(metadata.keys())[:-1]
                         if (
@@ -285,18 +251,36 @@ class CommunitySummaryKnowledgeGraph(BuiltinKnowledgeGraph):
                             and parent_titile == metadata.get(parent_direct)
                         ):
                             parent = chunks[n]
-                            chunk_data["parent_id"] = parent.chunk_id
-                            chunk_data["parent_title"] = parent_titile
+                            chunk.chunk_parent_id = parent.chunk_id
+                            chunk.chunk_parent_name = parent_titile
+                            chunk.parent_content = parent.content
                             break
                         if chunk_index - n > len(directory_keys):
                             break
-                    if chunk_data["parent_id"]:
+                    if chunk.chunk_parent_id:
                         break

-            if not chunk_data["parent_id"]:
-                chunk_data["parent_id"] = "document"
+            if not chunk.chunk_parent_id:
+                chunk.chunk_parent_id = doc_id
+                chunk.chunk_parent_name = doc_name
+                chunk.parent_content = ""
+                chunk.parent_is_document = True

-            data.append(chunk_data)
-        return data
+        return doc_chunk, chunks
+
+    def similar_search(
+        self, text: str, topk: int, filters: Optional[MetadataFilters] = None
+    ) -> List[Chunk]:
+        """Similar search in index database.
+
+        Args:
+            text(str): The query text.
+            topk(int): The number of similar documents to return.
+            filters(Optional[MetadataFilters]): metadata filters.
+
+        Return:
+            List[Chunk]: The similar documents.
+        """
+        pass

     async def asimilar_search_with_scores(
         self,
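To illustrate what `_load_chunks` (in the hunk above) expects, a small hypothetical set of markdown-header metadata as produced by a header splitter (the last key is always `source`), together with the parent each chunk would resolve to:

```python
# Hypothetical chunks produced by a markdown header splitter, in document order.
chunks_metadata = [
    {"Header0": "Guide", "source": "guide.md"},                        # chunk-1
    {"Header0": "Guide", "Header1": "Install", "source": "guide.md"},  # chunk-2
]

# Expected parent links under the logic above:
#   chunk-1 -> the synthetic document chunk (parent_is_document=True)
#   chunk-2 -> chunk-1 (its Header0, "Guide", matches chunk-1's deepest header)
```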
@@ -320,29 +304,29 @@ class CommunitySummaryKnowledgeGraph(BuiltinKnowledgeGraph):
         subgraph = MemoryGraph()
         subgraph_for_doc = MemoryGraph()

-        enable_triplet_graph = self._graph_store.get_config().enable_triplet_graph
-        enable_document_graph = self._graph_store.get_config().enable_document_graph
+        triplet_graph_enabled = self._graph_store.get_config().triplet_graph_enabled
+        document_graph_enabled = self._graph_store.get_config().document_graph_enabled

-        if enable_triplet_graph:
+        if triplet_graph_enabled:
             subgraph: MemoryGraph = self._graph_store_apdater.explore(
-                subs=keywords, limit=10, search_scope="knowledge_graph"
+                subs=keywords, limit=topk, search_scope="knowledge_graph"
             )

-            if enable_document_graph:
+            if document_graph_enabled:
                 keywords_for_document_graph = keywords
                 for vertex in subgraph.vertices():
                     keywords_for_document_graph.append(vertex.name)

                 subgraph_for_doc = self._graph_store_apdater.explore(
                     subs=keywords_for_document_graph,
-                    limit=5,
+                    limit=self._config.knowledge_graph_chunk_search_top_size,
                     search_scope="document_graph",
                 )
         else:
-            if enable_document_graph:
+            if document_graph_enabled:
                 subgraph_for_doc = self._graph_store_apdater.explore(
                     subs=keywords,
-                    limit=10,
+                    limit=self._config.knowledge_graph_chunk_search_top_size,
                     search_scope="document_graph",
                 )
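Condensed into one sketch, the exploration step after this change (a hypothetical helper, not the actual method body): triplet-graph search is now capped by `topk`, while document-graph search is capped by the new `knowledge_graph_chunk_search_top_size` setting.

```python
def explore_both_graphs(adapter, config, keywords, topk):
    """Illustrative helper mirroring the exploration step shown above."""
    subgraph = adapter.explore(
        subs=keywords, limit=topk, search_scope="knowledge_graph"
    )
    # Entity names found in the triplet graph widen the document-graph lookup.
    doc_keywords = list(keywords) + [v.name for v in subgraph.vertices()]
    subgraph_for_doc = adapter.explore(
        subs=doc_keywords,
        limit=config.knowledge_graph_chunk_search_top_size,
        search_scope="document_graph",
    )
    return subgraph, subgraph_for_doc
```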

View File

@@ -112,9 +112,10 @@ TUGRAPH_HOST=127.0.0.1
 TUGRAPH_PORT=7687
 TUGRAPH_USERNAME=admin
 TUGRAPH_PASSWORD=73@TuGraph
-ENABLE_GRAPH_COMMUNITY_SUMMARY=True # enable the graph community summary
-ENABLE_TRIPLET_GRAPH=True # enable the graph search for the triplets
-ENABLE_DOCUMENT_GRAPH=True # enable the graph search for documents and chunks
+GRAPH_COMMUNITY_SUMMARY_ENABLED=True # enable the graph community summary
+TRIPLET_GRAPH_ENABLED=True # enable the graph search for the triplets
+DOCUMENT_GRAPH_ENABLED=True # enable the graph search for documents and chunks
+KNOWLEDGE_GRAPH_CHUNK_SEARCH_TOP_SIZE=5 # the top size of knowledge graph search for chunks
 ```

Binary file not shown (image; previous size 195 KiB).