mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-07-22 20:01:46 +00:00
feat(RAG):add tree retriever based on document directory level (#2669)
This commit is contained in:
parent
3a00aca113
commit
421004a1d8
36
examples/rag/doc_tree_retriever_example.py
Normal file
36
examples/rag/doc_tree_retriever_example.py
Normal file
@ -0,0 +1,36 @@
|
||||
import asyncio
|
||||
|
||||
from dbgpt.model.proxy import DeepseekLLMClient
|
||||
from dbgpt.rag.knowledge.base import ChunkStrategy
|
||||
from dbgpt.rag.transformer.keyword_extractor import KeywordExtractor
|
||||
from dbgpt_ext.rag import ChunkParameters
|
||||
from dbgpt_ext.rag.knowledge import KnowledgeFactory
|
||||
from dbgpt_ext.rag.retriever.doc_tree import DocTreeRetriever
|
||||
|
||||
|
||||
async def main():
    """Build a document tree from the AWEL markdown doc and query it.

    The markdown file is loaded, chunked by markdown header, indexed by
    ``DocTreeRetriever`` and then searched with a sample question. Every
    matched node is printed together with its sub-tree.
    """
    import os

    # Resolve the sample document relative to this file so the example works
    # no matter which directory it is launched from.
    doc_path = os.path.join(
        os.path.dirname(os.path.abspath(__file__)),
        "../../docs/docs/awel/awel.md",
    )
    knowledge = KnowledgeFactory.from_file_path(doc_path)
    chunk_parameters = ChunkParameters(
        chunk_strategy=ChunkStrategy.CHUNK_BY_MARKDOWN_HEADER.name
    )
    docs = knowledge.load()
    docs = knowledge.extract(docs, chunk_parameters)

    # Prefer a key from the environment; the placeholder keeps the original
    # behaviour when the variable is not set.
    llm_client = DeepseekLLMClient(
        api_key=os.getenv("DEEPSEEK_API_KEY", "your_api_key")
    )
    keyword_extractor = KeywordExtractor(
        llm_client=llm_client, model_name="deepseek-chat"
    )
    # doc tree retriever
    retriever = DocTreeRetriever(
        docs=docs,
        top_k=10,
        keywords_extractor=keyword_extractor,
        with_content=False,
    )
    # Use the public accessor instead of the private ``_tree_indexes`` field.
    tree_indexes = retriever.get_tree_indexes()
    if not tree_indexes:
        print("No document tree was built; nothing to display.")
        return
    tree_index = tree_indexes[0]
    nodes = await retriever.aretrieve("Introduce awel Operators")
    for node in nodes:
        tree_index.display_tree(node)


if __name__ == "__main__":
    asyncio.run(main())
|
@ -16,6 +16,10 @@ class Document(BaseModel):
|
||||
default_factory=dict,
|
||||
description="metadata fields",
|
||||
)
|
||||
chunks: List["Chunk"] = Field(
|
||||
default_factory=list,
|
||||
description="list of chunks",
|
||||
)
|
||||
|
||||
def set_content(self, content: str) -> None:
|
||||
"""Set document content."""
|
||||
|
@ -164,7 +164,11 @@ class Knowledge(ABC):
|
||||
documents = self._load()
|
||||
return self._postprocess(documents)
|
||||
|
||||
def extract(
    self,
    documents: List[Document],
    chunk_parameters: Optional[Any] = None,
) -> List[Document]:
    """Extract knowledge from the loaded documents.

    The base implementation is a no-op hook: it returns ``documents``
    unchanged and does not look at ``chunk_parameters``.

    NOTE(review): ``MarkdownKnowledge.extract`` names its second parameter
    ``chunk_parameter`` (singular), so pass it positionally when the concrete
    knowledge type is not known.

    Args:
        documents(List[Document]): documents returned by ``load``.
        chunk_parameters(Optional[Any]): chunking configuration, unused here.
    Return:
        List[Document]: the same list that was passed in.
    """
    return documents
