mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-25 03:20:41 +00:00
feat: Enhance the triplets extraction in the knowledge graph by the batch size (#2091)
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
"""Transformer base class."""
|
||||
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Optional
|
||||
@@ -37,6 +38,15 @@ class ExtractorBase(TransformerBase, ABC):
|
||||
async def extract(self, text: str, limit: Optional[int] = None) -> List:
|
||||
"""Extract results from text."""
|
||||
|
||||
@abstractmethod
async def batch_extract(
    self, texts: List[str], batch_size: int = 1, limit: Optional[int] = None
) -> List:
    """Extract results from several texts, processed in batches.

    Abstract hook: concrete extractors must provide the implementation.

    Args:
        texts: Input texts to extract from.
        batch_size: Size of one batch; defaults to 1.
        limit: Optional limit; its interpretation is left to the
            implementation. Defaults to None.

    Returns:
        A list holding the extraction results.
    """
|
||||
|
||||
|
||||
class TranslatorBase(TransformerBase, ABC):
|
||||
"""Translator base class."""
|
||||
|
@@ -1,8 +1,9 @@
|
||||
"""GraphExtractor class."""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import re
|
||||
from typing import List, Optional
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from dbgpt.core import Chunk, LLMClient
|
||||
from dbgpt.rag.transformer.llm_extractor import LLMExtractor
|
||||
@@ -23,35 +24,96 @@ class GraphExtractor(LLMExtractor):
|
||||
self._chunk_history = chunk_history
|
||||
|
||||
config = self._chunk_history.get_config()
|
||||
|
||||
self._vector_space = config.name
|
||||
self._max_chunks_once_load = config.max_chunks_once_load
|
||||
self._max_threads = config.max_threads
|
||||
self._topk = config.topk
|
||||
self._score_threshold = config.score_threshold
|
||||
|
||||
async def extract(self, text: str, limit: Optional[int] = None) -> List:
|
||||
"""Load similar chunks."""
|
||||
# load similar chunks
|
||||
chunks = await self._chunk_history.asimilar_search_with_scores(
|
||||
text, self._topk, self._score_threshold
|
||||
)
|
||||
history = [
|
||||
f"Section {i + 1}:\n{chunk.content}" for i, chunk in enumerate(chunks)
|
||||
]
|
||||
context = "\n".join(history) if history else ""
|
||||
async def aload_chunk_context(self, texts: List[str]) -> Dict[str, str]:
|
||||
"""Load chunk context."""
|
||||
text_context_map: Dict[str, str] = {}
|
||||
|
||||
try:
|
||||
# extract with chunk history
|
||||
return await super()._extract(text, context, limit)
|
||||
for text in texts:
|
||||
# Load similar chunks
|
||||
chunks = await self._chunk_history.asimilar_search_with_scores(
|
||||
text, self._topk, self._score_threshold
|
||||
)
|
||||
history = [
|
||||
f"Section {i + 1}:\n{chunk.content}" for i, chunk in enumerate(chunks)
|
||||
]
|
||||
|
||||
finally:
|
||||
# save chunk to history
|
||||
# Save chunk to history
|
||||
await self._chunk_history.aload_document_with_limit(
|
||||
[Chunk(content=text, metadata={"relevant_cnt": len(history)})],
|
||||
self._max_chunks_once_load,
|
||||
self._max_threads,
|
||||
)
|
||||
|
||||
# Save chunk context to map
|
||||
context = "\n".join(history) if history else ""
|
||||
text_context_map[text] = context
|
||||
return text_context_map
|
||||
|
||||
async def extract(self, text: str, limit: Optional[int] = None) -> List:
    """Extract graphs from a single text.

    The chunk context for ``text`` is loaded through
    ``aload_chunk_context`` and then handed, together with ``text`` and
    ``limit``, to the parent class' ``_extract``.

    Tip: use ``batch_extract`` when triplets should be extracted from many
    texts in batches.
    """
    # Resolve the history context recorded for this one text.
    context_by_text = await self.aload_chunk_context([text])
    chunk_context = context_by_text[text]

    # Delegate the actual extraction, supplying the chunk history.
    return await super()._extract(text, chunk_context, limit)
|
||||
|
||||
async def batch_extract(
    self,
    texts: List[str],
    batch_size: int = 1,
    limit: Optional[int] = None,
) -> List[List[Graph]]:
    """Extract graphs from texts, running ``batch_size`` extractions at a time.

    The chunk context of every text is loaded first via
    ``aload_chunk_context``. The texts are then processed batch by batch:
    the extractions of one batch are awaited together with
    ``asyncio.gather`` and the next batch starts only once the current one
    has completed.

    Args:
        texts: Texts to extract graphs from.
        batch_size: Number of texts extracted concurrently per batch;
            must be at least 1.
        limit: Forwarded unchanged to ``_extract``.

    Returns:
        One result per input text, in the same order as ``texts``
        (text <-> graphs).

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError("batch_size >= 1")

    # 1. Load chunk context
    text_context_map = await self.aload_chunk_context(texts)

    graphs_list: List[List[Graph]] = []
    for start_idx in range(0, len(texts), batch_size):
        batch_texts = texts[start_idx : start_idx + batch_size]

        # 2. Create one extraction coroutine per text of the current batch
        extraction_tasks = [
            self._extract(text, text_context_map[text], limit)
            for text in batch_texts
        ]

        # 3. Run the batch concurrently. asyncio.gather returns the results
        #    in the order of the awaitables passed in, so appending batch
        #    after batch keeps the output aligned with ``texts`` without
        #    any explicit index bookkeeping or placeholder entries.
        batch_results = await asyncio.gather(*extraction_tasks)
        graphs_list.extend(batch_results)

    return graphs_list
|
||||
|
||||
def _parse_response(self, text: str, limit: Optional[int] = None) -> List[Graph]:
|
||||
graph = MemoryGraph()
|
||||
edge_count = 0
|
||||
|
@@ -1,4 +1,6 @@
|
||||
"""TripletExtractor class."""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Optional
|
||||
@@ -22,6 +24,32 @@ class LLMExtractor(ExtractorBase, ABC):
|
||||
"""Extract by LLM."""
|
||||
return await self._extract(text, None, limit)
|
||||
|
||||
async def batch_extract(
    self,
    texts: List[str],
    batch_size: int = 1,
    limit: Optional[int] = None,
) -> List:
    """Run LLM extraction over ``texts`` in consecutive batches.

    Every batch holds at most ``batch_size`` texts. The extractions of one
    batch are awaited together through ``asyncio.gather`` before the next
    batch is started, and the results are collected in input order.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError("batch_size >= 1")

    collected: List = []
    total = len(texts)
    offset = 0

    while offset < total:
        window = texts[offset : offset + batch_size]

        # One pending extraction (without history) per text in the window
        pending = [self._extract(item, None, limit) for item in window]

        # Wait for the whole window to finish, then keep its results
        collected += await asyncio.gather(*pending)
        offset += batch_size

    return collected
|
||||
|
||||
async def _extract(
|
||||
self, text: str, history: str = None, limit: Optional[int] = None
|
||||
) -> List:
|
||||
|
@@ -1,4 +1,5 @@
|
||||
"""TripletExtractor class."""
|
||||
|
||||
import logging
|
||||
import re
|
||||
from typing import Any, List, Optional, Tuple
|
||||
@@ -12,7 +13,7 @@ TRIPLET_EXTRACT_PT = (
|
||||
"Some text is provided below. Given the text, "
|
||||
"extract up to knowledge triplets as more as possible "
|
||||
"in the form of (subject, predicate, object).\n"
|
||||
"Avoid stopwords.\n"
|
||||
"Avoid stopwords. The subject, predicate, object can not be none.\n"
|
||||
"---------------------\n"
|
||||
"Example:\n"
|
||||
"Text: Alice is Bob's mother.\n"
|
||||
|
Reference in New Issue
Block a user