feat: Enhance triplet extraction in the knowledge graph with batched processing (#2091)

This commit is contained in:
Appointat
2024-11-05 14:01:18 +08:00
committed by GitHub
parent b4ce217ded
commit 25d47ce343
10 changed files with 360 additions and 242 deletions

View File

@@ -1,4 +1,5 @@
"""Transformer base class."""
import logging
from abc import ABC, abstractmethod
from typing import List, Optional
@@ -37,6 +38,15 @@ class ExtractorBase(TransformerBase, ABC):
async def extract(self, text: str, limit: Optional[int] = None) -> List:
"""Extract results from text."""
@abstractmethod
async def batch_extract(
    self,
    texts: List[str],
    batch_size: int = 1,
    limit: Optional[int] = None,
) -> List:
    """Batch extract results from texts.

    Args:
        texts: the input texts; one extraction result is produced per text.
        batch_size: how many texts a concrete implementation may process
            concurrently per batch (defaults to 1, i.e. sequential).
        limit: optional cap forwarded to each single extraction.

    Returns:
        The extraction results, in the same order as ``texts``.
    """
class TranslatorBase(TransformerBase, ABC):
"""Translator base class."""

View File

@@ -1,8 +1,9 @@
"""GraphExtractor class."""
import asyncio
import logging
import re
from typing import List, Optional
from typing import Dict, List, Optional
from dbgpt.core import Chunk, LLMClient
from dbgpt.rag.transformer.llm_extractor import LLMExtractor
@@ -23,35 +24,96 @@ class GraphExtractor(LLMExtractor):
self._chunk_history = chunk_history
config = self._chunk_history.get_config()
self._vector_space = config.name
self._max_chunks_once_load = config.max_chunks_once_load
self._max_threads = config.max_threads
self._topk = config.topk
self._score_threshold = config.score_threshold
async def aload_chunk_context(self, texts: List[str]) -> Dict[str, str]:
    """Load the chunk context for each text.

    For every text, similar historical chunks are retrieved from the chunk
    history store and joined into one context string. The text itself is
    then persisted to the history so that subsequent texts can retrieve it.

    Args:
        texts: chunk texts to build context for.

    Returns:
        A mapping from each input text to its context string; the context
        is "" when no similar chunk was found.
    """
    text_context_map: Dict[str, str] = {}
    for text in texts:
        # Load similar chunks
        chunks = await self._chunk_history.asimilar_search_with_scores(
            text, self._topk, self._score_threshold
        )
        history = [
            f"Section {i + 1}:\n{chunk.content}" for i, chunk in enumerate(chunks)
        ]

        # Save chunk to history so later texts see it as context
        await self._chunk_history.aload_document_with_limit(
            [Chunk(content=text, metadata={"relevant_cnt": len(history)})],
            self._max_chunks_once_load,
            self._max_threads,
        )

        # Save chunk context to map
        context = "\n".join(history) if history else ""
        text_context_map[text] = context
    return text_context_map
async def extract(self, text: str, limit: Optional[int] = None) -> List:
    """Extract graphs from a single text.

    Suggestion: to extract triplets in batches, call `batch_extract`.
    """
    # Build the chunk-history context for this one text.
    context_by_text = await self.aload_chunk_context([text])
    # Delegate the actual extraction, passing the chunk history along.
    return await super()._extract(text, context_by_text[text], limit)
async def batch_extract(
    self,
    texts: List[str],
    batch_size: int = 1,
    limit: Optional[int] = None,
) -> List[List[Graph]]:
    """Extract graphs from chunks in batches.

    Args:
        texts: input chunk texts; one list of graphs is produced per text.
        batch_size: number of texts extracted concurrently per batch.
        limit: optional cap forwarded to each single extraction.

    Returns:
        List of graph lists in the same order as the input texts
        (text <-> graphs).

    Raises:
        ValueError: if ``batch_size`` is less than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    # 1. Load the chunk context for every text up front.
    text_context_map = await self.aload_chunk_context(texts)

    # Pre-allocate the result list so concurrent batches can fill their
    # slots while preserving the input order. Slots are None until filled,
    # hence the Optional in the local annotation.
    graphs_list: List[Optional[List[Graph]]] = [None] * len(texts)

    for start_idx in range(0, len(texts), batch_size):
        batch_texts = texts[start_idx : start_idx + batch_size]

        # 2. Pair each extraction coroutine with its original index.
        indexed_tasks = [
            (idx, self._extract(text, text_context_map[text], limit))
            for idx, text in enumerate(batch_texts, start=start_idx)
        ]

        # 3. Run the whole batch concurrently.
        batch_results = await asyncio.gather(
            *(task for _, task in indexed_tasks)
        )

        # 4. Place results in their original positions.
        for (idx, _), graphs in zip(indexed_tasks, batch_results):
            graphs_list[idx] = graphs

    # Every index is covered by exactly one batch, so this should be
    # unreachable; raise (not assert, which -O strips) if it ever trips.
    if any(graphs is None for graphs in graphs_list):
        raise RuntimeError("batch_extract failed to fill all result slots")
    return graphs_list  # type: ignore[return-value]
def _parse_response(self, text: str, limit: Optional[int] = None) -> List[Graph]:
graph = MemoryGraph()
edge_count = 0

View File

@@ -1,4 +1,6 @@
"""TripletExtractor class."""
import asyncio
import logging
from abc import ABC, abstractmethod
from typing import List, Optional
@@ -22,6 +24,32 @@ class LLMExtractor(ExtractorBase, ABC):
"""Extract by LLM."""
return await self._extract(text, None, limit)
async def batch_extract(
    self,
    texts: List[str],
    batch_size: int = 1,
    limit: Optional[int] = None,
) -> List:
    """Batch extract by LLM.

    Args:
        texts: input texts; one extraction result is produced per text.
        batch_size: number of texts extracted concurrently per batch.
        limit: optional cap forwarded to each single extraction.

    Returns:
        The extraction results in the same order as ``texts``.

    Raises:
        ValueError: if ``batch_size`` is less than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    results: List = []
    for start in range(0, len(texts), batch_size):
        batch_texts = texts[start : start + batch_size]
        # Execute the current batch concurrently; asyncio.gather preserves
        # the order of its arguments, so results stay aligned with texts.
        batch_results = await asyncio.gather(
            *(self._extract(text, None, limit) for text in batch_texts)
        )
        results.extend(batch_results)
    return results
async def _extract(
self, text: str, history: str = None, limit: Optional[int] = None
) -> List:

View File

@@ -1,4 +1,5 @@
"""TripletExtractor class."""
import logging
import re
from typing import Any, List, Optional, Tuple
@@ -12,7 +13,7 @@ TRIPLET_EXTRACT_PT = (
"Some text is provided below. Given the text, "
"extract up to knowledge triplets as more as possible "
"in the form of (subject, predicate, object).\n"
"Avoid stopwords.\n"
"Avoid stopwords. The subject, predicate, object can not be none.\n"
"---------------------\n"
"Example:\n"
"Text: Alice is Bob's mother.\n"