|
||||
|
||||
|
@ -231,3 +231,47 @@ class IndexStoreBase(ABC):
|
||||
return await blocking_func_to_async_no_executor(
|
||||
self.similar_search_with_scores, query, topk, score_threshold, filters
|
||||
)
|
||||
|
||||
def full_text_search(
    self, text: str, topk: int, filters: Optional[MetadataFilters] = None
) -> List[Chunk]:
    """Run a full text search against the index database.

    This base version is only a placeholder and always raises.

    Args:
        text(str): The query text.
        topk(int): The maximum number of documents to return.
        filters(Optional[MetadataFilters]): metadata filters.
    Return:
        List[Chunk]: The matching documents.
    Raises:
        NotImplementedError: always, in this base implementation.
    """
    unsupported_message = "Full text search is not supported in this index store."
    raise NotImplementedError(unsupported_message)
|
||||
|
||||
async def afull_text_search(
    self, text: str, topk: int, filters: Optional[MetadataFilters] = None
) -> List[Chunk]:
    """Run a full text search in the index database asynchronously.

    Delegates to the synchronous ``full_text_search`` through
    ``blocking_func_to_async_no_executor``, so any error raised there
    (including ``NotImplementedError`` from the base stub) propagates.

    Args:
        text(str): The query text.
        topk(int): The maximum number of documents to return.
        filters(Optional[MetadataFilters]): metadata filters.
    Return:
        List[Chunk]: The matching documents.
    """
    return await blocking_func_to_async_no_executor(
        self.full_text_search, text, topk, filters
    )
|
||||
|
||||
def is_support_full_text_search(self) -> bool:
    """Report whether this index store supports full text search.

    The base implementation does not answer the question; it raises, so a
    store has to override this method to return a boolean.

    Return:
        bool: whether full text search is supported (overrides only).
    Raises:
        NotImplementedError: always, in this base implementation.
    """
    raise NotImplementedError(
        "Full text search is not supported in this index store."
    )
|
||||
|
@ -9,6 +9,7 @@ from dbgpt.rag.knowledge.base import (
|
||||
Knowledge,
|
||||
KnowledgeType,
|
||||
)
|
||||
from dbgpt_ext.rag import ChunkParameters
|
||||
|
||||
|
||||
class MarkdownKnowledge(Knowledge):
|
||||
@ -49,13 +50,30 @@ class MarkdownKnowledge(Knowledge):
|
||||
raise ValueError("file path is required")
|
||||
with open(self._path, encoding=self._encoding, errors="ignore") as f:
|
||||
markdown_text = f.read()
|
||||
metadata = {"source": self._path}
|
||||
metadata = {
|
||||
"source": self._path,
|
||||
"title": self._path.rsplit("/", 1)[-1],
|
||||
}
|
||||
if self._metadata:
|
||||
metadata.update(self._metadata) # type: ignore
|
||||
documents = [Document(content=markdown_text, metadata=metadata)]
|
||||
return documents
|
||||
return [Document.langchain2doc(lc_document) for lc_document in documents]
|
||||
|
||||
def extract(
    self,
    documents: List[Document],
    chunk_parameter: Optional[ChunkParameters] = None,
) -> List[Document]:
    """Split every document into chunks and attach them to that document.

    Args:
        documents(List[Document]): documents returned by ``load``.
        chunk_parameter(Optional[ChunkParameters]): chunking configuration
            handed to ``ChunkManager``.
    Return:
        List[Document]: the input documents, each with ``chunks`` populated.
    """
    # Imported lazily, as in the original code.
    from dbgpt_ext.rag.chunk_manager import ChunkManager

    chunk_manager = ChunkManager(knowledge=self, chunk_parameter=chunk_parameter)
    for document in documents:
        # Split one document at a time so that each document only receives
        # its own chunks instead of the combined chunks of all documents.
        document.chunks = chunk_manager.split([document])
    return documents
