feat(RAG):add cross-encoder rerank (#1442)

Co-authored-by: Fangyin Cheng <staneyffer@gmail.com>
This commit is contained in:
Aries-ckt
2024-04-24 14:12:44 +08:00
committed by GitHub
parent dce03862d5
commit 91c1371234
2 changed files with 175 additions and 8 deletions

View File

@@ -24,7 +24,9 @@ class Ranker(ABC):
self.rank_fn = rank_fn
@abstractmethod
def rank(self, candidates_with_scores: List) -> List[Chunk]:
def rank(
self, candidates_with_scores: List[Chunk], query: Optional[str] = None
) -> List[Chunk]:
"""Return top k chunks after ranker.
Rank algorithm implementation return topk documents by candidates
@@ -32,10 +34,9 @@ class Ranker(ABC):
Args:
candidates_with_scores: List[Tuple]
topk: int
query: Optional[str]
Return:
List[Document]
List[Chunk]
"""
def _filter(self, candidates_with_scores: List) -> List[Chunk]:
@@ -77,11 +78,17 @@ class Ranker(ABC):
class DefaultRanker(Ranker):
"""Default Ranker."""
def __init__(self, topk: int, rank_fn: Optional[RANK_FUNC] = None):
def __init__(
self,
topk: int = 4,
rank_fn: Optional[RANK_FUNC] = None,
):
"""Create Default Ranker with topk and rank_fn."""
super().__init__(topk, rank_fn)
def rank(self, candidates_with_scores: List[Chunk]) -> List[Chunk]:
def rank(
self, candidates_with_scores: List[Chunk], query: Optional[str] = None
) -> List[Chunk]:
"""Return top k chunks after ranker.
Return top k documents by candidates similarity score
@@ -105,11 +112,17 @@ class DefaultRanker(Ranker):
class RRFRanker(Ranker):
"""RRF(Reciprocal Rank Fusion) Ranker."""
def __init__(self, topk: int, rank_fn: Optional[RANK_FUNC] = None):
def __init__(
self,
topk: int = 4,
rank_fn: Optional[RANK_FUNC] = None,
):
"""RRF rank algorithm implementation."""
super().__init__(topk, rank_fn)
def rank(self, candidates_with_scores: List[Chunk]) -> List[Chunk]:
def rank(
self, candidates_with_scores: List[Chunk], query: Optional[str] = None
) -> List[Chunk]:
"""RRF rank algorithm implementation.
This code implements an algorithm called Reciprocal Rank Fusion (RRF), is a
@@ -128,3 +141,87 @@ class RRFRanker(Ranker):
"""
# it will be implemented soon when multi recall is implemented
return candidates_with_scores
@register_resource(
_("CrossEncoder Rerank"),
"cross_encoder_ranker",
category=ResourceCategory.RAG,
description=_("CrossEncoder ranker."),
parameters=[
Parameter.build_from(
_("Top k"),
"topk",
int,
description=_("The number of top k documents."),
),
Parameter.build_from(
_("Rerank Model"),
"model",
str,
description=_("rerank model name, e.g., 'BAAI/bge-reranker-base'."),
),
Parameter.build_from(
_("device"),
"device",
str,
description=_("device name, e.g., 'cpu'."),
),
],
)
class CrossEncoderRanker(Ranker):
"""CrossEncoder Ranker."""
def __init__(
self,
topk: int = 4,
model: str = "BAAI/bge-reranker-base",
device: str = "cpu",
rank_fn: Optional[RANK_FUNC] = None,
):
"""Cross Encoder rank algorithm implementation.
Args:
topk: int - The number of top k documents.
model: str - rerank model name, e.g., 'BAAI/bge-reranker-base'.
device: str - device name, e.g., 'cpu'.
rank_fn: Optional[callable] - The rank function.
Refer: https://www.sbert.net/examples/applications/cross-encoder/README.html
"""
try:
from sentence_transformers import CrossEncoder
except ImportError:
raise ImportError(
"please `pip install sentence-transformers`",
)
self._model = CrossEncoder(model, max_length=512, device=device)
super().__init__(topk, rank_fn)
def rank(
self, candidates_with_scores: List[Chunk], query: Optional[str] = None
) -> List[Chunk]:
"""Cross Encoder rank algorithm implementation.
Args:
candidates_with_scores: List[Chunk], candidates with scores
query: Optional[str], query text
Returns:
List[Chunk], reranked candidates
"""
contents = [candidate.content for candidate in candidates_with_scores]
query_content_pairs = [
[
query,
content,
]
for content in contents
]
rank_scores = self._model.predict(sentences=query_content_pairs)
for candidate, score in zip(candidates_with_scores, rank_scores):
candidate.score = score
new_candidates_with_scores = sorted(
candidates_with_scores, key=lambda x: x.score, reverse=True
)
return new_candidates_with_scores[: self.topk]