From 421004a1d89c3b12b69605ff77a04b49d276d1a7 Mon Sep 17 00:00:00 2001 From: Aries-ckt <916701291@qq.com> Date: Sat, 10 May 2025 13:02:26 +0800 Subject: [PATCH] feat(RAG):add tree retriever based on document directory level (#2669) --- examples/rag/doc_tree_retriever_example.py | 36 ++ .../src/dbgpt/core/interface/knowledge.py | 4 + .../src/dbgpt/rag/knowledge/base.py | 6 +- packages/dbgpt-core/src/dbgpt/storage/base.py | 44 +++ .../src/dbgpt_ext/rag/knowledge/markdown.py | 20 +- .../src/dbgpt_ext/rag/retriever/doc_tree.py | 364 ++++++++++++++++++ 6 files changed, 472 insertions(+), 2 deletions(-) create mode 100644 examples/rag/doc_tree_retriever_example.py create mode 100644 packages/dbgpt-ext/src/dbgpt_ext/rag/retriever/doc_tree.py diff --git a/examples/rag/doc_tree_retriever_example.py b/examples/rag/doc_tree_retriever_example.py new file mode 100644 index 000000000..f985a967a --- /dev/null +++ b/examples/rag/doc_tree_retriever_example.py @@ -0,0 +1,36 @@ +import asyncio + +from dbgpt.model.proxy import DeepseekLLMClient +from dbgpt.rag.knowledge.base import ChunkStrategy +from dbgpt.rag.transformer.keyword_extractor import KeywordExtractor +from dbgpt_ext.rag import ChunkParameters +from dbgpt_ext.rag.knowledge import KnowledgeFactory +from dbgpt_ext.rag.retriever.doc_tree import DocTreeRetriever + + +async def main(): + knowledge = KnowledgeFactory.from_file_path("../../docs/docs/awel/awel.md") + chunk_parameters = ChunkParameters( + chunk_strategy=ChunkStrategy.CHUNK_BY_MARKDOWN_HEADER.name + ) + docs = knowledge.load() + docs = knowledge.extract(docs, chunk_parameters) + llm_client = DeepseekLLMClient(api_key="your_api_key") + keyword_extractor = KeywordExtractor( + llm_client=llm_client, model_name="deepseek-chat" + ) + # doc tree retriever retriever + retriever = DocTreeRetriever( + docs=docs, + top_k=10, + keywords_extractor=keyword_extractor, + with_content=False, + ) + tree_index = retriever._tree_indexes[0] + nodes = await 
retriever.aretrieve("Introduce awel Operators") + for node in nodes: + tree_index.display_tree(node) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/packages/dbgpt-core/src/dbgpt/core/interface/knowledge.py b/packages/dbgpt-core/src/dbgpt/core/interface/knowledge.py index 5d7a33700..5f3a6b9fa 100644 --- a/packages/dbgpt-core/src/dbgpt/core/interface/knowledge.py +++ b/packages/dbgpt-core/src/dbgpt/core/interface/knowledge.py @@ -16,6 +16,10 @@ class Document(BaseModel): default_factory=dict, description="metadata fields", ) + chunks: List["Chunk"] = Field( + default_factory=list, + description="list of chunks", + ) def set_content(self, content: str) -> None: """Set document content.""" diff --git a/packages/dbgpt-core/src/dbgpt/rag/knowledge/base.py b/packages/dbgpt-core/src/dbgpt/rag/knowledge/base.py index b6fbfa9c2..0b809026f 100644 --- a/packages/dbgpt-core/src/dbgpt/rag/knowledge/base.py +++ b/packages/dbgpt-core/src/dbgpt/rag/knowledge/base.py @@ -164,7 +164,11 @@ class Knowledge(ABC): documents = self._load() return self._postprocess(documents) - def extract(self, documents: List[Document]) -> List[Document]: + def extract( + self, + documents: List[Document], + chunk_parameters: Optional[Any] = None, + ) -> List[Document]: """Extract knowledge from text.""" return documents diff --git a/packages/dbgpt-core/src/dbgpt/storage/base.py b/packages/dbgpt-core/src/dbgpt/storage/base.py index b025d71a7..80cd84ac8 100644 --- a/packages/dbgpt-core/src/dbgpt/storage/base.py +++ b/packages/dbgpt-core/src/dbgpt/storage/base.py @@ -231,3 +231,47 @@ class IndexStoreBase(ABC): return await blocking_func_to_async_no_executor( self.similar_search_with_scores, query, topk, score_threshold, filters ) + + def full_text_search( + self, text: str, topk: int, filters: Optional[MetadataFilters] = None + ) -> List[Chunk]: + """Full text search in index database. + + Args: + text(str): The query text. + topk(int): The number of similar documents to return. 
+ filters(Optional[MetadataFilters]): metadata filters. + Return: + List[Chunk]: The similar documents. + """ + raise NotImplementedError( + "Full text search is not supported in this index store." + ) + + async def afull_text_search( + self, text: str, topk: int, filters: Optional[MetadataFilters] = None + ) -> List[Chunk]: + """Full text search in index database asynchronously. + + Args: + text(str): The query text. + topk(int): The number of similar documents to return. + filters(Optional[MetadataFilters]): metadata filters. + Return: + List[Chunk]: The similar documents. + """ + return await blocking_func_to_async_no_executor( + self.full_text_search, text, topk, filters + ) + + def is_support_full_text_search(self) -> bool: + """Return whether full text search is supported. + + The base implementation raises NotImplementedError rather than returning. + + Return: + bool: Whether full text search is supported. + """ + raise NotImplementedError( + "Full text search is not supported in this index store." + ) diff --git a/packages/dbgpt-ext/src/dbgpt_ext/rag/knowledge/markdown.py b/packages/dbgpt-ext/src/dbgpt_ext/rag/knowledge/markdown.py index c76997136..c335d6342 100644 --- a/packages/dbgpt-ext/src/dbgpt_ext/rag/knowledge/markdown.py +++ b/packages/dbgpt-ext/src/dbgpt_ext/rag/knowledge/markdown.py @@ -9,6 +9,7 @@ from dbgpt.rag.knowledge.base import ( Knowledge, KnowledgeType, ) +from dbgpt_ext.rag import ChunkParameters class MarkdownKnowledge(Knowledge): @@ -49,13 +50,30 @@ class MarkdownKnowledge(Knowledge): raise ValueError("file path is required") with open(self._path, encoding=self._encoding, errors="ignore") as f: markdown_text = f.read() - metadata = {"source": self._path} + metadata = { + "source": self._path, + "title": self._path.rsplit("/", 1)[-1], + } if self._metadata: metadata.update(self._metadata) # type: ignore documents = [Document(content=markdown_text, metadata=metadata)] return documents return [Document.langchain2doc(lc_document) for lc_document in documents] + def extract( + self, + documents: List[Document], + 
chunk_parameter: Optional[ChunkParameters] = None, + ) -> List[Document]: + """Extract knowledge from text.""" + from dbgpt_ext.rag.chunk_manager import ChunkManager + + chunk_manager = ChunkManager(knowledge=self, chunk_parameter=chunk_parameter) + chunks = chunk_manager.split(documents) + for document in documents: + document.chunks = chunks + return documents + @classmethod def support_chunk_strategy(cls) -> List[ChunkStrategy]: """Return support chunk strategy.""" diff --git a/packages/dbgpt-ext/src/dbgpt_ext/rag/retriever/doc_tree.py b/packages/dbgpt-ext/src/dbgpt_ext/rag/retriever/doc_tree.py new file mode 100644 index 000000000..b2ce9cc08 --- /dev/null +++ b/packages/dbgpt-ext/src/dbgpt_ext/rag/retriever/doc_tree.py @@ -0,0 +1,364 @@ +"""Tree-based document retriever.""" + +import logging +from concurrent.futures import Executor, ThreadPoolExecutor +from typing import List, Optional + +from dbgpt.core import Document +from dbgpt.rag.retriever import BaseRetriever, DefaultRanker, QueryRewrite, Ranker +from dbgpt.rag.transformer.base import ExtractorBase +from dbgpt.storage.vector_store.filters import MetadataFilters + +logger = logging.getLogger(__name__) + +RETRIEVER_NAME = "doc_tree_retriever" +TITLE = "title" +HEADER1 = "Header1" +HEADER2 = "Header2" +HEADER3 = "Header3" +HEADER4 = "Header4" +HEADER5 = "Header5" +HEADER6 = "Header6" + + +class TreeNode: + """TreeNode class to represent a node in the document tree.""" + + def __init__( + self, node_id: str, level_text: str, level: int, content: Optional[str] = None + ): + """Initialize a TreeNode.""" + self.node_id = node_id + self.level = level # 0: title, 1: header1, 2: header2, 3: header3 + self.level_text = level_text # 0: title, 1: header1, 2: header2, 3: header3 + self.children = [] + self.content = content + self.retriever = RETRIEVER_NAME + + def add_child(self, child_node): + """Add a child node to the current node.""" + self.children.append(child_node) + + +class DocTreeIndex: + def 
__init__(self): + """Initialize the document tree index.""" + self.root = TreeNode("root_id", "Root", -1) + + def add_nodes( + self, + node_id: str, + title: str, + header1: Optional[str] = None, + header2: Optional[str] = None, + header3: Optional[str] = None, + header4: Optional[str] = None, + header5: Optional[str] = None, + header6: Optional[str] = None, + content: Optional[str] = None, + ): + """Add nodes to the document tree. + + Args: + node_id (str): The ID of the node. + title (str): The title of the document. + header1 (Optional[str]): The first header. + header2 (Optional[str]): The second header. + header3 (Optional[str]): The third header. + header4 (Optional[str]): The fourth header. + header5 (Optional[str]): The fifth header. + header6 (Optional[str]): The sixth header. + content (Optional[str]): The content of the node. + """ + # Reuse the level-0 title node if one exists; otherwise create it + title_node = None + if title: + title_nodes = self.get_node_by_level(0) + if not title_nodes: + # No title node yet: create it under the root + title_node = TreeNode(node_id, title, 0, content) + self.root.add_child(title_node) + else: + title_node = title_nodes[0] + current_node = title_node + headers = [header1, header2, header3, header4, header5, header6] + for level, header in enumerate(headers, start=1): + if header: + header_nodes = self.get_node_by_level_text(header) + if header_nodes: + # If header already exists, do not add it again + current_node = header_nodes[0] + continue + new_header_node = TreeNode(node_id, header, level, content) + current_node.add_child(new_header_node) + current_node = new_header_node + + def add_nodes_with_content( + self, + node_id: str, + title: str, + header1: Optional[str] = None, + header2: Optional[str] = None, + header3: Optional[str] = None, + header4: Optional[str] = None, + header5: Optional[str] = None, + header6: Optional[str] = None, + ): + """Add nodes without content to the document tree. 
+ + Args: + node_id (str): The ID of the node. + title (str): The title of the document. + header1 (Optional[str]): The first header. + header2 (Optional[str]): The second header. + header3 (Optional[str]): The third header. + header4 (Optional[str]): The fourth header. + header5 (Optional[str]): The fifth header. + header6 (Optional[str]): The sixth header. + """ + # Reuse the level-0 title node if one exists; otherwise create it + title_node = None + if title: + title_nodes = self.get_node_by_level(0) + if not title_nodes: + # No title node yet: create it under the root + title_node = TreeNode(node_id, title, 0) + self.root.add_child(title_node) + else: + title_node = title_nodes[0] + current_node = title_node + headers = [header1, header2, header3, header4, header5, header6] + for level, header in enumerate(headers, start=1): + if header: + header_nodes = self.get_node_by_level_text(header) + if header_nodes: + # If header already exists, do not add it again + current_node = header_nodes[0] + continue + new_header_node = TreeNode(node_id, header, level) + current_node.add_child(new_header_node) + current_node = new_header_node + + def get_node_by_level(self, level): + """Get nodes by level.""" + # Traverse the tree to find nodes at the specified level + result = [] + self._traverse(self.root, level, result) + return result + + def get_node_by_level_text(self, content): + """Get nodes by level text.""" + # Traverse the tree to find nodes whose level text matches + result = [] + self._traverse_by_level_text(self.root, content, result) + return result + + def get_all_children(self, node): + """get all children of the node.""" + # NOTE: _traverse only collects nodes whose level equals node.level + result = [] + self._traverse(node, node.level, result) + return result + + def display_tree(self, node: TreeNode, prefix: Optional[str] = ""): + """Recursive function to display the directory structure with visual cues.""" + # Print the current node title with prefix + if node.content: + print( + f"{prefix}├── 
{node.level_text} (node_id: {node.node_id}) " + f"(content: {node.content})" + ) + else: + print(f"{prefix}├── {node.level_text} (node_id: {node.node_id})") + + # Update prefix for children + new_prefix = prefix + "│ " # Extend the prefix for child nodes + for i, child in enumerate(node.children): + if i == len(node.children) - 1: # If it's the last child + new_prefix_child = prefix + "└── " + else: + new_prefix_child = new_prefix + + # Recursive call for the next child node + self.display_tree(child, new_prefix_child) + + def _traverse(self, node, level, result): + """Traverse the tree to find nodes at the specified level.""" + # If the current node's level matches the specified level, add it to the result + if node.level == level: + result.append(node) + # Recursively traverse child nodes + for child in node.children: + self._traverse(child, level, result) + + def _traverse_by_level_text(self, node, level_text, result): + """Traverse the tree to find nodes at the specified level.""" + # If the current node's level matches the specified level, add it to the result + if node.level_text == level_text: + result.append(node) + # Recursively traverse child nodes + for child in node.children: + self._traverse_by_level_text(child, level_text, result) + + def search_keywords(self, node, keyword) -> Optional[TreeNode]: + # Check if the keyword matches the current node title + if keyword.lower() == node.level_text.lower(): + logger.info(f"DocTreeIndex Match found in: {node.level_text}") + return node + # Recursively search in child nodes + for child in node.children: + result = self.search_keywords(child, keyword) + if result: + return result + # Check if the keyword matches any of the child nodes + # If no match, continue to search in all children + return None + + +class DocTreeRetriever(BaseRetriever): + """Doc Tree retriever.""" + + def __init__( + self, + docs: List[Document] = None, + top_k: Optional[int] = 10, + query_rewrite: Optional[QueryRewrite] = None, + rerank: 
Optional[Ranker] = None, + keywords_extractor: Optional[ExtractorBase] = None, + with_content: bool = False, + executor: Optional[Executor] = None, + ): + """Create DocTreeRetriever. + + Args: + docs (List[Document]): List of documents to initialize the tree with. + top_k (int): top k + query_rewrite (Optional[QueryRewrite]): query rewrite + rerank (Ranker): rerank + keywords_extractor (Optional[ExtractorBase]): keywords extractor + with_content (bool): whether to include content + executor (Optional[Executor]): executor + + Returns: + DocTreeRetriever: doc tree retriever + """ + super().__init__() + self._top_k = top_k + self._query_rewrite = query_rewrite + self._rerank = rerank or DefaultRanker(self._top_k) + self._keywords_extractor = keywords_extractor + self._with_content = with_content + self._tree_indexes = self._initialize_doc_tree(docs) + self._executor = executor or ThreadPoolExecutor() + + def get_tree_indexes(self): + """Get the tree indexes.""" + return self._tree_indexes + + def _retrieve( + self, query: str, filters: Optional[MetadataFilters] = None + ) -> List[TreeNode]: + """Retrieve knowledge chunks. + + Args: + query (str): query text + filters: metadata filters. + Raises: + NotImplementedError: sync retrieval is not supported + """ + raise NotImplementedError("DocTreeRetriever does not support retrieval.") + + def _retrieve_with_score( + self, + query: str, + score_threshold: float, + filters: Optional[MetadataFilters] = None, + ) -> List[TreeNode]: + """Retrieve knowledge chunks with score. + + Args: + query (str): query text + score_threshold (float): score threshold + filters: metadata filters. + Raises: + NotImplementedError: sync score retrieval is not supported + """ + raise NotImplementedError("DocTreeRetriever does not support score retrieval.") + + async def _aretrieve( + self, query: str, filters: Optional[MetadataFilters] = None + ) -> List[TreeNode]: + """Retrieve tree nodes matching the query keywords. + + Args: + query (str): query text. + filters: metadata filters. 
+ Return: + List[TreeNode]: list of matched tree nodes + """ + keywords = [query] + if self._keywords_extractor: + keywords = await self._keywords_extractor.extract(query) + all_nodes = [] + for keyword in keywords: + for tree_index in self._tree_indexes: + retrieve_node = tree_index.search_keywords(tree_index.root, keyword) + if retrieve_node: + # If a match is found, collect the matched tree node + all_nodes.append(retrieve_node) + return all_nodes + + async def _aretrieve_with_score( + self, + query: str, + score_threshold: float, + filters: Optional[MetadataFilters] = None, + ) -> List[TreeNode]: + """Retrieve tree nodes; delegates to _aretrieve. + + Args: + query (str): query text + score_threshold (float): score threshold (not used) + filters: metadata filters. + Return: + List[TreeNode]: list of matched tree nodes (no scores) + """ + return await self._aretrieve(query, filters) + + def _initialize_doc_tree(self, docs: List[Document]): + """Initialize the document tree with docs. + + Args: + docs (List[Document]): List of docs to initialize the tree with. + """ + tree_indexes = [] + for doc in docs: + tree_index = DocTreeIndex() + for chunk in doc.chunks: + if not chunk.metadata.get(TITLE): + continue + if not self._with_content: + tree_index.add_nodes( + node_id=chunk.chunk_id, + title=chunk.metadata[TITLE], + header1=chunk.metadata.get(HEADER1), + header2=chunk.metadata.get(HEADER2), + header3=chunk.metadata.get(HEADER3), + header4=chunk.metadata.get(HEADER4), + header5=chunk.metadata.get(HEADER5), + header6=chunk.metadata.get(HEADER6), + ) + else: + tree_index.add_nodes( + node_id=chunk.chunk_id, + title=chunk.metadata[TITLE], + header1=chunk.metadata.get(HEADER1), + header2=chunk.metadata.get(HEADER2), + header3=chunk.metadata.get(HEADER3), + header4=chunk.metadata.get(HEADER4), + header5=chunk.metadata.get(HEADER5), + header6=chunk.metadata.get(HEADER6), + content=chunk.content, + ) + tree_indexes.append(tree_index) + return tree_indexes