feat(GraphRAG): Support concurrent community summarization (#2160)

This commit is contained in:
Appointat
2024-11-29 20:51:58 +08:00
committed by GitHub
parent e5ec47145f
commit a14eeb56dd
10 changed files with 1226 additions and 784 deletions

View File

@@ -1,7 +1,8 @@
"""Define the CommunityStore class."""
import asyncio
import logging
from typing import List
from typing import List, Optional
from dbgpt.rag.transformer.community_summarizer import CommunitySummarizer
from dbgpt.storage.knowledge_graph.community.base import Community, GraphStoreAdapter
@@ -27,28 +28,38 @@ class CommunityStore:
self._community_summarizer = community_summarizer
self._meta_store = BuiltinCommunityMetastore(vector_store)
async def build_communities(self, batch_size: int = 1):
    """Discover communities and persist a summary for each of them.

    Community ids are obtained from the graph store adapter and summarized
    in consecutive batches: the summaries inside one batch are awaited
    concurrently, and the next batch only starts once the current one has
    finished. Afterwards the metastore is truncated and the freshly built
    summaries are saved.

    Args:
        batch_size: Number of communities summarized concurrently per
            batch. Must be at least 1.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    # Reject bad sizes before touching the graph or the metastore: a zero
    # step makes range() raise, and a negative step yields no batches at
    # all, which would truncate the metastore and save nothing in its place.
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    community_ids = await self._graph_store_adapter.discover_communities()

    # summarize communities
    communities = []
    for start in range(0, len(community_ids), batch_size):
        batch_ids = community_ids[start : start + batch_size]
        batch_results = await asyncio.gather(
            *(self._summary_community(cid) for cid in batch_ids)
        )
        # filter out None returns
        communities.extend(c for c in batch_results if c is not None)

    # truncate then save new summaries
    await self._meta_store.truncate()
    await self._meta_store.save(communities)
async def _summary_community(self, community_id: str) -> Optional[Community]:
    """Fetch one community and attach a generated summary to it.

    Args:
        community_id: Id of the community to load from the graph store
            adapter.

    Returns:
        The community with its ``summary`` attribute set, or ``None`` when
        the adapter returns no community or the community carries no data.
    """
    community = await self._graph_store_adapter.get_community(community_id)

    has_data = community is not None and community.data is not None
    if not has_data:
        logger.warning(f"Community {community_id} is empty")
        return None

    # Format the community's data first, then hand it to the summarizer.
    summary = await self._community_summarizer.summarize(
        graph=community.data.format()
    )
    community.summary = summary

    # Only the first 50 characters of the summary go to the log.
    logger.info(f"Summarize community {community_id}: {community.summary[:50]}...")
    return community
async def search_communities(self, query: str) -> List[Community]:
    """Look up stored communities for a query.

    Delegates to the community metastore's ``search`` coroutine and returns
    its result unchanged.
    """
    matches = await self._meta_store.search(query)
    return matches

View File

@@ -356,7 +356,7 @@ class TuGraphStoreAdapter(GraphStoreAdapter):
return
# Create the graph schema
def _format_graph_propertity_schema(
def _format_graph_property_schema(
name: str,
type: str = "STRING",
optional: bool = False,
@@ -390,9 +390,9 @@ class TuGraphStoreAdapter(GraphStoreAdapter):
# Create the graph label for document vertex
document_proerties: List[Dict[str, Union[str, bool]]] = [
_format_graph_propertity_schema("id", "STRING", False),
_format_graph_propertity_schema("name", "STRING", False),
_format_graph_propertity_schema("_community_id", "STRING", True, True),
_format_graph_property_schema("id", "STRING", False),
_format_graph_property_schema("name", "STRING", False),
_format_graph_property_schema("_community_id", "STRING", True, True),
]
self.create_graph_label(
graph_elem_type=GraphElemType.DOCUMENT, graph_properties=document_proerties
@@ -400,10 +400,10 @@ class TuGraphStoreAdapter(GraphStoreAdapter):
# Create the graph label for chunk vertex
chunk_proerties: List[Dict[str, Union[str, bool]]] = [
_format_graph_propertity_schema("id", "STRING", False),
_format_graph_propertity_schema("name", "STRING", False),
_format_graph_propertity_schema("_community_id", "STRING", True, True),
_format_graph_propertity_schema("content", "STRING", True, True),
_format_graph_property_schema("id", "STRING", False),
_format_graph_property_schema("name", "STRING", False),
_format_graph_property_schema("_community_id", "STRING", True, True),
_format_graph_property_schema("content", "STRING", True, True),
]
self.create_graph_label(
graph_elem_type=GraphElemType.CHUNK, graph_properties=chunk_proerties
@@ -411,10 +411,10 @@ class TuGraphStoreAdapter(GraphStoreAdapter):
# Create the graph label for entity vertex
vertex_proerties: List[Dict[str, Union[str, bool]]] = [
_format_graph_propertity_schema("id", "STRING", False),
_format_graph_propertity_schema("name", "STRING", False),
_format_graph_propertity_schema("_community_id", "STRING", True, True),
_format_graph_propertity_schema("description", "STRING", True, True),
_format_graph_property_schema("id", "STRING", False),
_format_graph_property_schema("name", "STRING", False),
_format_graph_property_schema("_community_id", "STRING", True, True),
_format_graph_property_schema("description", "STRING", True, True),
]
self.create_graph_label(
graph_elem_type=GraphElemType.ENTITY, graph_properties=vertex_proerties
@@ -422,10 +422,10 @@ class TuGraphStoreAdapter(GraphStoreAdapter):
# Create the graph label for relation edge
edge_proerties: List[Dict[str, Union[str, bool]]] = [
_format_graph_propertity_schema("id", "STRING", False),
_format_graph_propertity_schema("name", "STRING", False),
_format_graph_propertity_schema("_chunk_id", "STRING", True, True),
_format_graph_propertity_schema("description", "STRING", True, True),
_format_graph_property_schema("id", "STRING", False),
_format_graph_property_schema("name", "STRING", False),
_format_graph_property_schema("_chunk_id", "STRING", True, True),
_format_graph_property_schema("description", "STRING", True, True),
]
self.create_graph_label(
graph_elem_type=GraphElemType.RELATION, graph_properties=edge_proerties
@@ -433,9 +433,9 @@ class TuGraphStoreAdapter(GraphStoreAdapter):
# Create the graph label for include edge
include_proerties: List[Dict[str, Union[str, bool]]] = [
_format_graph_propertity_schema("id", "STRING", False),
_format_graph_propertity_schema("name", "STRING", False),
_format_graph_propertity_schema("description", "STRING", True),
_format_graph_property_schema("id", "STRING", False),
_format_graph_property_schema("name", "STRING", False),
_format_graph_property_schema("description", "STRING", True),
]
self.create_graph_label(
graph_elem_type=GraphElemType.INCLUDE, graph_properties=include_proerties
@@ -443,9 +443,9 @@ class TuGraphStoreAdapter(GraphStoreAdapter):
# Create the graph label for next edge
next_proerties: List[Dict[str, Union[str, bool]]] = [
_format_graph_propertity_schema("id", "STRING", False),
_format_graph_propertity_schema("name", "STRING", False),
_format_graph_propertity_schema("description", "STRING", True),
_format_graph_property_schema("id", "STRING", False),
_format_graph_property_schema("name", "STRING", False),
_format_graph_property_schema("description", "STRING", True),
]
self.create_graph_label(
graph_elem_type=GraphElemType.NEXT, graph_properties=next_proerties