|
||||
|
||||
@classmethod
|
||||
def support_chunk_strategy(cls) -> List[ChunkStrategy]:
|
||||
"""Return support chunk strategy."""
|
||||
|
364
packages/dbgpt-ext/src/dbgpt_ext/rag/retriever/doc_tree.py
Normal file
364
packages/dbgpt-ext/src/dbgpt_ext/rag/retriever/doc_tree.py
Normal file
@ -0,0 +1,364 @@
|
||||
"""Tree-based document retriever."""
|
||||
|
||||
import logging
|
||||
from concurrent.futures import Executor, ThreadPoolExecutor
|
||||
from typing import List, Optional
|
||||
|
||||
from dbgpt.core import Document
|
||||
from dbgpt.rag.retriever import BaseRetriever, DefaultRanker, QueryRewrite, Ranker
|
||||
from dbgpt.rag.transformer.base import ExtractorBase
|
||||
from dbgpt.storage.vector_store.filters import MetadataFilters
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
RETRIEVER_NAME = "doc_tree_retriever"
|
||||
TITLE = "title"
|
||||
HEADER1 = "Header1"
|
||||
HEADER2 = "Header2"
|
||||
HEADER3 = "Header3"
|
||||
HEADER4 = "Header4"
|
||||
HEADER5 = "Header5"
|
||||
HEADER6 = "Header6"
|
||||
|
||||
|
||||
class TreeNode:
    """TreeNode class to represent a node in the document tree."""

    def __init__(
        self, node_id: str, level_text: str, level: int, content: Optional[str] = None
    ):
        """Initialize a TreeNode.

        Args:
            node_id (str): The ID of the node.
            level_text (str): The title or header text of this node.
            level (int): The depth of the node in the hierarchy.
            content (Optional[str]): The content attached to the node.
        """
        self.node_id = node_id
        # Depth in the hierarchy: 0 is the title, 1..6 are header1..header6
        # (DocTreeIndex uses -1 for its synthetic root).
        self.level = level
        # The text of the title / header found at this level.
        self.level_text = level_text
        self.children: List["TreeNode"] = []
        self.content = content
        self.retriever = RETRIEVER_NAME

    def __repr__(self) -> str:
        """Return a short debug representation of the node."""
        return (
            f"{type(self).__name__}(node_id={self.node_id!r}, "
            f"level={self.level!r}, level_text={self.level_text!r})"
        )

    def add_child(self, child_node: "TreeNode") -> None:
        """Add a child node to the current node."""
        self.children.append(child_node)
