refactor: the document structure of GraphRAG (#2084)

This commit is contained in:
Appointat 2024-10-21 18:05:55 +08:00 committed by GitHub
parent 584f090847
commit 6d6667812b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
15 changed files with 401 additions and 284 deletions

View File

@ -162,9 +162,11 @@ KNOWLEDGE_GRAPH_EXTRACT_SEARCH_RECALL_SCORE=0.3
KNOWLEDGE_GRAPH_COMMUNITY_SEARCH_TOP_SIZE=20
KNOWLEDGE_GRAPH_COMMUNITY_SEARCH_RECALL_SCORE=0.0
ENABLE_GRAPH_COMMUNITY_SUMMARY=True # enable the graph community summary
ENABLE_TRIPLET_GRAPH=True # enable the graph search for triplets
ENABLE_DOCUMENT_GRAPH=True # enable the graph search for documents and chunks
GRAPH_COMMUNITY_SUMMARY_ENABLED=True # enable the graph community summary
TRIPLET_GRAPH_ENABLED=True # enable the graph search for triplets
DOCUMENT_GRAPH_ENABLED=True # enable the graph search for documents and chunks
KNOWLEDGE_GRAPH_CHUNK_SEARCH_TOP_SIZE=5 # the top size of knowledge graph search for chunks
### Chroma vector db config
#CHROMA_PERSIST_PATH=/root/DB-GPT/pilot/data

View File

@ -213,8 +213,8 @@ class Config(metaclass=Singleton):
# Vector Store Configuration
self.VECTOR_STORE_TYPE = os.getenv("VECTOR_STORE_TYPE", "Chroma")
self.ENABLE_GRAPH_COMMUNITY_SUMMARY = (
os.getenv("ENABLE_GRAPH_COMMUNITY_SUMMARY", "").lower() == "true"
self.GRAPH_COMMUNITY_SUMMARY_ENABLED = (
os.getenv("GRAPH_COMMUNITY_SUMMARY_ENABLED", "").lower() == "true"
)
self.MILVUS_URL = os.getenv("MILVUS_URL", "127.0.0.1")
self.MILVUS_PORT = os.getenv("MILVUS_PORT", "19530")

View File

@ -49,6 +49,7 @@ class Chunk(Document):
chunk_id: str = Field(
default_factory=lambda: str(uuid.uuid4()), description="unique id for the chunk"
)
chunk_name: str = Field(default="", description="chunk name")
content: str = Field(default="", description="chunk text content")
metadata: Dict[str, Any] = Field(

View File

@ -128,7 +128,7 @@ class VectorStoreConnector:
def __rewrite_index_store_type(self, index_store_type):
# Rewrite Knowledge Graph Type
if CFG.ENABLE_GRAPH_COMMUNITY_SUMMARY:
if CFG.GRAPH_COMMUNITY_SUMMARY_ENABLED:
if index_store_type == "KnowledgeGraph":
return "CommunitySummaryKnowledgeGraph"
return index_store_type

View File

@ -27,11 +27,11 @@ class GraphStoreConfig(BaseModel):
default=False,
description="Enable graph community summary or not.",
)
enable_document_graph: bool = Field(
document_graph_enabled: bool = Field(
default=True,
description="Enable document graph search or not.",
)
enable_triplet_graph: bool = Field(
triplet_graph_enabled: bool = Field(
default=True,
description="Enable knowledge graph search or not.",
)
@ -48,11 +48,3 @@ class GraphStoreBase(ABC):
@abstractmethod
def get_config(self) -> GraphStoreConfig:
"""Get the graph store config."""
@abstractmethod
def _escape_quotes(self, text: str) -> str:
"""Escape single and double quotes in a string for queries."""
# @abstractmethod
# def _paser(self, entities: List[Vertex]) -> str:
# """Parse entities to string."""

View File

@ -19,8 +19,8 @@ class GraphElemType(Enum):
DOCUMENT = "document"
CHUNK = "chunk"
ENTITY = "entity" # view as general vertex in the general case
RELATION = "relation" # view as general edge in the general case
ENTITY = "entity" # default vertex type in knowledge graph
RELATION = "relation" # default edge type in knowledge graph
INCLUDE = "include"
NEXT = "next"
@ -39,7 +39,15 @@ class GraphElemType(Enum):
def is_edge(self) -> bool:
"""Check if the element is an edge."""
return not self.is_vertex()
return self in [
GraphElemType.RELATION,
GraphElemType.INCLUDE,
GraphElemType.NEXT,
GraphElemType.DOCUMENT_INCLUDE_CHUNK,
GraphElemType.CHUNK_INCLUDE_CHUNK,
GraphElemType.CHUNK_INCLUDE_ENTITY,
GraphElemType.CHUNK_NEXT_CHUNK,
]
class Direction(Enum):
@ -335,26 +343,6 @@ class MemoryGraph(Graph):
self._edge_count += 1
return True
def upsert_vertex_and_edge(
self,
src_vid: str,
src_name: str,
src_props: Dict[str, Any],
dst_vid: str,
dst_name: str,
dst_props: Dict[str, Any],
edge_name: str,
edge_type: str,
):
"""Uperst src and dst vertex, and edge."""
src_vertex = Vertex(src_vid, src_name, **src_props)
dst_vertex = Vertex(dst_vid, dst_name, **dst_props)
edge = Edge(src_vid, dst_vid, edge_name, **{"edge_type": edge_type})
self.upsert_vertex(src_vertex)
self.upsert_vertex(dst_vertex)
self.append_edge(edge)
def upsert_graph(self, graph: "MemoryGraph"):
"""Upsert a graph."""
for vertex in graph.vertices():

View File

@ -26,9 +26,3 @@ class MemoryGraphStore(GraphStoreBase):
def get_config(self):
"""Get the graph store config."""
return self._graph_store_config
def _escape_quotes(self, text: str) -> str:
"""Escape single and double quotes in a string for queries."""
raise NotImplementedError(
"_escape_quotes is not implemented by MemoryGraphStore"
)

View File

@ -80,16 +80,16 @@ class TuGraphStore(GraphStoreBase):
self._username = os.getenv("TUGRAPH_USERNAME", config.username)
self._password = os.getenv("TUGRAPH_PASSWORD", config.password)
self._enable_summary = (
os.getenv("ENABLE_GRAPH_COMMUNITY_SUMMARY", "").lower() == "true"
os.getenv("GRAPH_COMMUNITY_SUMMARY_ENABLED", "").lower() == "true"
or config.enable_summary
)
self._enable_document_graph = (
os.getenv("ENABLE_DOCUMENT_GRAPH", "").lower() == "true"
or config.enable_document_graph
os.getenv("DOCUMENT_GRAPH_ENABLED", "").lower() == "true"
or config.document_graph_enabled
)
self._enable_triplet_graph = (
os.getenv("ENABLE_TRIPLET_GRAPH", "").lower() == "true"
or config.enable_triplet_graph
os.getenv("TRIPLET_GRAPH_ENABLED", "").lower() == "true"
or config.triplet_graph_enabled
)
self._plugin_names = (
os.getenv("TUGRAPH_PLUGIN_NAMES", "leiden").split(",")
@ -159,8 +159,3 @@ class TuGraphStore(GraphStoreBase):
f"'{name} Plugin', false, 'v1')"
)
self.conn.run(gql)
def _escape_quotes(self, value: str) -> str:
"""Escape single and double quotes in a string for queries."""
if value is not None:
return value.replace("'", "").replace('"', "")

View File

@ -4,7 +4,10 @@ import logging
from abc import ABC, abstractmethod
from typing import List, Optional
from pydantic import Field
from dbgpt._private.pydantic import ConfigDict
from dbgpt.core import Chunk
from dbgpt.rag.index.base import IndexStoreBase, IndexStoreConfig
from dbgpt.storage.graph_store.graph import Graph
@ -31,3 +34,14 @@ class KnowledgeGraphBase(IndexStoreBase, ABC):
@abstractmethod
def delete_by_ids(self, ids: str) -> List[str]:
"""Delete document by ids."""
class ParagraphChunk(Chunk):
"""Loaded chunk, used in GraphRAG."""
chunk_parent_id: str = Field(default=None, description="id of parent chunk")
chunk_parent_name: str = Field(default=None, description="parent chunk name")
parent_content: str = Field(default=None, description="parent chunk text content")
parent_is_document: bool = Field(
default=False, description="is parent chunk a document"
)

View File

@ -3,7 +3,7 @@
import logging
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import AsyncGenerator, Iterator, List, Optional
from typing import AsyncGenerator, Iterator, List, Optional, Union
from dbgpt.storage.graph_store.base import GraphStoreBase
from dbgpt.storage.graph_store.graph import (
@ -14,6 +14,7 @@ from dbgpt.storage.graph_store.graph import (
MemoryGraph,
Vertex,
)
from dbgpt.storage.knowledge_graph.base import ParagraphChunk
logger = logging.getLogger(__name__)
@ -91,17 +92,17 @@ class GraphStoreAdapter(ABC):
"""Upsert edge."""
@abstractmethod
def upsert_chunks(self, chunk: Iterator[Vertex]) -> None:
def upsert_chunks(
self, chunks: Union[Iterator[Vertex], Iterator[ParagraphChunk]]
) -> None:
"""Upsert chunk."""
@abstractmethod
def upsert_documents(self, documents: Iterator[Vertex]) -> None:
def upsert_documents(
self, documents: Union[Iterator[Vertex], Iterator[ParagraphChunk]]
) -> None:
"""Upsert documents."""
@abstractmethod
def upsert_relations(self, relations: Iterator[Edge]) -> None:
"""Upsert relations."""
@abstractmethod
def insert_triplet(self, sub: str, rel: str, obj: str) -> None:
"""Insert triplet."""
@ -110,6 +111,34 @@ class GraphStoreAdapter(ABC):
def upsert_graph(self, graph: Graph) -> None:
"""Insert graph."""
@abstractmethod
def upsert_doc_include_chunk(
self,
chunk: ParagraphChunk,
) -> None:
"""Convert chunk to document include chunk."""
@abstractmethod
def upsert_chunk_include_chunk(
self,
chunk: ParagraphChunk,
) -> None:
"""Convert chunk to chunk include chunk."""
@abstractmethod
def upsert_chunk_next_chunk(
self,
chunk: ParagraphChunk,
next_chunk: ParagraphChunk,
):
"""Uperst the vertices and the edge in chunk_next_chunk."""
@abstractmethod
def upsert_chunk_include_entity(
self, chunk: ParagraphChunk, entity: Vertex
) -> None:
"""Convert chunk to chunk include entity."""
@abstractmethod
def delete_document(self, chunk_id: str) -> None:
"""Delete document in graph store."""

View File

@ -2,7 +2,7 @@
import json
import logging
from typing import AsyncGenerator, Iterator, List, Optional, Tuple
from typing import AsyncGenerator, Iterator, List, Optional, Tuple, Union
from dbgpt.storage.graph_store.graph import (
Direction,
@ -16,6 +16,7 @@ from dbgpt.storage.graph_store.memgraph_store import (
MemoryGraphStore,
MemoryGraphStoreConfig,
)
from dbgpt.storage.knowledge_graph.base import ParagraphChunk
from dbgpt.storage.knowledge_graph.community.base import Community, GraphStoreAdapter
logger = logging.getLogger(__name__)
@ -38,11 +39,11 @@ class MemGraphStoreAdapter(GraphStoreAdapter):
async def discover_communities(self, **kwargs) -> List[str]:
"""Run community discovery with leiden."""
pass
[]
async def get_community(self, community_id: str) -> Community:
"""Get community."""
pass
raise NotImplementedError("Memory graph store does not have community")
def get_graph_config(self):
"""Get the graph store config."""
@ -96,11 +97,15 @@ class MemGraphStoreAdapter(GraphStoreAdapter):
"""Upsert edges."""
pass
def upsert_chunks(self, chunks: Iterator[Vertex]) -> None:
def upsert_chunks(
self, chunks: Union[Iterator[Vertex], Iterator[ParagraphChunk]]
) -> None:
"""Upsert chunks."""
pass
def upsert_documents(self, documents: Iterator[Vertex]) -> None:
def upsert_documents(
self, documents: Union[Iterator[Vertex], Iterator[ParagraphChunk]]
) -> None:
"""Upsert documents."""
pass
@ -108,6 +113,32 @@ class MemGraphStoreAdapter(GraphStoreAdapter):
"""Upsert relations."""
pass
def upsert_doc_include_chunk(
self,
chunk: ParagraphChunk,
) -> None:
"""Convert chunk to document include chunk."""
pass
def upsert_chunk_include_chunk(
self,
chunk: ParagraphChunk,
) -> None:
"""Convert chunk to chunk include chunk."""
pass
def upsert_chunk_next_chunk(
self, chunk: ParagraphChunk, next_chunk: ParagraphChunk
):
"""Uperst the vertices and the edge in chunk_next_chunk."""
pass
def upsert_chunk_include_entity(
self, chunk: ParagraphChunk, entity: Vertex
) -> None:
"""Convert chunk to chunk include entity."""
pass
def insert_triplet(self, subj: str, rel: str, obj: str) -> None:
"""Add triplet."""
self._graph_store._graph.append_edge(Edge(subj, obj, rel))

View File

@ -23,6 +23,7 @@ from dbgpt.storage.graph_store.graph import (
Vertex,
)
from dbgpt.storage.graph_store.tugraph_store import TuGraphStore
from dbgpt.storage.knowledge_graph.base import ParagraphChunk
from dbgpt.storage.knowledge_graph.community.base import Community, GraphStoreAdapter
logger = logging.getLogger(__name__)
@ -137,11 +138,9 @@ class TuGraphStoreAdapter(GraphStoreAdapter):
"""Upsert entities."""
entity_list = [
{
"id": self.graph_store._escape_quotes(entity.vid),
"name": self.graph_store._escape_quotes(entity.name),
"description": self.graph_store._escape_quotes(
entity.get_prop("description")
)
"id": self._escape_quotes(entity.vid),
"name": self._escape_quotes(entity.name),
"description": self._escape_quotes(entity.get_prop("description"))
or "",
"_document_id": "0",
"_chunk_id": "0",
@ -152,7 +151,7 @@ class TuGraphStoreAdapter(GraphStoreAdapter):
entity_query = (
f"CALL db.upsertVertex("
f'"{GraphElemType.ENTITY.value}", '
f"[{self._parser(entity_list)}])"
f"[{self._convert_dict_to_str(entity_list)}])"
)
self.graph_store.conn.run(query=entity_query)
@ -162,64 +161,81 @@ class TuGraphStoreAdapter(GraphStoreAdapter):
"""Upsert edges."""
edge_list = [
{
"sid": self.graph_store._escape_quotes(edge.sid),
"tid": self.graph_store._escape_quotes(edge.tid),
"id": self.graph_store._escape_quotes(edge.name),
"name": self.graph_store._escape_quotes(edge.name),
"description": self.graph_store._escape_quotes(
edge.get_prop("description")
)
or "",
"_chunk_id": self.graph_store._escape_quotes(edge.get_prop("_chunk_id"))
or "",
"sid": self._escape_quotes(edge.sid),
"tid": self._escape_quotes(edge.tid),
"id": self._escape_quotes(edge.name),
"name": self._escape_quotes(edge.name),
"description": self._escape_quotes(edge.get_prop("description")) or "",
"_chunk_id": self._escape_quotes(edge.get_prop("_chunk_id")) or "",
}
for edge in edges
]
relation_query = f"""CALL db.upsertEdge("{edge_type}",
{{type:"{src_type}", key:"sid"}},
{{type:"{dst_type}", key:"tid"}},
[{self._parser(edge_list)}])"""
[{self._convert_dict_to_str(edge_list)}])"""
self.graph_store.conn.run(query=relation_query)
def upsert_chunks(self, chunks: Iterator[Vertex]) -> None:
def upsert_chunks(
self, chunks: Union[Iterator[Vertex], Iterator[ParagraphChunk]]
) -> None:
"""Upsert chunks."""
chunk_list = [
{
"id": self.graph_store._escape_quotes(chunk.vid),
"name": self.graph_store._escape_quotes(chunk.name),
"content": self.graph_store._escape_quotes(chunk.get_prop("content")),
}
for chunk in chunks
]
chunks_list = list(chunks)
if chunks_list and isinstance(chunks_list[0], ParagraphChunk):
chunk_list = [
{
"id": self._escape_quotes(chunk.chunk_id),
"name": self._escape_quotes(chunk.chunk_name),
"content": self._escape_quotes(chunk.content),
}
for chunk in chunks_list
]
else:
chunk_list = [
{
"id": self._escape_quotes(chunk.vid),
"name": self._escape_quotes(chunk.name),
"content": self._escape_quotes(chunk.get_prop("content")),
}
for chunk in chunks_list
]
chunk_query = (
f"CALL db.upsertVertex("
f'"{GraphElemType.CHUNK.value}", '
f"[{self._parser(chunk_list)}])"
f"[{self._convert_dict_to_str(chunk_list)}])"
)
self.graph_store.conn.run(query=chunk_query)
def upsert_documents(self, documents: Iterator[Vertex]) -> None:
def upsert_documents(
self, documents: Union[Iterator[Vertex], Iterator[ParagraphChunk]]
) -> None:
"""Upsert documents."""
document_list = [
{
"id": self.graph_store._escape_quotes(document.vid),
"name": self.graph_store._escape_quotes(document.name),
"content": self.graph_store._escape_quotes(document.get_prop("content"))
or "",
}
for document in documents
]
documents_list = list(documents)
if documents_list and isinstance(documents_list[0], ParagraphChunk):
document_list = [
{
"id": self._escape_quotes(document.chunk_id),
"name": self._escape_quotes(document.chunk_name),
"content": "",
}
for document in documents_list
]
else:
document_list = [
{
"id": self._escape_quotes(document.vid),
"name": self._escape_quotes(document.name),
"content": self._escape_quotes(document.get_prop("content")) or "",
}
for document in documents_list
]
document_query = (
"CALL db.upsertVertex("
f'"{GraphElemType.DOCUMENT.value}", '
f"[{self._parser(document_list)}])"
f"[{self._convert_dict_to_str(document_list)}])"
)
self.graph_store.conn.run(query=document_query)
def upsert_relations(self, relations: Iterator[Edge]) -> None:
"""Upsert relations."""
pass
def insert_triplet(self, subj: str, rel: str, obj: str) -> None:
"""Add triplet."""
subj_escaped = subj.replace("'", "\\'").replace('"', '\\"')
@ -465,10 +481,6 @@ class TuGraphStoreAdapter(GraphStoreAdapter):
"properties": graph_properties,
})
gql = f"""CALL db.createVertexLabelByJson('{data}')"""
gql_check_exist = (
f"""CALL db.getLabelSchema('VERTEX', '{graph_elem_type.value}')"""
)
else: # edge
def edge_direction(graph_elem_type: GraphElemType) -> List[List[str]]:
@ -501,20 +513,7 @@ class TuGraphStoreAdapter(GraphStoreAdapter):
})
gql = f"""CALL db.createEdgeLabelByJson('{data}')"""
gql_check_exist = (
f"""CALL db.getLabelSchema('EDGE', '{graph_elem_type.value}')"""
)
# Make sure the graph label is identical
try:
self.graph_store.conn.run(
gql_check_exist
) # if not exist, qurying raises an exception
except Exception:
self.graph_store.conn.run(gql) # create the graph label
return
logger.info(f"Graph label {graph_elem_type.value} already exists.")
self.graph_store.conn.run(gql)
def truncate(self):
"""Truncate Graph."""
@ -542,7 +541,7 @@ class TuGraphStoreAdapter(GraphStoreAdapter):
self,
subs: List[str],
direct: Direction = Direction.BOTH,
depth: Optional[int] = None,
depth: int = 3,
limit: Optional[int] = None,
search_scope: Optional[
Literal["knowledge_graph", "document_graph"]
@ -552,10 +551,8 @@ class TuGraphStoreAdapter(GraphStoreAdapter):
if not subs:
return MemoryGraph()
if depth is None or depth < 0 or depth > self.MAX_HIERARCHY_LEVEL:
# TODO: to be discussed, be none or MAX_HIERARCHY_LEVEL
# depth_string = ".."
depth = self.MAX_HIERARCHY_LEVEL
if depth < 0:
depth = 3
depth_string = f"1..{depth}"
if limit is None:
@ -573,7 +570,8 @@ class TuGraphStoreAdapter(GraphStoreAdapter):
query = (
f"MATCH p=(n:{GraphElemType.ENTITY.value})"
f"{rel}(m:{GraphElemType.ENTITY.value}) "
f"WHERE n.id IN {subs} RETURN p {limit_string}"
f"WHERE n.id IN {[self._escape_quotes(sub) for sub in subs]} "
f"RETURN p {limit_string}"
)
return self.query(query)
else:
@ -583,7 +581,8 @@ class TuGraphStoreAdapter(GraphStoreAdapter):
query = (
f"MATCH p=(n:{GraphElemType.DOCUMENT.value})-"
f"[r:{GraphElemType.INCLUDE.value}*{depth_string}]-"
f"(m:{GraphElemType.CHUNK.value})WHERE m.content CONTAINS '{sub}' "
f"(m:{GraphElemType.CHUNK.value})WHERE m.content CONTAINS "
f"'{self._escape_quotes(sub)}' "
f"RETURN p {limit_string}"
) # if it contains the subjects
result = self.query(query)
@ -794,8 +793,8 @@ class TuGraphStoreAdapter(GraphStoreAdapter):
vertex_list.append(vertex)
return vertex_list, edge_list
def _parser(self, entity_list: List[Dict[str, Any]]) -> str:
"""Parse entities to string."""
def _convert_dict_to_str(self, entity_list: List[Dict[str, Any]]) -> str:
"""Convert a list of entities to a formatted string representation."""
formatted_nodes = [
"{"
+ ", ".join(
@ -806,3 +805,90 @@ class TuGraphStoreAdapter(GraphStoreAdapter):
for node in entity_list
]
return f"""{", ".join(formatted_nodes)}"""
def _escape_quotes(self, value: str) -> str:
"""Escape single and double quotes in a string for queries."""
if value is not None:
return value.replace("'", "").replace('"', "")
def upsert_doc_include_chunk(
self,
chunk: ParagraphChunk,
) -> None:
"""Convert chunk to document include chunk."""
assert (
chunk.chunk_parent_id and chunk.chunk_parent_name
), "Chunk parent ID and name are required (document_include_chunk)"
edge = Edge(
sid=chunk.chunk_parent_id,
tid=chunk.chunk_id,
name=GraphElemType.INCLUDE.value,
edge_type=GraphElemType.DOCUMENT_INCLUDE_CHUNK.value,
)
self.upsert_edge(
edges=iter([edge]),
edge_type=GraphElemType.INCLUDE.value,
src_type=GraphElemType.DOCUMENT.value,
dst_type=GraphElemType.CHUNK.value,
)
def upsert_chunk_include_chunk(
self,
chunk: ParagraphChunk,
) -> None:
"""Convert chunk to chunk include chunk."""
assert (
chunk.chunk_parent_id and chunk.chunk_parent_name
), "Chunk parent ID and name are required (chunk_include_chunk)"
edge = Edge(
sid=chunk.chunk_parent_id,
tid=chunk.chunk_id,
name=GraphElemType.INCLUDE.value,
edge_type=GraphElemType.CHUNK_INCLUDE_CHUNK.value,
)
self.upsert_edge(
edges=iter([edge]),
edge_type=GraphElemType.INCLUDE.value,
src_type=GraphElemType.CHUNK.value,
dst_type=GraphElemType.CHUNK.value,
)
def upsert_chunk_next_chunk(
self, chunk: ParagraphChunk, next_chunk: ParagraphChunk
):
"""Uperst the vertices and the edge in chunk_next_chunk."""
edge = Edge(
sid=chunk.chunk_id,
tid=next_chunk.chunk_id,
name=GraphElemType.NEXT.value,
edge_type=GraphElemType.CHUNK_NEXT_CHUNK.value,
)
self.upsert_edge(
edges=iter([edge]),
edge_type=GraphElemType.NEXT.value,
src_type=GraphElemType.CHUNK.value,
dst_type=GraphElemType.CHUNK.value,
)
def upsert_chunk_include_entity(
self, chunk: ParagraphChunk, entity: Vertex
) -> None:
"""Convert chunk to chunk include entity."""
edge = Edge(
sid=chunk.chunk_id,
tid=entity.vid,
name=GraphElemType.INCLUDE.value,
edge_type=GraphElemType.CHUNK_INCLUDE_ENTITY.value,
)
self.upsert_edge(
edges=iter([edge]),
edge_type=GraphElemType.INCLUDE.value,
src_type=GraphElemType.CHUNK.value,
dst_type=GraphElemType.ENTITY.value,
)

View File

@ -3,15 +3,15 @@
import logging
import os
import uuid
from typing import List, Optional
from typing import List, Optional, Tuple
from dbgpt._private.pydantic import ConfigDict, Field
from dbgpt.core import Chunk
from dbgpt.rag.transformer.community_summarizer import CommunitySummarizer
from dbgpt.rag.transformer.graph_extractor import GraphExtractor
from dbgpt.storage.graph_store.graph import GraphElemType, MemoryGraph
from dbgpt.storage.graph_store.graph import MemoryGraph
from dbgpt.storage.knowledge_graph.base import ParagraphChunk
from dbgpt.storage.knowledge_graph.community.community_store import CommunityStore
from dbgpt.storage.knowledge_graph.community.factory import GraphStoreAdapterFactory
from dbgpt.storage.knowledge_graph.knowledge_graph import (
BuiltinKnowledgeGraph,
BuiltinKnowledgeGraphConfig,
@ -59,6 +59,10 @@ class CommunitySummaryKnowledgeGraphConfig(BuiltinKnowledgeGraphConfig):
default=0.0,
description="Recall score of community search in knowledge graph",
)
knowledge_graph_chunk_search_top_size: int = Field(
default=5,
description="Top size of knowledge graph chunk search",
)
class CommunitySummaryKnowledgeGraph(BuiltinKnowledgeGraph):
@ -124,7 +128,7 @@ class CommunitySummaryKnowledgeGraph(BuiltinKnowledgeGraph):
cfg.score_threshold = self._community_score_threshold
self._community_store = CommunityStore(
GraphStoreAdapterFactory.create(self._graph_store),
self._graph_store_apdater,
CommunitySummarizer(self._llm_client, self._model_name),
VectorStoreFactory.create(
self._vector_store_type,
@ -138,145 +142,107 @@ class CommunitySummaryKnowledgeGraph(BuiltinKnowledgeGraph):
return self._config
async def aload_document(self, chunks: List[Chunk]) -> List[str]:
"""Extract and persist graph."""
data_list = self._parse_chunks(chunks) # parse the chunks by def _lod_doc_graph
graph_of_all = MemoryGraph()
# Support graph search by the document and the chunks
if self._graph_store.get_config().enable_document_graph:
doc_vid = str(uuid.uuid4())
doc_name = os.path.basename(chunks[0].metadata["source"] or "Text_Node")
for chunk_index, chunk in enumerate(data_list):
if chunk["parent_id"] != "document":
# chunk -> include -> chunk
graph_of_all.upsert_vertex_and_edge(
src_vid=chunk["parent_id"],
src_name=chunk["parent_title"],
src_props={
"vertex_type": GraphElemType.CHUNK.value,
"content": chunk["content"],
},
dst_vid=chunk["id"],
dst_name=chunk["title"],
dst_props={
"vertex_type": GraphElemType.CHUNK.value,
"content": chunk["content"],
},
edge_name=GraphElemType.INCLUDE.value,
edge_type=GraphElemType.CHUNK_INCLUDE_CHUNK.value,
)
else:
# document -> include -> chunk
graph_of_all.upsert_vertex_and_edge(
src_vid=doc_vid,
src_name=doc_name,
src_props={
"vertex_type": GraphElemType.DOCUMENT.value,
"content": "",
},
dst_vid=chunk["id"],
dst_name=chunk["title"],
dst_props={
"vertex_type": GraphElemType.CHUNK.value,
"content": chunk["content"],
},
edge_name=GraphElemType.INCLUDE.value,
edge_type=GraphElemType.DOCUMENT_INCLUDE_CHUNK.value,
)
# chunk -> next -> chunk
if chunk_index >= 1:
graph_of_all.upsert_vertex_and_edge(
src_vid=data_list[chunk_index - 1]["id"],
src_name=data_list[chunk_index - 1]["title"],
src_props={
"vertex_type": GraphElemType.CHUNK.value,
"content": data_list[chunk_index - 1]["content"],
},
dst_vid=chunk["id"],
dst_name=chunk["title"],
dst_props={
"vertex_type": GraphElemType.CHUNK.value,
"content": chunk["content"],
},
edge_name=GraphElemType.NEXT.value,
edge_type=GraphElemType.CHUNK_NEXT_CHUNK.value,
)
# Support knowledge graph search by the entities and the relationships
if self._graph_store.get_config().enable_triplet_graph:
for chunk_index, chunk in enumerate(data_list):
# TODO: Use asyncio to extract graph to accelerate the process
# (attention to the CAP of the graph db)
graphs: List[MemoryGraph] = await self._graph_extractor.extract(
chunk["content"]
)
for graph in graphs:
graph_of_all.upsert_graph(graph)
# chunk -> include -> entity
if self._graph_store.get_config().enable_document_graph:
for vertex in graph.vertices():
graph_of_all.upsert_vertex_and_edge(
src_vid=chunk["id"],
src_name=chunk["title"],
src_props={
"vertex_type": GraphElemType.CHUNK.value,
"content": chunk["content"],
},
dst_vid=vertex.vid,
dst_name=vertex.name,
dst_props={
"vertex_type": GraphElemType.ENTITY.value,
"description": vertex.props.get("description", ""),
}, # note: description is only used for the entity
edge_name=GraphElemType.INCLUDE.value,
edge_type=GraphElemType.CHUNK_INCLUDE_ENTITY.value,
)
self._graph_store_apdater.upsert_graph(graph_of_all)
# use asyncio.gather
# tasks = [self._graph_extractor.extract(chunk.content) for chunk in chunks]
# results = await asyncio.gather(*tasks)
# for result in results:
# self._graph_store_apdater.upsert_graph(result[0])
# build communities and save
"""Extract and persist graph from the document file."""
await self._aload_document_graph(chunks)
await self._aload_triplet_graph(chunks)
await self._community_store.build_communities()
return [chunk.chunk_id for chunk in chunks]
def _parse_chunks(slef, chunks: List[Chunk]):
"""Parse the chunks by anlyzing the markdown chunks."""
# TODO: Need to refact.
data = []
async def _aload_document_graph(self, chunks: List[Chunk]) -> List[str]:
"""Load the knowledge graph from the chunks.
The chunks include the doc structure.
"""
if not self._graph_store.get_config().document_graph_enabled:
return []
chunks: List[ParagraphChunk] = [
ParagraphChunk.model_validate(chunk.model_dump()) for chunk in chunks
]
documment_chunk, chunks = self._load_chunks(chunks)
# upsert the document and chunks vertices
self._graph_store_apdater.upsert_documents(iter([documment_chunk]))
self._graph_store_apdater.upsert_chunks(iter(chunks))
# upsert the document structure
for chunk_index, chunk in enumerate(chunks):
# document -> include -> chunk
if chunk.parent_is_document:
self._graph_store_apdater.upsert_doc_include_chunk(chunk=chunk)
else: # chunk -> include -> chunk
self._graph_store_apdater.upsert_chunk_include_chunk(chunk=chunk)
# chunk -> next -> chunk
if chunk_index >= 1:
self._graph_store_apdater.upsert_chunk_next_chunk(
chunk=chunks[chunk_index - 1], next_chunk=chunk
)
async def _aload_triplet_graph(self, chunks: List[Chunk]) -> None:
"""Load the knowledge graph from the chunks.
The chunks include the doc structure.
"""
if not self._graph_store.get_config().triplet_graph_enabled:
return
document_graph_enabled = self._graph_store.get_config().document_graph_enabled
for chunk in chunks:
# TODO: Use asyncio to extract graph to accelerate the process
# (attention to the CAP of the graph db)
graphs: List[MemoryGraph] = await self._graph_extractor.extract(
chunk.content
)
for graph in graphs:
if document_graph_enabled:
# append the chunk id to the edge
for edge in graph.edges():
edge.set_prop("_chunk_id", chunk.chunk_id)
graph.append_edge(edge=edge)
# upsert the graph
self._graph_store_apdater.upsert_graph(graph)
# chunk -> include -> entity
if document_graph_enabled:
for vertex in graph.vertices():
self._graph_store_apdater.upsert_chunk_include_entity(
chunk=chunk, entity=vertex
)
def _load_chunks(
self, chunks: List[ParagraphChunk]
) -> Tuple[ParagraphChunk, List[ParagraphChunk]]:
"""Load the chunks, and add the parent-child relationship within chunks."""
# init default document
doc_id = str(uuid.uuid4())
doc_name = os.path.basename(chunks[0].metadata["source"] or "Text_Node")
doc_chunk = ParagraphChunk(
chunk_id=doc_id,
chunk_name=doc_name,
)
# chunk.metadata = {"Header0": "title", "Header1": "title", ..., "source": "source_path"} # noqa: E501
for chunk_index, chunk in enumerate(chunks):
parent = None
directory_keys = list(chunk.metadata.keys())[:-1]
directory_keys = list(chunk.metadata.keys())[
:-1
] # ex: ['Header0', 'Header1', 'Header2', ...]
parent_level = directory_keys[-2] if len(directory_keys) > 1 else None
current_level = directory_keys[-1] if directory_keys else "Header0"
chunk_data = {
"id": chunk.chunk_id,
"title": chunk.metadata.get(current_level, "none_header_chunk"),
"directory_keys": directory_keys,
"level": current_level,
"content": chunk.content,
"parent_id": None,
"parent_title": None,
"type": "chunk",
"chunk_index": chunk_index,
}
chunk.chunk_name = chunk.metadata.get(current_level, "none_header_chunk")
# Find the parent chunk
# Find the parent chunk for every chunk
# parent chunk -> chunk
if parent_level:
for parent_direct in reversed(directory_keys[:-1]):
parent_titile = chunk.metadata.get(parent_direct, None)
for n in range(chunk_index - 1, -1, -1):
for n in reversed(range(chunk_index)):
metadata = chunks[n].metadata
keys = list(metadata.keys())[:-1]
if (
@ -285,18 +251,36 @@ class CommunitySummaryKnowledgeGraph(BuiltinKnowledgeGraph):
and parent_titile == metadata.get(parent_direct)
):
parent = chunks[n]
chunk_data["parent_id"] = parent.chunk_id
chunk_data["parent_title"] = parent_titile
chunk.chunk_parent_id = parent.chunk_id
chunk.chunk_parent_name = parent_titile
chunk.parent_content = parent.content
break
if chunk_index - n > len(directory_keys):
break
if chunk_data["parent_id"]:
if chunk.chunk_parent_id:
break
if not chunk_data["parent_id"]:
chunk_data["parent_id"] = "document"
data.append(chunk_data)
return data
if not chunk.chunk_parent_id:
chunk.chunk_parent_id = doc_id
chunk.chunk_parent_name = doc_name
chunk.parent_content = ""
chunk.parent_is_document = True
return doc_chunk, chunks
def similar_search(
self, text: str, topk: int, filters: Optional[MetadataFilters] = None
) -> List[Chunk]:
"""Similar search in index database.
Args:
text(str): The query text.
topk(int): The number of similar documents to return.
filters(Optional[MetadataFilters]): metadata filters.
Return:
List[Chunk]: The similar documents.
"""
pass
async def asimilar_search_with_scores(
self,
@ -320,29 +304,29 @@ class CommunitySummaryKnowledgeGraph(BuiltinKnowledgeGraph):
subgraph = MemoryGraph()
subgraph_for_doc = MemoryGraph()
enable_triplet_graph = self._graph_store.get_config().enable_triplet_graph
enable_document_graph = self._graph_store.get_config().enable_document_graph
triplet_graph_enabled = self._graph_store.get_config().triplet_graph_enabled
document_graph_enabled = self._graph_store.get_config().document_graph_enabled
if enable_triplet_graph:
if triplet_graph_enabled:
subgraph: MemoryGraph = self._graph_store_apdater.explore(
subs=keywords, limit=10, search_scope="knowledge_graph"
subs=keywords, limit=topk, search_scope="knowledge_graph"
)
if enable_document_graph:
if document_graph_enabled:
keywords_for_document_graph = keywords
for vertex in subgraph.vertices():
keywords_for_document_graph.append(vertex.name)
subgraph_for_doc = self._graph_store_apdater.explore(
subs=keywords_for_document_graph,
limit=5,
limit=self._config.knowledge_graph_chunk_search_top_size,
search_scope="document_graph",
)
else:
if enable_document_graph:
if document_graph_enabled:
subgraph_for_doc = self._graph_store_apdater.explore(
subs=keywords,
limit=10,
limit=self._config.knowledge_graph_chunk_search_top_size,
search_scope="document_graph",
)

View File

@ -112,9 +112,10 @@ TUGRAPH_HOST=127.0.0.1
TUGRAPH_PORT=7687
TUGRAPH_USERNAME=admin
TUGRAPH_PASSWORD=73@TuGraph
ENABLE_GRAPH_COMMUNITY_SUMMARY=True # enable the graph community summary
ENABLE_TRIPLET_GRAPH=True # enable the graph search for the triplets
ENABLE_DOCUMENT_GRAPH=True # enable the graph search for documents and chunks
GRAPH_COMMUNITY_SUMMARY_ENABLED=True # enable the graph community summary
TRIPLET_GRAPH_ENABLED=True # enable the graph search for the triplets
DOCUMENT_GRAPH_ENABLED=True # enable the graph search for documents and chunks
KNOWLEDGE_GRAPH_CHUNK_SEARCH_TOP_SIZE=5 # the number of the searched triplets in a retrieval
```

Binary file not shown.

Before

Width:  |  Height:  |  Size: 195 KiB