feat(GraphRAG): enhance GraphRAG by graph community summary (#1801)

Co-authored-by: Florian <fanzhidongyzby@163.com>
Co-authored-by: KingSkyLi <15566300566@163.com>
Co-authored-by: aries_ckt <916701291@qq.com>
Co-authored-by: Fangyin Cheng <staneyffer@gmail.com>
Co-authored-by: yvonneyx <zhuyuxin0627@gmail.com>
This commit is contained in:
M1n9X
2024-08-30 21:59:44 +08:00
committed by GitHub
parent 471689ba20
commit 759f7d99cc
59 changed files with 29316 additions and 411 deletions

View File

@@ -0,0 +1 @@
"""Community Module."""

View File

@@ -0,0 +1,73 @@
"""Define Classes about Community."""
import logging
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import List, Optional
from dbgpt.storage.graph_store.base import GraphStoreBase
from dbgpt.storage.graph_store.graph import Graph
logger = logging.getLogger(__name__)
@dataclass
class Community:
"""Community class."""
id: str
data: Optional[Graph] = None
summary: Optional[str] = None
@dataclass
class CommunityTree:
"""Represents a community tree."""
class CommunityStoreAdapter(ABC):
"""Community Store Adapter."""
def __init__(self, graph_store: GraphStoreBase):
"""Initialize Community Store Adapter."""
self._graph_store = graph_store
@property
def graph_store(self) -> GraphStoreBase:
"""Get graph store."""
return self._graph_store
@abstractmethod
async def discover_communities(self, **kwargs) -> List[str]:
"""Run community discovery."""
@abstractmethod
async def get_community(self, community_id: str) -> Community:
"""Get community."""
class CommunityMetastore(ABC):
"""Community metastore class."""
@abstractmethod
def get(self, community_id: str) -> Community:
"""Get community."""
@abstractmethod
def list(self) -> List[Community]:
"""Get all communities."""
@abstractmethod
async def search(self, query: str) -> List[Community]:
"""Search communities relevant to query."""
@abstractmethod
async def save(self, communities: List[Community]):
"""Save communities."""
@abstractmethod
async def truncate(self):
"""Truncate all communities."""
@abstractmethod
def drop(self):
"""Drop community metastore."""

View File

@@ -0,0 +1,63 @@
"""Builtin Community metastore."""
import logging
from typing import List, Optional
from dbgpt.core import Chunk
from dbgpt.datasource.rdbms.base import RDBMSConnector
from dbgpt.storage.knowledge_graph.community.base import Community, CommunityMetastore
from dbgpt.storage.vector_store.base import VectorStoreBase
logger = logging.getLogger(__name__)
class BuiltinCommunityMetastore(CommunityMetastore):
"""Builtin Community metastore."""
def __init__(
self, vector_store: VectorStoreBase, rdb_store: Optional[RDBMSConnector] = None
):
"""Initialize Community metastore."""
self._vector_store = vector_store
self._rdb_store = rdb_store
config = self._vector_store.get_config()
self._vector_space = config.name
self._max_chunks_once_load = config.max_chunks_once_load
self._max_threads = config.max_threads
self._topk = config.topk
self._score_threshold = config.score_threshold
def get(self, community_id: str) -> Community:
"""Get community."""
raise NotImplementedError("Get community not allowed")
def list(self) -> List[Community]:
"""Get all communities."""
raise NotImplementedError("List communities not allowed")
async def search(self, query: str) -> List[Community]:
"""Search communities relevant to query."""
chunks = await self._vector_store.asimilar_search_with_scores(
query, self._topk, self._score_threshold
)
return [Community(id=chunk.chunk_id, summary=chunk.content) for chunk in chunks]
async def save(self, communities: List[Community]):
"""Save communities."""
chunks = [
Chunk(id=c.id, content=c.summary, metadata={"total": len(communities)})
for c in communities
]
await self._vector_store.aload_document_with_limit(
chunks, self._max_chunks_once_load, self._max_threads
)
logger.info(f"Save {len(communities)} communities")
async def truncate(self):
"""Truncate community metastore."""
self._vector_store.truncate()
def drop(self):
"""Drop community metastore."""
if self._vector_store.vector_name_exists():
self._vector_store.delete_vector_name(self._vector_space)

View File

@@ -0,0 +1,83 @@
"""Define the CommunityStore class."""
import logging
from typing import List
from dbgpt.rag.transformer.community_summarizer import CommunitySummarizer
from dbgpt.storage.knowledge_graph.community.base import (
Community,
CommunityStoreAdapter,
)
from dbgpt.storage.knowledge_graph.community.community_metastore import (
BuiltinCommunityMetastore,
)
from dbgpt.storage.vector_store.base import VectorStoreBase
logger = logging.getLogger(__name__)
class CommunityStore:
"""CommunityStore Class."""
def __init__(
self,
community_store_adapter: CommunityStoreAdapter,
community_summarizer: CommunitySummarizer,
vector_store: VectorStoreBase,
):
"""Initialize the CommunityStore class."""
self._community_store_adapter = community_store_adapter
self._community_summarizer = community_summarizer
self._meta_store = BuiltinCommunityMetastore(vector_store)
async def build_communities(self):
"""Discover communities."""
community_ids = await (self._community_store_adapter.discover_communities())
# summarize communities
communities = []
for community_id in community_ids:
community = await (
self._community_store_adapter.get_community(community_id)
)
graph = community.data.format()
if not graph:
break
community.summary = await (
self._community_summarizer.summarize(graph=graph)
)
communities.append(community)
logger.info(
f"Summarize community {community_id}: " f"{community.summary[:50]}..."
)
# truncate then save new summaries
await self._meta_store.truncate()
await self._meta_store.save(communities)
async def search_communities(self, query: str) -> List[Community]:
"""Search communities."""
return await self._meta_store.search(query)
def truncate(self):
"""Truncate community store."""
logger.info("Truncate community metastore")
self._meta_store.truncate()
logger.info("Truncate community summarizer")
self._community_summarizer.truncate()
logger.info("Truncate graph")
self._community_store_adapter.graph_store.truncate()
def drop(self):
"""Drop community store."""
logger.info("Remove community metastore")
self._meta_store.drop()
logger.info("Remove community summarizer")
self._community_summarizer.drop()
logger.info("Remove graph")
self._community_store_adapter.graph_store.drop()

View File

@@ -0,0 +1,30 @@
"""CommunityStoreAdapter factory."""
import logging
from dbgpt.storage.graph_store.base import GraphStoreBase
from dbgpt.storage.graph_store.tugraph_store import TuGraphStore
from dbgpt.storage.knowledge_graph.community.base import CommunityStoreAdapter
from dbgpt.storage.knowledge_graph.community.tugraph_adapter import (
TuGraphCommunityStoreAdapter,
)
logger = logging.getLogger(__name__)
class CommunityStoreAdapterFactory:
"""Factory for community store adapter."""
@staticmethod
def create(graph_store: GraphStoreBase) -> CommunityStoreAdapter:
"""Create a CommunityStoreAdapter instance.
Args:
- graph_store_type: graph store type Memory, TuGraph, Neo4j
"""
if isinstance(graph_store, TuGraphStore):
return TuGraphCommunityStoreAdapter(graph_store)
else:
raise Exception(
"create community store adapter for %s failed",
graph_store.__class__.__name__,
)

View File

@@ -0,0 +1,52 @@
"""TuGraph Community Store Adapter."""
import json
import logging
from typing import List
from dbgpt.storage.graph_store.graph import MemoryGraph
from dbgpt.storage.knowledge_graph.community.base import (
Community,
CommunityStoreAdapter,
)
logger = logging.getLogger(__name__)
class TuGraphCommunityStoreAdapter(CommunityStoreAdapter):
"""TuGraph Community Store Adapter."""
MAX_HIERARCHY_LEVEL = 3
async def discover_communities(self, **kwargs) -> List[str]:
"""Run community discovery with leiden."""
mg = self._graph_store.query(
"CALL db.plugin.callPlugin"
"('CPP','leiden','{\"leiden_val\":\"_community_id\"}',60.00,false)"
)
result = mg.get_vertex("json_node").get_prop("description")
community_ids = json.loads(result)["community_id_list"]
logger.info(f"Discovered {len(community_ids)} communities.")
return community_ids
async def get_community(self, community_id: str) -> Community:
"""Get community."""
query = (
f"MATCH (n:{self._graph_store.get_vertex_type()})"
f"WHERE n._community_id = '{community_id}' RETURN n"
)
edge_query = (
f"MATCH (n:{self._graph_store.get_vertex_type()})-"
f"[r:{self._graph_store.get_edge_type()}]-"
f"(m:{self._graph_store.get_vertex_type()})"
f"WHERE n._community_id = '{community_id}' RETURN n,r,m"
)
all_vertex_graph = self._graph_store.aquery(query)
all_edge_graph = self._graph_store.aquery(edge_query)
all_graph = MemoryGraph()
for vertex in all_vertex_graph.vertices():
all_graph.upsert_vertex(vertex)
for edge in all_edge_graph.edges():
all_graph.append_edge(edge)
return Community(id=community_id, data=all_graph)