feat(GraphRAG): Support concurrent community summarization (#2160)

This commit is contained in:
Appointat
2024-11-29 20:51:58 +08:00
committed by GitHub
parent e5ec47145f
commit a14eeb56dd
10 changed files with 1226 additions and 784 deletions

View File

@@ -73,7 +73,7 @@ class GraphExtractor(LLMExtractor):
texts: List[str],
batch_size: int = 1,
limit: Optional[int] = None,
) -> List[List[Graph]]:
) -> Optional[List[List[Graph]]]:
"""Extract graphs from chunks in batches.
Returns list of graphs in same order as input texts (text <-> graphs).
@@ -86,11 +86,12 @@ class GraphExtractor(LLMExtractor):
# Pre-allocate results list to maintain order
graphs_list: List[List[Graph]] = [None] * len(texts)
total_batches = (len(texts) + batch_size - 1) // batch_size
for batch_idx in range(total_batches):
start_idx = batch_idx * batch_size
end_idx = min((batch_idx + 1) * batch_size, len(texts))
n_texts = len(texts)
for batch_idx in range(0, n_texts, batch_size):
start_idx = batch_idx
end_idx = min(start_idx + batch_size, n_texts)
batch_texts = texts[start_idx:end_idx]
# 2. Create tasks with their original indices
@@ -104,11 +105,17 @@ class GraphExtractor(LLMExtractor):
# 3. Process extraction in parallel while keeping track of indices
batch_results = await asyncio.gather(
*(task for _, task in extraction_tasks)
*(task for _, task in extraction_tasks), return_exceptions=True
)
# 4. Place results in the correct positions
for (idx, _), graphs in zip(extraction_tasks, batch_results):
if isinstance(graphs, Exception):
raise RuntimeError(f"Failed to extract graph: {graphs}")
if not isinstance(graphs, list) or not all(
isinstance(g, Graph) for g in graphs
):
raise RuntimeError(f"Invalid graph extraction result: {graphs}")
graphs_list[idx] = graphs
assert all(x is not None for x in graphs_list), "All positions should be filled"