feat(RAG): add BM25 Retriever. (#1578)

Author: Aries-ckt
Date: 2024-05-30 10:13:24 +08:00
Committed by: GitHub
Parent: 8533b3d390
Commit: 47b0630e88
3 changed files with 470 additions and 0 deletions

dbgpt/rag/assembler/bm25.py (new file, +237)

@@ -0,0 +1,237 @@
"""BM25 Assembler."""
import json
from concurrent.futures import Executor, ThreadPoolExecutor
from typing import Any, List, Optional
from dbgpt.core import Chunk
from ...storage.vector_store.elastic_store import ElasticsearchVectorConfig
from ...util.executor_utils import blocking_func_to_async
from ..assembler.base import BaseAssembler
from ..chunk_manager import ChunkParameters
from ..knowledge.base import Knowledge
from ..retriever.bm25 import BM25Retriever

class BM25Assembler(BaseAssembler):
    """BM25 Assembler.

    Refer to
    https://www.elastic.co/guide/en/elasticsearch/reference/8.9/index-modules-similarity.html

    BM25 is a TF/IDF-based similarity with built-in tf normalization that is
    supposed to work better for short fields (like names). See Okapi_BM25 for
    more details. It is tuned through the ``k1`` and ``b`` options described
    on the constructor.

    Example:
        .. code-block:: python

            from dbgpt.rag.assembler import BM25Assembler

            pdf_path = "path/to/document.pdf"
            knowledge = KnowledgeFactory.from_file_path(pdf_path)
            assembler = BM25Assembler.load_from_knowledge(
                knowledge=knowledge,
                es_config=es_config,
                chunk_parameters=chunk_parameters,
            )
            assembler.persist()
            # get bm25 retriever
            retriever = assembler.as_retriever(3)
            chunks = retriever.retrieve_with_scores("what is awel talk about", 0.3)
            print(f"bm25 rag example results:{chunks}")
    """

    def __init__(
        self,
        knowledge: Knowledge,
        es_config: Optional[ElasticsearchVectorConfig] = None,
        k1: Optional[float] = 2.0,
        b: Optional[float] = 0.75,
        chunk_parameters: Optional[ChunkParameters] = None,
        executor: Optional[Executor] = None,
        **kwargs: Any,
    ) -> None:
        """Initialize with BM25 Assembler arguments.

        Args:
            knowledge: (Knowledge) Knowledge datasource.
            es_config: (ElasticsearchVectorConfig) Elasticsearch config.
            k1 (Optional[float]): Controls non-linear term frequency normalization
                (saturation). The default value is 2.0.
            b (Optional[float]): Controls to what degree document length normalizes
                tf values. The default value is 0.75.
            chunk_parameters: (Optional[ChunkParameters]) ChunkManager to use for
                chunking.
            executor: (Optional[Executor]) executor to run blocking calls in.
        """
        from elasticsearch import Elasticsearch

        if knowledge is None:
            raise ValueError("knowledge datasource must be provided.")
        if es_config is None:
            raise ValueError("es_config must be provided.")
        self._es_config = es_config
        self._es_url = es_config.uri
        self._es_port = es_config.port
        self._es_username = es_config.user
        self._es_password = es_config.password
        self._index_name = es_config.name
        self._k1 = k1
        self._b = b
        if self._es_username and self._es_password:
            self._es_client = Elasticsearch(
                hosts=[f"http://{self._es_url}:{self._es_port}"],
                basic_auth=(self._es_username, self._es_password),
            )
        else:
            self._es_client = Elasticsearch(
                hosts=[f"http://{self._es_url}:{self._es_port}"],
            )
        # Index settings: standard analyzer plus a custom BM25 similarity
        # parameterized by k1 and b.
        self._es_index_settings = {
            "analysis": {"analyzer": {"default": {"type": "standard"}}},
            "similarity": {
                "custom_bm25": {
                    "type": "BM25",
                    "k1": k1,
                    "b": b,
                }
            },
        }
        self._es_mappings = {
            "properties": {
                "content": {
                    "type": "text",
                    "similarity": "custom_bm25",
                },
                "metadata": {
                    "type": "keyword",
                },
            }
        }
        self._executor = executor or ThreadPoolExecutor()
        if not self._es_client.indices.exists(index=self._index_name):
            self._es_client.indices.create(
                index=self._index_name,
                mappings=self._es_mappings,
                settings=self._es_index_settings,
            )
        super().__init__(
            knowledge=knowledge,
            chunk_parameters=chunk_parameters,
            **kwargs,
        )

    @classmethod
    def load_from_knowledge(
        cls,
        knowledge: Knowledge,
        es_config: Optional[ElasticsearchVectorConfig] = None,
        k1: Optional[float] = 2.0,
        b: Optional[float] = 0.75,
        chunk_parameters: Optional[ChunkParameters] = None,
    ) -> "BM25Assembler":
        """Load document full text into elasticsearch from path.

        Args:
            knowledge: (Knowledge) Knowledge datasource.
            es_config: (ElasticsearchVectorConfig) Elasticsearch config.
            k1: (Optional[float]) BM25 parameter k1.
            b: (Optional[float]) BM25 parameter b.
            chunk_parameters: (Optional[ChunkParameters]) ChunkManager to use for
                chunking.

        Returns:
            BM25Assembler
        """
        return cls(
            knowledge=knowledge,
            es_config=es_config,
            k1=k1,
            b=b,
            chunk_parameters=chunk_parameters,
        )

    @classmethod
    async def aload_from_knowledge(
        cls,
        knowledge: Knowledge,
        es_config: Optional[ElasticsearchVectorConfig] = None,
        k1: Optional[float] = 2.0,
        b: Optional[float] = 0.75,
        chunk_parameters: Optional[ChunkParameters] = None,
        executor: Optional[ThreadPoolExecutor] = None,
    ) -> "BM25Assembler":
        """Load document full text into elasticsearch from path.

        Args:
            knowledge: (Knowledge) Knowledge datasource.
            es_config: (ElasticsearchVectorConfig) Elasticsearch config.
            k1: (Optional[float]) BM25 parameter k1.
            b: (Optional[float]) BM25 parameter b.
            chunk_parameters: (Optional[ChunkParameters]) ChunkManager to use for
                chunking.
            executor: (Optional[ThreadPoolExecutor]) executor.

        Returns:
            BM25Assembler
        """
        return await blocking_func_to_async(
            executor,
            cls,
            knowledge,
            es_config=es_config,
            k1=k1,
            b=b,
            chunk_parameters=chunk_parameters,
        )

    def persist(self) -> List[str]:
        """Persist chunks into elasticsearch.

        Returns:
            List[str]: List of chunk ids.
        """
        try:
            from elasticsearch.helpers import bulk
        except ImportError:
            raise ImportError("Please install package `pip install elasticsearch`.")
        es_requests = []
        ids = []
        contents = [chunk.content for chunk in self._chunks]
        metadatas = [json.dumps(chunk.metadata) for chunk in self._chunks]
        chunk_ids = [chunk.chunk_id for chunk in self._chunks]
        for i, content in enumerate(contents):
            es_request = {
                "_op_type": "index",
                "_index": self._index_name,
                "content": content,
                "metadata": metadatas[i],
                "_id": chunk_ids[i],
            }
            ids.append(chunk_ids[i])
            es_requests.append(es_request)
        # Bulk-index all chunks, then refresh so they become searchable at once.
        bulk(self._es_client, es_requests)
        self._es_client.indices.refresh(index=self._index_name)
        return ids

    async def apersist(self) -> List[str]:
        """Persist chunks into elasticsearch.

        Returns:
            List[str]: List of chunk ids.
        """
        return await blocking_func_to_async(self._executor, self.persist)

    def _extract_info(self, chunks) -> List[Chunk]:
        """Extract info from chunks."""
        return []

    def as_retriever(self, top_k: int = 4, **kwargs) -> BM25Retriever:
        """Create a BM25Retriever.

        Args:
            top_k(int): default 4.

        Returns:
            BM25Retriever
        """
        return BM25Retriever(
            top_k=top_k, es_index=self._index_name, es_client=self._es_client
        )
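For reference, the `k1` and `b` options wired into the `custom_bm25` similarity above are the two free parameters of the Okapi BM25 scoring function that Elasticsearch applies to the `content` field:

\[
\operatorname{score}(D, Q) = \sum_{i=1}^{n} \operatorname{IDF}(q_i) \cdot \frac{f(q_i, D)\,(k_1 + 1)}{f(q_i, D) + k_1\left(1 - b + b \cdot \frac{|D|}{\operatorname{avgdl}}\right)}
\]

where \(f(q_i, D)\) is the frequency of query term \(q_i\) in document \(D\), \(|D|\) is the document length, and \(\operatorname{avgdl}\) is the average document length in the index. A larger \(k_1\) delays term-frequency saturation; \(b = 1\) applies full document-length normalization and \(b = 0\) disables it.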

dbgpt/rag/retriever/bm25.py (new file, +183)

@@ -0,0 +1,183 @@
"""BM25 retriever."""
import json
from concurrent.futures import Executor, ThreadPoolExecutor
from typing import Any, List, Optional
from dbgpt.app.base import logger
from dbgpt.core import Chunk
from dbgpt.rag.retriever.base import BaseRetriever
from dbgpt.rag.retriever.rerank import DefaultRanker, Ranker
from dbgpt.rag.retriever.rewrite import QueryRewrite
from dbgpt.storage.vector_store.filters import MetadataFilters
from dbgpt.util.executor_utils import blocking_func_to_async

class BM25Retriever(BaseRetriever):
    """BM25 retriever.

    Refer to
    https://www.elastic.co/guide/en/elasticsearch/reference/8.9/index-modules-similarity.html

    BM25 is a TF/IDF-based similarity with built-in tf normalization that is
    supposed to work better for short fields (like names). See Okapi_BM25 for
    more details. It is tuned through the ``k1`` and ``b`` options described
    on the constructor.
    """

    def __init__(
        self,
        top_k: int = 4,
        es_index: Optional[str] = None,
        es_client: Any = None,
        query_rewrite: Optional[QueryRewrite] = None,
        rerank: Optional[Ranker] = None,
        k1: Optional[float] = 2.0,
        b: Optional[float] = 0.75,
        executor: Optional[Executor] = None,
    ):
        """Create BM25Retriever.

        Args:
            top_k (int): top k
            es_index (str): elasticsearch index name
            es_client (Any): elasticsearch client
            query_rewrite (Optional[QueryRewrite]): query rewrite
            rerank (Ranker): rerank
            k1 (Optional[float]): Controls non-linear term frequency normalization
                (saturation). The default value is 2.0.
            b (Optional[float]): Controls to what degree document length normalizes
                tf values. The default value is 0.75.
            executor (Optional[Executor]): executor

        Returns:
            BM25Retriever: BM25 retriever
        """
        super().__init__()
        self._top_k = top_k
        self._query_rewrite = query_rewrite
        try:
            from elasticsearch import Elasticsearch
        except ImportError:
            raise ImportError(
                "please install elasticsearch using `pip install elasticsearch`"
            )
        if es_client is None or es_index is None:
            raise ValueError("es_client and es_index must be provided.")
        self._es_client: Elasticsearch = es_client
        self._es_mappings = {
            "properties": {
                "content": {
                    "type": "text",
                    "similarity": "custom_bm25",
                }
            }
        }
        self._es_index_settings = {
            "analysis": {"analyzer": {"default": {"type": "standard"}}},
            "similarity": {
                "custom_bm25": {
                    "type": "BM25",
                    "k1": k1,
                    "b": b,
                }
            },
        }
        self._index_name = es_index
        if not self._es_client.indices.exists(index=self._index_name):
            self._es_client.indices.create(
                index=self._index_name,
                mappings=self._es_mappings,
                settings=self._es_index_settings,
            )
        self._rerank = rerank or DefaultRanker(self._top_k)
        self._executor = executor or ThreadPoolExecutor()

    def _retrieve(
        self, query: str, filters: Optional[MetadataFilters] = None
    ) -> List[Chunk]:
        """Retrieve knowledge chunks.

        Args:
            query (str): query text
            filters: metadata filters.

        Return:
            List[Chunk]: list of chunks
        """
        es_query = {"query": {"match": {"content": query}}}
        res = self._es_client.search(index=self._index_name, body=es_query)
        chunks = []
        for r in res["hits"]["hits"]:
            chunks.append(
                Chunk(
                    chunk_id=r["_id"],
                    content=r["_source"]["content"],
                    metadata=json.loads(r["_source"]["metadata"]),
                )
            )
        return chunks[: self._top_k]

    def _retrieve_with_score(
        self,
        query: str,
        score_threshold: float,
        filters: Optional[MetadataFilters] = None,
    ) -> List[Chunk]:
        """Retrieve knowledge chunks with score.

        Args:
            query (str): query text
            score_threshold (float): score threshold
            filters: metadata filters.

        Return:
            List[Chunk]: list of chunks with score
        """
        es_query = {"query": {"match": {"content": query}}}
        res = self._es_client.search(index=self._index_name, body=es_query)
        chunks_with_scores = []
        for r in res["hits"]["hits"]:
            if r["_score"] >= score_threshold:
                chunks_with_scores.append(
                    Chunk(
                        chunk_id=r["_id"],
                        content=r["_source"]["content"],
                        metadata=json.loads(r["_source"]["metadata"]),
                        score=r["_score"],
                    )
                )
        if score_threshold is not None and len(chunks_with_scores) == 0:
            logger.warning(
                "No relevant docs were retrieved using the relevance score"
                f" threshold {score_threshold}"
            )
        return chunks_with_scores[: self._top_k]

    async def _aretrieve(
        self, query: str, filters: Optional[MetadataFilters] = None
    ) -> List[Chunk]:
        """Retrieve knowledge chunks.

        Args:
            query (str): query text.
            filters: metadata filters.

        Return:
            List[Chunk]: list of chunks
        """
        return await blocking_func_to_async(
            self._executor, self.retrieve, query, filters
        )

    async def _aretrieve_with_score(
        self,
        query: str,
        score_threshold: float,
        filters: Optional[MetadataFilters] = None,
    ) -> List[Chunk]:
        """Retrieve knowledge chunks with score.

        Args:
            query (str): query text
            score_threshold (float): score threshold
            filters: metadata filters.

        Return:
            List[Chunk]: list of chunks with score
        """
        return await blocking_func_to_async(
            self._executor, self.retrieve_with_scores, query, score_threshold, filters
        )
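The retriever can also be used on its own against an index that already exists. Below is a minimal sketch, assuming an Elasticsearch instance reachable at localhost:9200 and the `bm25_es_dbgpt` index built by the assembler; the host, credentials, and index name are illustrative, not part of this commit.

    # Hedged sketch: query an existing BM25 index directly, without the
    # assembler. Host, credentials, and index name are illustrative.
    from elasticsearch import Elasticsearch

    from dbgpt.rag.retriever.bm25 import BM25Retriever

    es_client = Elasticsearch(
        hosts=["http://localhost:9200"],
        basic_auth=("elastic", "dbgpt"),  # adjust to your deployment
    )
    retriever = BM25Retriever(top_k=3, es_index="bm25_es_dbgpt", es_client=es_client)
    # Chunks scoring below 0.3 are dropped, mirroring the docstring example.
    chunks = retriever.retrieve_with_scores("what is awel talk about", 0.3)
    for chunk in chunks:
        print(chunk.score, chunk.content[:80])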

examples/rag/bm25_retriever_example.py (new file, +50)

@@ -0,0 +1,50 @@
"""BM25 retriever example.

pre-requirements:
    set your elasticsearch config in your example code.

Examples:
    .. code-block:: shell

        python examples/rag/bm25_retriever_example.py
"""
import asyncio
import os

from dbgpt.configs.model_config import ROOT_PATH
from dbgpt.rag import ChunkParameters
from dbgpt.rag.assembler.bm25 import BM25Assembler
from dbgpt.rag.knowledge import KnowledgeFactory
from dbgpt.storage.vector_store.elastic_store import ElasticsearchVectorConfig

def _create_es_config():
    """Create Elasticsearch config."""
    return ElasticsearchVectorConfig(
        name="bm25_es_dbgpt",
        uri="localhost",
        port="9200",
        user="elastic",
        password="dbgpt",
    )

async def main():
    file_path = os.path.join(ROOT_PATH, "docs/docs/awel/awel.md")
    knowledge = KnowledgeFactory.from_file_path(file_path)
    es_config = _create_es_config()
    chunk_parameters = ChunkParameters(chunk_strategy="CHUNK_BY_SIZE")
    # create bm25 assembler
    assembler = BM25Assembler.load_from_knowledge(
        knowledge=knowledge,
        es_config=es_config,
        chunk_parameters=chunk_parameters,
    )
    assembler.persist()
    # get bm25 retriever
    retriever = assembler.as_retriever(3)
    chunks = retriever.retrieve_with_scores("what is awel talk about", 0.3)
    print(f"bm25 rag example results:{chunks}")


if __name__ == "__main__":
    asyncio.run(main())
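Since the assembler also ships async variants, a fully asynchronous version of `main()` might look like the sketch below. It assumes `aretrieve_with_scores` is the public async counterpart of the `retrieve_with_scores` call used above; everything else comes from the code in this commit.

    # Hedged sketch of the fully asynchronous path.
    async def async_main():
        file_path = os.path.join(ROOT_PATH, "docs/docs/awel/awel.md")
        knowledge = KnowledgeFactory.from_file_path(file_path)
        es_config = _create_es_config()
        chunk_parameters = ChunkParameters(chunk_strategy="CHUNK_BY_SIZE")
        # build the assembler off the event loop
        assembler = await BM25Assembler.aload_from_knowledge(
            knowledge=knowledge,
            es_config=es_config,
            chunk_parameters=chunk_parameters,
        )
        # index the chunks without blocking the loop
        await assembler.apersist()
        retriever = assembler.as_retriever(3)
        # assumed async counterpart of retrieve_with_scores
        chunks = await retriever.aretrieve_with_scores("what is awel talk about", 0.3)
        print(f"bm25 rag async example results:{chunks}")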