mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-25 11:39:11 +00:00
feat(GraphRAG): Support concurrent community summarization (#2160)
This commit is contained in:
@@ -1,7 +1,8 @@
|
||||
"""Define the CommunityStore class."""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import List
|
||||
from typing import List, Optional
|
||||
|
||||
from dbgpt.rag.transformer.community_summarizer import CommunitySummarizer
|
||||
from dbgpt.storage.knowledge_graph.community.base import Community, GraphStoreAdapter
|
||||
@@ -27,28 +28,38 @@ class CommunityStore:
|
||||
self._community_summarizer = community_summarizer
|
||||
self._meta_store = BuiltinCommunityMetastore(vector_store)
|
||||
|
||||
async def build_communities(self):
|
||||
async def build_communities(self, batch_size: int = 1):
|
||||
"""Discover communities."""
|
||||
community_ids = await self._graph_store_adapter.discover_communities()
|
||||
|
||||
# summarize communities
|
||||
communities = []
|
||||
for community_id in community_ids:
|
||||
community = await self._graph_store_adapter.get_community(community_id)
|
||||
graph = community.data.format()
|
||||
if not graph:
|
||||
break
|
||||
n_communities = len(community_ids)
|
||||
|
||||
community.summary = await self._community_summarizer.summarize(graph=graph)
|
||||
communities.append(community)
|
||||
logger.info(
|
||||
f"Summarize community {community_id}: " f"{community.summary[:50]}..."
|
||||
for i in range(0, n_communities, batch_size):
|
||||
batch_ids = community_ids[i : i + batch_size]
|
||||
batch_results = await asyncio.gather(
|
||||
*[self._summary_community(cid) for cid in batch_ids]
|
||||
)
|
||||
# filter out None returns
|
||||
communities.extend([c for c in batch_results if c is not None])
|
||||
|
||||
# truncate then save new summaries
|
||||
await self._meta_store.truncate()
|
||||
await self._meta_store.save(communities)
|
||||
|
||||
async def _summary_community(self, community_id: str) -> Optional[Community]:
|
||||
"""Summarize single community."""
|
||||
community = await self._graph_store_adapter.get_community(community_id)
|
||||
if community is None or community.data is None:
|
||||
logger.warning(f"Community {community_id} is empty")
|
||||
return None
|
||||
|
||||
graph = community.data.format()
|
||||
community.summary = await self._community_summarizer.summarize(graph=graph)
|
||||
logger.info(f"Summarize community {community_id}: {community.summary[:50]}...")
|
||||
return community
|
||||
|
||||
async def search_communities(self, query: str) -> List[Community]:
|
||||
"""Search communities."""
|
||||
return await self._meta_store.search(query)
|
||||
|
@@ -356,7 +356,7 @@ class TuGraphStoreAdapter(GraphStoreAdapter):
|
||||
return
|
||||
|
||||
# Create the graph schema
|
||||
def _format_graph_propertity_schema(
|
||||
def _format_graph_property_schema(
|
||||
name: str,
|
||||
type: str = "STRING",
|
||||
optional: bool = False,
|
||||
@@ -390,9 +390,9 @@ class TuGraphStoreAdapter(GraphStoreAdapter):
|
||||
|
||||
# Create the graph label for document vertex
|
||||
document_proerties: List[Dict[str, Union[str, bool]]] = [
|
||||
_format_graph_propertity_schema("id", "STRING", False),
|
||||
_format_graph_propertity_schema("name", "STRING", False),
|
||||
_format_graph_propertity_schema("_community_id", "STRING", True, True),
|
||||
_format_graph_property_schema("id", "STRING", False),
|
||||
_format_graph_property_schema("name", "STRING", False),
|
||||
_format_graph_property_schema("_community_id", "STRING", True, True),
|
||||
]
|
||||
self.create_graph_label(
|
||||
graph_elem_type=GraphElemType.DOCUMENT, graph_properties=document_proerties
|
||||
@@ -400,10 +400,10 @@ class TuGraphStoreAdapter(GraphStoreAdapter):
|
||||
|
||||
# Create the graph label for chunk vertex
|
||||
chunk_proerties: List[Dict[str, Union[str, bool]]] = [
|
||||
_format_graph_propertity_schema("id", "STRING", False),
|
||||
_format_graph_propertity_schema("name", "STRING", False),
|
||||
_format_graph_propertity_schema("_community_id", "STRING", True, True),
|
||||
_format_graph_propertity_schema("content", "STRING", True, True),
|
||||
_format_graph_property_schema("id", "STRING", False),
|
||||
_format_graph_property_schema("name", "STRING", False),
|
||||
_format_graph_property_schema("_community_id", "STRING", True, True),
|
||||
_format_graph_property_schema("content", "STRING", True, True),
|
||||
]
|
||||
self.create_graph_label(
|
||||
graph_elem_type=GraphElemType.CHUNK, graph_properties=chunk_proerties
|
||||
@@ -411,10 +411,10 @@ class TuGraphStoreAdapter(GraphStoreAdapter):
|
||||
|
||||
# Create the graph label for entity vertex
|
||||
vertex_proerties: List[Dict[str, Union[str, bool]]] = [
|
||||
_format_graph_propertity_schema("id", "STRING", False),
|
||||
_format_graph_propertity_schema("name", "STRING", False),
|
||||
_format_graph_propertity_schema("_community_id", "STRING", True, True),
|
||||
_format_graph_propertity_schema("description", "STRING", True, True),
|
||||
_format_graph_property_schema("id", "STRING", False),
|
||||
_format_graph_property_schema("name", "STRING", False),
|
||||
_format_graph_property_schema("_community_id", "STRING", True, True),
|
||||
_format_graph_property_schema("description", "STRING", True, True),
|
||||
]
|
||||
self.create_graph_label(
|
||||
graph_elem_type=GraphElemType.ENTITY, graph_properties=vertex_proerties
|
||||
@@ -422,10 +422,10 @@ class TuGraphStoreAdapter(GraphStoreAdapter):
|
||||
|
||||
# Create the graph label for relation edge
|
||||
edge_proerties: List[Dict[str, Union[str, bool]]] = [
|
||||
_format_graph_propertity_schema("id", "STRING", False),
|
||||
_format_graph_propertity_schema("name", "STRING", False),
|
||||
_format_graph_propertity_schema("_chunk_id", "STRING", True, True),
|
||||
_format_graph_propertity_schema("description", "STRING", True, True),
|
||||
_format_graph_property_schema("id", "STRING", False),
|
||||
_format_graph_property_schema("name", "STRING", False),
|
||||
_format_graph_property_schema("_chunk_id", "STRING", True, True),
|
||||
_format_graph_property_schema("description", "STRING", True, True),
|
||||
]
|
||||
self.create_graph_label(
|
||||
graph_elem_type=GraphElemType.RELATION, graph_properties=edge_proerties
|
||||
@@ -433,9 +433,9 @@ class TuGraphStoreAdapter(GraphStoreAdapter):
|
||||
|
||||
# Create the graph label for include edge
|
||||
include_proerties: List[Dict[str, Union[str, bool]]] = [
|
||||
_format_graph_propertity_schema("id", "STRING", False),
|
||||
_format_graph_propertity_schema("name", "STRING", False),
|
||||
_format_graph_propertity_schema("description", "STRING", True),
|
||||
_format_graph_property_schema("id", "STRING", False),
|
||||
_format_graph_property_schema("name", "STRING", False),
|
||||
_format_graph_property_schema("description", "STRING", True),
|
||||
]
|
||||
self.create_graph_label(
|
||||
graph_elem_type=GraphElemType.INCLUDE, graph_properties=include_proerties
|
||||
@@ -443,9 +443,9 @@ class TuGraphStoreAdapter(GraphStoreAdapter):
|
||||
|
||||
# Create the graph label for next edge
|
||||
next_proerties: List[Dict[str, Union[str, bool]]] = [
|
||||
_format_graph_propertity_schema("id", "STRING", False),
|
||||
_format_graph_propertity_schema("name", "STRING", False),
|
||||
_format_graph_propertity_schema("description", "STRING", True),
|
||||
_format_graph_property_schema("id", "STRING", False),
|
||||
_format_graph_property_schema("name", "STRING", False),
|
||||
_format_graph_property_schema("description", "STRING", True),
|
||||
]
|
||||
self.create_graph_label(
|
||||
graph_elem_type=GraphElemType.NEXT, graph_properties=next_proerties
|
||||
|
Reference in New Issue
Block a user