mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-05 02:51:07 +00:00
feat(GraphRAG): Support concurrent community summarization (#2160)
This commit is contained in:
@@ -1,7 +1,8 @@
|
||||
"""Define the CommunityStore class."""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import List
|
||||
from typing import List, Optional
|
||||
|
||||
from dbgpt.rag.transformer.community_summarizer import CommunitySummarizer
|
||||
from dbgpt.storage.knowledge_graph.community.base import Community, GraphStoreAdapter
|
||||
@@ -27,28 +28,38 @@ class CommunityStore:
|
||||
self._community_summarizer = community_summarizer
|
||||
self._meta_store = BuiltinCommunityMetastore(vector_store)
|
||||
|
||||
async def build_communities(self, batch_size: int = 1):
    """Discover communities and persist a summary for each of them.

    Community ids are obtained from the graph store adapter and summarized
    in consecutive batches of ``batch_size``. The summaries within one batch
    are awaited together through ``asyncio.gather``. After all batches have
    finished, the meta store is truncated and the new summaries are saved.

    Args:
        batch_size: Number of communities summarized concurrently per batch.
            Must be at least 1.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    # Validate before touching any store: a zero step makes ``range`` raise,
    # and a negative step yields no batches at all, which would truncate the
    # meta store and then save an empty list.
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    community_ids = await self._graph_store_adapter.discover_communities()

    # summarize communities
    communities = []
    n_communities = len(community_ids)

    for i in range(0, n_communities, batch_size):
        batch_ids = community_ids[i : i + batch_size]
        batch_results = await asyncio.gather(
            *[self._summary_community(cid) for cid in batch_ids]
        )
        # filter out None returns
        communities.extend([c for c in batch_results if c is not None])

    # truncate then save new summaries
    await self._meta_store.truncate()
    await self._meta_store.save(communities)
|
||||
|
||||
async def _summary_community(self, community_id: str) -> Optional[Community]:
    """Summarize single community.

    Args:
        community_id: Id of the community to fetch from the graph store
            adapter.

    Returns:
        The fetched community with its ``summary`` attribute set, or ``None``
        when the community, its data, or its formatted graph text is empty.
    """
    community = await self._graph_store_adapter.get_community(community_id)
    if community is None or community.data is None:
        logger.warning(f"Community {community_id} is empty")
        return None

    graph = community.data.format()
    if not graph:
        # Empty graph text gives the summarizer nothing to work with; skip the
        # community so the caller filters it out like any other ``None``.
        logger.warning(f"Community {community_id} is empty")
        return None

    community.summary = await self._community_summarizer.summarize(graph=graph)
    # ``or ""`` keeps the preview slice from failing if no summary came back.
    preview = (community.summary or "")[:50]
    logger.info(f"Summarize community {community_id}: {preview}...")
    return community
|
||||
|
||||
async def search_communities(self, query: str) -> List[Community]:
    """Return the meta store's search results for ``query``."""
    matches = await self._meta_store.search(query)
    return matches
|
||||
|
@@ -356,7 +356,7 @@ class TuGraphStoreAdapter(GraphStoreAdapter):
|
||||
return
|
||||
|
||||
# Create the graph schema
|
||||
def _format_graph_propertity_schema(
|
||||
def _format_graph_property_schema(
|
||||
name: str,
|
||||
type: str = "STRING",
|
||||
optional: bool = False,
|
||||
@@ -390,9 +390,9 @@ class TuGraphStoreAdapter(GraphStoreAdapter):
|
||||
|
||||
# Create the graph label for document vertex
|
||||
document_proerties: List[Dict[str, Union[str, bool]]] = [
|
||||
_format_graph_propertity_schema("id", "STRING", False),
|
||||
_format_graph_propertity_schema("name", "STRING", False),
|
||||
_format_graph_propertity_schema("_community_id", "STRING", True, True),
|
||||
_format_graph_property_schema("id", "STRING", False),
|
||||
_format_graph_property_schema("name", "STRING", False),
|
||||
_format_graph_property_schema("_community_id", "STRING", True, True),
|
||||
]
|
||||
self.create_graph_label(
|
||||
graph_elem_type=GraphElemType.DOCUMENT, graph_properties=document_proerties
|
||||
@@ -400,10 +400,10 @@ class TuGraphStoreAdapter(GraphStoreAdapter):
|
||||
|
||||
# Create the graph label for chunk vertex
|
||||
chunk_proerties: List[Dict[str, Union[str, bool]]] = [
|
||||
_format_graph_propertity_schema("id", "STRING", False),
|
||||
_format_graph_propertity_schema("name", "STRING", False),
|
||||
_format_graph_propertity_schema("_community_id", "STRING", True, True),
|
||||
_format_graph_propertity_schema("content", "STRING", True, True),
|
||||
_format_graph_property_schema("id", "STRING", False),
|
||||
_format_graph_property_schema("name", "STRING", False),
|
||||
_format_graph_property_schema("_community_id", "STRING", True, True),
|
||||
_format_graph_property_schema("content", "STRING", True, True),
|
||||
]
|
||||
self.create_graph_label(
|
||||
graph_elem_type=GraphElemType.CHUNK, graph_properties=chunk_proerties
|
||||
@@ -411,10 +411,10 @@ class TuGraphStoreAdapter(GraphStoreAdapter):
|
||||
|
||||
# Create the graph label for entity vertex
|
||||
vertex_proerties: List[Dict[str, Union[str, bool]]] = [
|
||||
_format_graph_propertity_schema("id", "STRING", False),
|
||||
_format_graph_propertity_schema("name", "STRING", False),
|
||||
_format_graph_propertity_schema("_community_id", "STRING", True, True),
|
||||
_format_graph_propertity_schema("description", "STRING", True, True),
|
||||
_format_graph_property_schema("id", "STRING", False),
|
||||
_format_graph_property_schema("name", "STRING", False),
|
||||
_format_graph_property_schema("_community_id", "STRING", True, True),
|
||||
_format_graph_property_schema("description", "STRING", True, True),
|
||||
]
|
||||
self.create_graph_label(
|
||||
graph_elem_type=GraphElemType.ENTITY, graph_properties=vertex_proerties
|
||||
@@ -422,10 +422,10 @@ class TuGraphStoreAdapter(GraphStoreAdapter):
|
||||
|
||||
# Create the graph label for relation edge
|
||||
edge_proerties: List[Dict[str, Union[str, bool]]] = [
|
||||
_format_graph_propertity_schema("id", "STRING", False),
|
||||
_format_graph_propertity_schema("name", "STRING", False),
|
||||
_format_graph_propertity_schema("_chunk_id", "STRING", True, True),
|
||||
_format_graph_propertity_schema("description", "STRING", True, True),
|
||||
_format_graph_property_schema("id", "STRING", False),
|
||||
_format_graph_property_schema("name", "STRING", False),
|
||||
_format_graph_property_schema("_chunk_id", "STRING", True, True),
|
||||
_format_graph_property_schema("description", "STRING", True, True),
|
||||
]
|
||||
self.create_graph_label(
|
||||
graph_elem_type=GraphElemType.RELATION, graph_properties=edge_proerties
|
||||
@@ -433,9 +433,9 @@ class TuGraphStoreAdapter(GraphStoreAdapter):
|
||||
|
||||
# Create the graph label for include edge
|
||||
include_proerties: List[Dict[str, Union[str, bool]]] = [
|
||||
_format_graph_propertity_schema("id", "STRING", False),
|
||||
_format_graph_propertity_schema("name", "STRING", False),
|
||||
_format_graph_propertity_schema("description", "STRING", True),
|
||||
_format_graph_property_schema("id", "STRING", False),
|
||||
_format_graph_property_schema("name", "STRING", False),
|
||||
_format_graph_property_schema("description", "STRING", True),
|
||||
]
|
||||
self.create_graph_label(
|
||||
graph_elem_type=GraphElemType.INCLUDE, graph_properties=include_proerties
|
||||
@@ -443,9 +443,9 @@ class TuGraphStoreAdapter(GraphStoreAdapter):
|
||||
|
||||
# Create the graph label for next edge
|
||||
next_proerties: List[Dict[str, Union[str, bool]]] = [
|
||||
_format_graph_propertity_schema("id", "STRING", False),
|
||||
_format_graph_propertity_schema("name", "STRING", False),
|
||||
_format_graph_propertity_schema("description", "STRING", True),
|
||||
_format_graph_property_schema("id", "STRING", False),
|
||||
_format_graph_property_schema("name", "STRING", False),
|
||||
_format_graph_property_schema("description", "STRING", True),
|
||||
]
|
||||
self.create_graph_label(
|
||||
graph_elem_type=GraphElemType.NEXT, graph_properties=next_proerties
|
||||
|
@@ -38,8 +38,7 @@ class CommunitySummaryKnowledgeGraphConfig(BuiltinKnowledgeGraphConfig):
|
||||
password: Optional[str] = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"The password of vector store, "
|
||||
"if not set, will use the default password."
|
||||
"The password of vector store, if not set, will use the default password."
|
||||
),
|
||||
)
|
||||
extract_topk: int = Field(
|
||||
@@ -75,6 +74,10 @@ class CommunitySummaryKnowledgeGraphConfig(BuiltinKnowledgeGraphConfig):
|
||||
default=20,
|
||||
description="Batch size of triplets extraction from the text",
|
||||
)
|
||||
community_summary_batch_size: int = Field(
|
||||
default=20,
|
||||
description="Batch size of parallel community building process",
|
||||
)
|
||||
|
||||
|
||||
class CommunitySummaryKnowledgeGraph(BuiltinKnowledgeGraph):
|
||||
@@ -130,6 +133,12 @@ class CommunitySummaryKnowledgeGraph(BuiltinKnowledgeGraph):
|
||||
config.knowledge_graph_extraction_batch_size,
|
||||
)
|
||||
)
|
||||
self._community_summary_batch_size = int(
|
||||
os.getenv(
|
||||
"COMMUNITY_SUMMARY_BATCH_SIZE",
|
||||
config.community_summary_batch_size,
|
||||
)
|
||||
)
|
||||
|
||||
def extractor_configure(name: str, cfg: VectorStoreConfig):
|
||||
cfg.name = name
|
||||
@@ -177,9 +186,12 @@ class CommunitySummaryKnowledgeGraph(BuiltinKnowledgeGraph):
|
||||
|
||||
async def aload_document(self, chunks: List[Chunk]) -> List[str]:
    """Extract and persist graph from the document file.

    Loads the document graph and then the triplet graph for ``chunks``, and
    finally rebuilds the communities once, passing the configured community
    summary batch size.

    Args:
        chunks: Chunks of the document to load.

    Returns:
        The ``chunk_id`` of every chunk in ``chunks``, in input order.
    """
    await self._aload_document_graph(chunks)
    await self._aload_triplet_graph(chunks)
    # Build communities exactly once, using the batch-size-aware call.
    await self._community_store.build_communities(
        batch_size=self._community_summary_batch_size
    )

    return [chunk.chunk_id for chunk in chunks]
|
||||
|
||||
@@ -230,6 +242,8 @@ class CommunitySummaryKnowledgeGraph(BuiltinKnowledgeGraph):
|
||||
[chunk.content for chunk in chunks],
|
||||
batch_size=self._triplet_extraction_batch_size,
|
||||
)
|
||||
if not graphs_list:
|
||||
raise ValueError("No graphs extracted from the chunks")
|
||||
|
||||
# Upsert the graphs into the graph store
|
||||
for idx, graphs in enumerate(graphs_list):
|
||||
|
Reference in New Issue
Block a user