|
||||
|
||||
|
||||
class DocTreeIndex:
    """Tree index over the title / header hierarchy of one document."""

    def __init__(self):
        """Initialize the document tree index."""
        self.root = TreeNode("root_id", "Root", -1)

    def add_nodes(
        self,
        node_id: str,
        title: str,
        header1: Optional[str] = None,
        header2: Optional[str] = None,
        header3: Optional[str] = None,
        header4: Optional[str] = None,
        header5: Optional[str] = None,
        header6: Optional[str] = None,
        content: Optional[str] = None,
    ):
        """Add the title / header path of one chunk to the document tree.

        Nodes that already exist on the path are reused. A header is looked
        up among the children of its parent only, so equally named sections
        under different parents stay separate. ``content`` is attached to the
        deepest node of the path only.

        Args:
            node_id (str): The ID of the node.
            title (str): The title of the document.
            header1 (Optional[str]): The first header.
            header2 (Optional[str]): The second header.
            header3 (Optional[str]): The third header.
            header4 (Optional[str]): The fourth header.
            header5 (Optional[str]): The fifth header.
            header6 (Optional[str]): The sixth header.
            content (Optional[str]): The content of the node.
        """
        headers = [header1, header2, header3, header4, header5, header6]
        # Level of the deepest header present; 0 means "title only".
        deepest_level = max(
            (level for level, header in enumerate(headers, start=1) if header),
            default=0,
        )
        current_node = self._get_or_create_title_node(
            node_id, title, content if deepest_level == 0 else None
        )
        for level, header in enumerate(headers, start=1):
            if not header:
                continue
            current_node = self._get_or_create_child(
                current_node,
                node_id,
                header,
                level,
                content if level == deepest_level else None,
            )

    def add_nodes_with_content(
        self,
        node_id: str,
        title: str,
        header1: Optional[str] = None,
        header2: Optional[str] = None,
        header3: Optional[str] = None,
        header4: Optional[str] = None,
        header5: Optional[str] = None,
        header6: Optional[str] = None,
    ):
        """Add the title / header path of one chunk without any content.

        Despite its name this method takes no content; it is kept for
        backward compatibility and forwards to ``add_nodes``.

        Args:
            node_id (str): The ID of the node.
            title (str): The title of the document.
            header1 (Optional[str]): The first header.
            header2 (Optional[str]): The second header.
            header3 (Optional[str]): The third header.
            header4 (Optional[str]): The fourth header.
            header5 (Optional[str]): The fifth header.
            header6 (Optional[str]): The sixth header.
        """
        self.add_nodes(
            node_id, title, header1, header2, header3, header4, header5, header6
        )

    def _get_or_create_title_node(
        self, node_id: str, title: str, content: Optional[str]
    ) -> TreeNode:
        """Return the title node, creating it under the root when missing."""
        if not title:
            # Without a title the headers hang directly off the root.
            return self.root
        for child in self.root.children:
            if child.level == 0:
                # If title already exists, do not add it again
                if child.content is None and content is not None:
                    child.content = content
                return child
        title_node = TreeNode(node_id, title, 0, content)
        self.root.add_child(title_node)
        return title_node

    @staticmethod
    def _get_or_create_child(
        parent: TreeNode,
        node_id: str,
        header: str,
        level: int,
        content: Optional[str],
    ) -> TreeNode:
        """Return the child of ``parent`` for ``header``, creating it if needed."""
        for child in parent.children:
            if child.level == level and child.level_text == header:
                # If header already exists, do not add it again; only fill in
                # content that was not known when the node was created.
                if child.content is None and content is not None:
                    child.content = content
                return child
        new_header_node = TreeNode(node_id, header, level, content)
        parent.add_child(new_header_node)
        return new_header_node

    def get_node_by_level(self, level):
        """Get all nodes of the tree located at the given level."""
        result = []
        self._traverse(self.root, level, result)
        return result

    def get_node_by_level_text(self, content):
        """Get all nodes of the tree whose title / header text equals ``content``."""
        result = []
        self._traverse_by_level_text(self.root, content, result)
        return result

    def get_all_children(self, node):
        """Get all descendants of the node in depth-first order."""
        result = []
        for child in node.children:
            result.append(child)
            result.extend(self.get_all_children(child))
        return result

    def display_tree(self, node: TreeNode, prefix: Optional[str] = ""):
        """Print the node and its sub-tree as a directory-like listing."""
        prefix = prefix or ""
        print(f"{prefix}├── {self._format_node(node)}")
        self._display_children(node, prefix + "│   ")

    @staticmethod
    def _format_node(node: TreeNode) -> str:
        """Build the one-line label printed for a node."""
        label = f"{node.level_text} (node_id: {node.node_id})"
        if node.content:
            label += f" (content: {node.content})"
        return label

    def _display_children(self, node: TreeNode, prefix: str) -> None:
        """Print the children of ``node`` recursively below ``prefix``."""
        last_index = len(node.children) - 1
        for index, child in enumerate(node.children):
            is_last = index == last_index
            connector = "└── " if is_last else "├── "
            print(f"{prefix}{connector}{self._format_node(child)}")
            # Below the last child there is no sibling line left to draw.
            self._display_children(child, prefix + ("    " if is_last else "│   "))

    def _traverse(self, node, level, result):
        """Collect into ``result`` every node of the sub-tree at ``level``."""
        if node.level == level:
            result.append(node)
        for child in node.children:
            self._traverse(child, level, result)

    def _traverse_by_level_text(self, node, level_text, result):
        """Collect into ``result`` every node of the sub-tree named ``level_text``."""
        if node.level_text == level_text:
            result.append(node)
        for child in node.children:
            self._traverse_by_level_text(child, level_text, result)

    def search_keywords(self, node, keyword) -> Optional[TreeNode]:
        """Find the first node whose text equals the keyword, ignoring case.

        The sub-tree of ``node`` is searched depth-first. The synthetic root
        (level below 0) is never reported as a match.

        Args:
            node (TreeNode): The node to start the search from.
            keyword (str): The keyword to look for.

        Returns:
            Optional[TreeNode]: The first matching node, or None.
        """
        if node.level >= 0 and keyword.lower() == node.level_text.lower():
            logger.info("DocTreeIndex Match found in: %s", node.level_text)
            return node
        for child in node.children:
            result = self.search_keywords(child, keyword)
            if result is not None:
                return result
        return None
