mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-25 11:39:11 +00:00
✨ feat(GraphRAG): enhance GraphRAG by graph community summary (#1801)
Co-authored-by: Florian <fanzhidongyzby@163.com> Co-authored-by: KingSkyLi <15566300566@163.com> Co-authored-by: aries_ckt <916701291@qq.com> Co-authored-by: Fangyin Cheng <staneyffer@gmail.com> Co-authored-by: yvonneyx <zhuyuxin0627@gmail.com>
This commit is contained in:
1
dbgpt/storage/knowledge_graph/community/__init__.py
Normal file
1
dbgpt/storage/knowledge_graph/community/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Community Module."""
|
73
dbgpt/storage/knowledge_graph/community/base.py
Normal file
73
dbgpt/storage/knowledge_graph/community/base.py
Normal file
@@ -0,0 +1,73 @@
|
||||
"""Define Classes about Community."""
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional
|
||||
|
||||
from dbgpt.storage.graph_store.base import GraphStoreBase
|
||||
from dbgpt.storage.graph_store.graph import Graph
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
class Community:
    """A single graph community and, optionally, its content and summary.

    Only ``id`` is required. ``data`` and ``summary`` both default to ``None``
    so an instance can be created first and filled in afterwards.
    """

    # Identifier of the community.
    id: str
    # Graph holding the community's content; None until it is provided.
    data: Optional[Graph] = None
    # Text summary of the community; None until it is provided.
    summary: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass
class CommunityTree:
    """Tree of communities; the class currently declares no fields."""
|
||||
|
||||
|
||||
class CommunityStoreAdapter(ABC):
    """Abstract adapter exposing community operations on top of a graph store.

    Concrete subclasses implement how communities are discovered and how a
    single community is loaded from the wrapped graph store.
    """

    def __init__(self, graph_store: GraphStoreBase):
        """Remember the graph store this adapter operates on.

        Args:
            graph_store: the graph store wrapped by this adapter.
        """
        self._graph_store = graph_store

    @property
    def graph_store(self) -> GraphStoreBase:
        """Return the graph store that was passed to the constructor."""
        return self._graph_store

    @abstractmethod
    async def discover_communities(self, **kwargs) -> List[str]:
        """Run community discovery.

        Args:
            **kwargs: implementation-specific options.

        Returns:
            A list of strings produced by the discovery run.
        """

    @abstractmethod
    async def get_community(self, community_id: str) -> Community:
        """Load one community.

        Args:
            community_id: identifier of the community to load.

        Returns:
            The matching ``Community``.
        """
|
||||
|
||||
|
||||
class CommunityMetastore(ABC):
    """Abstract storage interface for communities.

    ``get``, ``list`` and ``drop`` are declared as plain methods, while
    ``search``, ``save`` and ``truncate`` are declared as coroutines.
    """

    @abstractmethod
    def get(self, community_id: str) -> Community:
        """Return the community identified by ``community_id``."""

    @abstractmethod
    def list(self) -> List[Community]:
        """Return every stored community."""

    @abstractmethod
    async def search(self, query: str) -> List[Community]:
        """Return the communities relevant to ``query``."""

    @abstractmethod
    async def save(self, communities: List[Community]):
        """Persist the given communities."""

    @abstractmethod
    async def truncate(self):
        """Remove all stored communities."""

    @abstractmethod
    def drop(self):
        """Drop the community metastore itself."""
|
@@ -0,0 +1,63 @@
|
||||
"""Builtin Community metastore."""
|
||||
import logging
|
||||
from typing import List, Optional
|
||||
|
||||
from dbgpt.core import Chunk
|
||||
from dbgpt.datasource.rdbms.base import RDBMSConnector
|
||||
from dbgpt.storage.knowledge_graph.community.base import Community, CommunityMetastore
|
||||
from dbgpt.storage.vector_store.base import VectorStoreBase
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BuiltinCommunityMetastore(CommunityMetastore):
    """Community metastore that keeps community summaries in a vector store."""

    def __init__(
        self, vector_store: VectorStoreBase, rdb_store: Optional[RDBMSConnector] = None
    ):
        """Bind the backing stores and cache the vector store settings.

        Args:
            vector_store: store used to save and search community summaries.
            rdb_store: optional relational connector; it is only stored on the
                instance by this class.
        """
        self._vector_store = vector_store
        self._rdb_store = rdb_store

        # Read the configuration once and keep the values used by later calls.
        store_config = self._vector_store.get_config()
        self._vector_space = store_config.name
        self._max_chunks_once_load = store_config.max_chunks_once_load
        self._max_threads = store_config.max_threads
        self._topk = store_config.topk
        self._score_threshold = store_config.score_threshold

    def get(self, community_id: str) -> Community:
        """Not supported by this metastore.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError("Get community not allowed")

    def list(self) -> List[Community]:
        """Not supported by this metastore.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError("List communities not allowed")

    async def search(self, query: str) -> List[Community]:
        """Find communities whose stored summary is similar to ``query``.

        The search uses the ``topk`` and ``score_threshold`` values read from
        the vector store configuration at construction time.

        Returns:
            One ``Community`` per matching chunk, with ``id`` taken from the
            chunk's ``chunk_id`` and ``summary`` from its ``content``.
        """
        matches = await self._vector_store.asimilar_search_with_scores(
            query, self._topk, self._score_threshold
        )
        found = []
        for match in matches:
            found.append(Community(id=match.chunk_id, summary=match.content))
        return found

    async def save(self, communities: List[Community]):
        """Store one chunk per community in the vector store.

        Each chunk carries the community summary as content and the number of
        communities in this batch under the ``total`` metadata key.
        """
        # NOTE(review): chunks are built with ``id=`` here but ``search`` reads
        # ``chunk_id`` back; confirm that Chunk maps one onto the other.
        chunks = []
        for community in communities:
            chunks.append(
                Chunk(
                    id=community.id,
                    content=community.summary,
                    metadata={"total": len(communities)},
                )
            )
        await self._vector_store.aload_document_with_limit(
            chunks, self._max_chunks_once_load, self._max_threads
        )
        logger.info(f"Save {len(communities)} communities")

    async def truncate(self):
        """Truncate the backing vector store."""
        self._vector_store.truncate()

    def drop(self):
        """Delete the vector space if the vector store reports that it exists."""
        if not self._vector_store.vector_name_exists():
            return
        self._vector_store.delete_vector_name(self._vector_space)
|
83
dbgpt/storage/knowledge_graph/community/community_store.py
Normal file
83
dbgpt/storage/knowledge_graph/community/community_store.py
Normal file
@@ -0,0 +1,83 @@
|
||||
"""Define the CommunityStore class."""
|
||||
|
||||
import logging
|
||||
from typing import List
|
||||
|
||||
from dbgpt.rag.transformer.community_summarizer import CommunitySummarizer
|
||||
from dbgpt.storage.knowledge_graph.community.base import (
|
||||
Community,
|
||||
CommunityStoreAdapter,
|
||||
)
|
||||
from dbgpt.storage.knowledge_graph.community.community_metastore import (
|
||||
BuiltinCommunityMetastore,
|
||||
)
|
||||
from dbgpt.storage.vector_store.base import VectorStoreBase
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CommunityStore:
    """Build, persist and search summaries of graph communities.

    The store combines three collaborators: an adapter that discovers and
    loads communities from a graph store, a summarizer that turns a formatted
    community graph into text, and a metastore (built on the given vector
    store) that saves and searches the resulting summaries.
    """

    def __init__(
        self,
        community_store_adapter: CommunityStoreAdapter,
        community_summarizer: CommunitySummarizer,
        vector_store: VectorStoreBase,
    ):
        """Initialize the CommunityStore class.

        Args:
            community_store_adapter: adapter used to discover and load
                communities.
            community_summarizer: summarizer applied to each community graph.
            vector_store: vector store wrapped in a
                ``BuiltinCommunityMetastore`` to hold the summaries.
        """
        self._community_store_adapter = community_store_adapter
        self._community_summarizer = community_summarizer
        self._meta_store = BuiltinCommunityMetastore(vector_store)

    async def build_communities(self):
        """Discover communities, summarize them and replace stored summaries."""
        community_ids = await self._community_store_adapter.discover_communities()

        # Summarize communities one by one.
        communities = []
        for community_id in community_ids:
            community = await self._community_store_adapter.get_community(
                community_id
            )
            graph = community.data.format()
            if not graph:
                # NOTE(review): an empty graph ends the whole loop, so the ids
                # after it are never summarized; confirm ``break`` (rather
                # than ``continue``) is intended.
                break

            community.summary = await self._community_summarizer.summarize(
                graph=graph
            )
            communities.append(community)
            logger.info(
                f"Summarize community {community_id}: " f"{community.summary[:50]}..."
            )

        # Truncate first so only the new summaries remain afterwards.
        await self._meta_store.truncate()
        await self._meta_store.save(communities)

    async def search_communities(self, query: str) -> List[Community]:
        """Return the stored communities relevant to ``query``."""
        return await self._meta_store.search(query)

    def truncate(self):
        """Truncate the metastore, the summarizer and the graph store.

        The metastore's ``truncate`` is a coroutine function, so calling it
        only creates a coroutine. It is explicitly run to completion here;
        otherwise the metastore would silently keep its content.
        """
        logger.info("Truncate community metastore")
        self._run_sync(self._meta_store.truncate())

        logger.info("Truncate community summarizer")
        self._community_summarizer.truncate()

        logger.info("Truncate graph")
        self._community_store_adapter.graph_store.truncate()

    def drop(self):
        """Drop the metastore, the summarizer and the graph store."""
        logger.info("Remove community metastore")
        self._meta_store.drop()

        logger.info("Remove community summarizer")
        self._community_summarizer.drop()

        logger.info("Remove graph")
        self._community_store_adapter.graph_store.drop()

    @staticmethod
    def _run_sync(result):
        """Finish ``result`` if it is awaitable; otherwise return it unchanged.

        When the calling thread already runs an event loop, the awaitable is
        completed on a private loop in a worker thread and the calling thread
        blocks until it is done.
        """
        import asyncio
        import inspect
        from concurrent.futures import ThreadPoolExecutor

        if not inspect.isawaitable(result):
            return result

        async def _wait():
            return await result

        try:
            asyncio.get_running_loop()
        except RuntimeError:
            # No loop is running in this thread, so one can be started here.
            return asyncio.run(_wait())

        # A running loop cannot be re-entered from synchronous code.
        with ThreadPoolExecutor(max_workers=1) as pool:
            return pool.submit(asyncio.run, _wait()).result()
|
30
dbgpt/storage/knowledge_graph/community/factory.py
Normal file
30
dbgpt/storage/knowledge_graph/community/factory.py
Normal file
@@ -0,0 +1,30 @@
|
||||
"""CommunityStoreAdapter factory."""
|
||||
import logging
|
||||
|
||||
from dbgpt.storage.graph_store.base import GraphStoreBase
|
||||
from dbgpt.storage.graph_store.tugraph_store import TuGraphStore
|
||||
from dbgpt.storage.knowledge_graph.community.base import CommunityStoreAdapter
|
||||
from dbgpt.storage.knowledge_graph.community.tugraph_adapter import (
|
||||
TuGraphCommunityStoreAdapter,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CommunityStoreAdapterFactory:
    """Factory for community store adapter."""

    @staticmethod
    def create(graph_store: GraphStoreBase) -> CommunityStoreAdapter:
        """Create the CommunityStoreAdapter matching a graph store.

        Args:
            graph_store: graph store instance the adapter should wrap. Only
                ``TuGraphStore`` instances are handled here.

        Returns:
            A ``TuGraphCommunityStoreAdapter`` wrapping ``graph_store``.

        Raises:
            Exception: if no adapter is available for the type of
                ``graph_store``.
        """
        if isinstance(graph_store, TuGraphStore):
            return TuGraphCommunityStoreAdapter(graph_store)

        # Exception does not interpolate logging-style "%s" arguments, so the
        # message is formatted before it is raised.
        raise Exception(
            "create community store adapter for "
            f"{graph_store.__class__.__name__} failed"
        )
|
52
dbgpt/storage/knowledge_graph/community/tugraph_adapter.py
Normal file
52
dbgpt/storage/knowledge_graph/community/tugraph_adapter.py
Normal file
@@ -0,0 +1,52 @@
|
||||
"""TuGraph Community Store Adapter."""
|
||||
import json
|
||||
import logging
|
||||
from typing import List
|
||||
|
||||
from dbgpt.storage.graph_store.graph import MemoryGraph
|
||||
from dbgpt.storage.knowledge_graph.community.base import (
|
||||
Community,
|
||||
CommunityStoreAdapter,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TuGraphCommunityStoreAdapter(CommunityStoreAdapter):
    """Community store adapter backed by a TuGraph graph store."""

    # Not referenced inside this class.
    # NOTE(review): presumably a cap on community hierarchy depth; confirm.
    MAX_HIERARCHY_LEVEL = 3

    async def discover_communities(self, **kwargs) -> List[str]:
        """Run the ``leiden`` plugin and return the community ids it reports.

        The plugin result is read from the ``description`` property of the
        ``json_node`` vertex and parsed as JSON; the ids are the value stored
        under ``community_id_list``.
        """
        leiden_call = (
            "CALL db.plugin.callPlugin"
            "('CPP','leiden','{\"leiden_val\":\"_community_id\"}',60.00,false)"
        )
        plugin_result = self._graph_store.query(leiden_call)
        payload = plugin_result.get_vertex("json_node").get_prop("description")
        community_ids = json.loads(payload)["community_id_list"]
        logger.info(f"Discovered {len(community_ids)} communities.")
        return community_ids

    async def get_community(self, community_id: str) -> Community:
        """Load the vertices and edges of one community into a memory graph.

        Two queries are issued: one for vertices whose ``_community_id``
        equals ``community_id``, and one for edges attached to such vertices.

        Returns:
            A ``Community`` whose ``data`` is a ``MemoryGraph`` holding the
            query results.
        """
        store = self._graph_store
        id_filter = f"WHERE n._community_id = '{community_id}' RETURN "

        vertex_query = f"MATCH (n:{store.get_vertex_type()})" + id_filter + "n"
        edge_query = (
            f"MATCH (n:{store.get_vertex_type()})-"
            f"[r:{store.get_edge_type()}]-"
            f"(m:{store.get_vertex_type()})" + id_filter + "n,r,m"
        )

        # NOTE(review): the results of ``aquery`` are used without ``await``;
        # confirm that ``aquery`` returns a graph rather than a coroutine.
        vertex_graph = store.aquery(vertex_query)
        edge_graph = store.aquery(edge_query)

        community_graph = MemoryGraph()
        for vertex in vertex_graph.vertices():
            community_graph.upsert_vertex(vertex)
        for edge in edge_graph.edges():
            community_graph.append_edge(edge)

        return Community(id=community_id, data=community_graph)
|
Reference in New Issue
Block a user