feat(GraphRAG): Support concurrent community summarization (#2160)

This commit is contained in:
Appointat
2024-11-29 20:51:58 +08:00
committed by GitHub
parent e5ec47145f
commit a14eeb56dd
10 changed files with 1226 additions and 784 deletions

View File

@@ -1,7 +1,8 @@
"""Define the CommunityStore class."""
import asyncio
import logging
from typing import List
from typing import List, Optional
from dbgpt.rag.transformer.community_summarizer import CommunitySummarizer
from dbgpt.storage.knowledge_graph.community.base import Community, GraphStoreAdapter
@@ -27,28 +28,38 @@ class CommunityStore:
self._community_summarizer = community_summarizer
self._meta_store = BuiltinCommunityMetastore(vector_store)
async def build_communities(self):
async def build_communities(self, batch_size: int = 1):
"""Discover communities."""
community_ids = await self._graph_store_adapter.discover_communities()
# summarize communities
communities = []
for community_id in community_ids:
community = await self._graph_store_adapter.get_community(community_id)
graph = community.data.format()
if not graph:
break
n_communities = len(community_ids)
community.summary = await self._community_summarizer.summarize(graph=graph)
communities.append(community)
logger.info(
f"Summarize community {community_id}: " f"{community.summary[:50]}..."
for i in range(0, n_communities, batch_size):
batch_ids = community_ids[i : i + batch_size]
batch_results = await asyncio.gather(
*[self._summary_community(cid) for cid in batch_ids]
)
# filter out None returns
communities.extend([c for c in batch_results if c is not None])
# truncate then save new summaries
await self._meta_store.truncate()
await self._meta_store.save(communities)
async def _summary_community(self, community_id: str) -> Optional[Community]:
"""Summarize single community."""
community = await self._graph_store_adapter.get_community(community_id)
if community is None or community.data is None:
logger.warning(f"Community {community_id} is empty")
return None
graph = community.data.format()
community.summary = await self._community_summarizer.summarize(graph=graph)
logger.info(f"Summarize community {community_id}: {community.summary[:50]}...")
return community
async def search_communities(self, query: str) -> List[Community]:
"""Search communities."""
return await self._meta_store.search(query)

View File

@@ -356,7 +356,7 @@ class TuGraphStoreAdapter(GraphStoreAdapter):
return
# Create the graph schema
def _format_graph_propertity_schema(
def _format_graph_property_schema(
name: str,
type: str = "STRING",
optional: bool = False,
@@ -390,9 +390,9 @@ class TuGraphStoreAdapter(GraphStoreAdapter):
# Create the graph label for document vertex
document_proerties: List[Dict[str, Union[str, bool]]] = [
_format_graph_propertity_schema("id", "STRING", False),
_format_graph_propertity_schema("name", "STRING", False),
_format_graph_propertity_schema("_community_id", "STRING", True, True),
_format_graph_property_schema("id", "STRING", False),
_format_graph_property_schema("name", "STRING", False),
_format_graph_property_schema("_community_id", "STRING", True, True),
]
self.create_graph_label(
graph_elem_type=GraphElemType.DOCUMENT, graph_properties=document_proerties
@@ -400,10 +400,10 @@ class TuGraphStoreAdapter(GraphStoreAdapter):
# Create the graph label for chunk vertex
chunk_proerties: List[Dict[str, Union[str, bool]]] = [
_format_graph_propertity_schema("id", "STRING", False),
_format_graph_propertity_schema("name", "STRING", False),
_format_graph_propertity_schema("_community_id", "STRING", True, True),
_format_graph_propertity_schema("content", "STRING", True, True),
_format_graph_property_schema("id", "STRING", False),
_format_graph_property_schema("name", "STRING", False),
_format_graph_property_schema("_community_id", "STRING", True, True),
_format_graph_property_schema("content", "STRING", True, True),
]
self.create_graph_label(
graph_elem_type=GraphElemType.CHUNK, graph_properties=chunk_proerties
@@ -411,10 +411,10 @@ class TuGraphStoreAdapter(GraphStoreAdapter):
# Create the graph label for entity vertex
vertex_proerties: List[Dict[str, Union[str, bool]]] = [
_format_graph_propertity_schema("id", "STRING", False),
_format_graph_propertity_schema("name", "STRING", False),
_format_graph_propertity_schema("_community_id", "STRING", True, True),
_format_graph_propertity_schema("description", "STRING", True, True),
_format_graph_property_schema("id", "STRING", False),
_format_graph_property_schema("name", "STRING", False),
_format_graph_property_schema("_community_id", "STRING", True, True),
_format_graph_property_schema("description", "STRING", True, True),
]
self.create_graph_label(
graph_elem_type=GraphElemType.ENTITY, graph_properties=vertex_proerties
@@ -422,10 +422,10 @@ class TuGraphStoreAdapter(GraphStoreAdapter):
# Create the graph label for relation edge
edge_proerties: List[Dict[str, Union[str, bool]]] = [
_format_graph_propertity_schema("id", "STRING", False),
_format_graph_propertity_schema("name", "STRING", False),
_format_graph_propertity_schema("_chunk_id", "STRING", True, True),
_format_graph_propertity_schema("description", "STRING", True, True),
_format_graph_property_schema("id", "STRING", False),
_format_graph_property_schema("name", "STRING", False),
_format_graph_property_schema("_chunk_id", "STRING", True, True),
_format_graph_property_schema("description", "STRING", True, True),
]
self.create_graph_label(
graph_elem_type=GraphElemType.RELATION, graph_properties=edge_proerties
@@ -433,9 +433,9 @@ class TuGraphStoreAdapter(GraphStoreAdapter):
# Create the graph label for include edge
include_proerties: List[Dict[str, Union[str, bool]]] = [
_format_graph_propertity_schema("id", "STRING", False),
_format_graph_propertity_schema("name", "STRING", False),
_format_graph_propertity_schema("description", "STRING", True),
_format_graph_property_schema("id", "STRING", False),
_format_graph_property_schema("name", "STRING", False),
_format_graph_property_schema("description", "STRING", True),
]
self.create_graph_label(
graph_elem_type=GraphElemType.INCLUDE, graph_properties=include_proerties
@@ -443,9 +443,9 @@ class TuGraphStoreAdapter(GraphStoreAdapter):
# Create the graph label for next edge
next_proerties: List[Dict[str, Union[str, bool]]] = [
_format_graph_propertity_schema("id", "STRING", False),
_format_graph_propertity_schema("name", "STRING", False),
_format_graph_propertity_schema("description", "STRING", True),
_format_graph_property_schema("id", "STRING", False),
_format_graph_property_schema("name", "STRING", False),
_format_graph_property_schema("description", "STRING", True),
]
self.create_graph_label(
graph_elem_type=GraphElemType.NEXT, graph_properties=next_proerties

View File

@@ -38,8 +38,7 @@ class CommunitySummaryKnowledgeGraphConfig(BuiltinKnowledgeGraphConfig):
password: Optional[str] = Field(
default=None,
description=(
"The password of vector store, "
"if not set, will use the default password."
"The password of vector store, if not set, will use the default password."
),
)
extract_topk: int = Field(
@@ -75,6 +74,10 @@ class CommunitySummaryKnowledgeGraphConfig(BuiltinKnowledgeGraphConfig):
default=20,
description="Batch size of triplets extraction from the text",
)
community_summary_batch_size: int = Field(
default=20,
description="Batch size of parallel community building process",
)
class CommunitySummaryKnowledgeGraph(BuiltinKnowledgeGraph):
@@ -130,6 +133,12 @@ class CommunitySummaryKnowledgeGraph(BuiltinKnowledgeGraph):
config.knowledge_graph_extraction_batch_size,
)
)
self._community_summary_batch_size = int(
os.getenv(
"COMMUNITY_SUMMARY_BATCH_SIZE",
config.community_summary_batch_size,
)
)
def extractor_configure(name: str, cfg: VectorStoreConfig):
cfg.name = name
@@ -177,9 +186,12 @@ class CommunitySummaryKnowledgeGraph(BuiltinKnowledgeGraph):
async def aload_document(self, chunks: List[Chunk]) -> List[str]:
"""Extract and persist graph from the document file."""
await self._aload_document_graph(chunks)
await self._aload_triplet_graph(chunks)
await self._community_store.build_communities()
await self._community_store.build_communities(
batch_size=self._community_summary_batch_size
)
return [chunk.chunk_id for chunk in chunks]
@@ -230,6 +242,8 @@ class CommunitySummaryKnowledgeGraph(BuiltinKnowledgeGraph):
[chunk.content for chunk in chunks],
batch_size=self._triplet_extraction_batch_size,
)
if not graphs_list:
raise ValueError("No graphs extracted from the chunks")
# Upsert the graphs into the graph store
for idx, graphs in enumerate(graphs_list):