|
||||
|
||||
|
||||
class DocTreeRetriever(BaseRetriever):
    """Doc Tree retriever."""

    def __init__(
        self,
        docs: Optional[List[Document]] = None,
        top_k: Optional[int] = 10,
        query_rewrite: Optional[QueryRewrite] = None,
        rerank: Optional[Ranker] = None,
        keywords_extractor: Optional[ExtractorBase] = None,
        with_content: bool = False,
        executor: Optional[Executor] = None,
    ):
        """Create DocTreeRetriever.

        Args:
            docs (Optional[List[Document]]): Documents (with chunks) used to
                build the trees. None is treated as an empty list.
            top_k (Optional[int]): Maximum number of nodes to return, None
                means no limit.
            query_rewrite (Optional[QueryRewrite]): query rewrite
            rerank (Ranker): rerank
            keywords_extractor (Optional[ExtractorBase]): keywords extractor
            with_content: bool: whether to include content
            executor (Optional[Executor]): executor
        """
        super().__init__()
        self._top_k = top_k
        self._query_rewrite = query_rewrite
        self._rerank = rerank or DefaultRanker(self._top_k)
        self._keywords_extractor = keywords_extractor
        # Must be set before the trees are built, the builder reads it.
        self._with_content = with_content
        self._tree_indexes = self._initialize_doc_tree(docs)
        self._executor = executor or ThreadPoolExecutor()

    def get_tree_indexes(self):
        """Get the tree indexes."""
        return self._tree_indexes

    def _retrieve(
        self, query: str, filters: Optional[MetadataFilters] = None
    ) -> List[TreeNode]:
        """Retrieve tree nodes synchronously (not supported).

        Args:
            query (str): query text
            filters: metadata filters.
        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError("DocTreeRetriever does not support retrieval.")

    def _retrieve_with_score(
        self,
        query: str,
        score_threshold: float,
        filters: Optional[MetadataFilters] = None,
    ) -> List[TreeNode]:
        """Retrieve tree nodes with score synchronously (not supported).

        Args:
            query (str): query text
            score_threshold (float): score threshold
            filters: metadata filters.
        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError("DocTreeRetriever does not support score retrieval.")

    async def _aretrieve(
        self, query: str, filters: Optional[MetadataFilters] = None
    ) -> List[TreeNode]:
        """Retrieve the tree nodes matching the keywords of the query.

        Without a keywords extractor the whole query is used as the only
        keyword. Each keyword is searched in every tree; a node is returned
        once even if several keywords match it, and at most ``top_k`` nodes
        are returned.

        Args:
            query (str): query text.
            filters: metadata filters, not used by this retriever.
        Return:
            List[TreeNode]: list of matched tree nodes
        """
        keywords = [query]
        if self._keywords_extractor:
            extracted = await self._keywords_extractor.extract(query)
            # Guard against an extractor that yields nothing.
            keywords = extracted or []
        all_nodes = []
        seen_nodes = set()
        for keyword in keywords:
            for tree_index in self._tree_indexes:
                retrieve_node = tree_index.search_keywords(tree_index.root, keyword)
                if retrieve_node is None or id(retrieve_node) in seen_nodes:
                    continue
                seen_nodes.add(id(retrieve_node))
                all_nodes.append(retrieve_node)
        if self._top_k is not None:
            return all_nodes[: max(self._top_k, 0)]
        return all_nodes

    async def _aretrieve_with_score(
        self,
        query: str,
        score_threshold: float,
        filters: Optional[MetadataFilters] = None,
    ) -> List[TreeNode]:
        """Retrieve tree nodes; the score threshold is not applied.

        Args:
            query (str): query text
            score_threshold (float): score threshold, not used.
            filters: metadata filters.
        Return:
            List[TreeNode]: list of matched tree nodes
        """
        return await self._aretrieve(query, filters)

    def _initialize_doc_tree(
        self, docs: Optional[List[Document]]
    ) -> List[DocTreeIndex]:
        """Initialize one document tree per doc.

        Chunks without a title in their metadata are skipped.

        Args:
            docs (Optional[List[Document]]): List of docs to initialize the
                trees with.
        Return:
            List[DocTreeIndex]: one tree index per document.
        """
        tree_indexes = []
        for doc in docs or []:
            tree_index = DocTreeIndex()
            for chunk in doc.chunks:
                title = chunk.metadata.get(TITLE)
                if not title:
                    continue
                tree_index.add_nodes(
                    node_id=chunk.chunk_id,
                    title=title,
                    header1=chunk.metadata.get(HEADER1),
                    header2=chunk.metadata.get(HEADER2),
                    header3=chunk.metadata.get(HEADER3),
                    header4=chunk.metadata.get(HEADER4),
                    header5=chunk.metadata.get(HEADER5),
                    header6=chunk.metadata.get(HEADER6),
                    content=chunk.content if self._with_content else None,
                )
            tree_indexes.append(tree_index)
        return tree_indexes
|
Loading…
Reference in New Issue
Block a user