mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-13 21:21:08 +00:00
feat(GraphRAG): Support concurrent community summarization (#2160)
This commit is contained in:
@@ -73,7 +73,7 @@ class GraphExtractor(LLMExtractor):
|
||||
texts: List[str],
|
||||
batch_size: int = 1,
|
||||
limit: Optional[int] = None,
|
||||
) -> List[List[Graph]]:
|
||||
) -> Optional[List[List[Graph]]]:
|
||||
"""Extract graphs from chunks in batches.
|
||||
|
||||
Returns list of graphs in same order as input texts (text <-> graphs).
|
||||
@@ -86,11 +86,12 @@ class GraphExtractor(LLMExtractor):
|
||||
|
||||
# Pre-allocate results list to maintain order
|
||||
graphs_list: List[List[Graph]] = [None] * len(texts)
|
||||
total_batches = (len(texts) + batch_size - 1) // batch_size
|
||||
|
||||
for batch_idx in range(total_batches):
|
||||
start_idx = batch_idx * batch_size
|
||||
end_idx = min((batch_idx + 1) * batch_size, len(texts))
|
||||
n_texts = len(texts)
|
||||
|
||||
for batch_idx in range(0, n_texts, batch_size):
|
||||
start_idx = batch_idx
|
||||
end_idx = min(start_idx + batch_size, n_texts)
|
||||
batch_texts = texts[start_idx:end_idx]
|
||||
|
||||
# 2. Create tasks with their original indices
|
||||
@@ -104,11 +105,17 @@ class GraphExtractor(LLMExtractor):
|
||||
|
||||
# 3. Process extraction in parallel while keeping track of indices
|
||||
batch_results = await asyncio.gather(
|
||||
*(task for _, task in extraction_tasks)
|
||||
*(task for _, task in extraction_tasks), return_exceptions=True
|
||||
)
|
||||
|
||||
# 4. Place results in the correct positions
|
||||
for (idx, _), graphs in zip(extraction_tasks, batch_results):
|
||||
if isinstance(graphs, Exception):
|
||||
raise RuntimeError(f"Failed to extract graph: {graphs}")
|
||||
if not isinstance(graphs, list) or not all(
|
||||
isinstance(g, Graph) for g in graphs
|
||||
):
|
||||
raise RuntimeError(f"Invalid graph extraction result: {graphs}")
|
||||
graphs_list[idx] = graphs
|
||||
|
||||
assert all(x is not None for x in graphs_list), "All positions should be filled"
|
||||
|
@@ -1,7 +1,8 @@
|
||||
"""Define the CommunityStore class."""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import List
|
||||
from typing import List, Optional
|
||||
|
||||
from dbgpt.rag.transformer.community_summarizer import CommunitySummarizer
|
||||
from dbgpt.storage.knowledge_graph.community.base import Community, GraphStoreAdapter
|
||||
@@ -27,28 +28,38 @@ class CommunityStore:
|
||||
self._community_summarizer = community_summarizer
|
||||
self._meta_store = BuiltinCommunityMetastore(vector_store)
|
||||
|
||||
async def build_communities(self):
|
||||
async def build_communities(self, batch_size: int = 1):
|
||||
"""Discover communities."""
|
||||
community_ids = await self._graph_store_adapter.discover_communities()
|
||||
|
||||
# summarize communities
|
||||
communities = []
|
||||
for community_id in community_ids:
|
||||
community = await self._graph_store_adapter.get_community(community_id)
|
||||
graph = community.data.format()
|
||||
if not graph:
|
||||
break
|
||||
n_communities = len(community_ids)
|
||||
|
||||
community.summary = await self._community_summarizer.summarize(graph=graph)
|
||||
communities.append(community)
|
||||
logger.info(
|
||||
f"Summarize community {community_id}: " f"{community.summary[:50]}..."
|
||||
for i in range(0, n_communities, batch_size):
|
||||
batch_ids = community_ids[i : i + batch_size]
|
||||
batch_results = await asyncio.gather(
|
||||
*[self._summary_community(cid) for cid in batch_ids]
|
||||
)
|
||||
# filter out None returns
|
||||
communities.extend([c for c in batch_results if c is not None])
|
||||
|
||||
# truncate then save new summaries
|
||||
await self._meta_store.truncate()
|
||||
await self._meta_store.save(communities)
|
||||
|
||||
async def _summary_community(self, community_id: str) -> Optional[Community]:
    """Load one community by id and attach a summary of its graph to it.

    Args:
        community_id: Identifier passed to the graph store adapter's
            ``get_community``.

    Returns:
        The fetched community with ``summary`` set from the community
        summarizer, or ``None`` (after logging a warning) when the adapter
        returns no community or the community carries no data.
    """
    fetched = await self._graph_store_adapter.get_community(community_id)

    # Guard clause: nothing to summarize without a community and its data.
    has_data = fetched is not None and fetched.data is not None
    if not has_data:
        logger.warning(f"Community {community_id} is empty")
        return None

    # The summarizer receives the formatted form of the community's data.
    formatted_graph = fetched.data.format()
    fetched.summary = await self._community_summarizer.summarize(
        graph=formatted_graph
    )

    # Only the first 50 characters of the summary are logged.
    logger.info(f"Summarize community {community_id}: {fetched.summary[:50]}...")
    return fetched
|
||||
|
||||
async def search_communities(self, query: str) -> List[Community]:
    """Return whatever the community metastore's ``search`` yields for ``query``."""
    matches = await self._meta_store.search(query)
    return matches
