feat(RAG): add BM25 Retriever. (#1578)
This commit is contained in:
parent 8533b3d390
commit 47b0630e88

dbgpt/rag/assembler/bm25.py (new file, 237 lines)
@@ -0,0 +1,237 @@
"""BM25 Assembler."""
import json
from concurrent.futures import Executor, ThreadPoolExecutor
from typing import Any, List, Optional

from dbgpt.core import Chunk

from ...storage.vector_store.elastic_store import ElasticsearchVectorConfig
from ...util.executor_utils import blocking_func_to_async
from ..assembler.base import BaseAssembler
from ..chunk_manager import ChunkParameters
from ..knowledge.base import Knowledge
from ..retriever.bm25 import BM25Retriever


class BM25Assembler(BaseAssembler):
    """BM25 Assembler.

    Refer to
    https://www.elastic.co/guide/en/elasticsearch/reference/8.9/index-modules-similarity.html

    A TF/IDF-based similarity with built-in tf normalization that is supposed to
    work better for short fields (like names). See Okapi_BM25 for more details.
    This similarity has two tunable options, k1 (term-frequency saturation) and
    b (document-length normalization); see the constructor arguments below.

    Example:
        .. code-block:: python

            from dbgpt.rag.assembler import BM25Assembler

            pdf_path = "path/to/document.pdf"
            knowledge = KnowledgeFactory.from_file_path(pdf_path)
            assembler = BM25Assembler.load_from_knowledge(
                knowledge=knowledge,
                es_config=es_config,
                chunk_parameters=chunk_parameters,
            )
            assembler.persist()
            # get bm25 retriever
            retriever = assembler.as_retriever(3)
            chunks = retriever.retrieve_with_scores("what is awel talk about", 0.3)
            print(f"bm25 rag example results:{chunks}")
    """

    def __init__(
        self,
        knowledge: Knowledge,
        es_config: Optional[ElasticsearchVectorConfig] = None,
        k1: Optional[float] = 2.0,
        b: Optional[float] = 0.75,
        chunk_parameters: Optional[ChunkParameters] = None,
        executor: Optional[Executor] = None,
        **kwargs: Any,
    ) -> None:
        """Initialize with BM25 Assembler arguments.

        Args:
            knowledge: (Knowledge) Knowledge datasource.
            es_config: (ElasticsearchVectorConfig) Elasticsearch config.
            k1 (Optional[float]): Controls non-linear term frequency normalization
                (saturation). The default value is 2.0.
            b (Optional[float]): Controls to what degree document length normalizes
                tf values. The default value is 0.75.
            chunk_parameters: (Optional[ChunkParameters]) ChunkManager to use for
                chunking.
            executor: (Optional[Executor]) Executor used for blocking calls.
        """
        from elasticsearch import Elasticsearch

        self._es_config = es_config
        self._es_url = es_config.uri
        self._es_port = es_config.port
        self._es_username = es_config.user
        self._es_password = es_config.password
        self._index_name = es_config.name
        self._k1 = k1
        self._b = b
        if self._es_username and self._es_password:
            self._es_client = Elasticsearch(
                hosts=[f"http://{self._es_url}:{self._es_port}"],
                basic_auth=(self._es_username, self._es_password),
            )
        else:
            self._es_client = Elasticsearch(
                hosts=[f"http://{self._es_url}:{self._es_port}"],
            )
        self._es_index_settings = {
            "analysis": {"analyzer": {"default": {"type": "standard"}}},
            "similarity": {
                "custom_bm25": {
                    "type": "BM25",
                    "k1": k1,
                    "b": b,
                }
            },
        }
        self._es_mappings = {
            "properties": {
                "content": {
                    "type": "text",
                    "similarity": "custom_bm25",
                },
                "metadata": {
                    "type": "keyword",
                },
            }
        }
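        # Descriptive note: the "content" field mapping above points at the
        # "custom_bm25" similarity defined in the index settings, so full-text
        # matches on "content" are scored with the chosen k1/b values. Creating
        # the index below submits the same settings/mappings JSON in a single
        # `PUT /<index_name>` request; this comment adds no behavior.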
        self._executor = executor or ThreadPoolExecutor()
        if knowledge is None:
            raise ValueError("knowledge datasource must be provided.")
        if not self._es_client.indices.exists(index=self._index_name):
            self._es_client.indices.create(
                index=self._index_name,
                mappings=self._es_mappings,
                settings=self._es_index_settings,
            )
        super().__init__(
            knowledge=knowledge,
            chunk_parameters=chunk_parameters,
            **kwargs,
        )

    @classmethod
    def load_from_knowledge(
        cls,
        knowledge: Knowledge,
        es_config: Optional[ElasticsearchVectorConfig] = None,
        k1: Optional[float] = 2.0,
        b: Optional[float] = 0.75,
        chunk_parameters: Optional[ChunkParameters] = None,
    ) -> "BM25Assembler":
        """Load document full text into Elasticsearch from path.

        Args:
            knowledge: (Knowledge) Knowledge datasource.
            es_config: (ElasticsearchVectorConfig) Elasticsearch config.
            k1: (Optional[float]) BM25 parameter k1.
            b: (Optional[float]) BM25 parameter b.
            chunk_parameters: (Optional[ChunkParameters]) ChunkManager to use for
                chunking.

        Returns:
            BM25Assembler
        """
        return cls(
            knowledge=knowledge,
            es_config=es_config,
            k1=k1,
            b=b,
            chunk_parameters=chunk_parameters,
        )

    @classmethod
    async def aload_from_knowledge(
        cls,
        knowledge: Knowledge,
        es_config: Optional[ElasticsearchVectorConfig] = None,
        k1: Optional[float] = 2.0,
        b: Optional[float] = 0.75,
        chunk_parameters: Optional[ChunkParameters] = None,
        executor: Optional[ThreadPoolExecutor] = None,
    ) -> "BM25Assembler":
        """Load document full text into Elasticsearch from path asynchronously.

        Args:
            knowledge: (Knowledge) Knowledge datasource.
            es_config: (ElasticsearchVectorConfig) Elasticsearch config.
            k1: (Optional[float]) BM25 parameter k1.
            b: (Optional[float]) BM25 parameter b.
            chunk_parameters: (Optional[ChunkParameters]) ChunkManager to use for
                chunking.
            executor: (Optional[ThreadPoolExecutor]) Executor to run the blocking
                constructor in.

        Returns:
            BM25Assembler
        """
        return await blocking_func_to_async(
            executor,
            cls,
            knowledge,
            es_config=es_config,
            k1=k1,
            b=b,
            chunk_parameters=chunk_parameters,
        )

    def persist(self) -> List[str]:
        """Persist chunks into Elasticsearch.

        Returns:
            List[str]: List of chunk ids.
        """
        try:
            from elasticsearch.helpers import bulk
        except ImportError:
            raise ImportError(
                "Please install the `elasticsearch` package: "
                "`pip install elasticsearch`."
            )
        es_requests = []
        ids = []
        contents = [chunk.content for chunk in self._chunks]
        # Metadata is stored as a JSON string in a keyword field and decoded
        # again by BM25Retriever on the way out.
        metadatas = [json.dumps(chunk.metadata) for chunk in self._chunks]
        chunk_ids = [chunk.chunk_id for chunk in self._chunks]
        for i, content in enumerate(contents):
            es_request = {
                "_op_type": "index",
                "_index": self._index_name,
                "content": content,
                "metadata": metadatas[i],
                "_id": chunk_ids[i],
            }
            ids.append(chunk_ids[i])
            es_requests.append(es_request)
        bulk(self._es_client, es_requests)
        # Refresh so the documents are searchable immediately after persist().
        self._es_client.indices.refresh(index=self._index_name)
        return ids

    async def apersist(self) -> List[str]:
        """Persist chunks into Elasticsearch asynchronously.

        Returns:
            List[str]: List of chunk ids.
        """
        return await blocking_func_to_async(self._executor, self.persist)

    def _extract_info(self, chunks) -> List[Chunk]:
        """Extract info from chunks (no-op for the BM25 assembler)."""
        return []

    def as_retriever(self, top_k: int = 4, **kwargs) -> BM25Retriever:
        """Create a BM25Retriever.

        Args:
            top_k (int): top k, default 4.

        Returns:
            BM25Retriever
        """
        return BM25Retriever(
            top_k=top_k, es_index=self._index_name, es_client=self._es_client
        )
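
For intuition about the k1 and b options that both new files wire into the Elasticsearch index settings, here is a minimal pure-Python sketch of the Okapi BM25 score that Elasticsearch computes server-side. It is illustrative only and not part of this commit; the simplified IDF form and pre-tokenized inputs are assumptions.

import math
from collections import Counter
from typing import List


def bm25_score(
    query_terms: List[str],
    doc_terms: List[str],
    corpus: List[List[str]],
    k1: float = 2.0,
    b: float = 0.75,
) -> float:
    """Okapi BM25 score of one tokenized document against a tokenized query."""
    n_docs = len(corpus)
    avgdl = sum(len(doc) for doc in corpus) / n_docs  # average document length
    tf = Counter(doc_terms)
    score = 0.0
    for term in query_terms:
        df = sum(1 for doc in corpus if term in doc)  # document frequency
        idf = math.log(1 + (n_docs - df + 0.5) / (df + 0.5))
        freq = tf[term]
        # k1 caps how quickly repeated terms saturate; b scales the penalty
        # for longer-than-average documents (b=0 disables it, b=1 fully applies it).
        denom = freq + k1 * (1 - b + b * len(doc_terms) / avgdl)
        score += idf * freq * (k1 + 1) / denom
    return score


corpus = [["awel", "workflow"], ["bm25", "ranking"], ["bm25", "bm25", "scoring"]]
print(bm25_score(["bm25"], corpus[2], corpus))  # highest for the bm25-heavy doc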

dbgpt/rag/retriever/bm25.py (new file, 183 lines)
@@ -0,0 +1,183 @@
"""BM25 retriever."""
|
||||
import json
|
||||
from concurrent.futures import Executor, ThreadPoolExecutor
|
||||
from typing import Any, List, Optional
|
||||
|
||||
from dbgpt.app.base import logger
|
||||
from dbgpt.core import Chunk
|
||||
from dbgpt.rag.retriever.base import BaseRetriever
|
||||
from dbgpt.rag.retriever.rerank import DefaultRanker, Ranker
|
||||
from dbgpt.rag.retriever.rewrite import QueryRewrite
|
||||
from dbgpt.storage.vector_store.filters import MetadataFilters
|
||||
from dbgpt.util.executor_utils import blocking_func_to_async
|
||||
|
||||
|
||||


class BM25Retriever(BaseRetriever):
    """BM25 retriever.

    Refer to
    https://www.elastic.co/guide/en/elasticsearch/reference/8.9/index-modules-similarity.html

    A TF/IDF-based similarity with built-in tf normalization that is supposed to
    work better for short fields (like names). See Okapi_BM25 for more details.
    This similarity has two tunable options, k1 and b, described below.
    """

    def __init__(
        self,
        top_k: int = 4,
        es_index: Optional[str] = None,
        es_client: Any = None,
        query_rewrite: Optional[QueryRewrite] = None,
        rerank: Optional[Ranker] = None,
        k1: Optional[float] = 2.0,
        b: Optional[float] = 0.75,
        executor: Optional[Executor] = None,
    ):
        """Create BM25Retriever.

        Args:
            top_k (int): top k
            es_index (str): elasticsearch index name
            es_client (Any): elasticsearch client
            query_rewrite (Optional[QueryRewrite]): query rewrite
            rerank (Ranker): rerank
            k1 (Optional[float]): Controls non-linear term frequency normalization
                (saturation). The default value is 2.0.
            b (Optional[float]): Controls to what degree document length normalizes
                tf values. The default value is 0.75.
            executor (Optional[Executor]): executor

        Returns:
            BM25Retriever: BM25 retriever
        """
        super().__init__()
        self._top_k = top_k
        self._query_rewrite = query_rewrite
        try:
            from elasticsearch import Elasticsearch
        except ImportError:
            raise ImportError(
                "please install elasticsearch using `pip install elasticsearch`"
            )
        self._es_client: Elasticsearch = es_client

        self._es_mappings = {
            "properties": {
                "content": {
                    "type": "text",
                    "similarity": "custom_bm25",
                }
            }
        }
        self._es_index_settings = {
            "analysis": {"analyzer": {"default": {"type": "standard"}}},
            "similarity": {
                "custom_bm25": {
                    "type": "BM25",
                    "k1": k1,
                    "b": b,
                }
            },
        }
        self._index_name = es_index
        if not self._es_client.indices.exists(index=self._index_name):
            self._es_client.indices.create(
                index=self._index_name,
                mappings=self._es_mappings,
                settings=self._es_index_settings,
            )
        self._rerank = rerank or DefaultRanker(self._top_k)
        self._executor = executor or ThreadPoolExecutor()

    def _retrieve(
        self, query: str, filters: Optional[MetadataFilters] = None
    ) -> List[Chunk]:
        """Retrieve knowledge chunks.

        Args:
            query (str): query text
            filters: metadata filters (currently not applied to the ES query).

        Returns:
            List[Chunk]: list of chunks
        """
        # Ask Elasticsearch for top_k hits explicitly; otherwise only the
        # default 10 hits are returned.
        es_query = {"query": {"match": {"content": query}}, "size": self._top_k}
        res = self._es_client.search(index=self._index_name, body=es_query)

        chunks = []
        for r in res["hits"]["hits"]:
            chunks.append(
                Chunk(
                    chunk_id=r["_id"],
                    content=r["_source"]["content"],
                    metadata=json.loads(r["_source"]["metadata"]),
                )
            )
        return chunks[: self._top_k]

    def _retrieve_with_score(
        self,
        query: str,
        score_threshold: float,
        filters: Optional[MetadataFilters] = None,
    ) -> List[Chunk]:
        """Retrieve knowledge chunks with score.

        Args:
            query (str): query text
            score_threshold (float): score threshold
            filters: metadata filters (currently not applied to the ES query).

        Returns:
            List[Chunk]: list of chunks with score
        """
        es_query = {"query": {"match": {"content": query}}, "size": self._top_k}
        res = self._es_client.search(index=self._index_name, body=es_query)

        chunks_with_scores = []
        for r in res["hits"]["hits"]:
            # Note: Elasticsearch BM25 _score values are unbounded positive
            # numbers, not similarities in [0, 1], so score_threshold is an
            # absolute BM25 score here.
            if r["_score"] >= score_threshold:
                chunks_with_scores.append(
                    Chunk(
                        chunk_id=r["_id"],
                        content=r["_source"]["content"],
                        metadata=json.loads(r["_source"]["metadata"]),
                        score=r["_score"],
                    )
                )
        if score_threshold is not None and len(chunks_with_scores) == 0:
            logger.warning(
                "No relevant docs were retrieved using the relevance score"
                f" threshold {score_threshold}"
            )
        return chunks_with_scores[: self._top_k]

    async def _aretrieve(
        self, query: str, filters: Optional[MetadataFilters] = None
    ) -> List[Chunk]:
        """Retrieve knowledge chunks asynchronously.

        Args:
            query (str): query text.
            filters: metadata filters.

        Returns:
            List[Chunk]: list of chunks
        """
        return await blocking_func_to_async(
            self._executor, self.retrieve, query, filters
        )

    async def _aretrieve_with_score(
        self,
        query: str,
        score_threshold: float,
        filters: Optional[MetadataFilters] = None,
    ) -> List[Chunk]:
        """Retrieve knowledge chunks with score asynchronously.

        Args:
            query (str): query text
            score_threshold (float): score threshold
            filters: metadata filters.

        Returns:
            List[Chunk]: list of chunks with score
        """
        # Delegate to the scored retrieval path so score_threshold is honored
        # (the unscored self.retrieve would silently drop it).
        return await blocking_func_to_async(
            self._executor, self.retrieve_with_scores, query, score_threshold, filters
        )
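
BM25Retriever can also be pointed at an already-populated index without going through BM25Assembler. A minimal usage sketch follows; the host, credentials, and index name are placeholders to adapt, not values defined by this commit.

from elasticsearch import Elasticsearch

from dbgpt.rag.retriever.bm25 import BM25Retriever

# Placeholder connection details; adjust to your Elasticsearch deployment.
es_client = Elasticsearch(
    hosts=["http://localhost:9200"],
    basic_auth=("elastic", "dbgpt"),
)
retriever = BM25Retriever(top_k=3, es_index="bm25_es_dbgpt", es_client=es_client)
chunks = retriever.retrieve_with_scores("what is awel talk about", 0.3)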

examples/rag/bm25_retriever_example.py (new file, 50 lines)
@@ -0,0 +1,50 @@
import asyncio
import os

from dbgpt.configs.model_config import ROOT_PATH
from dbgpt.rag import ChunkParameters
from dbgpt.rag.assembler.bm25 import BM25Assembler
from dbgpt.rag.knowledge import KnowledgeFactory
from dbgpt.storage.vector_store.elastic_store import ElasticsearchVectorConfig

"""BM25 rag example.
    prerequisites:
        set your elasticsearch config in your example code.

    Examples:
        .. code-block:: shell

            python examples/rag/bm25_retriever_example.py
"""


def _create_es_config():
    """Create the Elasticsearch config for this example."""
    return ElasticsearchVectorConfig(
        name="bm25_es_dbgpt",
        uri="localhost",
        port="9200",
        user="elastic",
        password="dbgpt",
    )


async def main():
    file_path = os.path.join(ROOT_PATH, "docs/docs/awel/awel.md")
    knowledge = KnowledgeFactory.from_file_path(file_path)
    es_config = _create_es_config()
    chunk_parameters = ChunkParameters(chunk_strategy="CHUNK_BY_SIZE")
    # create bm25 assembler
    assembler = BM25Assembler.load_from_knowledge(
        knowledge=knowledge,
        es_config=es_config,
        chunk_parameters=chunk_parameters,
    )
    assembler.persist()
    # get bm25 retriever
    retriever = assembler.as_retriever(3)
    chunks = retriever.retrieve_with_scores("what is awel talk about", 0.3)
    print(f"bm25 rag example results:{chunks}")


if __name__ == "__main__":
    asyncio.run(main())
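
Before running the example, it can help to verify that Elasticsearch is reachable with the credentials used in _create_es_config. A small preflight sketch under the same assumptions (local instance, user elastic, password dbgpt):

from elasticsearch import Elasticsearch

client = Elasticsearch(
    hosts=["http://localhost:9200"],
    basic_auth=("elastic", "dbgpt"),
)
# ping() returns True if the cluster responds, False otherwise.
if not client.ping():
    raise SystemExit("Elasticsearch is not reachable at http://localhost:9200")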