mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-08-31 16:39:48 +00:00
feat(RAG):add cross-encoder rerank (#1442)
Co-authored-by: Fangyin Cheng <staneyffer@gmail.com>
This commit is contained in:
@@ -24,7 +24,9 @@ class Ranker(ABC):
|
||||
self.rank_fn = rank_fn
|
||||
|
||||
@abstractmethod
|
||||
def rank(self, candidates_with_scores: List) -> List[Chunk]:
|
||||
def rank(
|
||||
self, candidates_with_scores: List[Chunk], query: Optional[str] = None
|
||||
) -> List[Chunk]:
|
||||
"""Return top k chunks after ranker.
|
||||
|
||||
Rank algorithm implementation return topk documents by candidates
|
||||
@@ -32,10 +34,9 @@ class Ranker(ABC):
|
||||
|
||||
Args:
|
||||
candidates_with_scores: List[Tuple]
|
||||
topk: int
|
||||
|
||||
query: Optional[str]
|
||||
Return:
|
||||
List[Document]
|
||||
List[Chunk]
|
||||
"""
|
||||
|
||||
def _filter(self, candidates_with_scores: List) -> List[Chunk]:
|
||||
@@ -77,11 +78,17 @@ class Ranker(ABC):
|
||||
class DefaultRanker(Ranker):
|
||||
"""Default Ranker."""
|
||||
|
||||
def __init__(self, topk: int, rank_fn: Optional[RANK_FUNC] = None):
|
||||
def __init__(
|
||||
self,
|
||||
topk: int = 4,
|
||||
rank_fn: Optional[RANK_FUNC] = None,
|
||||
):
|
||||
"""Create Default Ranker with topk and rank_fn."""
|
||||
super().__init__(topk, rank_fn)
|
||||
|
||||
def rank(self, candidates_with_scores: List[Chunk]) -> List[Chunk]:
|
||||
def rank(
|
||||
self, candidates_with_scores: List[Chunk], query: Optional[str] = None
|
||||
) -> List[Chunk]:
|
||||
"""Return top k chunks after ranker.
|
||||
|
||||
Return top k documents by candidates similarity score
|
||||
@@ -105,11 +112,17 @@ class DefaultRanker(Ranker):
|
||||
class RRFRanker(Ranker):
|
||||
"""RRF(Reciprocal Rank Fusion) Ranker."""
|
||||
|
||||
def __init__(self, topk: int, rank_fn: Optional[RANK_FUNC] = None):
|
||||
def __init__(
|
||||
self,
|
||||
topk: int = 4,
|
||||
rank_fn: Optional[RANK_FUNC] = None,
|
||||
):
|
||||
"""RRF rank algorithm implementation."""
|
||||
super().__init__(topk, rank_fn)
|
||||
|
||||
def rank(self, candidates_with_scores: List[Chunk]) -> List[Chunk]:
|
||||
def rank(
|
||||
self, candidates_with_scores: List[Chunk], query: Optional[str] = None
|
||||
) -> List[Chunk]:
|
||||
"""RRF rank algorithm implementation.
|
||||
|
||||
This code implements an algorithm called Reciprocal Rank Fusion (RRF), is a
|
||||
@@ -128,3 +141,87 @@ class RRFRanker(Ranker):
|
||||
"""
|
||||
# it will be implemented soon when multi recall is implemented
|
||||
return candidates_with_scores
|
||||
|
||||
|
||||
@register_resource(
|
||||
_("CrossEncoder Rerank"),
|
||||
"cross_encoder_ranker",
|
||||
category=ResourceCategory.RAG,
|
||||
description=_("CrossEncoder ranker."),
|
||||
parameters=[
|
||||
Parameter.build_from(
|
||||
_("Top k"),
|
||||
"topk",
|
||||
int,
|
||||
description=_("The number of top k documents."),
|
||||
),
|
||||
Parameter.build_from(
|
||||
_("Rerank Model"),
|
||||
"model",
|
||||
str,
|
||||
description=_("rerank model name, e.g., 'BAAI/bge-reranker-base'."),
|
||||
),
|
||||
Parameter.build_from(
|
||||
_("device"),
|
||||
"device",
|
||||
str,
|
||||
description=_("device name, e.g., 'cpu'."),
|
||||
),
|
||||
],
|
||||
)
|
||||
class CrossEncoderRanker(Ranker):
|
||||
"""CrossEncoder Ranker."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
topk: int = 4,
|
||||
model: str = "BAAI/bge-reranker-base",
|
||||
device: str = "cpu",
|
||||
rank_fn: Optional[RANK_FUNC] = None,
|
||||
):
|
||||
"""Cross Encoder rank algorithm implementation.
|
||||
|
||||
Args:
|
||||
topk: int - The number of top k documents.
|
||||
model: str - rerank model name, e.g., 'BAAI/bge-reranker-base'.
|
||||
device: str - device name, e.g., 'cpu'.
|
||||
rank_fn: Optional[callable] - The rank function.
|
||||
Refer: https://www.sbert.net/examples/applications/cross-encoder/README.html
|
||||
"""
|
||||
try:
|
||||
from sentence_transformers import CrossEncoder
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"please `pip install sentence-transformers`",
|
||||
)
|
||||
self._model = CrossEncoder(model, max_length=512, device=device)
|
||||
super().__init__(topk, rank_fn)
|
||||
|
||||
def rank(
|
||||
self, candidates_with_scores: List[Chunk], query: Optional[str] = None
|
||||
) -> List[Chunk]:
|
||||
"""Cross Encoder rank algorithm implementation.
|
||||
|
||||
Args:
|
||||
candidates_with_scores: List[Chunk], candidates with scores
|
||||
query: Optional[str], query text
|
||||
Returns:
|
||||
List[Chunk], reranked candidates
|
||||
"""
|
||||
contents = [candidate.content for candidate in candidates_with_scores]
|
||||
query_content_pairs = [
|
||||
[
|
||||
query,
|
||||
content,
|
||||
]
|
||||
for content in contents
|
||||
]
|
||||
rank_scores = self._model.predict(sentences=query_content_pairs)
|
||||
|
||||
for candidate, score in zip(candidates_with_scores, rank_scores):
|
||||
candidate.score = score
|
||||
|
||||
new_candidates_with_scores = sorted(
|
||||
candidates_with_scores, key=lambda x: x.score, reverse=True
|
||||
)
|
||||
return new_candidates_with_scores[: self.topk]
|
||||
|
70
examples/rag/cross_encoder_rerank_example.py
Normal file
70
examples/rag/cross_encoder_rerank_example.py
Normal file
@@ -0,0 +1,70 @@
|
||||
"""This example demonstrates how to use the cross-encoder reranker
|
||||
to rerank the retrieved chunks.
|
||||
The cross-encoder reranker is a neural network model that takes a query
|
||||
and a chunk as input and outputs a score that represents the relevance of the chunk
|
||||
to the query.
|
||||
|
||||
Download pretrained cross-encoder models can be found at https://huggingface.co/models.
|
||||
Example:
|
||||
python examples/rag/cross_encoder_rerank_example.py
|
||||
"""
|
||||
import asyncio
|
||||
import os
|
||||
|
||||
from dbgpt.configs.model_config import MODEL_PATH, PILOT_PATH, ROOT_PATH
|
||||
from dbgpt.rag import ChunkParameters
|
||||
from dbgpt.rag.assembler import EmbeddingAssembler
|
||||
from dbgpt.rag.embedding import DefaultEmbeddingFactory
|
||||
from dbgpt.rag.knowledge import KnowledgeFactory
|
||||
from dbgpt.rag.retriever.rerank import CrossEncoderRanker
|
||||
from dbgpt.storage.vector_store.chroma_store import ChromaVectorConfig
|
||||
from dbgpt.storage.vector_store.connector import VectorStoreConnector
|
||||
|
||||
|
||||
def _create_vector_connector():
|
||||
"""Create vector connector."""
|
||||
print(f"persist_path:{os.path.join(PILOT_PATH, 'data')}")
|
||||
return VectorStoreConnector.from_default(
|
||||
"Chroma",
|
||||
vector_store_config=ChromaVectorConfig(
|
||||
name="example_cross_encoder_rerank",
|
||||
persist_path=os.path.join(PILOT_PATH, "data"),
|
||||
),
|
||||
embedding_fn=DefaultEmbeddingFactory(
|
||||
default_model_name=os.path.join(MODEL_PATH, "text2vec-large-chinese"),
|
||||
).create(),
|
||||
)
|
||||
|
||||
|
||||
async def main():
|
||||
file_path = os.path.join(ROOT_PATH, "docs/docs/awel/awel.md")
|
||||
knowledge = KnowledgeFactory.from_file_path(file_path)
|
||||
vector_connector = _create_vector_connector()
|
||||
chunk_parameters = ChunkParameters(chunk_strategy="CHUNK_BY_MARKDOWN_HEADER")
|
||||
# get embedding assembler
|
||||
assembler = EmbeddingAssembler.load_from_knowledge(
|
||||
knowledge=knowledge,
|
||||
chunk_parameters=chunk_parameters,
|
||||
vector_store_connector=vector_connector,
|
||||
)
|
||||
assembler.persist()
|
||||
# get embeddings retriever
|
||||
retriever = assembler.as_retriever(3)
|
||||
# create metadata filter
|
||||
query = "what is awel talk about"
|
||||
chunks = await retriever.aretrieve_with_scores(query, 0.3)
|
||||
|
||||
print("before rerank results:\n")
|
||||
for i, chunk in enumerate(chunks):
|
||||
print(f"----{i+1}.chunk content:{chunk.content}\n score:{chunk.score}")
|
||||
# cross-encoder rerank
|
||||
cross_encoder_model = os.path.join(MODEL_PATH, "bge-reranker-base")
|
||||
rerank = CrossEncoderRanker(topk=3, model=cross_encoder_model)
|
||||
new_chunks = rerank.rank(chunks, query=query)
|
||||
print("after cross-encoder rerank results:\n")
|
||||
for i, chunk in enumerate(new_chunks):
|
||||
print(f"----{i+1}.chunk content:{chunk.content}\n score:{chunk.score}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
Reference in New Issue
Block a user