feat(RAG):add cross-encoder rerank (#1442)

Co-authored-by: Fangyin Cheng <staneyffer@gmail.com>
This commit is contained in:
Aries-ckt
2024-04-24 14:12:44 +08:00
committed by GitHub
parent dce03862d5
commit 91c1371234
2 changed files with 175 additions and 8 deletions

View File

@@ -24,7 +24,9 @@ class Ranker(ABC):
self.rank_fn = rank_fn
@abstractmethod
def rank(self, candidates_with_scores: List) -> List[Chunk]:
def rank(
self, candidates_with_scores: List[Chunk], query: Optional[str] = None
) -> List[Chunk]:
"""Return top k chunks after ranker.
Rank algorithm implementation return topk documents by candidates
@@ -32,10 +34,9 @@ class Ranker(ABC):
Args:
candidates_with_scores: List[Tuple]
topk: int
query: Optional[str]
Return:
List[Document]
List[Chunk]
"""
def _filter(self, candidates_with_scores: List) -> List[Chunk]:
@@ -77,11 +78,17 @@ class Ranker(ABC):
class DefaultRanker(Ranker):
"""Default Ranker."""
def __init__(self, topk: int, rank_fn: Optional[RANK_FUNC] = None):
def __init__(
self,
topk: int = 4,
rank_fn: Optional[RANK_FUNC] = None,
):
"""Create Default Ranker with topk and rank_fn."""
super().__init__(topk, rank_fn)
def rank(self, candidates_with_scores: List[Chunk]) -> List[Chunk]:
def rank(
self, candidates_with_scores: List[Chunk], query: Optional[str] = None
) -> List[Chunk]:
"""Return top k chunks after ranker.
Return top k documents by candidates similarity score
@@ -105,11 +112,17 @@ class DefaultRanker(Ranker):
class RRFRanker(Ranker):
"""RRF(Reciprocal Rank Fusion) Ranker."""
def __init__(self, topk: int, rank_fn: Optional[RANK_FUNC] = None):
def __init__(
self,
topk: int = 4,
rank_fn: Optional[RANK_FUNC] = None,
):
"""RRF rank algorithm implementation."""
super().__init__(topk, rank_fn)
def rank(self, candidates_with_scores: List[Chunk]) -> List[Chunk]:
def rank(
self, candidates_with_scores: List[Chunk], query: Optional[str] = None
) -> List[Chunk]:
"""RRF rank algorithm implementation.
This code implements an algorithm called Reciprocal Rank Fusion (RRF), is a
@@ -128,3 +141,87 @@ class RRFRanker(Ranker):
"""
# it will be implemented soon when multi recall is implemented
return candidates_with_scores
@register_resource(
    _("CrossEncoder Rerank"),
    "cross_encoder_ranker",
    category=ResourceCategory.RAG,
    description=_("CrossEncoder ranker."),
    parameters=[
        Parameter.build_from(
            _("Top k"),
            "topk",
            int,
            description=_("The number of top k documents."),
        ),
        Parameter.build_from(
            _("Rerank Model"),
            "model",
            str,
            description=_("rerank model name, e.g., 'BAAI/bge-reranker-base'."),
        ),
        Parameter.build_from(
            _("device"),
            "device",
            str,
            description=_("device name, e.g., 'cpu'."),
        ),
    ],
)
class CrossEncoderRanker(Ranker):
    """Rerank candidate chunks with a sentence-transformers CrossEncoder.

    Every candidate is paired with the query, the pair is scored by the
    cross-encoder model, and the candidates are returned ordered by that
    score (highest first), truncated to ``topk``.
    """

    def __init__(
        self,
        topk: int = 4,
        model: str = "BAAI/bge-reranker-base",
        device: str = "cpu",
        rank_fn: Optional[RANK_FUNC] = None,
    ):
        """Load the cross-encoder model and initialise the ranker.

        Args:
            topk: int - The number of top k documents.
            model: str - rerank model name, e.g., 'BAAI/bge-reranker-base'.
            device: str - device name, e.g., 'cpu'.
            rank_fn: Optional[callable] - The rank function.

        Raises:
            ImportError: if ``sentence-transformers`` is not installed.

        Refer: https://www.sbert.net/examples/applications/cross-encoder/README.html
        """
        # Imported lazily so the dependency is only required when this
        # ranker is actually instantiated.
        try:
            from sentence_transformers import CrossEncoder
        except ImportError as err:
            # Chain the original error so the underlying cause stays visible.
            raise ImportError(
                "please `pip install sentence-transformers`",
            ) from err

        self._model = CrossEncoder(model, max_length=512, device=device)
        super().__init__(topk, rank_fn)

    def rank(
        self, candidates_with_scores: List[Chunk], query: Optional[str] = None
    ) -> List[Chunk]:
        """Score each candidate against the query and return the best ones.

        The ``score`` attribute of every candidate is overwritten in place
        with the cross-encoder score.

        Args:
            candidates_with_scores: List[Chunk], candidates with scores
            query: Optional[str], query text; ``None`` is scored as an
                empty string.

        Returns:
            List[Chunk], at most ``topk`` candidates ordered by descending
            cross-encoder score; an empty list when there are no candidates.
        """
        # Nothing to rank: skip the model call entirely.
        if not candidates_with_scores:
            return []

        # ``query`` defaults to None in the ranker interface and a chunk may
        # carry no content; substitute empty strings so every pair handed to
        # the model is a pair of strings.
        query_text = query if query is not None else ""
        query_content_pairs = [
            [
                query_text,
                candidate.content if candidate.content is not None else "",
            ]
            for candidate in candidates_with_scores
        ]
        rank_scores = self._model.predict(sentences=query_content_pairs)

        for candidate, score in zip(candidates_with_scores, rank_scores):
            # Store a built-in float rather than the model's own scalar type
            # (presumably a numpy float - TODO confirm) so the score stays a
            # plain Python number for downstream consumers.
            candidate.score = float(score)

        new_candidates_with_scores = sorted(
            candidates_with_scores, key=lambda x: x.score, reverse=True
        )
        return new_candidates_with_scores[: self.topk]

View File

@@ -0,0 +1,70 @@
"""This example demonstrates how to use the cross-encoder reranker
to rerank the retrieved chunks.
The cross-encoder reranker is a neural network model that takes a query
and a chunk as input and outputs a score that represents the relevance of the chunk
to the query.
Pretrained cross-encoder models can be downloaded from https://huggingface.co/models.
Example:
python examples/rag/cross_encoder_rerank_example.py
"""
import asyncio
import os
from dbgpt.configs.model_config import MODEL_PATH, PILOT_PATH, ROOT_PATH
from dbgpt.rag import ChunkParameters
from dbgpt.rag.assembler import EmbeddingAssembler
from dbgpt.rag.embedding import DefaultEmbeddingFactory
from dbgpt.rag.knowledge import KnowledgeFactory
from dbgpt.rag.retriever.rerank import CrossEncoderRanker
from dbgpt.storage.vector_store.chroma_store import ChromaVectorConfig
from dbgpt.storage.vector_store.connector import VectorStoreConnector
def _create_vector_connector():
    """Build the Chroma vector store connector used by this example.

    Prints the persistence directory, then returns a connector configured
    with the ``example_cross_encoder_rerank`` store and an embedding function
    created from the ``text2vec-large-chinese`` model under ``MODEL_PATH``.
    """
    print(f"persist_path:{os.path.join(PILOT_PATH, 'data')}")

    # Store configuration: data is persisted under PILOT_PATH/data.
    store_config = ChromaVectorConfig(
        name="example_cross_encoder_rerank",
        persist_path=os.path.join(PILOT_PATH, "data"),
    )

    # Embedding function created from the model directory under MODEL_PATH.
    embedding_model_path = os.path.join(MODEL_PATH, "text2vec-large-chinese")
    embeddings = DefaultEmbeddingFactory(
        default_model_name=embedding_model_path,
    ).create()

    return VectorStoreConnector.from_default(
        "Chroma",
        vector_store_config=store_config,
        embedding_fn=embeddings,
    )
async def main():
    """Run the cross-encoder rerank example end to end.

    Loads the AWEL markdown document, persists it through an embedding
    assembler, retrieves chunks for a sample query, prints them, reranks them
    with ``CrossEncoderRanker`` and prints the reranked result.
    """

    def _show(title, items):
        # Print a heading followed by the content and score of each chunk.
        print(title)
        for i, chunk in enumerate(items):
            print(f"----{i+1}.chunk content:{chunk.content}\n score:{chunk.score}")

    doc_path = os.path.join(ROOT_PATH, "docs/docs/awel/awel.md")
    knowledge = KnowledgeFactory.from_file_path(doc_path)
    connector = _create_vector_connector()
    chunking = ChunkParameters(chunk_strategy="CHUNK_BY_MARKDOWN_HEADER")

    # Build the embedding assembler and persist its chunks.
    assembler = EmbeddingAssembler.load_from_knowledge(
        knowledge=knowledge,
        chunk_parameters=chunking,
        vector_store_connector=connector,
    )
    assembler.persist()

    # Retriever backed by the assembler (3 is presumably the top-k - confirm).
    retriever = assembler.as_retriever(3)

    # Sample query; 0.3 is presumably a score threshold - confirm against the
    # retriever API.
    query = "what is awel talk about"
    chunks = await retriever.aretrieve_with_scores(query, 0.3)
    _show("before rerank results:\n", chunks)

    # Rerank the retrieved chunks with the local bge-reranker-base model.
    reranker_model_path = os.path.join(MODEL_PATH, "bge-reranker-base")
    reranker = CrossEncoderRanker(topk=3, model=reranker_model_path)
    reranked = reranker.rank(chunks, query=query)
    _show("after cross-encoder rerank results:\n", reranked)
# Run the async example only when this file is executed as a script.
if __name__ == "__main__":
    asyncio.run(main())