From 47b0630e88dc4dcfd5ef7f0c56a2bdd6ba32e758 Mon Sep 17 00:00:00 2001 From: Aries-ckt <916701291@qq.com> Date: Thu, 30 May 2024 10:13:24 +0800 Subject: [PATCH] feat(RAG): add BM25 Retriever. (#1578) --- dbgpt/rag/assembler/bm25.py | 237 +++++++++++++++++++++++++ dbgpt/rag/retriever/bm25.py | 183 +++++++++++++++++++ examples/rag/bm25_retriever_example.py | 50 ++++++ 3 files changed, 470 insertions(+) create mode 100644 dbgpt/rag/assembler/bm25.py create mode 100644 dbgpt/rag/retriever/bm25.py create mode 100644 examples/rag/bm25_retriever_example.py diff --git a/dbgpt/rag/assembler/bm25.py b/dbgpt/rag/assembler/bm25.py new file mode 100644 index 000000000..b20da1375 --- /dev/null +++ b/dbgpt/rag/assembler/bm25.py @@ -0,0 +1,237 @@ +"""BM25 Assembler.""" +import json +from concurrent.futures import Executor, ThreadPoolExecutor +from typing import Any, List, Optional + +from dbgpt.core import Chunk + +from ...storage.vector_store.elastic_store import ElasticsearchVectorConfig +from ...util.executor_utils import blocking_func_to_async +from ..assembler.base import BaseAssembler +from ..chunk_manager import ChunkParameters +from ..knowledge.base import Knowledge +from ..retriever.bm25 import BM25Retriever + + +class BM25Assembler(BaseAssembler): + """BM25 Assembler. + refer https://www.elastic.co/guide/en/elasticsearch/reference/8.9/index- + modules-similarity.html + TF/IDF based similarity that has built-in tf normalization and is supposed to + work better for short fields (like names). See Okapi_BM25 for more details. + This similarity has the following options: + + Example: + .. 
code-block:: python + + from dbgpt.rag.assembler import BM25Assembler + + pdf_path = "path/to/document.pdf" + knowledge = KnowledgeFactory.from_file_path(pdf_path) + assembler = BM25Assembler.load_from_knowledge( + knowledge=knowledge, + es_config=es_config, + chunk_parameters=chunk_parameters, + ) + assembler.persist() + # get bm25 retriever + retriever = assembler.as_retriever(3) + chunks = retriever.retrieve_with_scores("what is awel talk about", 0.3) + print(f"bm25 rag example results:{chunks}") + """ + + def __init__( + self, + knowledge: Knowledge, + es_config: ElasticsearchVectorConfig = None, + k1: Optional[float] = 2.0, + b: Optional[float] = 0.75, + chunk_parameters: Optional[ChunkParameters] = None, + executor: Optional[Executor] = None, + **kwargs: Any, + ) -> None: + """Initialize with BM25 Assembler arguments. + + Args: + knowledge: (Knowledge) Knowledge datasource. + es_config: (ElasticsearchVectorConfig) Elasticsearch config. + k1 (Optional[float]): Controls non-linear term frequency normalization + (saturation). The default value is 2.0. + b (Optional[float]): Controls to what degree document length normalizes + tf values. The default value is 0.75. + chunk_parameters: (Optional[ChunkParameters]) ChunkManager to use for + chunking. 
+ """ + from elasticsearch import Elasticsearch + + self._es_config = es_config + self._es_url = es_config.uri + self._es_port = es_config.port + self._es_username = es_config.user + self._es_password = es_config.password + self._index_name = es_config.name + self._k1 = k1 + self._b = b + if self._es_username and self._es_password: + self._es_client = Elasticsearch( + hosts=[f"http://{self._es_url}:{self._es_port}"], + basic_auth=(self._es_username, self._es_password), + ) + else: + self._es_client = Elasticsearch( + hosts=[f"http://{self._es_url}:{self._es_port}"], + ) + self._es_index_settings = { + "analysis": {"analyzer": {"default": {"type": "standard"}}}, + "similarity": { + "custom_bm25": { + "type": "BM25", + "k1": k1, + "b": b, + } + }, + } + self._es_mappings = { + "properties": { + "content": { + "type": "text", + "similarity": "custom_bm25", + }, + "metadata": { + "type": "keyword", + }, + } + } + + self._executor = executor or ThreadPoolExecutor() + if knowledge is None: + raise ValueError("knowledge datasource must be provided.") + if not self._es_client.indices.exists(index=self._index_name): + self._es_client.indices.create( + index=self._index_name, + mappings=self._es_mappings, + settings=self._es_index_settings, + ) + super().__init__( + knowledge=knowledge, + chunk_parameters=chunk_parameters, + **kwargs, + ) + + @classmethod + def load_from_knowledge( + cls, + knowledge: Knowledge, + es_config: ElasticsearchVectorConfig = None, + k1: Optional[float] = 2.0, + b: Optional[float] = 0.75, + chunk_parameters: Optional[ChunkParameters] = None, + ) -> "BM25Assembler": + """Load document full text into elasticsearch from path. + + Args: + knowledge: (Knowledge) Knowledge datasource. + es_config: (ElasticsearchVectorConfig) Elasticsearch config. + k1: (Optional[float]) BM25 parameter k1. + b: (Optional[float]) BM25 parameter b. + chunk_parameters: (Optional[ChunkParameters]) ChunkManager to use for + chunking. 
+ + Returns: + BM25Assembler + """ + return cls( + knowledge=knowledge, + es_config=es_config, + k1=k1, + b=b, + chunk_parameters=chunk_parameters, + ) + + @classmethod + async def aload_from_knowledge( + cls, + knowledge: Knowledge, + es_config: ElasticsearchVectorConfig = None, + k1: Optional[float] = 2.0, + b: Optional[float] = 0.75, + chunk_parameters: Optional[ChunkParameters] = None, + executor: Optional[ThreadPoolExecutor] = None, + ) -> "BM25Assembler": + """Load document full text into elasticsearch from path. + + Args: + knowledge: (Knowledge) Knowledge datasource. + es_config: (ElasticsearchVectorConfig) Elasticsearch config. + k1: (Optional[float]) BM25 parameter k1. + b: (Optional[float]) BM25 parameter b. + chunk_parameters: (Optional[ChunkParameters]) ChunkManager to use for + chunking. + executor: (Optional[ThreadPoolExecutor]) executor. + + Returns: + BM25Assembler + """ + return await blocking_func_to_async( + executor, + cls, + knowledge, + es_config=es_config, + k1=k1, + b=b, + chunk_parameters=chunk_parameters, + ) + + def persist(self) -> List[str]: + """Persist chunks into elasticsearch. + + Returns: + List[str]: List of chunk ids. + """ + try: + from elasticsearch.helpers import bulk + except ImportError: + raise ValueError("Please install package `pip install elasticsearch`.") + es_requests = [] + ids = [] + contents = [chunk.content for chunk in self._chunks] + metadatas = [json.dumps(chunk.metadata) for chunk in self._chunks] + chunk_ids = [chunk.chunk_id for chunk in self._chunks] + for i, content in enumerate(contents): + es_request = { + "_op_type": "index", + "_index": self._index_name, + "content": content, + "metadata": metadatas[i], + "_id": chunk_ids[i], + } + ids.append(chunk_ids[i]) + es_requests.append(es_request) + bulk(self._es_client, es_requests) + self._es_client.indices.refresh(index=self._index_name) + return ids + + async def apersist(self) -> List[str]: + """Persist chunks into elasticsearch. 
+ + Returns: + List[str]: List of chunk ids. + """ + return await blocking_func_to_async(self._executor, self.persist) + + def _extract_info(self, chunks) -> List[Chunk]: + """Extract info from chunks.""" + return [] + + def as_retriever(self, top_k: int = 4, **kwargs) -> BM25Retriever: + """Create a BM25Retriever. + + Args: + top_k(int): default 4. + + Returns: + BM25Retriever + """ + return BM25Retriever( + top_k=top_k, es_index=self._index_name, es_client=self._es_client + ) diff --git a/dbgpt/rag/retriever/bm25.py b/dbgpt/rag/retriever/bm25.py new file mode 100644 index 000000000..244feb8f6 --- /dev/null +++ b/dbgpt/rag/retriever/bm25.py @@ -0,0 +1,183 @@ +"""BM25 retriever.""" +import json +from concurrent.futures import Executor, ThreadPoolExecutor +from typing import Any, List, Optional + +from dbgpt.app.base import logger +from dbgpt.core import Chunk +from dbgpt.rag.retriever.base import BaseRetriever +from dbgpt.rag.retriever.rerank import DefaultRanker, Ranker +from dbgpt.rag.retriever.rewrite import QueryRewrite +from dbgpt.storage.vector_store.filters import MetadataFilters +from dbgpt.util.executor_utils import blocking_func_to_async + + +class BM25Retriever(BaseRetriever): + """BM25 retriever. + refer https://www.elastic.co/guide/en/elasticsearch/reference/8.9/index- + modules-similarity.html + TF/IDF based similarity that has built-in tf normalization and is supposed to + work better for short fields (like names). See Okapi_BM25 for more details. + This similarity has the following options:""" + + def __init__( + self, + top_k: int = 4, + es_index: str = None, + es_client: Any = None, + query_rewrite: Optional[QueryRewrite] = None, + rerank: Optional[Ranker] = None, + k1: Optional[float] = 2.0, + b: Optional[float] = 0.75, + executor: Optional[Executor] = None, + ): + """Create BM25Retriever. 
+ + Args: + top_k (int): top k + es_index (str): elasticsearch index + es_client (Any): elasticsearch client + query_rewrite (Optional[QueryRewrite]): query rewrite + rerank (Ranker): rerank + k1 (Optional[float]): Controls non-linear term frequency normalization + (saturation). The default value is 2.0. + b (Optional[float]): Controls to what degree document length normalizes + tf values. The default value is 0.75. + executor (Optional[Executor]): executor + + Returns: + BM25Retriever: BM25 retriever + """ + super().__init__() + self._top_k = top_k + self._query_rewrite = query_rewrite + try: + from elasticsearch import Elasticsearch + except ImportError: + raise ImportError( + "please install elasticsearch using `pip install elasticsearch`" + ) + self._es_client: Elasticsearch = es_client + + self._es_mappings = { + "properties": { + "content": { + "type": "text", + "similarity": "custom_bm25", + } + } + } + self._es_index_settings = { + "analysis": {"analyzer": {"default": {"type": "standard"}}}, + "similarity": { + "custom_bm25": { + "type": "BM25", + "k1": k1, + "b": b, + } + }, + } + self._index_name = es_index + if not self._es_client.indices.exists(index=self._index_name): + self._es_client.indices.create( + index=self._index_name, + mappings=self._es_mappings, + settings=self._es_index_settings, + ) + self._rerank = rerank or DefaultRanker(self._top_k) + self._executor = executor or ThreadPoolExecutor() + + def _retrieve( + self, query: str, filters: Optional[MetadataFilters] = None + ) -> List[Chunk]: + """Retrieve knowledge chunks. + + Args: + query (str): query text + filters: metadata filters. 
+ Return: + List[Chunk]: list of chunks + """ + es_query = {"query": {"match": {"content": query}}} + res = self._es_client.search(index=self._index_name, body=es_query) + + chunks = [] + for r in res["hits"]["hits"]: + chunks.append( + Chunk( + chunk_id=r["_id"], + content=r["_source"]["content"], + metadata=json.loads(r["_source"]["metadata"]), + ) + ) + return chunks[: self._top_k] + + def _retrieve_with_score( + self, + query: str, + score_threshold: float, + filters: Optional[MetadataFilters] = None, + ) -> List[Chunk]: + """Retrieve knowledge chunks with score. + + Args: + query (str): query text + score_threshold (float): score threshold + filters: metadata filters. + Return: + List[Chunk]: list of chunks with score + """ + es_query = {"query": {"match": {"content": query}}} + res = self._es_client.search(index=self._index_name, body=es_query) + + chunks_with_scores = [] + for r in res["hits"]["hits"]: + if r["_score"] >= score_threshold: + chunks_with_scores.append( + Chunk( + chunk_id=r["_id"], + content=r["_source"]["content"], + metadata=json.loads(r["_source"]["metadata"]), + score=r["_score"], + ) + ) + if score_threshold is not None and len(chunks_with_scores) == 0: + logger.warning( + "No relevant docs were retrieved using the relevance score" + f" threshold {score_threshold}" + ) + return chunks_with_scores[: self._top_k] + + async def _aretrieve( + self, query: str, filters: Optional[MetadataFilters] = None + ) -> List[Chunk]: + """Retrieve knowledge chunks. + + Args: + query (str): query text. + filters: metadata filters. + Return: + List[Chunk]: list of chunks + """ + return await blocking_func_to_async( + self._executor, self.retrieve, query, filters + ) + + async def _aretrieve_with_score( + self, + query: str, + score_threshold: float, + filters: Optional[MetadataFilters] = None, + ) -> List[Chunk]: + """Retrieve knowledge chunks with score. 
+
+        Args:
+            query (str): query text
+            score_threshold (float): score threshold
+            filters: metadata filters.
+        Return:
+            List[Chunk]: list of chunks with score
+        """
+        return await blocking_func_to_async(
+            self._executor, self.retrieve, query, filters
+        )
diff --git a/examples/rag/bm25_retriever_example.py b/examples/rag/bm25_retriever_example.py
new file mode 100644
index 000000000..6bc0d780a
--- /dev/null
+++ b/examples/rag/bm25_retriever_example.py
@@ -0,0 +1,50 @@
+import asyncio
+import os
+
+from dbgpt.configs.model_config import ROOT_PATH
+from dbgpt.rag import ChunkParameters
+from dbgpt.rag.assembler.bm25 import BM25Assembler
+from dbgpt.rag.knowledge import KnowledgeFactory
+from dbgpt.storage.vector_store.elastic_store import ElasticsearchVectorConfig
+
+"""BM25 retriever example.
+    pre-requirements:
+    set your elasticsearch config in your example code.
+
+    Examples:
+        .. code-block:: shell
+            python examples/rag/bm25_retriever_example.py
+"""
+
+
+def _create_es_config():
+    """Create vector connector."""
+    return ElasticsearchVectorConfig(
+        name="bm25_es_dbgpt",
+        uri="localhost",
+        port="9200",
+        user="elastic",
+        password="dbgpt",
+    )
+
+
+async def main():
+    file_path = os.path.join(ROOT_PATH, "docs/docs/awel/awel.md")
+    knowledge = KnowledgeFactory.from_file_path(file_path)
+    es_config = _create_es_config()
+    chunk_parameters = ChunkParameters(chunk_strategy="CHUNK_BY_SIZE")
+    # create bm25 assembler
+    assembler = BM25Assembler.load_from_knowledge(
+        knowledge=knowledge,
+        es_config=es_config,
+        chunk_parameters=chunk_parameters,
+    )
+    assembler.persist()
+    # get bm25 retriever
+    retriever = assembler.as_retriever(3)
+    chunks = retriever.retrieve_with_scores("what is awel talk about", 0.3)
+    print(f"bm25 rag example results:{chunks}")
+
+
+if __name__ == "__main__":
+    asyncio.run(main())