feat(RAG): add tree retriever based on document directory level (#2669)

Aries-ckt 2025-05-10 13:02:26 +08:00 committed by GitHub
parent 3a00aca113
commit 421004a1d8
6 changed files with 472 additions and 2 deletions

View File

@@ -0,0 +1,36 @@
import asyncio
from dbgpt.model.proxy import DeepseekLLMClient
from dbgpt.rag.knowledge.base import ChunkStrategy
from dbgpt.rag.transformer.keyword_extractor import KeywordExtractor
from dbgpt_ext.rag import ChunkParameters
from dbgpt_ext.rag.knowledge import KnowledgeFactory
from dbgpt_ext.rag.retriever.doc_tree import DocTreeRetriever
async def main():
knowledge = KnowledgeFactory.from_file_path("../../docs/docs/awel/awel.md")
chunk_parameters = ChunkParameters(
chunk_strategy=ChunkStrategy.CHUNK_BY_MARKDOWN_HEADER.name
)
docs = knowledge.load()
docs = knowledge.extract(docs, chunk_parameters)
llm_client = DeepseekLLMClient(api_key="your_api_key")
keyword_extractor = KeywordExtractor(
llm_client=llm_client, model_name="deepseek-chat"
)
    # Create the doc tree retriever
retriever = DocTreeRetriever(
docs=docs,
top_k=10,
keywords_extractor=keyword_extractor,
with_content=False,
)
    tree_index = retriever.get_tree_indexes()[0]
nodes = await retriever.aretrieve("Introduce awel Operators")
for node in nodes:
tree_index.display_tree(node)
if __name__ == "__main__":
asyncio.run(main())
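Not part of the diff, but a minimal usage sketch building on the example above: when no keywords_extractor is passed, _aretrieve falls back to using the raw query as the keyword, so the retriever can be smoke-tested without an LLM client. The file path and the query string are illustrative; in this mode the query must equal a header text (case-insensitively) for a node to be found.

import asyncio

from dbgpt.rag.knowledge.base import ChunkStrategy
from dbgpt_ext.rag import ChunkParameters
from dbgpt_ext.rag.knowledge import KnowledgeFactory
from dbgpt_ext.rag.retriever.doc_tree import DocTreeRetriever


async def main():
    # Load a markdown file and split it by its header hierarchy (path is illustrative).
    knowledge = KnowledgeFactory.from_file_path("../../docs/docs/awel/awel.md")
    docs = knowledge.extract(
        knowledge.load(),
        ChunkParameters(chunk_strategy=ChunkStrategy.CHUNK_BY_MARKDOWN_HEADER.name),
    )
    # No keywords_extractor: the raw query itself is used as the keyword,
    # so it must match one of the document's header texts exactly.
    retriever = DocTreeRetriever(docs=docs, top_k=10, with_content=False)
    nodes = await retriever.aretrieve("Operators")
    tree_index = retriever.get_tree_indexes()[0]
    for node in nodes:
        tree_index.display_tree(node)


if __name__ == "__main__":
    asyncio.run(main())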

View File

@@ -16,6 +16,10 @@ class Document(BaseModel):
default_factory=dict,
description="metadata fields",
)
chunks: List["Chunk"] = Field(
default_factory=list,
description="list of chunks",
)
def set_content(self, content: str) -> None:
"""Set document content."""

View File

@@ -164,7 +164,11 @@ class Knowledge(ABC):
documents = self._load()
return self._postprocess(documents)
def extract(self, documents: List[Document]) -> List[Document]:
def extract(
self,
documents: List[Document],
chunk_parameters: Optional[Any] = None,
) -> List[Document]:
"""Extract knowledge from text."""
return documents
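The base implementation is a pass-through; knowledge subclasses that actually split text override it, as MarkdownKnowledge does further down in this diff. A hypothetical override sketch reusing the ChunkManager call from that hunk (the subclass name is illustrative, the abstract members of Knowledge are omitted, and it splits per document rather than all documents at once as the MarkdownKnowledge implementation does):

from typing import Any, List, Optional

from dbgpt.core import Document


class MyMarkdownLikeKnowledge(Knowledge):  # hypothetical subclass, other members omitted
    def extract(
        self,
        documents: List[Document],
        chunk_parameters: Optional[Any] = None,
    ) -> List[Document]:
        """Split each document and attach the resulting chunks to it."""
        from dbgpt_ext.rag.chunk_manager import ChunkManager

        chunk_manager = ChunkManager(knowledge=self, chunk_parameter=chunk_parameters)
        for document in documents:
            document.chunks = chunk_manager.split([document])
        return documents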

View File

@@ -231,3 +231,47 @@ class IndexStoreBase(ABC):
return await blocking_func_to_async_no_executor(
self.similar_search_with_scores, query, topk, score_threshold, filters
)
def full_text_search(
self, text: str, topk: int, filters: Optional[MetadataFilters] = None
) -> List[Chunk]:
"""Full text search in index database.
Args:
text(str): The query text.
topk(int): The number of similar documents to return.
filters(Optional[MetadataFilters]): metadata filters.
Return:
List[Chunk]: The similar documents.
"""
raise NotImplementedError(
"Full text search is not supported in this index store."
)
async def afull_text_search(
self, text: str, topk: int, filters: Optional[MetadataFilters] = None
) -> List[Chunk]:
"""Similar search in index database.
Args:
text(str): The query text.
topk(int): The number of similar documents to return.
filters(Optional[MetadataFilters]): metadata filters.
Return:
List[Chunk]: The similar documents.
"""
return await blocking_func_to_async_no_executor(
self.full_text_search, text, topk, filters
)
def is_support_full_text_search(self) -> bool:
"""Support full text search.
Args:
collection_name(str): collection name.
Return:
bool: The similar documents.
"""
raise NotImplementedError(
"Full text search is not supported in this index store."
)
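These are opt-in hooks: a concrete store overrides full_text_search and advertises it via is_support_full_text_search, while afull_text_search simply runs the sync version in a worker thread. A hypothetical sketch under those assumptions; the class name is illustrative and the remaining abstract methods of IndexStoreBase are omitted for brevity.

from typing import List, Optional

from dbgpt.core import Chunk
from dbgpt.storage.vector_store.filters import MetadataFilters


class NaiveFullTextStore(IndexStoreBase):  # hypothetical sketch, other abstract methods omitted
    def __init__(self, chunks: List[Chunk]):
        # Sketch only: skips the base-class constructor arguments.
        self._chunks = chunks

    def is_support_full_text_search(self) -> bool:
        # Advertise full text support so callers can take this path.
        return True

    def full_text_search(
        self, text: str, topk: int, filters: Optional[MetadataFilters] = None
    ) -> List[Chunk]:
        # Naive case-insensitive substring match standing in for a real inverted index.
        hits = [c for c in self._chunks if text.lower() in c.content.lower()]
        return hits[:topk]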

View File

@@ -9,6 +9,7 @@ from dbgpt.rag.knowledge.base import (
Knowledge,
KnowledgeType,
)
from dbgpt_ext.rag import ChunkParameters
class MarkdownKnowledge(Knowledge):
@@ -49,13 +50,30 @@ class MarkdownKnowledge(Knowledge):
raise ValueError("file path is required")
with open(self._path, encoding=self._encoding, errors="ignore") as f:
markdown_text = f.read()
metadata = {"source": self._path}
metadata = {
"source": self._path,
"title": self._path.rsplit("/", 1)[-1],
}
if self._metadata:
metadata.update(self._metadata) # type: ignore
documents = [Document(content=markdown_text, metadata=metadata)]
return documents
return [Document.langchain2doc(lc_document) for lc_document in documents]
def extract(
self,
documents: List[Document],
chunk_parameter: Optional[ChunkParameters] = None,
) -> List[Document]:
"""Extract knowledge from text."""
from dbgpt_ext.rag.chunk_manager import ChunkManager
chunk_manager = ChunkManager(knowledge=self, chunk_parameter=chunk_parameter)
chunks = chunk_manager.split(documents)
for document in documents:
document.chunks = chunks
return documents
@classmethod
def support_chunk_strategy(cls) -> List[ChunkStrategy]:
"""Return support chunk strategy."""

View File

@@ -0,0 +1,364 @@
"""Tree-based document retriever."""
import logging
from concurrent.futures import Executor, ThreadPoolExecutor
from typing import List, Optional
from dbgpt.core import Document
from dbgpt.rag.retriever import BaseRetriever, DefaultRanker, QueryRewrite, Ranker
from dbgpt.rag.transformer.base import ExtractorBase
from dbgpt.storage.vector_store.filters import MetadataFilters
logger = logging.getLogger(__name__)
RETRIEVER_NAME = "doc_tree_retriever"
TITLE = "title"
HEADER1 = "Header1"
HEADER2 = "Header2"
HEADER3 = "Header3"
HEADER4 = "Header4"
HEADER5 = "Header5"
HEADER6 = "Header6"
class TreeNode:
"""TreeNode class to represent a node in the document tree."""
def __init__(
self, node_id: str, level_text: str, level: int, content: Optional[str] = None
):
"""Initialize a TreeNode."""
self.node_id = node_id
        self.level = level  # 0: title, 1-6: header depth (Header1..Header6)
        self.level_text = level_text  # the title or header text at this level
        self.children: List["TreeNode"] = []
self.content = content
self.retriever = RETRIEVER_NAME
def add_child(self, child_node):
"""Add a child node to the current node."""
self.children.append(child_node)
class DocTreeIndex:
    """Tree index built from a document's title and header hierarchy."""
def __init__(self):
"""Initialize the document tree index."""
self.root = TreeNode("root_id", "Root", -1)
def add_nodes(
self,
node_id: str,
title: str,
header1: Optional[str] = None,
header2: Optional[str] = None,
header3: Optional[str] = None,
header4: Optional[str] = None,
header5: Optional[str] = None,
header6: Optional[str] = None,
content: Optional[str] = None,
):
"""Add nodes to the document tree.
Args:
node_id (str): The ID of the node.
title (str): The title of the document.
header1 (Optional[str]): The first header.
header2 (Optional[str]): The second header.
header3 (Optional[str]): The third header.
header4 (Optional[str]): The fourth header.
header5 (Optional[str]): The fifth header.
header6 (Optional[str]): The sixth header.
content (Optional[str]): The content of the node.
"""
        title_node = None
        if title:
            title_nodes = self.get_node_by_level(0)
            if not title_nodes:
                # No title node yet: create it under the root
                title_node = TreeNode(node_id, title, 0, content)
                self.root.add_child(title_node)
            else:
                # The title node already exists: reuse it
                title_node = title_nodes[0]
current_node = title_node
headers = [header1, header2, header3, header4, header5, header6]
for level, header in enumerate(headers, start=1):
if header:
header_nodes = self.get_node_by_level_text(header)
if header_nodes:
# If header already exists, do not add it again
current_node = header_nodes[0]
continue
new_header_node = TreeNode(node_id, header, level, content)
current_node.add_child(new_header_node)
current_node = new_header_node
def add_nodes_with_content(
self,
node_id: str,
title: str,
header1: Optional[str] = None,
header2: Optional[str] = None,
header3: Optional[str] = None,
header4: Optional[str] = None,
header5: Optional[str] = None,
header6: Optional[str] = None,
):
"""Add nodes to the document tree.
Args:
node_id (str): The ID of the node.
title (str): The title of the document.
header1 (Optional[str]): The first header.
header2 (Optional[str]): The second header.
header3 (Optional[str]): The third header.
header4 (Optional[str]): The fourth header.
header5 (Optional[str]): The fifth header.
header6 (Optional[str]): The sixth header.
"""
        title_node = None
        if title:
            title_nodes = self.get_node_by_level(0)
            if not title_nodes:
                # No title node yet: create it under the root
                title_node = TreeNode(node_id, title, 0)
                self.root.add_child(title_node)
            else:
                # The title node already exists: reuse it
                title_node = title_nodes[0]
current_node = title_node
headers = [header1, header2, header3, header4, header5, header6]
for level, header in enumerate(headers, start=1):
if header:
header_nodes = self.get_node_by_level_text(header)
if header_nodes:
# If header already exists, do not add it again
current_node = header_nodes[0]
continue
new_header_node = TreeNode(node_id, header, level)
current_node.add_child(new_header_node)
current_node = new_header_node
def get_node_by_level(self, level):
"""Get nodes by level."""
# Traverse the tree to find nodes at the specified level
result = []
self._traverse(self.root, level, result)
return result
def get_node_by_level_text(self, content):
"""Get nodes by level."""
# Traverse the tree to find nodes at the specified level
result = []
self._traverse_by_level_text(self.root, content, result)
return result
    def get_all_children(self, node):
        """Get all descendant nodes of the given node."""
        result = []
        for child in node.children:
            result.append(child)
            result.extend(self.get_all_children(child))
        return result
def display_tree(self, node: TreeNode, prefix: Optional[str] = ""):
"""Recursive function to display the directory structure with visual cues."""
# Print the current node title with prefix
if node.content:
print(
f"{prefix}├── {node.level_text} (node_id: {node.node_id}) "
f"(content: {node.content})"
)
else:
print(f"{prefix}├── {node.level_text} (node_id: {node.node_id})")
        # Recurse into children, indenting them one level deeper
        for child in node.children:
            self.display_tree(child, prefix + "    ")
def _traverse(self, node, level, result):
"""Traverse the tree to find nodes at the specified level."""
# If the current node's level matches the specified level, add it to the result
if node.level == level:
result.append(node)
# Recursively traverse child nodes
for child in node.children:
self._traverse(child, level, result)
def _traverse_by_level_text(self, node, level_text, result):
"""Traverse the tree to find nodes at the specified level."""
# If the current node's level matches the specified level, add it to the result
if node.level_text == level_text:
result.append(node)
# Recursively traverse child nodes
for child in node.children:
self._traverse_by_level_text(child, level_text, result)
    def search_keywords(self, node, keyword) -> Optional[TreeNode]:
        """Search the subtree rooted at node for a case-insensitive exact keyword match."""
        # Check if the keyword matches the current node title
if keyword.lower() == node.level_text.lower():
logger.info(f"DocTreeIndex Match found in: {node.level_text}")
return node
# Recursively search in child nodes
for child in node.children:
result = self.search_keywords(child, keyword)
if result:
return result
        # No exact match found in this subtree
        return None
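A standalone sketch of DocTreeIndex, using only the methods defined above: add two chunks' header paths (identifiers are illustrative), print the tree, and look a header up by its exact text.

index = DocTreeIndex()
index.add_nodes(node_id="chunk-1", title="awel.md", header1="AWEL", header2="Operators")
index.add_nodes(node_id="chunk-2", title="awel.md", header1="AWEL", header2="Workflow")

# Print the whole tree, starting from the synthetic root node.
index.display_tree(index.root)

# Exact, case-insensitive match against the header text.
node = index.search_keywords(index.root, "operators")
if node:
    print(node.node_id, node.level, node.level_text)  # chunk-1 2 Operators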
class DocTreeRetriever(BaseRetriever):
"""Doc Tree retriever."""
def __init__(
self,
        docs: Optional[List[Document]] = None,
top_k: Optional[int] = 10,
query_rewrite: Optional[QueryRewrite] = None,
rerank: Optional[Ranker] = None,
keywords_extractor: Optional[ExtractorBase] = None,
with_content: bool = False,
executor: Optional[Executor] = None,
):
"""Create DocTreeRetriever.
Args:
docs (List[Document]): List of documents to initialize the tree with.
            top_k (int): The number of top results to return.
            query_rewrite (Optional[QueryRewrite]): query rewrite component.
            rerank (Optional[Ranker]): reranker, defaults to DefaultRanker.
            keywords_extractor (Optional[ExtractorBase]): keywords extractor.
            with_content (bool): whether to attach chunk content to tree nodes.
            executor (Optional[Executor]): executor for blocking calls.

        Returns:
            DocTreeRetriever: The doc tree retriever.
"""
super().__init__()
self._top_k = top_k
self._query_rewrite = query_rewrite
self._rerank = rerank or DefaultRanker(self._top_k)
self._keywords_extractor = keywords_extractor
self._with_content = with_content
        self._tree_indexes = self._initialize_doc_tree(docs or [])
self._executor = executor or ThreadPoolExecutor()
def get_tree_indexes(self):
"""Get the tree indexes."""
return self._tree_indexes
def _retrieve(
self, query: str, filters: Optional[MetadataFilters] = None
) -> List[TreeNode]:
"""Retrieve knowledge chunks.
Args:
query (str): query text
filters: metadata filters.
        Return:
            List[TreeNode]: list of tree nodes
        """
        raise NotImplementedError(
            "DocTreeRetriever does not support synchronous retrieval; "
            "use aretrieve instead."
        )
def _retrieve_with_score(
self,
query: str,
score_threshold: float,
filters: Optional[MetadataFilters] = None,
) -> List[TreeNode]:
"""Retrieve knowledge chunks with score.
Args:
query (str): query text
score_threshold (float): score threshold
filters: metadata filters.
        Return:
            List[TreeNode]: list of tree nodes with score
        """
        raise NotImplementedError(
            "DocTreeRetriever does not support synchronous retrieval with score; "
            "use aretrieve instead."
        )
async def _aretrieve(
self, query: str, filters: Optional[MetadataFilters] = None
) -> List[TreeNode]:
"""Retrieve knowledge chunks.
Args:
query (str): query text.
filters: metadata filters.
        Return:
            List[TreeNode]: list of matched tree nodes
"""
keywords = [query]
if self._keywords_extractor:
keywords = await self._keywords_extractor.extract(query)
all_nodes = []
for keyword in keywords:
for tree_index in self._tree_indexes:
retrieve_node = tree_index.search_keywords(tree_index.root, keyword)
if retrieve_node:
                    # Collect the matched node and continue with the remaining indexes
all_nodes.append(retrieve_node)
return all_nodes
async def _aretrieve_with_score(
self,
query: str,
score_threshold: float,
filters: Optional[MetadataFilters] = None,
) -> List[TreeNode]:
"""Retrieve knowledge chunks with score.
Args:
query (str): query text
score_threshold (float): score threshold
filters: metadata filters.
        Return:
            List[TreeNode]: list of matched tree nodes
"""
return await self._aretrieve(query, filters)
def _initialize_doc_tree(self, docs: List[Document]):
"""Initialize the document tree with docs.
Args:
docs (List[Document]): List of docs to initialize the tree with.
"""
tree_indexes = []
for doc in docs:
tree_index = DocTreeIndex()
for chunk in doc.chunks:
if not chunk.metadata.get(TITLE):
continue
if not self._with_content:
tree_index.add_nodes(
node_id=chunk.chunk_id,
title=chunk.metadata[TITLE],
header1=chunk.metadata.get(HEADER1),
header2=chunk.metadata.get(HEADER2),
header3=chunk.metadata.get(HEADER3),
header4=chunk.metadata.get(HEADER4),
header5=chunk.metadata.get(HEADER5),
header6=chunk.metadata.get(HEADER6),
)
else:
tree_index.add_nodes(
node_id=chunk.chunk_id,
title=chunk.metadata[TITLE],
header1=chunk.metadata.get(HEADER1),
header2=chunk.metadata.get(HEADER2),
header3=chunk.metadata.get(HEADER3),
header4=chunk.metadata.get(HEADER4),
header5=chunk.metadata.get(HEADER5),
header6=chunk.metadata.get(HEADER6),
content=chunk.content,
)
tree_indexes.append(tree_index)
return tree_indexes
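Finally, a sketch of how _initialize_doc_tree consumes chunk metadata with no file on disk: every chunk needs at least a title, and Header1..Header6 fill the deeper levels. It assumes Chunk and Document can be constructed directly from dbgpt.core with content/metadata keyword arguments and an auto-generated chunk_id, as they are used above; the contents and the query are illustrative.

import asyncio

from dbgpt.core import Chunk, Document  # assumes Chunk is exported by dbgpt.core

doc = Document(content="demo", metadata={"title": "demo.md"})
doc.chunks = [
    Chunk(content="AWEL intro ...", metadata={"title": "demo.md", "Header1": "AWEL"}),
    Chunk(
        content="Operator details ...",
        metadata={"title": "demo.md", "Header1": "AWEL", "Header2": "Operators"},
    ),
]

# with_content=True also stores each chunk's text on its tree node.
retriever = DocTreeRetriever(docs=[doc], with_content=True)
nodes = asyncio.run(retriever.aretrieve("Operators"))
for tree_index in retriever.get_tree_indexes():
    for node in nodes:
        tree_index.display_tree(node)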