|
||||
|
@@ -356,7 +356,7 @@ class TuGraphStoreAdapter(GraphStoreAdapter):
|
||||
return
|
||||
|
||||
# Create the graph schema
|
||||
def _format_graph_propertity_schema(
|
||||
def _format_graph_property_schema(
|
||||
name: str,
|
||||
type: str = "STRING",
|
||||
optional: bool = False,
|
||||
@@ -390,9 +390,9 @@ class TuGraphStoreAdapter(GraphStoreAdapter):
|
||||
|
||||
# Create the graph label for document vertex
|
||||
document_proerties: List[Dict[str, Union[str, bool]]] = [
|
||||
_format_graph_propertity_schema("id", "STRING", False),
|
||||
_format_graph_propertity_schema("name", "STRING", False),
|
||||
_format_graph_propertity_schema("_community_id", "STRING", True, True),
|
||||
_format_graph_property_schema("id", "STRING", False),
|
||||
_format_graph_property_schema("name", "STRING", False),
|
||||
_format_graph_property_schema("_community_id", "STRING", True, True),
|
||||
]
|
||||
self.create_graph_label(
|
||||
graph_elem_type=GraphElemType.DOCUMENT, graph_properties=document_proerties
|
||||
@@ -400,10 +400,10 @@ class TuGraphStoreAdapter(GraphStoreAdapter):
|
||||
|
||||
# Create the graph label for chunk vertex
|
||||
chunk_proerties: List[Dict[str, Union[str, bool]]] = [
|
||||
_format_graph_propertity_schema("id", "STRING", False),
|
||||
_format_graph_propertity_schema("name", "STRING", False),
|
||||
_format_graph_propertity_schema("_community_id", "STRING", True, True),
|
||||
_format_graph_propertity_schema("content", "STRING", True, True),
|
||||
_format_graph_property_schema("id", "STRING", False),
|
||||
_format_graph_property_schema("name", "STRING", False),
|
||||
_format_graph_property_schema("_community_id", "STRING", True, True),
|
||||
_format_graph_property_schema("content", "STRING", True, True),
|
||||
]
|
||||
self.create_graph_label(
|
||||
graph_elem_type=GraphElemType.CHUNK, graph_properties=chunk_proerties
|
||||
@@ -411,10 +411,10 @@ class TuGraphStoreAdapter(GraphStoreAdapter):
|
||||
|
||||
# Create the graph label for entity vertex
|
||||
vertex_proerties: List[Dict[str, Union[str, bool]]] = [
|
||||
_format_graph_propertity_schema("id", "STRING", False),
|
||||
_format_graph_propertity_schema("name", "STRING", False),
|
||||
_format_graph_propertity_schema("_community_id", "STRING", True, True),
|
||||
_format_graph_propertity_schema("description", "STRING", True, True),
|
||||
_format_graph_property_schema("id", "STRING", False),
|
||||
_format_graph_property_schema("name", "STRING", False),
|
||||
_format_graph_property_schema("_community_id", "STRING", True, True),
|
||||
_format_graph_property_schema("description", "STRING", True, True),
|
||||
]
|
||||
self.create_graph_label(
|
||||
graph_elem_type=GraphElemType.ENTITY, graph_properties=vertex_proerties
|
||||
@@ -422,10 +422,10 @@ class TuGraphStoreAdapter(GraphStoreAdapter):
|
||||
|
||||
# Create the graph label for relation edge
|
||||
edge_proerties: List[Dict[str, Union[str, bool]]] = [
|
||||
_format_graph_propertity_schema("id", "STRING", False),
|
||||
_format_graph_propertity_schema("name", "STRING", False),
|
||||
_format_graph_propertity_schema("_chunk_id", "STRING", True, True),
|
||||
_format_graph_propertity_schema("description", "STRING", True, True),
|
||||
_format_graph_property_schema("id", "STRING", False),
|
||||
_format_graph_property_schema("name", "STRING", False),
|
||||
_format_graph_property_schema("_chunk_id", "STRING", True, True),
|
||||
_format_graph_property_schema("description", "STRING", True, True),
|
||||
]
|
||||
self.create_graph_label(
|
||||
graph_elem_type=GraphElemType.RELATION, graph_properties=edge_proerties
|
||||
@@ -433,9 +433,9 @@ class TuGraphStoreAdapter(GraphStoreAdapter):
|
||||
|
||||
# Create the graph label for include edge
|
||||
include_proerties: List[Dict[str, Union[str, bool]]] = [
|
||||
_format_graph_propertity_schema("id", "STRING", False),
|
||||
_format_graph_propertity_schema("name", "STRING", False),
|
||||
_format_graph_propertity_schema("description", "STRING", True),
|
||||
_format_graph_property_schema("id", "STRING", False),
|
||||
_format_graph_property_schema("name", "STRING", False),
|
||||
_format_graph_property_schema("description", "STRING", True),
|
||||
]
|
||||
self.create_graph_label(
|
||||
graph_elem_type=GraphElemType.INCLUDE, graph_properties=include_proerties
|
||||
@@ -443,9 +443,9 @@ class TuGraphStoreAdapter(GraphStoreAdapter):
|
||||
|
||||
# Create the graph label for next edge
|
||||
next_proerties: List[Dict[str, Union[str, bool]]] = [
|
||||
_format_graph_propertity_schema("id", "STRING", False),
|
||||
_format_graph_propertity_schema("name", "STRING", False),
|
||||
_format_graph_propertity_schema("description", "STRING", True),
|
||||
_format_graph_property_schema("id", "STRING", False),
|
||||
_format_graph_property_schema("name", "STRING", False),
|
||||
_format_graph_property_schema("description", "STRING", True),
|
||||
]
|
||||
self.create_graph_label(
|
||||
graph_elem_type=GraphElemType.NEXT, graph_properties=next_proerties
|
||||
|
@@ -38,8 +38,7 @@ class CommunitySummaryKnowledgeGraphConfig(BuiltinKnowledgeGraphConfig):
|
||||
password: Optional[str] = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"The password of vector store, "
|
||||
"if not set, will use the default password."
|
||||
"The password of vector store, if not set, will use the default password."
|
||||
),
|
||||
)
|
||||
extract_topk: int = Field(
|
||||
@@ -75,6 +74,10 @@ class CommunitySummaryKnowledgeGraphConfig(BuiltinKnowledgeGraphConfig):
|
||||
default=20,
|
||||
description="Batch size of triplets extraction from the text",
|
||||
)
|
||||
community_summary_batch_size: int = Field(
|
||||
default=20,
|
||||
description="Batch size of parallel community building process",
|
||||
)
|
||||
|
||||
|
||||
class CommunitySummaryKnowledgeGraph(BuiltinKnowledgeGraph):
|
||||
@@ -130,6 +133,12 @@ class CommunitySummaryKnowledgeGraph(BuiltinKnowledgeGraph):
|
||||
config.knowledge_graph_extraction_batch_size,
|
||||
)
|
||||
)
|
||||
self._community_summary_batch_size = int(
|
||||
os.getenv(
|
||||
"COMMUNITY_SUMMARY_BATCH_SIZE",
|
||||
config.community_summary_batch_size,
|
||||
)
|
||||
)
|
||||
|
||||
def extractor_configure(name: str, cfg: VectorStoreConfig):
|
||||
cfg.name = name
|
||||
@@ -177,9 +186,12 @@ class CommunitySummaryKnowledgeGraph(BuiltinKnowledgeGraph):
|
||||
|
||||
async def aload_document(self, chunks: List[Chunk]) -> List[str]:
|
||||
"""Extract and persist graph from the document file."""
|
||||
|
||||
await self._aload_document_graph(chunks)
|
||||
await self._aload_triplet_graph(chunks)
|
||||
await self._community_store.build_communities()
|
||||
await self._community_store.build_communities(
|
||||
batch_size=self._community_summary_batch_size
|
||||
)
|
||||
|
||||
return [chunk.chunk_id for chunk in chunks]
|
||||
|
||||
@@ -230,6 +242,8 @@ class CommunitySummaryKnowledgeGraph(BuiltinKnowledgeGraph):
|
||||
[chunk.content for chunk in chunks],
|
||||
batch_size=self._triplet_extraction_batch_size,
|
||||
)
|
||||
if not graphs_list:
|
||||
raise ValueError("No graphs extracted from the chunks")
|
||||
|
||||
# Upsert the graphs into the graph store
|
||||
for idx, graphs in enumerate(graphs_list):
|
||||
|
Reference in New Issue
Block a user